Laplace 1D
The smallest complete jNO example. Solves the 1-D Laplace equation with non-homogeneous Dirichlet boundary conditions and a hard-enforced linear-interpolant ansatz.
Problem Setup
Exact solution: u(x) = x (the straight line between the two boundary values).
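Written out, the boundary-value problem behind this example looks as follows. This is a sketch reconstructed from the boundary values quoted in Step 2; the unit interval (0, 1) is an assumption inferred from those values rather than something stated explicitly here.

```latex
\begin{aligned}
u''(x) &= 0, \qquad x \in (0, 1), \\
u(0) &= 0, \qquad u(1) = 1 .
\end{aligned}
```

Integrating twice gives u(x) = c_1 x + c_0, and the two boundary values force c_0 = 0 and c_1 = 1, which is the exact solution u(x) = x quoted above.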
Step 1: Create the Domain
Step 2: Build the Network with a Non-Homogeneous Hard Ansatz
The boundary values are non-zero (u(0)=0, u(1)=1), so the multiplicative x(1-x) trick alone won't enforce them. Instead, use a linear interpolant + correction ansatz:
u_net = jno.nn.wrap(
    foundax.mlp(in_features=1, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(0))
).optimizer(optax.adam(optax.exponential_decay(1e-3, 10, 0.5, end_value=1e-5)))
u = x + x * (1 - x) * u_net(x)
Why this works:
- x alone satisfies u(0)=0 and u(1)=1 exactly: it is the linear interpolant of the BCs.
- x*(1-x) vanishes at both endpoints, so u_net(x) can never break the BCs no matter what it learns.
- The network only has to learn the deviation from the linear interpolant. For pure Laplace u''=0 that deviation is exactly zero, so this is a clean test of whether the optimiser can drive u_net to the constant-zero function.
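These claims can be checked by hand. In the sketch below, N(x) is shorthand introduced here for the network output u_net(x); it is assumed to be twice differentiable.

```latex
\begin{aligned}
u(x) &= x + x(1-x)\,N(x), \\
u(0) &= 0 + 0 \cdot N(0) = 0, \qquad u(1) = 1 + 0 \cdot N(1) = 1, \\
u''(x) &= \frac{d^2}{dx^2}\bigl[x(1-x)\,N(x)\bigr]
        = -2N(x) + 2(1-2x)\,N'(x) + x(1-x)\,N''(x).
\end{aligned}
```

Setting u'' = 0 means the correction w(x) = x(1-x) N(x) must be a linear function that vanishes at both endpoints, so w is identically zero and N(x) = 0 on the open interval (0, 1).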
Step 3: Build the PDE Residual
.d2(x) is the second-derivative shortcut available on any Placeholder. It is equivalent to u.d(x).d(x) but more compact.
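To illustrate what a second-derivative shortcut amounts to, here is a minimal sketch in plain JAX. It is not the jNO API: the helpers d and d2 are hypothetical names that only mirror the naming above, and sin(x) stands in for a trained network.

```python
import jax
import jax.numpy as jnp


def d(f):
    """First derivative of a scalar function of one scalar argument."""
    return jax.grad(f)


def d2(f):
    """Second derivative, written as two nested first derivatives."""
    return d(d(f))


def u_trial(x):
    # Ansatz from Step 2 with sin(x) as a stand-in for the network output.
    return x + x * (1.0 - x) * jnp.sin(x)


def u_exact(x):
    # Exact solution of the Laplace problem.
    return x


x0 = 0.3
residual_trial = d2(u_trial)(x0)  # non-zero: the stand-in correction has curvature
residual_exact = d2(u_exact)(x0)  # zero: u(x) = x satisfies u'' = 0
print(float(residual_trial), float(residual_exact))
```

The residual of the trial function is what a training loop would push towards zero at every collocation point; for the exact solution it is zero already.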
Step 4: Solve
What To Notice
- The non-homogeneous BCs require an ansatz that adds a function satisfying the BCs, rather than only multiplying the network by a factor that vanishes at the boundary. The general recipe is u = boundary_lift(x) + zero_at_boundary(x) * net(x).
- For more interesting Poisson-style problems with a non-zero forcing, see Poisson 1D: same domain, but adds a soft-BC pattern with scheme="finite_difference".
- This is a trivial PDE in the sense that the exact solution is in the trial-space class; jNO is being asked to confirm convergence, not discover anything new.
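The general recipe can be sketched for an arbitrary interval [a, b] with Dirichlet values ua and ub. This is plain JAX rather than jNO code; make_hard_ansatz and toy_net are hypothetical names used only for illustration, and toy_net stands in for any network.

```python
import jax.numpy as jnp


def make_hard_ansatz(a, b, ua, ub, net):
    """Return u(x) = boundary_lift(x) + zero_at_boundary(x) * net(x) on [a, b]."""

    def boundary_lift(x):
        # Straight line through (a, ua) and (b, ub): matches both Dirichlet values.
        return ua + (ub - ua) * (x - a) / (b - a)

    def zero_at_boundary(x):
        # Vanishes at x = a and x = b, so the network term cannot disturb the BCs.
        return (x - a) * (b - x)

    def u(x):
        return boundary_lift(x) + zero_at_boundary(x) * net(x)

    return u


def toy_net(x):
    # Arbitrary smooth stand-in for a neural network.
    return jnp.tanh(3.0 * x) + 2.0


# The case from this page: [0, 1] with u(0)=0, u(1)=1 reduces to x + x*(1-x)*net(x).
u = make_hard_ansatz(0.0, 1.0, 0.0, 1.0, toy_net)
# A different interval and different boundary values use the same construction.
u_shifted = make_hard_ansatz(-1.0, 2.0, 5.0, -3.0, toy_net)

print(float(u(0.0)), float(u(1.0)))
print(float(u_shifted(-1.0)), float(u_shifted(2.0)))
```

Whatever toy_net returns, the boundary values are reproduced, because the network term is multiplied by zero at both endpoints.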
Script Snippet
"""01 — 1-D Laplace equation (simplest possible PINN)"""
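import foundax
import jax
import optax

import jno

# 1-D line domain discretised with mesh size 0.1; x is the interior coordinate.
domain = jno.domain.line(mesh_size=0.1)
x, _ = domain.variable("interior")

# Exact solution u(x) = x, used only for the error check at the end.
u_exact = x

# MLP wrapped as a jNO network; Adam with an exponentially decaying learning rate.
net = jno.nn.wrap(foundax.mlp(in_features=1, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(0)))
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 10, 0.5, end_value=1e-5)))

# Hard-enforced ansatz: linear interpolant plus a correction that vanishes at x=0 and x=1.
u = (x + x * (1 - x) * net(x)).scalar.bind(x=x)

pde = u.xx  # Laplace: u'' = 0

# Train on the mean-squared PDE residual.
crux = jno.core([pde.mse])
crux.solve(5000)

# Evaluate the prediction and the exact solution, then compare in the relative L2 norm.
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}")
assert rel_l2 < 1e-1, f"relative L2 error too large: {rel_l2:.3e}"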