Reaction-Diffusion 2D
This example augments diffusion with a linear reaction term in a transient 2D setting.
Problem Setup
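The PDE is $u_t - \nu \Delta u + \lambda u = f$ on the unit square $[0,1]^2$ for $t \in [0,1]$, with $u = 0$ on the boundary. The source $f$ is manufactured so that $u(x,y,t) = e^{-t}\sin(\pi x)\sin(\pi y)$ is the exact solution, which is used for validation.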
Step 1: Build the Space-Time Problem
The script samples interior and initial-time points of the space–time domain $[0,1]^2 \times [0,1]$. During training it tracks the mean-squared error against the exact reference solution.
Step 2: Use a Hard Boundary Ansatz
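The network output is multiplied by the envelope $x(1-x)\,y(1-y)$, which vanishes on all four edges of the unit square. The zero boundary condition therefore holds exactly and needs no separate loss term.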
Step 3: Add Both Initial and PDE Residuals
The time-dependent PDE residual and the initial-condition residual $u(x,y,0) - \sin(\pi x)\sin(\pi y)$ are minimized jointly as mean-squared-error losses.
Step 4: Use a Standard Training Schedule
Training uses Adam with a warmup–cosine learning-rate schedule. This makes the script a clean reference for a standard jNO transient training setup with a manufactured source term.
What To Notice
- The reaction term changes the balance of the dynamics without changing the basic workflow.
- Manufactured solutions are especially useful for validating transient codes.
- Because the reaction term here is still linear, this example is a useful bridge from the heat equation to nonlinear parabolic systems.
Script Snippet
"""03 - 2-D reaction-diffusion equation
Problem
-------
u_t - nu Delta u + lambda u = f(x, y, t), (x, y) in [0, 1]^2, t in [0, 1]
u = 0 on the boundary
u(x, y, 0) = sin(pi x) sin(pi y)
Analytical solution
-------------------
u(x, y, t) = exp(-t) sin(pi x) sin(pi y)
"""
import copy
import os
import jax
import jno
import jno.jnp_ops as jnn
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import numpy as np
import optax
from jno import LearningRateSchedule as lrs
def pick(default, test):
    """Return ``default`` for a full run, or ``test`` in smoke-test mode.

    Smoke-test mode is enabled when the ``JNO_SMOKE_TEST`` environment
    variable is set to "1", "true", "yes" or "on" (case-insensitive). When it
    is unset, the script runs its full configuration.

    Args:
        default: Value used for the full run (e.g. the real step count).
        test: Cheap value used for a quick smoke test.

    Returns:
        ``test`` in smoke-test mode, otherwise ``default``.
    """
    # NOTE(review): the original body of this helper was missing from the
    # snippet (a bare ``def`` is an IndentationError). It was reconstructed
    # around an environment switch, which is consistent with the otherwise
    # unused ``os`` import. Confirm the variable name against the other jNO
    # example scripts.
    flag = os.environ.get("JNO_SMOKE_TEST", "")
    return test if flag.strip().lower() in {"1", "true", "yes", "on"} else default
# Problem setup: physical constants, space-time domain, symbolic sample
# variables, and the manufactured exact solution / source term.
# pick(default, test) selects the full-run value or the cheap test value.
pi = jnn.pi
# Used below as the output directory in f"{dire}/..." paths
# (presumably created by jno.setup — confirm in the jno docs).
dire = jno.setup(__file__)

# Diffusivity nu, reaction rate lam, final time and number of time levels.
nu = 0.1
lam = 0.5
T_end = 1.0
N_t = pick(8, 4)

# Rectangular spatial domain (presumably the default unit square [0, 1]^2
# from the module docstring — confirm rect()'s defaults) combined with the
# time axis; `time` is assumed to be (start, end, number of levels).
domain = jno.domain(
constructor=jno.domain.rect(mesh_size=pick(0.05, 0.2)),
time=(0, T_end, N_t),
compute_mesh_connectivity=False,
)
# Symbolic coordinates of the interior space-time samples and of the
# initial-time samples (the initial time coordinate is discarded).
x, y, t = domain.variable("interior")
x0, y0, _ = domain.variable("initial")
# Broadcast t to the shape of x so the network receives same-shaped inputs.
t_like = 0 * x + t
# Time input of zero, shaped like x0, for the initial-condition points.
zero0 = 0 * x0

# Manufactured solution u = exp(-t) sin(pi x) sin(pi y).
u_exact = jnn.exp(-t) * jnn.sin(pi * x) * jnn.sin(pi * y)
# Forcing that makes u_exact satisfy the PDE. Since u_t = -u and
# Delta u = -2 pi^2 u, f = u_t - nu Delta u + lam u = (-1 + 2 nu pi^2 + lam) u.
source = (-1 + 2 * nu * pi**2 + lam) * u_exact
# MLP taking the three inputs (x, y, t); hidden width 48, 4 layers, fixed seed.
net = jnn.nn.mlp(in_features=3, hidden_dims=48, num_layers=4, key=jax.random.PRNGKey(21))
# Adam with a base rate of 1: the effective step size presumably comes from
# the schedule passed as `lr`. Assumed warmup_cosine arguments are
# (total steps, warmup steps, peak lr, final lr) — confirm in the jno docs.
# The total matches the step count passed to crux.solve below; keep them in sync.
net.optimizer(optax.adam(1), lr=lrs.warmup_cosine(pick(12_000, 10), pick(500, 1), 1e-3, 1e-5))

# Hard boundary ansatz: x(1-x)y(1-y) vanishes on every edge of the unit
# square, so u = 0 there regardless of the network output.
u = net(x, y, t_like) * x * (1 - x) * y * (1 - y)
# The same ansatz evaluated at the initial-time points with t = 0.
u0 = net(x0, y0, zero0) * x0 * (1 - x0) * y0 * (1 - y0)

# PDE residual u_t - nu Delta u + lam u - f at the interior points.
pde = jnn.grad(u, t) - nu * jnn.laplacian(u, [x, y]) + lam * u - source
# Initial-condition residual against u(x, y, 0) = sin(pi x) sin(pi y).
ini = u0 - jnn.sin(pi * x0) * jnn.sin(pi * y0)
# MSE against the exact solution, recorded every `interval` steps
# (presumably a monitor only, not a loss term — confirm jnn.tracker semantics).
error = jnn.tracker((u - u_exact).mse, interval=pick(200, 1))

# Train on the PDE and initial-condition MSE terms, with the error tracker.
crux = jno.core([pde.mse, ini.mse, error], domain)
history = crux.solve(pick(12_000, 10))
# Smoke-test runs stop after training; full runs also save the training
# history and an exact / PINN / error comparison figure.
if pick(False, True):
    # pick(default, test) returns its second argument only in the quick test
    # configuration, so this branch is taken exactly in smoke-test mode.
    print("Smoke test completed for reaction_diffusion_2d.py")
else:
    history.plot(f"{dire}/training_history.png")

    # Convert the sampled arrays once, outside the snapshot loop.
    all_times = np.asarray(domain.context["__time__"])
    all_interior = np.asarray(domain.context["interior"])

    # (x, y) columns of the interior points of the first slice; the same
    # triangulation is reused for every snapshot.
    pts = all_interior[0, 0]
    triang = tri.Triangulation(pts[:, 0], pts[:, 1])

    time_values = all_times.reshape(-1)
    n_times = len(time_values)
    # First, middle and last time slice. De-duplicate so a very short time
    # grid does not plot the same slice twice.
    snap_idx = sorted({0, n_times // 2, n_times - 1})
    n_cols = len(snap_idx)

    # squeeze=False keeps `axes` 2-D even when only one column is plotted.
    fig, axes = plt.subplots(3, n_cols, figsize=(4 * n_cols, 9), squeeze=False)
    for col, ti in enumerate(snap_idx):
        # Restrict a copy of the domain to time slice `ti`: axis 1 of the
        # interior samples is sliced with the same index as __time__.
        sub_domain = copy.deepcopy(domain)
        sub_domain.context["__time__"] = all_times[ti : ti + 1]
        sub_domain.context["interior"] = all_interior[:, ti : ti + 1, :, :]

        pred = np.array(crux.eval(u, domain=sub_domain))[0, :, 0]
        true = np.array(crux.eval(u_exact, domain=sub_domain))[0, :, 0]
        err = np.abs(pred - true)
        t_val = float(time_values[ti])

        # The exact and PINN panels share the exact solution's color range so
        # they compare directly; the error panel is scaled on its own.
        vmin, vmax = true.min(), true.max()
        panels = [
            (true, "viridis", f"Exact t={t_val:.2f}"),
            (pred, "viridis", f"PINN t={t_val:.2f}"),
            (err, "hot", f"|err| t={t_val:.2f}"),
        ]
        for row, (data, cmap, title) in enumerate(panels):
            ax = axes[row, col]
            kwargs = dict(shading="gouraud", cmap=cmap)
            if row < 2:
                kwargs.update(vmin=vmin, vmax=vmax)
            tc = ax.tripcolor(triang, data, **kwargs)
            fig.colorbar(tc, ax=ax, shrink=0.8)
            ax.set_aspect("equal")
            ax.set_title(title, fontsize=8)

    plt.tight_layout()
    plt.savefig(f"{dire}/comparison.png", dpi=150, bbox_inches="tight")
    # Free the figure; pyplot otherwise keeps it referenced until exit.
    plt.close(fig)
    print(f"Saved to {dire}/")