# Laplace 1D
This is the smallest complete jNO example. It solves a 1D Poisson or Laplace-type equation with homogeneous Dirichlet boundary conditions and compares the learned solution against the analytical one.
## Problem Setup

We solve

−u''(x) = sin(πx), x ∈ [0, 1], u(0) = u(1) = 0,

with exact solution

u(x) = sin(πx) / π².
## Step 1: Create the Domain

The script initializes jNO, creates a 1D line domain, and extracts the interior collocation points.

```python
π = jnn.pi

dire = jno.setup(__file__)
domain = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _ = domain.variable("interior")
```
What this does:

- `jno.setup(__file__)` creates the run directory for outputs.
- `jno.domain.line(...)` builds the 1D geometry.
- `domain.variable("interior")` gives the collocation points used for the PDE residual.
## Step 2: Define the Analytical Reference
The exact solution is only used for monitoring error, not for training supervision.
This is useful because you can track whether the PINN is converging toward the known solution.
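In the script the reference is defined as `u_exact = jnn.sin(π * x) / π**2`. A quick way to trust it is to check numerically, independent of jNO, that u(x) = sin(πx)/π² really satisfies −u'' = sin(πx) with zero boundary values:

```python
import numpy as np

# Verify the analytical reference u(x) = sin(pi x) / pi^2 against the PDE
# using a central finite-difference second derivative.
def u_exact(x):
    return np.sin(np.pi * x) / np.pi**2

h = 1e-4
x = np.linspace(0.05, 0.95, 19)           # interior points
u_xx = (u_exact(x + h) - 2 * u_exact(x) + u_exact(x - h)) / h**2
residual = -u_xx - np.sin(np.pi * x)      # -u'' - sin(pi x), should vanish

print(np.max(np.abs(residual)))           # tiny (finite-difference truncation only)
```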
## Step 3: Build the Network with Hard Boundary Conditions
The model is a small MLP. Boundary conditions are enforced by multiplying the network output by x(1-x).
```python
u_net = jno.nn.mlp(
    in_features=1,
    hidden_dims=32,
    num_layers=3,
    key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 5_000, 1e-5))

u = u_net(x) * x * (1 - x)
```
Why this matters:

- `u_net(x)` is unconstrained.
- Multiplying by `x(1-x)` forces `u(0) = u(1) = 0` exactly.
- This removes the need for a separate boundary loss term.
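To see why the multiplication enforces the boundary values, note that the factor x(1−x) vanishes at both endpoints no matter what the raw network returns. A minimal NumPy sketch, with `raw_net` as a hypothetical stand-in for the unconstrained MLP:

```python
import numpy as np

# Hypothetical stand-in for the unconstrained network output.
def raw_net(x):
    return np.cos(3 * x) + 0.7 * x**2 + 1.5

def u(x):
    # Hard Dirichlet BC: x * (1 - x) is zero at x = 0 and x = 1,
    # so u(0) = u(1) = 0 exactly, whatever raw_net produces.
    return raw_net(x) * x * (1 - x)

print(u(0.0), u(1.0))   # 0.0 0.0
```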
## Step 4: Build the PDE Residual and Error Tracker
The residual uses automatic differentiation twice to compute the second derivative.
```python
pde = -jnn.grad(jnn.grad(u, x), x) - jnn.sin(π * x)
error = jnn.tracker((u - u_exact).mse, interval=100)
```
This gives you:

- `pde.mse`: the physics loss to minimize
- `error`: a tracked metric that reports the solution error during training
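`jnn.grad` is jNO's wrapper around JAX differentiation. The twice-nested idea can be sketched in plain JAX (an illustration of the mechanism, not jNO's internals), using the known exact solution as a stand-in for the learned field:

```python
import jax
import jax.numpy as jnp

pi = jnp.pi

def u(x):
    # Stand-in for the learned field: the exact solution sin(pi x) / pi^2.
    return jnp.sin(pi * x) / pi**2

# Nesting grad twice yields the second derivative u''(x).
u_xx = jax.grad(jax.grad(u))

def residual(x):
    # -u''(x) - sin(pi x): vanishes (up to float rounding) for the exact solution.
    return -u_xx(x) - jnp.sin(pi * x)

print(float(residual(0.3)))   # approximately 0
```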
## Step 5: Solve the Problem

This is the standard jNO flow:

- Bundle constraints and tracked metrics into `jno.core(...)`
- Call `solve(...)`
- Use the returned history for diagnostics
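Concretely, this step is just three lines in the full script listed below:

```python
crux = jno.core([pde.mse, error], domain)     # bundle the physics loss and the tracker
history = crux.solve(5_000)                   # train for 5 000 iterations
history.plot(f"{dire}/training_history.png")  # save the training curves
```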
## Step 6: Evaluate and Plot
After training, the script sorts points, evaluates the learned field, and saves both the training history and solution plot.
```python
pts = np.array(crux.domain_data.context["interior"][0, 0, :, 0])
idx = np.argsort(pts)
xs = pts[idx]

pred = np.array(crux.eval(u)).reshape(xs.shape[0], 1)[:, 0][idx]
true = np.array(crux.eval(u_exact)).reshape(xs.shape[0], 1)[:, 0][idx]
```
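The `argsort` pattern is worth internalizing: collocation points come back in arbitrary order, so a single index array reorders both the coordinates and the evaluated field consistently. A minimal NumPy sketch:

```python
import numpy as np

pts = np.array([0.8, 0.1, 0.5, 0.3])   # unordered collocation points
vals = pts**2                          # field evaluated at those points

idx = np.argsort(pts)                  # one permutation sorts the coordinates...
xs, ys = pts[idx], vals[idx]           # ...and reorders the values to match

print(xs)   # [0.1 0.3 0.5 0.8]
```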
You end up with:

- `training_history.png`: the training history
- `solution.png`: the solution plot
## What To Notice

- This example uses hard constraints, which keeps the loss simple.
- The exact solution is not part of the PDE loss, only the tracker.
- For many introductory PDEs, this is the cleanest pattern to start from.
## Script Snippet
"""01 — 1-D Laplace / Poisson equation (simplest possible PINN)
Problem
-------
−u''(x) = sin(πx), x ∈ [0, 1], u(0) = u(1) = 0
Analytical solution
-------------------
u(x) = sin(πx) / π²
Techniques shown
----------------
* Homogeneous Dirichlet BCs via hard constraint: u = net(x) · x (1−x)
* Automatic-differentiation gradient (jnn.grad)
* Tracker to log the L²-error against the exact solution
* Single-phase Adam with exponential LR decay
"""
import jax
import jno
import jno.jnp_ops as jnn
import optax
from jno import LearningRateSchedule as lrs
import matplotlib.pyplot as plt
import numpy as np
π = jnn.pi
dire = jno.setup(__file__)
# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _ = domain.variable("interior")
# ── Analytical solution ───────────────────────────────────────────────────────
u_exact = jnn.sin(π * x) / π**2
# ── Network with hard-enforced BCs ────────────────────────────────────────────
u_net = jno.nn.mlp(
in_features=1,
hidden_dims=32,
num_layers=3,
key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 5_000, 1e-5))
u = u_net(x) * x * (1 - x) # hard BC: u(0) = u(1) = 0
# ── Constraints ───────────────────────────────────────────────────────────────
pde = -jnn.grad(jnn.grad(u, x), x) - jnn.sin(π * x) # should be 0
error = jnn.tracker((u - u_exact).mse, interval=100)
# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, error], domain)
history = crux.solve(5_000)
history.plot(f"{dire}/training_history.png")
# ── Plot ──────────────────────────────────────────────────────────────────────
pts = np.array(crux.domain_data.context["interior"][0, 0, :, 0])
idx = np.argsort(pts)
xs = pts[idx]
pred = np.array(crux.eval(u)).reshape(xs.shape[0], 1)[:, 0][idx]
true = np.array(crux.eval(u_exact)).reshape(xs.shape[0], 1)[:, 0][idx]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.set_title("Solution")
ax1.plot(xs, pred, label="PINN")
ax1.plot(xs, true, "--", label="exact")
ax1.set_xlabel("x")
ax1.legend()
ax2.set_title("Pointwise |error|")
ax2.plot(xs, np.abs(pred - true))
ax2.set_xlabel("x")
plt.tight_layout()
plt.savefig(f"{dire}/solution.png", dpi=150)
print(f"Saved to {dire}/")