Reaction-Diffusion 2D
This example augments diffusion with a linear reaction term in a transient 2D setting.
Problem Setup
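The PDE is u_t - ν Δu + λu = f on the unit square (0, 1)² for 0 < t ≤ T_end = 1, with u = 0 on the boundary. The manufactured exact solution u(x, y, t) = exp(-t) sin(πx) sin(πy) determines both the source f and the initial condition u(x, y, 0) = sin(πx) sin(πy), and is used to validate the trained model.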
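The source term follows from substituting the exact solution into the PDE. With u = e^{-t} sin(πx) sin(πy), differentiating once in time and twice in each space direction gives:

```latex
\begin{aligned}
u_t &= -u, \qquad \Delta u = u_{xx} + u_{yy} = -2\pi^2 u, \\
f &= u_t - \nu\,\Delta u + \lambda u = \left(-1 + 2\nu\pi^2 + \lambda\right) u .
\end{aligned}
```

This is exactly the coefficient used for `source` in the code.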
Step 1: Build the Space-Time Problem
The script samples interior space-time points on a 2D mesh plus a separate set of initial points at t = 0, and defines the exact solution and the source term on those points.
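import jno

pi = jno.np.pi  # defined as in the full script
nu = 0.1        # diffusion coefficient
lam = 0.5       # reaction coefficient
T_end = 1.0
N_t = 4

# Rectangular mesh with a time axis given as (0, T_end, N_t)
domain = jno.domain(
    constructor=jno.domain.rect(mesh_size=0.05),
    time=(0, T_end, N_t),
    compute_mesh_connectivity=False,
)

# Interior space-time points and points on the initial slice t = 0
x, y, t = domain.variable("interior")
x0, y0, t0 = domain.variable("initial")

# Manufactured solution and the source term it implies
u_exact = jno.np.exp(-t) * jno.np.sin(pi * x) * jno.np.sin(pi * y)
source = (-1 + 2 * nu * pi**2 + lam) * u_exact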
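As a rough picture of what the two variable sets contain, the sketch below builds analogous collocation sets in pure Python and evaluates the source on them. It is illustrative only: `sample_points` is a hypothetical helper, it uses a uniform grid where jno uses a mesh with `mesh_size=0.05`, and reading `N_t` as N_t equally spaced time levels in (0, T_end] is an assumption about the time tuple, not documented jno behavior.

```python
import math

def sample_points(h, n_t, t_end):
    # Hypothetical stand-in for domain.variable(...): a uniform grid of spacing h
    # strictly inside the unit square, repeated at n_t time levels in (0, t_end]
    n = round(1 / h)
    grid = [(i * h, j * h) for i in range(1, n) for j in range(1, n)]
    times = [t_end * k / n_t for k in range(1, n_t + 1)]
    interior = [(x, y, t) for t in times for (x, y) in grid]
    initial = [(x, y, 0.0) for (x, y) in grid]  # same spatial points at t = 0
    return interior, initial

def u_exact(x, y, t):
    return math.exp(-t) * math.sin(math.pi * x) * math.sin(math.pi * y)

nu, lam = 0.1, 0.5
coef = -1 + 2 * nu * math.pi**2 + lam  # source = coef * u_exact

interior, initial = sample_points(0.05, 4, 1.0)
source = [coef * u_exact(x, y, t) for (x, y, t) in interior]
print(len(interior), len(initial))  # 1444 361
```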
Step 2: Use a Hard Boundary Ansatz
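The network output is multiplied by the envelope x(1 - x) y(1 - y), which is zero on all four edges of the unit square. The homogeneous Dirichlet condition u = 0 therefore holds exactly for any network output, and no boundary loss is needed.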
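import foundax
import jax
import optax

# n_sensors=1 matches the scalar time input t; coord_dim=2 matches xy = (x, y)
net = jno.nn.wrap(
    foundax.deeponet(
        n_sensors=1, coord_dim=2, n_outputs=1,
        n_layers=5, basis_functions=96, hidden_dim=64,
        key=jax.random.PRNGKey(21),
    )
)
net.optimizer(optax.adam(optax.warmup_cosine_decay_schedule(...)))
xy = jno.np.concat([x, y])
xy0 = jno.np.concat([x0, y0])
# Hard boundary ansatz: the envelope vanishes on x = 0, 1 and y = 0, 1
u = net(t, xy) * x * (1 - x) * y * (1 - y)
u0 = net(t0, xy0) * x0 * (1 - x0) * y0 * (1 - y0)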
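The effect of the envelope can be checked without any network at all. In this sketch, `fake_net` is a hypothetical placeholder for the trained model (any function returning nonzero values works); multiplying by x(1 - x) y(1 - y) drives the output to exactly zero at randomly chosen boundary points while leaving interior values free.

```python
import math
import random

def fake_net(t, x, y):
    # Hypothetical stand-in for the DeepONet: arbitrary, nonzero, smooth
    return 1.0 + math.sin(3 * x + 2 * y + t) + 0.5 * x * y

def envelope(x, y):
    # Zero on x = 0, x = 1, y = 0, y = 1 and positive inside the unit square
    return x * (1 - x) * y * (1 - y)

def u(t, x, y):
    return fake_net(t, x, y) * envelope(x, y)

random.seed(0)
boundary = []
for _ in range(200):
    s, t = random.random(), random.random()
    boundary += [(0.0, s, t), (1.0, s, t), (s, 0.0, t), (s, 1.0, t)]

max_bc_error = max(abs(u(t, x, y)) for (x, y, t) in boundary)
print(max_bc_error)  # 0.0: the boundary condition holds by construction
```

The ansatz says nothing about the solution at t = 0, which is why the initial condition still needs its own loss term.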
Step 3: Add Both Initial and PDE Residuals
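Because the boundary condition is built into the ansatz, only two losses remain: the mean squared PDE residual on the interior space-time points and the mean squared initial-condition mismatch on the t = 0 points. `jno.core` combines them into a single training objective.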
pde = u.d(t) - nu * jno.np.laplacian(u, [x, y]) + lam * u - source
ini = u0 - jno.np.sin(pi * x0) * jno.np.sin(pi * y0)
crux = jno.core([pde.mse, ini.mse])
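The two-term objective can be sketched in plain Python to see how it separates good and bad candidates. This is a minimal sketch under two assumptions not stated in the source: the losses are summed with equal weights, and central finite differences stand in for jno's automatic derivatives. `u_wrong` is a hypothetical candidate that matches the initial condition but decays at the wrong rate.

```python
import math

nu, lam = 0.1, 0.5
coef = -1 + 2 * nu * math.pi**2 + lam

def u_exact(x, y, t):
    return math.exp(-t) * math.sin(math.pi * x) * math.sin(math.pi * y)

def source(x, y, t):
    return coef * u_exact(x, y, t)

def pde_residual(u, x, y, t, h=1e-3):
    # Central differences for u_t and the 5-point Laplacian
    u_t = (u(x, y, t + h) - u(x, y, t - h)) / (2 * h)
    lap = (u(x + h, y, t) + u(x - h, y, t) + u(x, y + h, t) + u(x, y - h, t)
           - 4 * u(x, y, t)) / h**2
    return u_t - nu * lap + lam * u(x, y, t) - source(x, y, t)

def mse(values):
    return sum(v * v for v in values) / len(values)

def total_loss(u, interior, initial):
    # Assumed equal weighting: L = mse(pde residual) + mse(initial mismatch)
    pde = [pde_residual(u, x, y, t) for (x, y, t) in interior]
    ini = [u(x, y, 0.0) - math.sin(math.pi * x) * math.sin(math.pi * y)
           for (x, y, _) in initial]
    return mse(pde) + mse(ini)

grid = [(i / 10, j / 10) for i in range(1, 10) for j in range(1, 10)]
interior = [(x, y, t) for t in (0.25, 0.5, 0.75, 1.0) for (x, y) in grid]
initial = [(x, y, 0.0) for (x, y) in grid]

def u_wrong(x, y, t):
    # Correct initial condition, wrong decay rate
    return math.exp(-2 * t) * math.sin(math.pi * x) * math.sin(math.pi * y)

print(total_loss(u_exact, interior, initial))  # close to zero
print(total_loss(u_wrong, interior, initial))  # clearly positive
```

The wrong candidate has zero initial-condition loss, so only the PDE residual term exposes it.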
Step 4: Use a Standard Training Schedule
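This step trains for 40,000 iterations and then measures the relative L2 error of the prediction against the exact solution. It is a good reference for a standard jNO transient training setup with a manufactured source term. The script snippet below uses a smaller network (3 layers, 48 basis functions, hidden width 32) and 5,000 iterations, and asserts that the relative L2 error stays below 0.1.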
history = crux.solve(40000)
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))
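The metric is the discrete 2-norm of the error divided by the 2-norm of the reference, with a small constant guarding against division by zero. A pure-Python version of the same formula, applied to a toy vector pair:

```python
import math

def relative_l2(pred, ref, eps=1e-8):
    # ||pred - ref||_2 / (||ref||_2 + eps); eps avoids dividing by zero
    diff = math.sqrt(sum((p - r) ** 2 for p, r in zip(pred, ref)))
    norm = math.sqrt(sum(r * r for r in ref))
    return diff / (norm + eps)

ref = [3.0, 4.0]   # ||ref|| = 5
pred = [3.0, 4.5]  # error vector (0, 0.5), so the ratio is 0.5 / 5
print(relative_l2(pred, ref))  # about 0.1
```

A value of 0.1, the script's pass threshold, corresponds to a 10% error relative to the size of the exact solution over the evaluated points.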
What To Notice
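- The reaction term λu adds one term to the residual and, for λ > 0, increases the decay rate of the solution, but the workflow (domain, ansatz, losses, training) is the same as for pure diffusion.
- Manufactured solutions are especially useful for validating transient codes, because the exact solution is known at every point in space and time and the error can be measured directly.
- Because the reaction term here is linear, this example is a natural step from the heat equation toward nonlinear parabolic systems.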
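The role of λ can be made concrete by substituting u = a(t) sin(πx) sin(πy), the single spatial mode of the exact solution. Since Δ of that mode is -2π² times the mode, the PDE reduces to the scalar ODE a' = -(2νπ² + λ) a + c e^{-t} with c = -1 + 2νπ² + λ. Without forcing, the reaction term simply adds λ to the decay rate; with the manufactured source, the decay is held at e^{-t}. The sketch below checks both cases with a hand-written RK4 integrator, a swapped-in verification tool rather than part of the jno workflow.

```python
import math

nu, lam = 0.1, 0.5
k = 2 * nu * math.pi**2 + lam       # decay rate of the sin(pi x) sin(pi y) mode
c = -1 + 2 * nu * math.pi**2 + lam  # source coefficient, equal to k - 1

def rk4(rhs, a0, t_end, n_steps):
    # Classic fourth-order Runge-Kutta for a scalar ODE a' = rhs(t, a)
    h = t_end / n_steps
    t, a = 0.0, a0
    for _ in range(n_steps):
        k1 = rhs(t, a)
        k2 = rhs(t + h / 2, a + h / 2 * k1)
        k3 = rhs(t + h / 2, a + h / 2 * k2)
        k4 = rhs(t + h, a + h * k3)
        a += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return a

# Unforced: a' = -k a, so a(1) = exp(-k)
a_free = rk4(lambda t, a: -k * a, 1.0, 1.0, 1000)
# Forced by the manufactured source: a' = -k a + c exp(-t), so a(1) = exp(-1)
a_forced = rk4(lambda t, a: -k * a + c * math.exp(-t), 1.0, 1.0, 1000)
print(a_free, math.exp(-k))
print(a_forced, math.exp(-1))
```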
Script Snippet
"""03 — 2-D reaction-diffusion equation"""
from pathlib import Path
import foundax
import jax
import optax
import jno
π = jno.np.pi
ν = 0.1
λ = 0.5
T_end = 1.0
domain = jno.domain.rect(mesh_size=0.05, time=(0, T_end, 4))
x, y, t = domain.variable("interior")
x0, y0, t0 = domain.variable("initial")
u_exact = jno.np.exp(-t) * jno.np.sin(π * x) * jno.np.sin(π * y)
source = (-1 + 2 * ν * π**2 + λ) * u_exact
net = jno.nn.wrap(
foundax.deeponet(
n_sensors=1,
coord_dim=2,
n_outputs=1,
n_layers=3,
basis_functions=48,
hidden_dim=32,
key=jax.random.PRNGKey(21),
)
)
net.optimizer(optax.adam(optax.warmup_cosine_decay_schedule(0.0, 1e-3, 50, 5000, 1e-5)))
xy = jno.np.concat([x, y])
xy0 = jno.np.concat([x0, y0])
u = (net(t, xy) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y, t=t)
u0 = net(t0, xy0) * x0 * (1 - x0) * y0 * (1 - y0)
pde = u.t - ν * (u.xx + u.yy) + λ * u - source
ini = u0 - jno.np.sin(π * x0) * jno.np.sin(π * y0)
crux = jno.core([pde.mse, ini.mse])
crux.solve(5_000)
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}")
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(f"03_parabolic/reaction_diffusion_2d.py | epochs=5000 | rel_L2={rel_l2:.6e}\n")
assert rel_l2 < 1e-1, f"relative L2 error too large: {rel_l2:.3e}"
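The script's optimizer uses `optax.warmup_cosine_decay_schedule(0.0, 1e-3, 50, 5000, 1e-5)`. Reading the positional arguments as init_value, peak_value, warmup_steps, decay_steps and end_value, the learning rate ramps linearly from 0 to 1e-3 over the first 50 steps and then follows a half-cosine down to 1e-5 at step 5000, which appears to line up with the 5,000 iterations passed to `crux.solve`. The function below is a pure-Python sketch of that curve based on optax's documented behavior (default exponent of 1), useful for inspecting values; it is not optax's implementation.

```python
import math

def warmup_cosine(step, init_value=0.0, peak_value=1e-3, warmup_steps=50,
                  decay_steps=5000, end_value=1e-5):
    # Linear warmup from init_value to peak_value over warmup_steps
    if step < warmup_steps:
        return init_value + (peak_value - init_value) * step / warmup_steps
    # Cosine decay from peak_value to end_value; decay_steps counts the whole
    # schedule including warmup, and steps past the end stay at end_value
    span = decay_steps - warmup_steps
    frac = min(step - warmup_steps, span) / span
    return end_value + (peak_value - end_value) * 0.5 * (1 + math.cos(math.pi * frac))

for step in (0, 25, 50, 2525, 5000):
    print(step, warmup_cosine(step))
```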