Allen-Cahn 2D
This example solves a manufactured 2D Allen-Cahn problem and introduces a nonlinear cubic reaction term.
Problem Setup
The PDE has the Allen-Cahn structure $u_t = \varepsilon^2 \Delta u + u - u^3 + f$, with a known exact solution used to build the forcing term $f$.
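Written out, the manufactured solution and the forcing it implies are derived below. This assumes the spatial domain is the unit square $[0,1]^2$, which the boundary factor $x(1-x)y(1-y)$ used later suggests. Every term agrees with the `source` expression in the code.

```latex
\begin{aligned}
u^*(x,y,t) &= e^{-t} S(x,y), \qquad S = \sin(\pi x)\sin(\pi y), \\
\partial_t u^* &= -e^{-t} S, \qquad \Delta u^* = -2\pi^2 e^{-t} S, \\
f &= \partial_t u^* - \varepsilon^2 \Delta u^* - u^* + (u^*)^3 \\
  &= e^{-t} S \,(2\varepsilon^2\pi^2 - 2) + e^{-3t} S^3, \\
u^*(x,y,0) &= S(x,y).
\end{aligned}
```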
Step 1: Build a Manufactured Nonlinear Problem
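Substituting the exact solution u = exp(-t) sin(πx) sin(πy) into the PDE yields the forcing term f (`source` below). Because the solution is known in closed form, the trained network can be compared against it directly.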
eps = 0.1
T_end = 1.0
domain = jno.domain(
    constructor=jno.domain.rect(mesh_size=0.05),
    time=(0, T_end, 4),
)
x, y, t = domain.variable("interior")
π = jno.np.pi
S = jno.np.sin(π * x) * jno.np.sin(π * y)
u_exact = jno.np.exp(-t) * S
coeff = 2 * eps**2 * π**2 - 2
source = jno.np.exp(-t) * S * coeff + jno.np.exp(-3 * t) * S**3
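A quick way to catch sign errors in a manufactured forcing term is to check it numerically. The sketch below uses plain Python `math` rather than `jno`. It estimates u_t - ε²(u_xx + u_yy) - u + u³ for the exact solution with central finite differences and compares the result to the `source` formula at a few sample points. The step size `h` and the sample points are arbitrary illustrative choices.

```python
import math

EPS = 0.1  # same epsilon as the example


def u_exact(x: float, y: float, t: float) -> float:
    """Manufactured solution exp(-t) * sin(pi x) * sin(pi y)."""
    return math.exp(-t) * math.sin(math.pi * x) * math.sin(math.pi * y)


def source(x: float, y: float, t: float) -> float:
    """Forcing term, written exactly as in the example."""
    s = math.sin(math.pi * x) * math.sin(math.pi * y)
    return math.exp(-t) * s * (2 * EPS**2 * math.pi**2 - 2) + math.exp(-3 * t) * s**3


def residual_without_source(x: float, y: float, t: float, h: float = 1e-4) -> float:
    """u_t - eps^2 (u_xx + u_yy) - u + u^3, via central finite differences."""
    u = u_exact(x, y, t)
    u_t = (u_exact(x, y, t + h) - u_exact(x, y, t - h)) / (2 * h)
    u_xx = (u_exact(x + h, y, t) - 2 * u + u_exact(x - h, y, t)) / h**2
    u_yy = (u_exact(x, y + h, t) - 2 * u + u_exact(x, y - h, t)) / h**2
    return u_t - EPS**2 * (u_xx + u_yy) - u + u**3


# If the forcing is right, the mismatch is only finite-difference error.
points = [(0.3, 0.7, 0.2), (0.5, 0.5, 0.9), (0.1, 0.4, 0.5)]
max_err = max(abs(residual_without_source(*p) - source(*p)) for p in points)
print(f"max |residual - source| = {max_err:.2e}")
```

A wrong sign on either term of the forcing would produce a mismatch of order one rather than a tiny number.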
Step 2: Set Up the Space-Time Network
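A DeepONet takes the time t as its first input and the spatial coordinates (x, y) as its second. Its output is multiplied by x(1 - x) y(1 - y), which vanishes on the lines x = 0, x = 1, y = 0 and y = 1. The homogeneous Dirichlet boundary condition therefore holds exactly and needs no loss term.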
net = jno.nn.wrap(
    foundax.deeponet(
        n_sensors=1, coord_dim=2, n_outputs=1,
        n_layers=3, basis_functions=64, hidden_dim=40,
        key=jax.random.PRNGKey(42),
    )
)
net.optimizer(
    optax.adam(
        optax.warmup_cosine_decay_schedule(
            init_value=0.0, peak_value=1e-3, warmup_steps=1, decay_steps=500, end_value=1e-5
        )
    )
)
xy = jno.np.concat([x, y])
u = net(t, xy) * x * (1 - x) * y * (1 - y)
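The hard boundary constraint works for any network output, not just a trained one. The sketch below is plain Python. It uses a hypothetical stand-in `raw_net` in place of the real `net(t, xy)` and evaluates the masked field on random points along all four edges of the unit square.

```python
import math
import random


def boundary_factor(x: float, y: float) -> float:
    """x(1-x)y(1-y): zero on every edge of the unit square, positive inside."""
    return x * (1 - x) * y * (1 - y)


def raw_net(t: float, x: float, y: float) -> float:
    """Hypothetical stand-in for net(t, xy): any smooth function of (t, x, y)."""
    return math.tanh(3.0 * t - x + 2.0 * y) + 0.5


def u(t: float, x: float, y: float) -> float:
    """Network output times the boundary factor (hard Dirichlet BC)."""
    return raw_net(t, x, y) * boundary_factor(x, y)


random.seed(0)
edge_values = []
for _ in range(100):
    s, t = random.random(), random.random()
    # One point on each edge: x=0, x=1, y=0, y=1.
    for x, y in [(0.0, s), (1.0, s), (s, 0.0), (s, 1.0)]:
        edge_values.append(u(t, x, y))

print(max(abs(v) for v in edge_values))  # 0.0 for any choice of raw_net
```

The trade-off is that the mask also shapes the interior. The network has to learn u divided by x(1 - x) y(1 - y), which is smooth here because the exact solution vanishes on the same edges.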
Step 3: Encode the Nonlinear Residual
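The key change relative to the heat equation is the nonlinear reaction term u - u^3. Once u is bound to its coordinate variables, partial derivatives are available by name and the residual reads like the PDE itself:
u = u.scalar.bind(x=x, y=y, t=t)
pde = u.t - eps**2 * (u.xx + u.yy) - u + u**3 - source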
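One way to read the reaction term is as minus the derivative of the double-well potential F(u) = (u² - 1)²/4. The sketch below is plain Python and independent of `jno`. It checks that u - u³ vanishes at the two pure phases u = ±1 and at the unstable state u = 0. It also checks, with a finite-difference derivative, that u - u³ equals -F'(u) at a few sample values.

```python
def reaction(u: float) -> float:
    """Allen-Cahn reaction term u - u^3."""
    return u - u**3


def double_well(u: float) -> float:
    """Double-well potential F(u) = (u^2 - 1)^2 / 4, minimised at u = -1 and u = +1."""
    return (u**2 - 1) ** 2 / 4


def dF(u: float, h: float = 1e-6) -> float:
    """Central-difference derivative of the double-well potential."""
    return (double_well(u + h) - double_well(u - h)) / (2 * h)


# Zeros of the reaction term: the pure phases and the unstable middle state.
print([reaction(v) for v in (-1.0, 0.0, 1.0)])  # [0.0, 0.0, 0.0]

# The reaction term matches -F'(u) up to finite-difference error.
samples = (-1.5, -0.3, 0.2, 0.8, 1.2)
max_dev = max(abs(reaction(v) + dF(v)) for v in samples)
print(f"max |reaction + F'| = {max_dev:.1e}")
```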
Step 4: Impose the Initial Condition
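The initial condition u(x, y, 0) = sin(πx) sin(πy) is imposed as a soft constraint. The network is evaluated at t = 0, written as 0 * t so that it reuses the sampled points with time set to zero, and the mismatch is added as a second loss term alongside the PDE residual.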
u_at_0 = net(0 * t, xy) * x * (1 - x) * y * (1 - y)
ini = u_at_0 - S
crux = jno.core([pde.mse, ini.mse])
history = crux.solve(5000)
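Passing `[pde.mse, ini.mse]` to `jno.core` asks the optimizer to drive both mean-squared residuals down together. The sketch below shows such a composite objective in plain Python. It assumes a plain weighted sum with equal default weights; how `jno.core` actually combines or weights its terms is not stated here. The residual values are made up for illustration, with a candidate field that misses the initial condition by a constant 0.1.

```python
import math
from typing import List, Optional


def mse(residuals: List[float]) -> float:
    """Mean of squared residuals, the quantity behind a `.mse` loss term."""
    return sum(r * r for r in residuals) / len(residuals)


def composite_loss(terms: List[List[float]], weights: Optional[List[float]] = None) -> float:
    """Weighted sum of per-term MSEs; equal weights by default (an assumption)."""
    if weights is None:
        weights = [1.0] * len(terms)
    return sum(w * mse(r) for w, r in zip(weights, terms))


# Hypothetical residual samples for an imperfect candidate solution.
points = [(0.25, 0.25), (0.5, 0.5), (0.75, 0.25)]
pde_residuals = [0.02, -0.01, 0.03]
ini_residuals = []
for x, y in points:
    target = math.sin(math.pi * x) * math.sin(math.pi * y)
    u0 = target + 0.1  # candidate value at t = 0
    ini_residuals.append(u0 - target)

print(composite_loss([pde_residuals, ini_residuals]))
```

With equal weights the initial-condition error dominates this total. That imbalance is what per-term gradient-norm tracking, as in the full script, helps to reveal.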
What To Notice
- Nonlinear reaction terms are easy to express once the field is available symbolically.
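- Manufactured solutions are especially valuable for nonlinear PDEs, where closed-form solutions are rarely available: the known u_exact gives a direct reference for measuring the solver's error.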
- This example is a good template for phase-field style problems.
Script Snippet
"""03 — 2-D Allen–Cahn equation (manufactured-solution verification)"""
import foundax
import jax
import optax
import jno
π = jno.np.pi
ε = 0.1
T_end = 1.0
# Time-dependent rectangle (jno.domain.rect produces the time-extended sampler;
# PolygonDomain doesn't yet support the ``time=`` axis).
domain = jno.domain.rect(mesh_size=0.05, time=(0, T_end, 4))
x, y, t = domain.variable("interior")
S = jno.np.sin(π * x) * jno.np.sin(π * y)
u_exact = jno.np.exp(-t) * S
source = jno.np.exp(-t) * S * (2 * ε**2 * π**2 - 2) + jno.np.exp(-3 * t) * S**3
# Network with hard Dirichlet BCs in space; t is fed via the trunk input.
net = jno.nn.wrap(
    foundax.deeponet(
        n_sensors=1,
        coord_dim=2,
        n_outputs=1,
        n_layers=3,
        basis_functions=64,
        hidden_dim=40,
        key=jax.random.PRNGKey(42),
    )
)
net.optimizer(
    optax.adam(
        optax.warmup_cosine_decay_schedule(
            init_value=0.0, peak_value=1e-3, warmup_steps=1, decay_steps=500, end_value=1e-5
        )
    )
)
xy = jno.np.concat([x, y])
# Bind names so partials read like the math: u.t, u.xx, u.yy, u.xy, ...
u = (net(t, xy) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y, t=t)
pde = u.t - ε**2 * (u.xx + u.yy) - u + u**3 - source
# Initial condition (t=0 via 0*t trick)
u_at_0 = net(0 * t, xy) * x * (1 - x) * y * (1 - y)
ini = u_at_0 - S
grad_norms = jno.trackers.gradient_norms(interval=500)
crux = jno.core([pde.mse, ini.mse])
print(f"Allen–Cahn 2-D (ε = {ε})")
crux.solve(5000, callbacks=[grad_norms])
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}")
if grad_norms.value is not None:
    print(f"Final ∇L norms (pde, ini): {grad_norms.value['norms']}")
assert rel_l2 < 1e-1, f"relative L2 error too large: {rel_l2:.3e}"
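The optimizer's schedule is `optax.warmup_cosine_decay_schedule` with `decay_steps=500`, while `crux.solve` runs 5000 iterations. In optax, `decay_steps` counts the warmup steps too, and the value is held at `end_value` once it is reached. So, assuming one optimizer step per solve iteration, most of training runs at 1e-5. The sketch below is a plain-Python re-implementation of that schedule shape with the default exponent of 1, written here so it runs without optax; it is not optax's own code.

```python
import math


def warmup_cosine(step: int, init_value: float, peak_value: float,
                  warmup_steps: int, decay_steps: int, end_value: float) -> float:
    """Linear warmup to peak_value, then cosine decay to end_value.

    decay_steps includes the warmup; after it, the value stays at end_value.
    """
    if step < warmup_steps:
        return init_value + (peak_value - init_value) * step / warmup_steps
    decay_len = decay_steps - warmup_steps
    progress = min(step - warmup_steps, decay_len) / decay_len
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


params = dict(init_value=0.0, peak_value=1e-3, warmup_steps=1, decay_steps=500, end_value=1e-5)
for step in (0, 1, 250, 500, 5000):
    print(step, warmup_cosine(step, **params))
```

To keep the cosine decay active across the whole run instead, `decay_steps` would be set to match the number of solve iterations.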