2-D Poisson with Stochastic Forcing
This example solves a 2-D Poisson equation whose source term is random — a model for PDEs driven by turbulent forces, uncertain body loads, or random media. Unlike the Fokker-Planck tutorial (where noise appeared on boundary observations), here jno.noise is placed directly inside the PDE residual. The network still recovers the correct deterministic solution; the tutorial explains exactly why.
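Problem
We solve the Poisson equation on the unit square with a randomly perturbed source and homogeneous Dirichlet boundary conditions:
\[
-\Delta u(x,y) = f(x,y) + \sigma\,\xi(x,y), \quad (x,y) \in \Omega = [0,1]^2, \qquad u = 0 \text{ on } \partial\Omega,
\]
where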
- \(f(x,y) = 2\pi^2 \sin(\pi x)\sin(\pi y)\) — deterministic part of the forcing.
- \(\xi(x,y) \sim \mathcal{N}(0, 1)\) pointwise — random fluctuation, freshly drawn each epoch.
- \(\sigma = 0.5\) — noise amplitude.
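Because the equation is linear and \(\mathbb{E}[\xi] = 0\), the mean-field solution (the expectation over all noise realisations) solves the deterministic problem \(-\Delta u = f\) with the same boundary conditions, so it is known exactly:
\[
u^*(x,y) = \mathbb{E}[u](x,y) = \sin(\pi x)\sin(\pi y).
\]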
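A quick way to confirm that \(u^*\) really solves the deterministic problem is to evaluate \(-\Delta u^* - f\) with automatic differentiation at random points. The sketch below uses plain JAX (jax.hessian, jax.vmap) rather than jno's own derivative syntax; the helper names u_star, forcing and laplacian are ours.

```python
import jax
import jax.numpy as jnp

PI = jnp.pi


def u_star(p):
    # Mean-field solution u*(x, y) = sin(pi x) sin(pi y), with p = (x, y)
    return jnp.sin(PI * p[0]) * jnp.sin(PI * p[1])


def forcing(p):
    # Deterministic forcing f(x, y) = 2 pi^2 sin(pi x) sin(pi y)
    return 2 * PI**2 * jnp.sin(PI * p[0]) * jnp.sin(PI * p[1])


def laplacian(fn, p):
    # Trace of the 2x2 Hessian = u_xx + u_yy
    return jnp.trace(jax.hessian(fn)(p))


def residual(p):
    # Deterministic PDE residual -Δu* - f; should vanish everywhere
    return -laplacian(u_star, p) - forcing(p)


points = jax.random.uniform(jax.random.PRNGKey(0), (1000, 2))  # random points in [0, 1]^2
res = jax.vmap(residual)(points)
max_abs_residual = float(jnp.max(jnp.abs(res)))
print(f"max |-Δu* - f| = {max_abs_residual:.2e}")
```

The residual is zero up to float32 round-off, which is what licenses using u_exact as the reference in the evaluation step.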
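The MSE loss at training step \(k\), evaluated at the \(N\) collocation points \((x_i, y_i)\), is
\[
\mathcal{L}_k(\theta) = \frac{1}{N}\sum_{i=1}^{N}\bigl(r_\theta(x_i,y_i) - \sigma\,\xi_i^{(k)}\bigr)^2, \qquad r_\theta = -\Delta u_\theta - f.
\]
Taking the expectation over the i.i.d. noise \(\xi_i^{(k)} \sim \mathcal{N}(0,1)\) and expanding the square:
\[
\mathbb{E}\bigl[\mathcal{L}_k(\theta)\bigr] = \frac{1}{N}\sum_{i=1}^{N} r_\theta(x_i,y_i)^2 - \frac{2\sigma}{N}\sum_{i=1}^{N} r_\theta(x_i,y_i)\,\mathbb{E}\bigl[\xi_i^{(k)}\bigr] + \frac{\sigma^2}{N}\sum_{i=1}^{N}\mathbb{E}\bigl[(\xi_i^{(k)})^2\bigr].
\]
The cross term vanishes because \(\mathbb{E}[\xi_i^{(k)}] = 0\), and the last term equals \(\sigma^2\) because \(\mathbb{E}[(\xi_i^{(k)})^2] = 1\). The noise therefore adds a constant that does not depend on \(\theta\):
\[
\mathbb{E}\bigl[\mathcal{L}_k(\theta)\bigr] = \mathcal{L}_{\text{det}}(\theta) + \sigma^2, \qquad \mathcal{L}_{\text{det}}(\theta) = \frac{1}{N}\sum_{i=1}^{N} r_\theta(x_i,y_i)^2.
\]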
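Shifting a loss by a constant does not move its minimiser, so the minimiser is identical to the deterministic case. Each training step's loss is a single-sample Monte Carlo estimate of the expected loss, and because the noise does not depend on \(\theta\), its gradient is an unbiased estimate of the gradient of the deterministic loss.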
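The constant-offset argument can be checked numerically. The sketch below replaces the PINN residual with a hypothetical one-parameter residual \(r_i(\theta) = \theta g_i - h_i\) (the arrays g and h are arbitrary choices of ours, not anything from jno), averages the noisy loss over many noise draws, and compares it with the deterministic loss at several values of \(\theta\).

```python
import numpy as np

rng = np.random.default_rng(0)
N, K, sigma = 200, 10_000, 0.5  # collocation points, noise draws, noise amplitude

# Toy one-parameter residual r_i(theta) = theta * g_i - h_i (illustrative, not a PINN)
g = rng.uniform(0.5, 1.5, N)
h = rng.uniform(-1.0, 1.0, N)
xi = rng.standard_normal((K, N))  # K independent noise realisations, one row per "step"


def det_loss(theta):
    # Deterministic MSE loss L_det(theta)
    return np.mean((theta * g - h) ** 2)


def mean_noisy_loss(theta):
    # Monte Carlo average of L_k(theta) = mean((r - sigma * xi)^2) over the K draws
    noisy = (theta * g - h) - sigma * xi
    return np.mean(noisy**2)


thetas = np.linspace(-1.0, 1.0, 5)
offsets = [mean_noisy_loss(t) - det_loss(t) for t in thetas]
print("E[L_k] - L_det at each theta:", np.round(offsets, 3))
```

Every offset sits near \(\sigma^2 = 0.25\) regardless of \(\theta\), so the noisy and deterministic losses differ by a vertical shift and share the same minimiser.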
Hard Dirichlet BCs are enforced via the ansatz:
\[
u_\theta(x,y) = x(1-x)\,y(1-y)\,\hat{u}_\theta(x,y).
\]
The factor \(x(1-x)y(1-y)\) is zero on all four edges, so \(u_\theta = 0\) on \(\partial\Omega\) for any network output \(\hat{u}_\theta\). No boundary loss term is needed.
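The boundary property of the ansatz holds for any network, which the sketch below illustrates with NumPy. fake_network is a hypothetical stand-in for \(\hat{u}_\theta\); any finite function would do.

```python
import numpy as np


def bc_factor(x, y):
    # x(1-x)y(1-y): exactly zero whenever x or y equals 0 or 1
    return x * (1 - x) * y * (1 - y)


def fake_network(x, y):
    # Hypothetical stand-in for the raw network output \hat{u}_theta
    return 3.0 + np.exp(x) * np.cos(5 * y)


def u_ansatz(x, y):
    return fake_network(x, y) * bc_factor(x, y)


s = np.linspace(0.0, 1.0, 101)
zeros, ones = np.zeros_like(s), np.ones_like(s)
edges = [(s, zeros), (s, ones), (zeros, s), (ones, s)]  # bottom, top, left, right
max_on_boundary = max(np.max(np.abs(u_ansatz(x, y))) for x, y in edges)
print(max_on_boundary, u_ansatz(0.5, 0.5))
```

The ansatz is exactly zero on all four edges while remaining nonzero in the interior, so the optimizer only has to shape the interior.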
Code Walkthrough
Step 1 — Domain and forcing
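```python
from shapely.geometry import box
import jno
π = jno.np.pi
domain = jno.domain(box(0, 0, 1, 1), mesh_size=0.05)  # unit square [0, 1]^2, as in the full script
x, y, _ = domain.variable("interior")
f = 2 * π**2 * jno.np.sin(π * x) * jno.np.sin(π * y)
u_exact = jno.np.sin(π * x) * jno.np.sin(π * y)
```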
mesh_size=0.05 gives roughly 500 interior collocation points on \([0,1]^2\).
Step 2 — Network with hard BCs
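```python
import foundax
import jax
import optax
net = jno.nn.wrap(
    foundax.mlp(in_features=2, hidden_dims=64, num_layers=5,
                activation=jax.nn.tanh, key=jax.random.PRNGKey(0))
)
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 10, 0.5, end_value=1e-5)))
# Hard-BC ansatz: the factor x(1-x)y(1-y) vanishes on all four edges
u = (net(x, y) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y)
```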
The hard-BC ansatz means the only constraint needed is the PDE residual.
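The learning-rate schedule in Step 2 is worth a second look. With positional arguments init_value=1e-3, transition_steps=10, decay_rate=0.5 and end_value=1e-5, the rate halves every 10 optimizer steps and reaches its floor at step 67, so, assuming one optimizer step per epoch, nearly all of a 40,000-epoch run proceeds at 1e-5. The sketch reimplements the non-staircase formula of optax.exponential_decay in plain Python so it runs without optax; with optax installed, optax.exponential_decay(1e-3, 10, 0.5, end_value=1e-5)(step) should give the same values.

```python
def exponential_decay(step, init_value=1e-3, transition_steps=10,
                      decay_rate=0.5, end_value=1e-5):
    # Mirrors optax.exponential_decay with staircase=False and transition_begin=0:
    # init_value * decay_rate ** (step / transition_steps), floored at end_value
    value = init_value * decay_rate ** (step / transition_steps)
    return max(value, end_value)


# First step at which the schedule sits on its floor
floor_step = next(s for s in range(1000) if exponential_decay(s) == 1e-5)

for step in (0, 10, 20, 50, floor_step, 40_000):
    print(step, exponential_decay(step))
```

If a longer high-learning-rate phase is wanted, increasing transition_steps is the knob to turn.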
Step 3 — Stochastic PDE residual
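jno.noise.gaussian(std=σ) is a symbolic Placeholder: it participates in the expression tree like any other traced node, so the stochastic residual is written exactly like a deterministic one. In the full script this is noise = jno.noise.gaussian(std=σ) followed by pde = -(u.xx + u.yy) - f - noise. At each training step the solver splits its PRNG key and draws a fresh \((N, 1)\) sample, one value per collocation point, so every gradient update sees a different noisy residual.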
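The per-step resampling can be mimicked in plain JAX. This is a sketch of the idea rather than jno's internal code: the key handling and the helper names draw_noise and noisy_residual are assumptions. Each step splits the key, carries one half forward, and uses the other half to draw a fresh \((N, 1)\) Gaussian scaled by \(\sigma\).

```python
import jax
import jax.numpy as jnp

SIGMA = 0.5
N = 500  # number of collocation points (illustrative)


def draw_noise(key, n=N, std=SIGMA):
    # Fresh (n, 1) sample of std * xi with xi ~ N(0, 1)
    return std * jax.random.normal(key, (n, 1))


def noisy_residual(det_residual, key):
    # Stochastic residual: deterministic part minus this step's noise
    return det_residual - draw_noise(key, det_residual.shape[0])


key = jax.random.PRNGKey(42)
det_residual = jnp.zeros((N, 1))  # stand-in for -Δu - f at the collocation points
samples = []
for step in range(3):
    key, step_key = jax.random.split(key)  # carry `key`, consume `step_key`
    samples.append(noisy_residual(det_residual, step_key))

print([s.shape for s in samples])
```

Splitting rather than reusing the key is what guarantees a new noise field at every step.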
Noise amplitude and convergence
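With \(\sigma = 0.5\), the noise is small compared with the forcing itself, whose peak value is \(2\pi^2 \approx 19.7\), so early in training the gradient is dominated by the deterministic residual. Near convergence the deterministic residual drops well below \(\sigma\), the per-step loss levels off around \(\sigma^2 = 0.25\), and the gradient is mostly noise; the small learning rate at the end of the decay schedule keeps these fluctuations small. Training takes slightly more epochs than the deterministic problem, but the final accuracy is the same. Larger \(\sigma\) adds more variance to the stochastic gradient without changing its expectation, i.e. without introducing bias.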
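The "more variance, no bias" claim can be seen on the same kind of toy residual used for the loss derivation, a hypothetical \(r_i(\theta) = \theta g_i - h_i\) with arrays of our own choosing. Reusing one set of noise draws across the three \(\sigma\) values makes the comparison clean: the gradient's mean stays at the deterministic value while its standard deviation scales linearly with \(\sigma\).

```python
import numpy as np

rng = np.random.default_rng(1)
N, K = 200, 10_000               # collocation points, noise draws
g = rng.uniform(0.5, 1.5, N)     # toy residual r_i(theta) = theta * g_i - h_i
h = rng.uniform(-1.0, 1.0, N)
theta = 0.3
r = theta * g - h
grad_det = 2 * np.mean(r * g)    # exact gradient of the deterministic loss
xi = rng.standard_normal((K, N)) # shared noise draws for every sigma

stats = {}
for sigma in (0.25, 0.5, 1.0):
    # Per-step gradient of L_k = mean((r - sigma * xi)^2) with respect to theta
    grads = 2 * np.mean((r - sigma * xi) * g, axis=1)
    stats[sigma] = (grads.mean(), grads.std())
    print(f"sigma={sigma}: mean={grads.mean():+.4f} (exact {grad_det:+.4f}), "
          f"std={grads.std():.4f}")
```

Doubling \(\sigma\) doubles the gradient's standard deviation, which is why larger noise amplitudes call for more steps or smaller learning rates rather than a different target.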
Step 4 — Single-loss solve
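Because hard BCs remove the boundary term, there is only one loss; the full script builds it with crux = jno.core([pde.mse]) and trains with crux.solve(40_000). The noise term changes every step, so the stochastic residual itself never reaches zero; instead the optimizer drives its deterministic part, \(-\Delta u - f\), toward zero and thereby learns the mean-field solution \(u^* = \sin(\pi x)\sin(\pi y)\).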
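To see the mechanism end to end without a network, consider a one-parameter analogue: restrict the trial function to \(u = a\,\sin(\pi x)\sin(\pi y)\), for which \(-\Delta u = a f\) exactly. Plain SGD on the noisy residual, with fresh noise every step and a decaying step size, settles at \(a \approx 1\), i.e. the mean-field solution. The trial function, step sizes, and point count are our own illustrative choices, not the tutorial's settings.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, n_pts, n_steps = 0.5, 400, 2_000

# Random collocation points in the unit square and the basis function / forcing there
pts = rng.uniform(0.0, 1.0, (n_pts, 2))
phi = np.sin(np.pi * pts[:, 0]) * np.sin(np.pi * pts[:, 1])
f = 2 * np.pi**2 * phi

a = 0.0  # coefficient of the trial function u = a * phi
for step in range(n_steps):
    lr = 1e-3 * 0.5 ** (step / 500)      # decaying step size (illustrative)
    xi = rng.standard_normal(n_pts)      # fresh noise every step
    residual = a * f - f - sigma * xi    # -Δu - f - σξ, using -Δ(a·phi) = a·f
    grad = 2 * np.mean(residual * f)     # d/da of mean(residual**2)
    a -= lr * grad

print(f"learned a = {a:.4f} (mean-field value 1)")
```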
Comparison: Deterministic vs Stochastic
| | Deterministic | Stochastic (this tutorial) |
|---|---|---|
| Loss | \(\lVert{-\Delta u - f}\rVert^2\) | \(\lVert{-\Delta u - f - \sigma\xi}\rVert^2\) |
| Gradient | exact | noisy estimate (unbiased) |
| Minimiser \(u^*\) | \(\sin(\pi x)\sin(\pi y)\) | same |
| Extra cost | none | one fold_in per step |
| Noise source | none | solver PRNG, seeded via jno.setup |
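The "one fold_in per step" entry refers to deriving each step's noise key from a base key and the step counter. A minimal plain-JAX sketch, assuming a base key built from an integer seed (seed 42 mirrors jno.setup(seed=42); the helper step_noise is hypothetical): jax.random.fold_in(base_key, step) is deterministic in (seed, step), so rerunning with the same seed replays the same noise, while different steps and different seeds give different draws.

```python
import jax
import jax.numpy as jnp


def step_noise(seed, step, n=500, std=0.5):
    # Derive this step's key from the base key with a single fold_in, then sample
    base_key = jax.random.PRNGKey(seed)
    step_key = jax.random.fold_in(base_key, step)
    return std * jax.random.normal(step_key, (n, 1))


run_a = [step_noise(42, k) for k in range(3)]
run_b = [step_noise(42, k) for k in range(3)]  # same seed: identical sequence
run_c = [step_noise(7, k) for k in range(3)]   # different seed: different sequence

same_seed_matches = all(bool(jnp.array_equal(p, q)) for p, q in zip(run_a, run_b))
steps_differ = not bool(jnp.array_equal(run_a[0], run_a[1]))
seeds_differ = not bool(jnp.array_equal(run_a[0], run_c[0]))
print(same_seed_matches, steps_differ, seeds_differ)
```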
What to Notice
- One loss, no boundary term. Hard BCs eliminate the need for a separate boundary loss, so the entire problem reduces to a single stochastic PDE constraint.
- Noise on the physics, not the data. This is the complement of the Fokker-Planck tutorial: there, noise modelled uncertain measurements; here, it models an uncertain forcing within the governing equation itself.
- ndim for vector noise. If your field has multiple components (e.g., a 2-D velocity vector), use jno.noise.gaussian(std=σ, ndim=2) to draw an \((N, 2)\) sample in one call rather than two separate scalar noise nodes.
- Reproducibility via seed. Fix the seed with jno.setup(seed=42) (or set seed = 42 in .jno.toml) to get identical noise sequences across runs.
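For the ndim option, the plain-JAX analogue of one \((N, 2)\) draw versus two scalar \((N, 1)\) draws is sketched below. That jno's ndim=2 sample has independent columns is an assumption on our part; with jax.random.normal the two columns are independent standard normals, so their sample correlation is close to zero and the two approaches agree in distribution.

```python
import jax
import jax.numpy as jnp

N, SIGMA = 2_000, 0.5
key = jax.random.PRNGKey(0)

# One call: an (N, 2) sample, e.g. noise on both velocity components
vec_noise = SIGMA * jax.random.normal(key, (N, 2))

# Alternative with the same distribution: two scalar (N, 1) draws from split keys
k1, k2 = jax.random.split(key)
two_draws = jnp.concatenate(
    [SIGMA * jax.random.normal(k1, (N, 1)), SIGMA * jax.random.normal(k2, (N, 1))],
    axis=1,
)

# Sample correlation between the two components of the single (N, 2) draw
corr = float(jnp.corrcoef(vec_noise[:, 0], vec_noise[:, 1])[0, 1])
print(vec_noise.shape, two_draws.shape, round(corr, 3))
```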
Full Script
```python
"""07 — 2-D Poisson equation with stochastic forcing"""
from pathlib import Path

import foundax
import jax
import jax.numpy as jnp
import optax
from shapely.geometry import box

import jno

π = jno.np.pi
σ = 0.5  # noise amplitude on the forcing

# ── Domain ─────────────────────────────────────────────────────────────────────
domain = jno.domain(box(0, 0, 1, 1), mesh_size=0.05)
x, y, _ = domain.variable("interior")

# ── Deterministic forcing and exact solution ───────────────────────────────────
f = 2 * π**2 * jno.np.sin(π * x) * jno.np.sin(π * y)
u_exact = jno.np.sin(π * x) * jno.np.sin(π * y)

# ── Network (hard BCs via ansatz) ──────────────────────────────────────────────
net = jno.nn.wrap(
    foundax.mlp(
        in_features=2,
        hidden_dims=64,
        num_layers=5,
        activation=jax.nn.tanh,
        key=jax.random.PRNGKey(0),
    )
)
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 10, 0.5, end_value=1e-5)))
u = (net(x, y) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y)  # u = 0 on ∂Ω by construction

# ── Stochastic PDE residual ────────────────────────────────────────────────────
# The noise term is resampled every training step. Its expectation is zero,
# so E[loss] is minimised by the deterministic solution u*(x,y) = sin(πx)sin(πy).
noise = jno.noise.gaussian(std=σ)
pde = -(u.xx + u.yy) - f - noise

# ── Solve ──────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse])
history = crux.solve(40_000)

# ── Evaluate ───────────────────────────────────────────────────────────────────
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jnp.linalg.norm(_u - _u_exact) / (jnp.linalg.norm(_u_exact) + 1e-8))
print(f"Noise amplitude σ = {σ}")
print(f"Relative L2 error: {rel_l2:.4e}")

# ── Record ─────────────────────────────────────────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f_out:
    f_out.write(f"07_stochastic/stochastic_forcing_2d.py | epochs=40000 | sigma={σ} | rel_L2={rel_l2:.6e}\n")

assert rel_l2 < 1e-1, f"relative L2 error too large: {rel_l2:.3e}"
```