
Fokker-Planck 2D (Ornstein-Uhlenbeck)

This example solves the stationary Fokker-Planck equation for a 2-D Ornstein-Uhlenbeck process. The Fokker-Planck equation describes how a probability density evolves under drift and diffusion; its stationary form picks out the equilibrium density. The example also introduces jno.noise, which adds stochastic measurement uncertainty to the boundary observations so the network sees freshly sampled noisy data at every training step.


The Stochastic Process

The underlying stochastic differential equations (Itô form) are:

\[dX_t = -X_t \, dt + dW_t^{(1)}, \qquad dY_t = -Y_t \, dt + dW_t^{(2)}\]
  • The drift \(-X_t\), \(-Y_t\) pulls the process back toward the origin (restoring force with rate 1).
  • The diffusion coefficient \(\sigma = 1\) sets the strength of the Brownian noise.

This is the simplest mean-reverting process in 2-D and is analytically tractable — an ideal test case for a PINN.
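Simulating the SDE directly gives an empirical check of the stationary variance. The sketch below is a plain-Python Euler-Maruyama scheme and does not use jno. `simulate_ou` is a hypothetical helper, and the step size, horizon and path count are illustrative choices.

```python
import math
import random


def simulate_ou(n_paths=2000, dt=0.01, n_steps=400, seed=0):
    """Euler-Maruyama for dX = -X dt + dW1, dY = -Y dt + dW2, all paths started at (0, 0)."""
    rng = random.Random(seed)
    sqrt_dt = math.sqrt(dt)
    samples = []
    for _ in range(n_paths):
        x, y = 0.0, 0.0
        for _ in range(n_steps):
            # Drift pulls toward the origin; independent Brownian increments ~ N(0, dt).
            x += -x * dt + sqrt_dt * rng.gauss(0.0, 1.0)
            y += -y * dt + sqrt_dt * rng.gauss(0.0, 1.0)
        samples.append((x, y))
    return samples


samples = simulate_ou()  # endpoints at t = n_steps * dt = 4
# The mean is known to be 0, so the second moment estimates the variance.
var_x = sum(x * x for x, _ in samples) / len(samples)
var_y = sum(y * y for _, y in samples) / len(samples)
print(f"Var[X] ~ {var_x:.3f}, Var[Y] ~ {var_y:.3f}  (stationary value: 0.5)")
```

With step dt, the Euler-Maruyama recursion for this process has stationary variance 1/(2 - dt) ≈ 0.5025. A slight upward bias therefore sits on top of the sampling noise.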

The forward Kolmogorov (Fokker-Planck) equation for the stationary density \(p(x, y)\) is:

\[\frac{\partial (x \, p)}{\partial x} + \frac{\partial (y \, p)}{\partial y} + \frac{\sigma^2}{2} \left(\frac{\partial^2 p}{\partial x^2} + \frac{\partial^2 p}{\partial y^2}\right) = 0, \qquad (x, y) \in \Omega = \{(x, y) : x^2 + y^2 \le 9\}\]

The first two terms equal \(-\nabla \cdot (\mathbf{b} \, p)\), the negative divergence of the drift flux, with \(\mathbf{b}(x,y) = (-x, -y)\). The last term is the diffusion Laplacian \(\tfrac{\sigma^2}{2} \Delta p\).
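The signs on the drift terms come from the time-dependent Fokker-Planck equation for a process with drift \(\mathbf{b}\) and constant diffusion \(\sigma\). Substituting \(\mathbf{b} = (-x, -y)\) gives:

\[\frac{\partial p}{\partial t} = -\nabla \cdot (\mathbf{b}\, p) + \frac{\sigma^2}{2}\,\Delta p, \qquad -\nabla \cdot (\mathbf{b}\, p) = -\frac{\partial (-x\, p)}{\partial x} - \frac{\partial (-y\, p)}{\partial y} = \frac{\partial (x\, p)}{\partial x} + \frac{\partial (y\, p)}{\partial y}\]

Setting \(\partial p / \partial t = 0\) recovers the stationary equation term by term.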

Boundary condition: \(p \approx 0\) on \(\partial\Omega\) (the Gaussian decays to \(e^{-9}/\pi \approx 4 \times 10^{-5}\) at the domain edges).

For unit restoring rate and \(\sigma = 1\), the stationary density is a zero-mean bivariate Gaussian with variance \(\sigma^2/2 = 1/2\) in each coordinate:

\[p^\infty(x, y) = \frac{1}{\pi} \exp\!\left(-(x^2 + y^2)\right)\]
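A standard result gives the stationary law of the 1-D process \(dX_t = -\theta X_t\,dt + \sigma\,dW_t\) as \(\mathcal{N}(0, \sigma^2/(2\theta))\). With \(\theta = \sigma = 1\), and with the two coordinates independent, the 2-D density factorizes into two normals of variance 1/2:

\[p^\infty(x, y) = \underbrace{\frac{1}{\sqrt{\pi}}\, e^{-x^2}}_{\mathcal{N}(0,\,1/2)} \cdot \underbrace{\frac{1}{\sqrt{\pi}}\, e^{-y^2}}_{\mathcal{N}(0,\,1/2)}, \qquad \frac{1}{\sqrt{2\pi s^2}}\, e^{-x^2/(2s^2)} \Big|_{s^2 = 1/2} = \frac{1}{\sqrt{\pi}}\, e^{-x^2}\]

In polar coordinates \((r, \phi)\), its total mass over the plane is one:

\[\iint_{\mathbb{R}^2} p^\infty \, dx\, dy = \frac{1}{\pi} \int_0^{2\pi}\!\! \int_0^\infty e^{-r^2}\, r \, dr \, d\phi = \frac{1}{\pi} \cdot 2\pi \cdot \frac{1}{2} = 1\]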

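Verification. Substitute \(p = C e^{-(x^2+y^2)}\). Then \(\partial p/\partial x = -2x\,p\) and \(\partial^2 p/\partial x^2 = (4x^2 - 2)\,p\), and likewise in \(y\):

| Term | Value |
|---|---|
| \(\partial(xp)/\partial x\) | \((1 - 2x^2)\,p\) |
| \(\partial(yp)/\partial y\) | \((1 - 2y^2)\,p\) |
| \(\tfrac{1}{2}\Delta p\) | \(\tfrac{1}{2}(4x^2 + 4y^2 - 4)\,p = (2x^2 + 2y^2 - 2)\,p\) |

Adding the three rows gives \((2 - 2x^2 - 2y^2)\,p + (2x^2 + 2y^2 - 2)\,p = 0\) for every \(C\), so the residual vanishes identically. ✓ The PDE does not fix the constant \(C = 1/\pi\); normalization does.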
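The same cancellation can be checked numerically with central finite differences. The sketch is plain Python. `gaussian_density` and `fp_residual` are hypothetical helpers, and the step h and test points are arbitrary choices. A Gaussian with the wrong variance serves as a negative control.

```python
import math


def gaussian_density(x, y, s2=0.5):
    """Zero-mean isotropic Gaussian with variance s2 per coordinate (s2 = 0.5 gives p^inf)."""
    return math.exp(-(x * x + y * y) / (2.0 * s2)) / (2.0 * math.pi * s2)


def fp_residual(p, x, y, h=1e-3):
    """Central-difference estimate of d(xp)/dx + d(yp)/dy + 0.5 * (p_xx + p_yy) at (x, y)."""
    drift = ((x + h) * p(x + h, y) - (x - h) * p(x - h, y)) / (2 * h) + (
        (y + h) * p(x, y + h) - (y - h) * p(x, y - h)
    ) / (2 * h)
    laplacian = (p(x + h, y) + p(x - h, y) + p(x, y + h) + p(x, y - h) - 4 * p(x, y)) / h**2
    return drift + 0.5 * laplacian


points = [(0.0, 0.0), (0.7, -1.2), (2.5, 1.0)]
exact = [fp_residual(gaussian_density, x, y) for x, y in points]
# A Gaussian with the wrong variance (1 instead of 1/2) does not satisfy the equation.
wrong = [fp_residual(lambda x, y: gaussian_density(x, y, s2=1.0), x, y) for x, y in points]
print("p^inf residuals:", [f"{r:.1e}" for r in exact])
print("variance-1 residuals:", [f"{r:.3f}" for r in wrong])
```

For the variance-1 Gaussian the residual works out to \((1 - r^2/2)\,p\). At the origin that is \(1/(2\pi) \approx 0.159\), far from zero.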


Code Walkthrough

Step 1 — Domain centred at the origin and analytical boundary

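The domain is a disc of radius 3 centred at \((0, 0)\), so the Gaussian peak sits at the domain centre. mesh_size=0.25 controls the mesh resolution. The analytical solution is evaluated on both point sets:

  • p_exact on the interior points, used after training to measure the relative error.
  • p_exact_bc on the boundary points, used as the (near-zero) target for the noisy boundary observations.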

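import jno
from shapely.geometry import Point

π = jno.np.pi

# Disc of radius 3 centred at the origin — captures the Gaussian's effective support.
domain = jno.domain(Point(0, 0).buffer(3.0), mesh_size=0.25)
x, y, _ = domain.variable("interior")  # interior collocation points; third component unused here
xb, yb, _ = domain.variable("boundary")  # points on the disc boundary

p_exact = jno.np.exp(-(x**2 + y**2)) / π  # reference solution, used for the final error
p_exact_bc = jno.np.exp(-(xb**2 + yb**2)) / π  # boundary target, e^-9/π ≈ 3.9e-5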
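The plain-Python sketch below mimics the two point sets on a regular grid: interior points strictly inside the disc, and equally spaced points on the circle. It is a hypothetical stand-in, not jno's mesher, which builds its own point set from mesh_size. It shows that the boundary target is the same tiny constant \(e^{-9}/\pi\) at every boundary point, while the interior peaks at \(1/\pi\). That is a ratio of \(e^9 \approx 8100\), which the network must resolve.

```python
import math

H = 0.25  # grid spacing, echoing mesh_size=0.25 (assumption: jno's actual mesh differs)
R = 3.0   # disc radius, as in Point(0, 0).buffer(3.0)


def p_stationary(x, y):
    """Exact stationary density exp(-(x^2 + y^2)) / pi."""
    return math.exp(-(x * x + y * y)) / math.pi


# Interior: grid points strictly inside the disc (the grid includes the origin).
n = int(round(2 * R / H))
interior = [
    (-R + i * H, -R + j * H)
    for i in range(n + 1)
    for j in range(n + 1)
    if (-R + i * H) ** 2 + (-R + j * H) ** 2 < R * R
]
# Boundary: equally spaced points on the circle of radius R.
n_b = int(round(2 * math.pi * R / H))
boundary = [
    (R * math.cos(2 * math.pi * k / n_b), R * math.sin(2 * math.pi * k / n_b))
    for k in range(n_b)
]

peak = max(p_stationary(x, y) for x, y in interior)
bc_values = [p_stationary(x, y) for x, y in boundary]
print(f"{len(interior)} interior points, {len(boundary)} boundary points")
print(f"interior peak = {peak:.4f}, boundary target = {bc_values[0]:.3e}")
```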

Step 2 — Fokker-Planck residual

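jno.np.vector(...) builds a typed VectorView from scalar components without manual concatenation. .div(x, y) then takes its divergence, so the flux \((x p, y p) = -\mathbf{b}\,p\) turns directly into \(\partial(xp)/\partial x + \partial(yp)/\partial y = -\nabla \cdot (\mathbf{b} p)\). The diffusion term writes the Laplacian \(\Delta p\) out as p.xx + p.yy; jno.np.laplacian computes the same operator in a single call.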
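Operators like .div and .xx are presumably built on JAX automatic differentiation. As an illustration only, here is the same residual in plain JAX, with no jno and hypothetical helper names. jax.jacfwd and jax.hessian compute the derivatives, and the result is checked against the exact density. In a PINN, p would be the network rather than p_stationary.

```python
import jax
import jax.numpy as jnp


def p_stationary(xy):
    """Stationary OU density exp(-(x^2 + y^2)) / pi for a single point xy of shape (2,)."""
    return jnp.exp(-jnp.sum(xy**2)) / jnp.pi


def fokker_planck_residual(p):
    """Build r(xy) = d(x p)/dx + d(y p)/dy + 0.5 * (p_xx + p_yy) for a scalar field p."""

    def flux(xy):
        return xy * p(xy)  # (x p, y p), i.e. -b p for b = (-x, -y)

    def residual(xy):
        div_flux = jnp.trace(jax.jacfwd(flux)(xy))  # divergence = trace of the Jacobian
        laplacian = jnp.trace(jax.hessian(p)(xy))   # Laplacian = trace of the Hessian
        return div_flux + 0.5 * laplacian

    return residual


residual_batch = jax.vmap(fokker_planck_residual(p_stationary))
pts = jnp.array([[0.0, 0.0], [0.7, -1.2], [2.5, 1.0]])
print(residual_batch(pts))  # close to 0, up to float32 round-off
```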

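import foundax
import jax
import optax

# MLP on (x, y); Adam with a learning rate that halves every 2000 steps, floored at 1e-5.
net = jno.nn.wrap(foundax.mlp(in_features=2, hidden_dims=64, num_layers=5, key=jax.random.PRNGKey(0)))
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 2000, 0.5, end_value=1e-5)))

p = net(x, y).scalar.bind(x=x, y=y)  # network output as the scalar density p(x, y)
prob_flux = jno.np.vector(x * p, y * p)  # (x p, y p) = -b p as a VectorView
drift = prob_flux.div(x, y)  # ∂(xp)/∂x + ∂(yp)/∂y = -∇·(b p)
diff = 0.5 * (p.xx + p.yy)  # (σ²/2) Δp with σ = 1
fp = drift + diff  # Fokker-Planck residual; should vanish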
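The learning-rate schedule passed to optax.exponential_decay can be tabulated by hand. The sketch below re-implements the formula as we understand optax to apply it with default transition_begin=0 and staircase=False: init · rate^(step / transition_steps), clipped from below by end_value. It is plain Python, not optax itself.

```python
import math

INIT, STEPS, RATE, END = 1e-3, 2000, 0.5, 1e-5  # values from the net.optimizer(...) call


def exp_decay(step):
    """Learning rate at `step`: INIT * RATE**(step / STEPS), never below END."""
    return max(INIT * RATE ** (step / STEPS), END)


# Step at which the smooth decay reaches the END floor.
floor_step = STEPS * math.log(INIT / END) / math.log(1.0 / RATE)

for step in (0, 2000, 6000, 10_000, 15_000):
    print(f"step {step:>6}: lr = {exp_decay(step):.2e}")
print(f"floor reached near step {floor_step:.0f}")
```

The floor is reached near step 13,300. Roughly the last 1,700 of the 15,000 training steps therefore run at the constant rate 1e-5.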

Step 3 — Normalization and noisy boundary

.integrate() reduces \(p(x, y)\) to the scalar \(\iint_\Omega p \, dx \, dy\) using mesh-based quadrature weights. Subtracting 1 creates a loss that drives total probability mass to 1.
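The normalization target of 1 is very slightly too large for this domain. In polar coordinates the exact density carries \(\iint_{r \le R} p^\infty = 1 - e^{-R^2}\) of its mass inside a disc of radius \(R\). For \(R = 3\) that is \(1 - e^{-9} \approx 0.99988\), a mismatch of about \(1.2 \times 10^{-4}\). The sketch below compares that closed form with a cell-centred midpoint rule. The midpoint rule is a hypothetical stand-in for jno's mesh-based quadrature, not its actual weights.

```python
import math


def disc_mass(h=0.25, radius=3.0):
    """Midpoint-rule estimate of the integral of exp(-(x^2+y^2))/pi over x^2 + y^2 <= radius^2."""
    n = int(round(2 * radius / h))
    total = 0.0
    for i in range(n):
        for j in range(n):
            x = -radius + (i + 0.5) * h  # cell centre
            y = -radius + (j + 0.5) * h
            if x * x + y * y <= radius * radius:  # keep cells whose centre lies in the disc
                total += math.exp(-(x * x + y * y)) / math.pi * h * h
    return total


exact_disc_mass = 1.0 - math.exp(-9.0)  # closed form 1 - exp(-R^2) with R = 3
print(f"midpoint rule: {disc_mass():.6f}, closed form: {exact_disc_mass:.6f}")
```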

jno.noise.gaussian(std=1e-4) is a lazy Placeholder — each training step the solver splits its PRNG key and samples a fresh \((N_b, 1)\) array, simulating noisy physical measurements at the boundary.

norm = p.integrate() - 1.0  # ∬ p dx dy = 1
p_bc = net(xb, yb) - (p_exact_bc + jno.noise.gaussian(std=1e-4))  # noisy target, resampled every step

Reproducibility

The noise sequence is fully determined by the global seed. Set it with jno.setup(seed=42) or in .jno.toml to reproduce the exact same training run.
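The sampling pattern described above (split the key, then draw a fresh array) can be sketched in plain JAX. This is not jno's internal code: noise_stream is a hypothetical helper, and n_b stands for the number of boundary points. Seeding once and splitting at every step gives a different array each step, yet the identical sequence on every rerun with the same seed.

```python
import jax
import jax.numpy as jnp


def noise_stream(seed, n_steps, n_b, std=1e-4):
    """Yield one fresh (n_b, 1) array of N(0, std^2) noise per training step."""
    key = jax.random.PRNGKey(seed)
    for _ in range(n_steps):
        key, subkey = jax.random.split(key)  # keep `key` for later steps, consume `subkey` now
        yield std * jax.random.normal(subkey, (n_b, 1))


run_a = list(noise_stream(seed=42, n_steps=3, n_b=5))
run_b = list(noise_stream(seed=42, n_steps=3, n_b=5))
print(jnp.allclose(run_a[0], run_a[1]))  # consecutive steps draw different noise
print(all(jnp.array_equal(a, b) for a, b in zip(run_a, run_b)))  # same seed, same sequence
```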

Step 4 — Solve

Three losses compete: PDE residual, normalization, and noisy boundary data. The solver balances them using its built-in loss weighting.

residuals = jno.trackers.residual_stats(interval=1000)
crux = jno.core([fp.mse, norm.mse, p_bc.mse])
crux.solve(15_000, callbacks=[residuals])
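Assume each .mse averages the squared residual over its points; for the scalar norm that is simply its square. Under that assumption the training objective has the form below. The weights \(w_i\) stand for the solver's built-in weighting, whose exact scheme this page does not specify. Here \(r_\text{fp}\) is the residual fp at the \(N\) interior points, and the \(N_b\) noise samples \(\varepsilon_k\) are redrawn at every step:

\[\mathcal{L} = w_1 \, \frac{1}{N}\sum_{k=1}^{N} r_\text{fp}(x_k, y_k)^2 + w_2 \left(\iint_\Omega p \, dx \, dy - 1\right)^2 + w_3 \, \frac{1}{N_b}\sum_{k=1}^{N_b} \bigl(p(x^b_k, y^b_k) - p_\text{exact}(x^b_k, y^b_k) - \varepsilon_k\bigr)^2, \qquad \varepsilon_k \sim \mathcal{N}(0, 10^{-8})\]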

What to Notice

  • Noise on observations, not the physics. The Fokker-Planck residual is deterministic; only the boundary data carries noise. This mirrors the real-world setting where the governing equation is known but measurements are uncertain.
  • The network still converges. The noise level \(\operatorname{std} = 10^{-4}\) is about 2.5 times the true boundary value (\(e^{-9}/\pi \approx 3.9 \times 10^{-5}\)). Individual boundary observations are therefore dominated by noise, with a signal-to-noise ratio of about 0.4. The PDE and normalization constraints compensate and anchor the interior solution.
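  • .integrate() is differentiable. Gradients flow through the normalization term, which fixes the global scale of \(p\). The Fokker-Planck equation is linear and homogeneous, so any multiple of the solution (including \(p \equiv 0\)) satisfies it, and the near-zero boundary data cannot pin the scale either. The residual term shapes \(p\); the normalization term sets its amplitude.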
  • No manual key management. The user never calls jax.random.split or threads keys through the loss — jno.noise handles all of that inside the solver.
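Convergence despite a signal-to-noise ratio below 1 rests partly on a simple fact. Zero-mean noise in the target adds a constant to the expected squared error without moving its minimizer, since \(\mathbb{E}[(c - t - \varepsilon)^2] = (c - t)^2 + \operatorname{std}^2\). The plain-Python Monte Carlo sketch below checks this at the noise level used here. For illustration, a constant prediction c stands in for the network's boundary output. The same seed is reused for every candidate, so all candidates see identical noise draws (common random numbers).

```python
import math
import random

TRUE_BC = math.exp(-9.0) / math.pi  # exact boundary value, about 3.9e-5
STD = 1e-4                          # noise level passed to jno.noise.gaussian


def expected_bc_loss(c, n_draws=20000, seed=0):
    """Monte Carlo average of (c - (TRUE_BC + eps))^2 over fresh noise draws eps ~ N(0, STD^2)."""
    rng = random.Random(seed)
    return sum((c - TRUE_BC - rng.gauss(0.0, STD)) ** 2 for _ in range(n_draws)) / n_draws


candidates = [TRUE_BC + k * 2e-5 for k in range(-5, 6)]
losses = [expected_bc_loss(c) for c in candidates]
best = candidates[losses.index(min(losses))]
print(f"SNR at the boundary: {TRUE_BC / STD:.2f}")
print(f"best constant prediction: {best:.2e} (true value {TRUE_BC:.2e})")
print(f"loss floor: {min(losses):.2e} (noise variance {STD**2:.1e})")
```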

Full Script

"""07: 2-D Fokker-Planck on a disc (shapely geometry + noisy boundary data + residual tracker)"""

from pathlib import Path

import foundax
import jax
import jax.numpy as jnp
import optax
from shapely.geometry import Point

import jno

π = jno.np.pi

# Disc of radius 3 centred at the origin — captures the Gaussian's effective support.
domain = jno.domain(Point(0, 0).buffer(3.0), mesh_size=0.25)
x, y, _ = domain.variable("interior")
xb, yb, _ = domain.variable("boundary")

p_exact = jno.np.exp(-(x**2 + y**2)) / π
p_exact_bc = jno.np.exp(-(xb**2 + yb**2)) / π

net = jno.nn.wrap(foundax.mlp(in_features=2, hidden_dims=64, num_layers=5, key=jax.random.PRNGKey(0)))
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 2000, 0.5, end_value=1e-5)))

p = net(x, y).scalar.bind(x=x, y=y)
prob_flux = jno.np.vector(x * p, y * p)  # (x p, y p) = -b p as a VectorView
drift = prob_flux.div(x, y)  # ∂(xp)/∂x + ∂(yp)/∂y = -∇·(b p)
diff = 0.5 * (p.xx + p.yy)  # (σ²/2) Δp with σ = 1
fp = drift + diff  # Fokker-Planck residual; should vanish

norm = p.integrate() - 1.0  # ∬ p dx dy = 1
p_bc = net(xb, yb) - (p_exact_bc + jno.noise.gaussian(std=1e-4))

residuals = jno.trackers.residual_stats(interval=1000)
crux = jno.core([fp.mse, norm.mse, p_bc.mse])
crux.solve(15_000, callbacks=[residuals])

_p, _p_exact = crux.eval([p, p_exact])
rel_l2 = float(jnp.linalg.norm(_p - _p_exact) / (jnp.linalg.norm(_p_exact) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}")
if residuals.value is not None:
    print(f"Per-constraint max residuals: {residuals.value['maxes']}")

results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(f"07_stochastic/fokker_planck_2d.py | epochs=15000 | rel_L2={rel_l2:.6e}\n")

assert rel_l2 < 3e-1, f"relative L2 error too large: {rel_l2:.3e}"