Poisson 1D
This example solves nearly the same equation as Laplace 1D, but changes two important implementation choices: it uses soft boundary constraints and a finite-difference second derivative.
Problem Setup
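We solve

$$-u''(x) = \sin(\pi x), \qquad x \in [0, 1], \qquad u(0) = u(1) = 0,$$

with exact solution

$$u(x) = \frac{\sin(\pi x)}{\pi^2}.$$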
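Differentiating the exact solution u(x) = sin(πx)/π² twice confirms that it satisfies both the equation and the boundary conditions:

$$
u'(x) = \frac{\cos(\pi x)}{\pi}, \qquad u''(x) = -\sin(\pi x) \quad\Longrightarrow\quad -u''(x) = \sin(\pi x),
$$

$$
u(0) = \frac{\sin 0}{\pi^2} = 0, \qquad u(1) = \frac{\sin \pi}{\pi^2} = 0.
$$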
Step 1: Create Interior and Boundary Variables
Unlike the hard-constraint example, this script explicitly asks for both interior and boundary points.
domain = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")
Why both are needed:
- `x` is used for the PDE residual.
- `xb` is used to define the boundary-condition loss.
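To make the two tags concrete, here is a plain NumPy sketch of how a uniform 1D grid with spacing 0.01 splits into interior and boundary points. It assumes uniformly spaced nodes. The helper `line_points` is hypothetical, and the mesh that `jno.domain.line` actually generates (including its point ordering) may differ.

```python
import numpy as np

def line_points(h: float, a: float = 0.0, b: float = 1.0):
    """Hypothetical helper: split a uniform grid on [a, b] into interior and boundary points.

    Assumes uniform spacing h; this is an illustration, not jno's mesh generator.
    """
    n = int(round((b - a) / h))          # number of intervals
    nodes = np.linspace(a, b, n + 1)     # n + 1 equally spaced nodes, endpoints included
    interior = nodes[1:-1]               # used for the PDE residual (like x)
    boundary = nodes[[0, -1]]            # used for the boundary loss (like xb)
    return interior, boundary

interior, boundary = line_points(0.01)
print(interior.shape, boundary)  # (99,) [0. 1.]
```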
Step 2: Define the Reference Solution
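The exact solution, `u_exact = jnn.sin(π * x) / π**2`, is defined on the interior variable `x`. As in the previous example, it is used only to track model quality, never as a training target.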
Step 3: Build the Network
This version keeps the network output unconstrained.
u_net = jno.nn.mlp(
in_features=1,
hidden_dims=64,
num_layers=4,
key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 10_000, 1e-5))
u = u_net(x)
The consequence is that boundary conditions must now be enforced through an explicit loss term.
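For contrast, a hard constraint is usually built into the model itself, for example by multiplying the raw output by a factor that vanishes on the boundary. The NumPy sketch below uses a small random tanh network as a stand-in for `u_net` and the factor x(1 - x). Both are illustrative assumptions, not necessarily what the Laplace 1D example does. The unconstrained output is generally nonzero at the endpoints, which is why this example needs a boundary loss.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for an untrained network: one hidden tanh layer of width 16.
W1, b1 = rng.normal(size=(16, 1)), rng.normal(size=16)
W2, b2 = rng.normal(size=(1, 16)), rng.normal(size=1)

def raw_net(x):
    """Unconstrained output, analogous to u_net(x) in this example."""
    h = np.tanh(x[:, None] @ W1.T + b1)   # (N, 16) hidden activations
    return (h @ W2.T + b2)[:, 0]          # (N,) scalar output per point

def hard_net(x):
    """Hard-constrained ansatz: the factor x(1 - x) forces zeros at x = 0 and x = 1."""
    return x * (1.0 - x) * raw_net(x)

ends = np.array([0.0, 1.0])
print(raw_net(ends))   # generally nonzero, so a boundary loss is needed
print(hard_net(ends))  # both entries are zero by construction
```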
Step 4: Define PDE and Boundary Losses
pde = -u.d2(x, scheme="finite_difference") - jnn.sin(π * x)
bc = u_net(xb)
error = jnn.tracker((u - u_exact).mse, interval=100)
Key difference from Laplace 1D:
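- `u.d2(..., scheme="finite_difference")` approximates the second derivative with finite differences instead of automatic differentiation.
- `bc = u_net(xb)` is the network output at the boundary points. Minimizing its mean square (`bc.mse`) drives u(0) and u(1) toward zero, which enforces the boundary condition softly.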
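The exact stencil behind `scheme="finite_difference"` is not shown here, so the following is a sketch assuming the standard second-order central difference (u[i+1] - 2u[i] + u[i-1]) / h² on a uniform grid. Applying it to the exact solution shows that the discrete PDE residual is small but not zero, and that it shrinks by roughly a factor of 4 when h is halved.

```python
import numpy as np

def central_d2(u, h):
    """Second-order central difference approximation of u'' at interior nodes (uniform grid)."""
    return (u[2:] - 2.0 * u[1:-1] + u[:-2]) / h**2

def fd_residual_max(h):
    """Largest |-u'' - sin(pi x)| when u'' is the FD approximation of the exact solution."""
    x = np.linspace(0.0, 1.0, int(round(1.0 / h)) + 1)
    u = np.sin(np.pi * x) / np.pi**2                 # exact solution sampled on the grid
    residual = -central_d2(u, h) - np.sin(np.pi * x[1:-1])
    return np.abs(residual).max()

e1, e2 = fd_residual_max(0.01), fd_residual_max(0.005)
print(f"{e1:.2e}, {e2:.2e}, ratio {e1 / e2:.2f}")  # ratio close to 4 (second-order scheme)
```

Under this assumption, a network trained against the FD residual inherits an O(h²) discretization error on top of its approximation error, which an automatic-differentiation residual would not introduce.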
Step 5: Solve With Multiple Constraints
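The script passes `[pde.mse, bc.mse, error]` to `jno.core` and trains for 10,000 epochs. The optimization balances two loss terms:
- the PDE residual in the interior (`pde.mse`)
- the boundary mismatch at the endpoints (`bc.mse`)

The third entry, the tracked error against the exact solution, is only logged every 100 epochs to monitor progress; it is not part of the objective.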
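The interplay of the two loss terms can be illustrated without a network. In the NumPy sketch below (a hypothetical analogue, not jno's solver), the unknowns are nodal values on a uniform grid, the FD residual rows reproduce `pde.mse`, and two extra rows reproduce `bc.mse`. The weight `bc_weight` is an assumption added for illustration; the script itself passes the two terms to `jno.core` without explicit weights.

```python
import numpy as np

def soft_constrained_fd_poisson(h=0.01, bc_weight=1.0):
    """Least-squares analogue of pde.mse + bc_weight * bc.mse (hypothetical sketch).

    Unknowns are nodal values u_0..u_n on a uniform grid instead of network weights.
    Interior rows encode the FD residual -u'' - sin(pi x); the two extra rows
    penalize u(0) and u(1), i.e. the soft boundary condition.
    """
    n = int(round(1.0 / h))
    x = np.linspace(0.0, 1.0, n + 1)
    n_int = n - 1
    A = np.zeros((n_int + 2, n + 1))
    b = np.zeros(n_int + 2)
    for k, i in enumerate(range(1, n)):
        # Scaling by 1/sqrt(n_int) turns the sum of squared rows into a mean, like pde.mse.
        A[k, i - 1:i + 2] = np.array([-1.0, 2.0, -1.0]) / h**2 / np.sqrt(n_int)
        b[k] = np.sin(np.pi * x[i]) / np.sqrt(n_int)
    # Boundary rows contribute bc_weight * mean(u_b**2) over the two endpoints.
    A[n_int, 0] = np.sqrt(bc_weight / 2.0)
    A[n_int + 1, n] = np.sqrt(bc_weight / 2.0)
    u, *_ = np.linalg.lstsq(A, b, rcond=None)   # minimizes ||A u - b||^2
    return x, u

x, u = soft_constrained_fd_poisson()
err = np.abs(u - np.sin(np.pi * x) / np.pi**2).max()
print(f"max error {err:.1e}, u(0) = {u[0]:.1e}, u(1) = {u[-1]:.1e}")
```

Because this discrete system is square, the minimizer satisfies both blocks exactly, so `bc_weight` does not change the answer here. A network with finite capacity and a finite training budget generally cannot drive both terms to zero at once, and that is when the relative weighting starts to matter.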
Step 6: Evaluate Error and Plot
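After solving, the script computes the mean absolute error between the prediction and the exact solution on the sorted interior points, then saves a two-panel figure (solution and pointwise absolute error) alongside the training-history plot.
The printed value gives a simple scalar quality check in addition to the saved figures.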
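The two error numbers the script produces are different metrics: the tracker logs the mean squared error during training, while the final printout is the mean absolute error. A small NumPy sketch with made-up pointwise errors shows how they relate; the square root of the MSE is never smaller than the MAE.

```python
import numpy as np

def mae(pred, true):
    """Mean absolute error, as printed at the end of the script."""
    return float(np.mean(np.abs(pred - true)))

def mse(pred, true):
    """Mean squared error, the quantity the tracker logs during training."""
    return float(np.mean((pred - true) ** 2))

true = np.zeros(3)
pred = np.array([0.01, -0.02, 0.03])   # hypothetical pointwise errors
print(mae(pred, true))                 # ≈ 0.02
print(np.sqrt(mse(pred, true)))        # ≈ 0.0216, never below the MAE
```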
What To Notice
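- Soft constraints are more flexible, especially when hard constraints are awkward to encode, but the boundary values are satisfied only approximately, to the extent that the boundary loss is driven to zero.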
- Finite differences provide an alternative to fully automatic differentiation.
- This is a good template when you need explicit control over boundary and residual terms.
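Because a soft constraint is a penalty term rather than a property of the model, it holds only approximately whenever the boundary loss has to compete with other terms. A one-parameter toy problem (purely illustrative, not taken from jno) makes the trade-off explicit: an interior term pulls a boundary value c toward 1, and a soft penalty with weight `bc_weight` pulls it toward the required value 0.

```python
def penalized_boundary_value(bc_weight: float) -> float:
    """Minimizer of (c - 1)**2 + bc_weight * c**2, a toy model of a soft constraint.

    Setting the derivative 2 * (c - 1) + 2 * bc_weight * c to zero gives
    c = 1 / (1 + bc_weight).
    """
    return 1.0 / (1.0 + bc_weight)

for w in (1.0, 10.0, 100.0):
    print(w, penalized_boundary_value(w))
```

The boundary value approaches 0 only as the weight grows, which mirrors why soft boundary conditions are met approximately rather than exactly.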
Script Snippet
"""01 — 1-D Poisson equation (soft boundary constraints + FD Laplacian)
Problem
-------
−u''(x) = sin(πx), x ∈ [0, 1], u(0) = u(1) = 0
Analytical solution
-------------------
u(x) = sin(πx) / π²
Compared to laplace_1d.py this example uses:
* Finite-difference second derivative (u.d2)
* Soft boundary constraints (separate boundary tag)
* A tracker that logs the MSE error every 100 epochs
"""
import jax
import jno
import jno.jnp_ops as jnn
import optax
import numpy as np
import matplotlib.pyplot as plt
from jno import LearningRateSchedule as lrs
π = jnn.pi
dire = jno.setup(__file__)
# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")
# ── Analytical solution ───────────────────────────────────────────────────────
u_exact = jnn.sin(π * x) / π**2
# ── Network ───────────────────────────────────────────────────────────────────
u_net = jno.nn.mlp(
in_features=1,
hidden_dims=64,
num_layers=4,
key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 10_000, 1e-5))
u = u_net(x)
# ── Constraints ───────────────────────────────────────────────────────────────
pde = -u.d2(x, scheme="finite_difference") - jnn.sin(π * x)
bc = u_net(xb) # soft: u(0) = u(1) = 0
error = jnn.tracker((u - u_exact).mse, interval=100)
# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, bc.mse, error], domain)
history = crux.solve(10_000)
history.plot(f"{dire}/training_history.png")
# ── Plot ──────────────────────────────────────────────────────────────────────
pts = np.array(crux.domain_data.context["interior"][0, 0, :, 0])
idx = np.argsort(pts)
xs = pts[idx]
pred = np.array(crux.eval(u)).reshape(xs.shape[0], 1)[:, 0][idx]
true = np.array(crux.eval(u_exact)).reshape(xs.shape[0], 1)[:, 0][idx]
mae = np.abs(pred - true).mean()
print(f"Mean absolute error vs exact: {mae:.6e}")
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.set_title("Solution")
ax1.plot(xs, pred, label="PINN")
ax1.plot(xs, true, "--", label="exact")
ax1.set_xlabel("x")
ax1.legend()
ax2.set_title("Pointwise |error|")
ax2.plot(xs, np.abs(pred - true))
ax2.set_xlabel("x")
plt.tight_layout()
plt.savefig(f"{dire}/solution.png", dpi=150)
print(f"Saved to {dire}/")