Fokker-Planck 2D (Ornstein-Uhlenbeck)
This example solves the stationary Fokker-Planck equation for a 2-D Ornstein-Uhlenbeck process. The Fokker-Planck equation is a PDE that describes how a probability density evolves under drift and diffusion; its stationary form characterizes the equilibrium density the process settles into. The example also introduces jno.noise to add stochastic measurement uncertainty to the boundary observations, so the network sees freshly sampled noisy data at every training step.
The Stochastic Process
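The underlying stochastic differential equations (Itô form) are:
\[
dX_t = -X_t \, dt + \sigma \, dW_t^{(1)}, \qquad dY_t = -Y_t \, dt + \sigma \, dW_t^{(2)},
\]
where \(W_t^{(1)}\) and \(W_t^{(2)}\) are independent standard Brownian motions.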
- The drift \(-X_t\), \(-Y_t\) pulls the process back toward the origin (restoring force with rate 1).
- The diffusion coefficient \(\sigma = 1\) sets the strength of the Brownian noise.
This is the simplest mean-reverting process in 2-D and is analytically tractable — an ideal test case for a PINN.
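To see the mean reversion numerically, the SDEs can be integrated with the Euler-Maruyama scheme. The sketch below uses plain NumPy rather than jno. The particle count, step size, horizon and starting point are illustrative choices, not values from this tutorial. For unit restoring rate, each coordinate of an OU process has stationary variance \(\sigma^2/2 = 1/2\). After several relaxation times the empirical variance should therefore settle near 0.5, with a small discretization bias of order `dt`.

```python
import numpy as np


def simulate_ou_2d(n_particles=20_000, dt=1e-2, n_steps=1_000, sigma=1.0, seed=0):
    """Euler-Maruyama for dX = -X dt + sigma dW1, dY = -Y dt + sigma dW2."""
    rng = np.random.default_rng(seed)
    # Start every particle away from the origin so the relaxation is visible.
    xy = np.full((n_particles, 2), 2.0)
    for _ in range(n_steps):
        dw = rng.normal(scale=np.sqrt(dt), size=xy.shape)  # independent Brownian increments
        xy = xy - xy * dt + sigma * dw                      # drift -X dt plus diffusion
    return xy


samples = simulate_ou_2d()
print("mean:", samples.mean(axis=0))     # close to (0, 0)
print("variance:", samples.var(axis=0))  # close to (0.5, 0.5)
```

The horizon `n_steps * dt = 10` is ten relaxation times, so the initial offset of 2 has decayed by a factor of about \(e^{-10}\).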
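The forward Kolmogorov (Fokker-Planck) equation for the stationary density \(p(x, y)\) is:
\[
\frac{\partial (x p)}{\partial x} + \frac{\partial (y p)}{\partial y} + \frac{\sigma^2}{2} \left( \frac{\partial^2 p}{\partial x^2} + \frac{\partial^2 p}{\partial y^2} \right) = 0.
\]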
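The first two terms are the negative drift divergence \(-\nabla \cdot (\mathbf{b} \, p) = \nabla \cdot \big((x, y)\, p\big)\) with \(\mathbf{b}(x,y) = (-x, -y)\). The last term is the diffusion Laplacian \(\tfrac{\sigma^2}{2} \Delta p\), which reduces to \(\tfrac{1}{2} \Delta p\) for \(\sigma = 1\).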
Boundary condition: \(p \approx 0\) on \(\partial\Omega\) (the Gaussian decays to \(e^{-9}/\pi \approx 4 \times 10^{-5}\) at the domain edges).
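For unit restoring rate and \(\sigma = 1\), each coordinate has stationary variance \(\sigma^2/2 = 1/2\), so the stationary density is a bivariate Gaussian:
\[
p(x, y) = \frac{1}{\pi} e^{-(x^2 + y^2)}.
\]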
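A quick numerical check of this closed form: on a uniform midpoint grid over \([-6, 6]^2\), the density should integrate to 1 and each coordinate should have variance 1/2. The grid is an illustrative choice, wide enough that the truncated tails are negligible. This is plain NumPy, independent of jno.

```python
import numpy as np

# Uniform grid on [-6, 6]^2; cell centres are used as midpoint-rule nodes.
n, half_width = 600, 6.0
h = 2 * half_width / n
centres = -half_width + h * (np.arange(n) + 0.5)
X, Y = np.meshgrid(centres, centres, indexing="ij")

p = np.exp(-(X**2 + Y**2)) / np.pi  # stationary OU density

mass = p.sum() * h * h
var_x = (X**2 * p).sum() * h * h  # E[X^2]; the mean is 0 by symmetry
print(f"mass = {mass:.6f}, var_x = {var_x:.6f}")
```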
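Verification. Substitute \(p = C e^{-(x^2+y^2)}\) and use \(\partial p/\partial x = -2x\,p\) and \(\partial p/\partial y = -2y\,p\):
| Term | Value |
|---|---|
| \(\partial(xp)/\partial x\) | \(p - 2x^2 p\) |
| \(\partial(yp)/\partial y\) | \(p - 2y^2 p\) |
| \(\tfrac{1}{2}\Delta p\) | \(\tfrac{1}{2}(-2 + 4x^2 - 2 + 4y^2)\,p = (-2 + 2x^2 + 2y^2)\,p\) |
Summing the three rows gives \(\big[(1 - 2x^2) + (1 - 2y^2) + (-2 + 2x^2 + 2y^2)\big]\,p = 0\), so the residual vanishes identically for any constant \(C\). ✓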
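The same cancellation can be checked numerically. The sketch below is plain NumPy with illustrative grid spacings. It evaluates the residual \(\partial_x(xp) + \partial_y(yp) + \tfrac{1}{2}\Delta p\) of the exact density using central differences. Because the stencils are second-order accurate, the maximum residual should be of order \(h^2\) and shrink roughly fourfold each time \(h\) is halved.

```python
import numpy as np


def fp_residual_max(h, half_width=3.0):
    """Max |FP residual| of p = exp(-(x^2+y^2))/pi via central differences."""
    g = np.arange(-half_width, half_width + h / 2, h)
    X, Y = np.meshgrid(g, g, indexing="ij")
    p = np.exp(-(X**2 + Y**2)) / np.pi

    # np.gradient uses second-order central differences at interior nodes.
    d_xp_dx = np.gradient(X * p, h, axis=0)
    d_yp_dy = np.gradient(Y * p, h, axis=1)
    p_xx = np.gradient(np.gradient(p, h, axis=0), h, axis=0)
    p_yy = np.gradient(np.gradient(p, h, axis=1), h, axis=1)

    residual = d_xp_dx + d_yp_dy + 0.5 * (p_xx + p_yy)
    # Drop two layers of edge nodes, where np.gradient falls back to one-sided stencils.
    return np.abs(residual[2:-2, 2:-2]).max()


for h in (0.1, 0.05, 0.025):
    print(f"h = {h:<6} max |residual| = {fp_residual_max(h):.2e}")
```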
Code Walkthrough
Step 1 — Domain centred at the origin and analytical boundary
The domain is centred at \((0, 0)\) so the Gaussian peak sits at the domain centre. The analytical solution is defined on both point sets. p_exact on the interior points is used for the final error check. p_exact_bc on the boundary points serves as the (near-zero) target for the noisy boundary observations.
```python
from shapely.geometry import Point
import jno

π = jno.np.pi

# Disc of radius 3 centred at the origin: captures the Gaussian's effective support.
domain = jno.domain(Point(0, 0).buffer(3.0), mesh_size=0.25)
x, y, _ = domain.variable("interior")
xb, yb, _ = domain.variable("boundary")
p_exact = jno.np.exp(-(x**2 + y**2)) / π      # exact solution on interior points
p_exact_bc = jno.np.exp(-(xb**2 + yb**2)) / π  # exact solution on boundary points
```
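For intuition about the two point sets, here is a plain-NumPy stand-in. It is an assumption for illustration only, not how jno.domain meshes the disc: it uses uniformly sampled interior points and evenly spaced boundary points on the circle of radius 3. It evaluates the same exact solution on both sets and confirms that the boundary target is the constant \(e^{-9}/\pi \approx 3.9 \times 10^{-5}\).

```python
import numpy as np

R = 3.0
rng = np.random.default_rng(0)

# Interior: uniform samples in the disc (sqrt keeps the area density uniform).
n_int = 2_000
r = R * np.sqrt(rng.uniform(size=n_int))
theta = rng.uniform(0.0, 2.0 * np.pi, size=n_int)
x, y = r * np.cos(theta), r * np.sin(theta)

# Boundary: evenly spaced points on the circle x^2 + y^2 = R^2.
n_bnd = 200
phi = np.linspace(0.0, 2.0 * np.pi, n_bnd, endpoint=False)
xb, yb = R * np.cos(phi), R * np.sin(phi)

p_exact = np.exp(-(x**2 + y**2)) / np.pi
p_exact_bc = np.exp(-(xb**2 + yb**2)) / np.pi

print(f"peak value 1/pi        = {1 / np.pi:.4f}")
print(f"interior max           = {p_exact.max():.4f}")
print(f"boundary value e^-9/pi = {p_exact_bc.mean():.3e}")
```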
Step 2 — Fokker-Planck residual
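jno.np.vector(...) builds a typed VectorView from scalar components without manual concatenation, and .div(x, y) takes its divergence. The flux is built as \((x p, y p) = -\mathbf{b}\, p\), so prob_flux.div(x, y) evaluates \(\nabla \cdot \big((x, y)\,p\big) = -\nabla \cdot (\mathbf{b}\, p)\), which matches the first two terms of the equation. The Laplacian \(\Delta p\) is written out explicitly here as p.xx + p.yy rather than via jno.np.laplacian.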
```python
import foundax
import jax
import optax

net = jno.nn.wrap(foundax.mlp(in_features=2, hidden_dims=64, num_layers=5, key=jax.random.PRNGKey(0)))
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 2000, 0.5, end_value=1e-5)))
p = net(x, y).scalar.bind(x=x, y=y)
prob_flux = jno.np.vector(x * p, y * p)  # (x p, y p) = -b p, the negated OU drift flux
drift = prob_flux.div(x, y)              # ∇·((x, y) p) = -∇·(b p)
diff = 0.5 * (p.xx + p.yy)               # ½ Δp (σ = 1)
fp = drift + diff                        # stationary residual, driven to 0
```
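Expanding the divergence with the product rule shows what prob_flux.div(x, y) contributes. It also gives a form that is convenient for checking the residual by hand:
\[
\nabla \cdot \big((x, y)\, p\big)
  = \frac{\partial (x p)}{\partial x} + \frac{\partial (y p)}{\partial y}
  = 2p + x\,\frac{\partial p}{\partial x} + y\,\frac{\partial p}{\partial y},
\qquad
\text{fp} = 2p + x\,p_x + y\,p_y + \tfrac{1}{2}\,(p_{xx} + p_{yy}).
\]
For the exact density, \(x\,p_x + y\,p_y = -2r^2 p\) and \(\tfrac{1}{2}(p_{xx} + p_{yy}) = (-2 + 2r^2)\,p\) with \(r^2 = x^2 + y^2\). The residual therefore collapses to \((2 - 2r^2 - 2 + 2r^2)\,p = 0\).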
Step 3 — Normalization and noisy boundary
.integrate() reduces \(p(x, y)\) to the scalar \(\iint_\Omega p \, dx \, dy\) using mesh-based quadrature weights. Subtracting 1 creates a loss that drives total probability mass to 1.
jno.noise.gaussian(std=1e-4) is a lazy Placeholder. At each training step the solver splits its PRNG key and samples a fresh \((N_b, 1)\) array, where \(N_b\) is the number of boundary points. This simulates noisy physical measurements at the boundary.
```python
norm = p.integrate() - 1.0  # ∬ p dx dy = 1
p_bc = net(xb, yb) - (p_exact_bc + jno.noise.gaussian(std=1e-4))  # fresh noise every step
```
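One subtlety of the normalization target: .integrate() only sees the disc of radius 3, while the exact Gaussian has a little mass outside it. In polar coordinates the mass inside radius \(R\) is \(\int_0^R \tfrac{1}{\pi} e^{-r^2}\, 2\pi r\, dr = 1 - e^{-R^2}\). For \(R = 3\), a target of 1 therefore exceeds the exact in-domain mass by \(e^{-9} \approx 1.2 \times 10^{-4}\), a negligible mismatch for this tutorial. The plain-NumPy sketch below, which is independent of jno, checks the closed form with a radial midpoint rule.

```python
import numpy as np


def disc_mass(R, n=100_000):
    """Midpoint-rule integral of p = exp(-r^2)/pi over the disc of radius R (polar coords)."""
    dr = R / n
    r = (np.arange(n) + 0.5) * dr
    # After integrating over the angle, the area element is 2*pi*r dr.
    return np.sum(np.exp(-r**2) / np.pi * 2.0 * np.pi * r) * dr


R = 3.0
print(f"quadrature : {disc_mass(R):.8f}")
print(f"closed form: {1.0 - np.exp(-R**2):.8f}")
print(f"shortfall  : {np.exp(-R**2):.2e}")
```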
Reproducibility
The noise sequence is fully determined by the global seed. Set it with jno.setup(seed=42) or in .jno.toml to reproduce the exact same training run.
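The "fresh noise every step, yet reproducible from one seed" pattern can be mimicked outside jno. The sketch below is an assumption about the general idea, not jno's implementation. It uses NumPy's `SeedSequence.spawn` to derive one independent child stream per training step from a single global seed. The helper `noise_stream` is hypothetical.

```python
import numpy as np


def noise_stream(seed, n_steps, n_boundary, std=1e-4):
    """Hypothetical helper: yield one fresh (n_boundary, 1) noise array per step from `seed`."""
    children = np.random.SeedSequence(seed).spawn(n_steps)  # one child seed per step
    for child in children:
        rng = np.random.default_rng(child)
        yield rng.normal(0.0, std, size=(n_boundary, 1))


run_a = list(noise_stream(seed=42, n_steps=3, n_boundary=4))
run_b = list(noise_stream(seed=42, n_steps=3, n_boundary=4))

print(all(np.array_equal(a, b) for a, b in zip(run_a, run_b)))  # True: same seed, same noise
print(np.array_equal(run_a[0], run_a[1]))                         # False: fresh draw each step
```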
Step 4 — Solve
Three losses compete: PDE residual, normalization, and noisy boundary data. The solver balances them using its built-in loss weighting.
```python
residuals = jno.trackers.residual_stats(interval=1000)
crux = jno.core([fp.mse, norm.mse, p_bc.mse])
crux.solve(15_000, callbacks=[residuals])
```
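It helps to see how the learning rate from Step 2 evolves over this 15_000-step budget. The plain-Python function below re-implements the formula documented for optax.exponential_decay in the non-staircase case with transition_begin=0: `lr(step) = init_value * decay_rate ** (step / transition_steps)`, floored at end_value because the decay rate is below 1. It is a sketch for inspection, not a replacement for optax.

```python
import math


def exponential_decay_lr(step, init_value=1e-3, transition_steps=2000,
                         decay_rate=0.5, end_value=1e-5):
    """Continuous exponential decay, floored at end_value (valid for decay_rate < 1)."""
    lr = init_value * decay_rate ** (step / transition_steps)
    return max(lr, end_value)


for step in (0, 2_000, 6_000, 10_000, 13_000, 15_000):
    print(f"step {step:>6}: lr = {exponential_decay_lr(step):.2e}")

# The floor is reached when 0.5 ** (t / 2000) = 1e-2, i.e. t = 2000 * log2(100).
print(f"floor reached near step {2000 * math.log2(100):.0f}")
```

With these settings the rate halves every 2000 steps and sits on the 1e-5 floor for roughly the last 1700 steps of the run.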
What to Notice
- Noise on observations, not the physics. The Fokker-Planck residual is deterministic; only the boundary data carries noise. This mirrors the real-world setting where the governing equation is known but measurements are uncertain.
- The network still converges. The noise level \(\operatorname{std} = 10^{-4}\) is about 2.5 times the true boundary value (\(\approx 4 \times 10^{-5}\)), so the signal-to-noise ratio at the boundary is below 1. The PDE and normalization constraints compensate and anchor the interior solution. Because the zero-mean noise is redrawn every step, it also tends to average out rather than acting as a fixed, biased target.
- .integrate() is differentiable. Gradients flow through the normalization term, so the optimizer adjusts the global scale of \(p\) at the same time as the Fokker-Planck residual shapes it.
- No manual key management. The user never calls jax.random.split or threads keys through the loss; jno.noise handles all of that inside the solver.
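The low signal-to-noise ratio at the boundary is easy to quantify. The plain-NumPy sketch below is an illustration, not part of the jno script. It draws noisy targets at one boundary point the same way the p_bc loss forms them: the exact value plus zero-mean Gaussian noise with std 1e-4. It then measures how often a single draw is negative, which is impossible for a density, and how close the average over many draws is to the true value.

```python
import numpy as np

rng = np.random.default_rng(0)
p_true = np.exp(-9.0) / np.pi  # exact density on the circle r = 3
std = 1e-4

# One noisy target per simulated training step at a single boundary point.
targets = p_true + rng.normal(0.0, std, size=10_000)

print(f"SNR (signal / std)          : {p_true / std:.2f}")
print(f"fraction of negative targets: {np.mean(targets < 0):.1%}")
print(f"mean target vs exact        : {targets.mean():.2e} vs {p_true:.2e}")
```

Roughly a third of the individual targets are negative, yet their running mean stays close to the exact value.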
Full Script
"""07 — 2-D Fokker–Planck on a disc (shapely + RAD resampling + residual tracker)"""
from pathlib import Path
import foundax
import jax
import jax.numpy as jnp
import optax
from shapely.geometry import Point
import jno
π = jno.np.pi
# Disc of radius 3 centred at the origin — captures the Gaussian's effective support.
domain = jno.domain(Point(0, 0).buffer(3.0), mesh_size=0.25)
x, y, _ = domain.variable("interior")
xb, yb, _ = domain.variable("boundary")
p_exact = jno.np.exp(-(x**2 + y**2)) / π
p_exact_bc = jno.np.exp(-(xb**2 + yb**2)) / π
net = jno.nn.wrap(foundax.mlp(in_features=2, hidden_dims=64, num_layers=5, key=jax.random.PRNGKey(0)))
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 2000, 0.5, end_value=1e-5)))
p = net(x, y).scalar.bind(x=x, y=y)
prob_flux = jno.np.vector(x * p, y * p) # OU drift flux as a VectorView
drift = prob_flux.div(x, y) # ∇·(b·p)
diff = 0.5 * (p.xx + p.yy) # ½ ∆p
fp = drift + diff # residual = 0
norm = p.integrate() - 1.0 # ∬ p dx dy = 1
p_bc = net(xb, yb) - (p_exact_bc + jno.noise.gaussian(std=1e-4))
residuals = jno.trackers.residual_stats(interval=1000)
crux = jno.core([fp.mse, norm.mse, p_bc.mse])
crux.solve(15_000, callbacks=[residuals])
_p, _p_exact = crux.eval([p, p_exact])
rel_l2 = float(jnp.linalg.norm(_p - _p_exact) / (jnp.linalg.norm(_p_exact) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}")
if residuals.value is not None:
print(f"Per-constraint max residuals: {residuals.value['maxes']}")
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
f.write(f"07_stochastic/fokker_planck_2d.py | epochs=15000 | rel_L2={rel_l2:.6e}\n")
assert rel_l2 < 3e-1, f"relative L2 error too large: {rel_l2:.3e}"