Skip to content

Allen-Cahn 2D

This example solves a manufactured 2D Allen-Cahn problem and introduces a nonlinear cubic reaction term.

Problem Setup

The PDE has the Allen-Cahn form u_t = epsilon^2 Delta u + u - u^3 + f on the unit square [0,1]^2 for t in [0,1], with epsilon = 0.1. A known exact solution, u = exp(-t) sin(pi x) sin(pi y), is used to build the forcing term f.

Step 1: Build a Manufactured Nonlinear Problem

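Substituting the exact solution into the PDE gives the forcing term f = exp(-t) sin(pi x) sin(pi y) (2 epsilon^2 pi^2 - 2) + exp(-3t) sin^3(pi x) sin^3(pi y). Because the true solution is known, the trained network can be validated against it directly.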

Step 2: Set Up the Space-Time Network

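The network learns a field over space and time. The homogeneous Dirichlet boundary conditions are imposed exactly by multiplying the network output by x(1 - x) y(1 - y), which vanishes on the boundary of the unit square.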

Step 3: Encode the Nonlinear Residual

The key change relative to the heat equation is the nonlinear reaction term u - u^3, which enters the residual directly as `- u + u**3` on the network output.

Step 4: Impose the Initial Condition

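The script reuses the same PDE infrastructure but adds a second loss term that anchors the solution at the initial time. The network is evaluated at t = 0 and penalized for deviating from u(x, y, 0) = sin(pi x) sin(pi y).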

What To Notice

  • Nonlinear reaction terms are easy to express once the field is available symbolically.
  • Manufactured solutions are especially valuable for nonlinear PDEs.
  • This example is a good template for phase-field style problems.

Script Snippet

"""03 — 2-D Allen–Cahn equation  (manufactured-solution verification)

Problem (Allen–Cahn with source)
---------------------------------
    ∂u/∂t = ε² ∇²u + u − u³ + f(x,y,t),   (x,y) ∈ [0,1]²,  t ∈ [0,1]

    u(x,y,t) = 0                  on ∂[0,1]²
    u(x,y,0) = sin(πx) sin(πy)

Manufactured solution
---------------------
    u(x,y,t) = e^{−t} sin(πx) sin(πy)

This automatically satisfies homogeneous Dirichlet BCs on ∂[0,1]².
The source term is computed by substitution (for this u: u_t = −u and
∇²u = −2π² u):

    f = u_t − ε² ∇²u − u + u³
      = e^{−t} sin(πx) sin(πy) (2ε²π² − 2)
        + e^{−3t} sin³(πx) sin³(πy)

Method
------
The Dirichlet BCs are imposed exactly by multiplying the network output by
x(1−x)y(1−y), which vanishes on ∂[0,1]².  The initial condition is imposed
softly: the network is also evaluated at t = 0 and the mean-squared mismatch
against sin(πx) sin(πy) is added as a second loss term next to the PDE
residual.

Parameters: ε = 0.1  (interface width)
"""

import copy
import jax
import jno
import jno.jnp_ops as jnn
import optax
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
from jno import LearningRateSchedule as lrs

# Short aliases so the formulas below read like the math in the docstring.
π, sin, exp = jnn.pi, jnn.sin, jnn.exp
dire = jno.setup(__file__)  # directory prefix used for the saved figures

# Model parameters.
eps = 0.1     # ε, scales the ε²∇²u diffusion term
T_end = 1.0   # end of the time interval [0, T_end]

# ── Domain ────────────────────────────────────────────────────────────────────
# Spatial mesh (expected to be [0,1]²; the hard-BC envelope x(1−x)y(1−y) used
# for the network relies on that) combined with the time interval [0, T_end].
# NOTE(review): the third entry of `time` (10) presumably controls the number
# of time samples — confirm against the jno.domain documentation.
spatial_mesh = jno.domain.rect(mesh_size=0.05)
domain = jno.domain(constructor=spatial_mesh, time=(0, T_end, 10))

# Symbolic coordinates of the interior sample points.
x, y, t = domain.variable("interior")

# ── Manufactured solution + source ───────────────────────────────────────────
# Spatial profile; it is zero on every edge of [0,1]².
S = sin(π * x) * sin(π * y)
u_exact = exp(-t) * S

# For u = e^{−t} S:  u_t − ε²∇²u − u = (2ε²π² − 2) e^{−t} S,  and u³ = e^{−3t} S³.
coeff = 2 * eps**2 * π**2 - 2
linear_part = exp(-t) * S * coeff
cubic_part = exp(-3 * t) * S**3
source = linear_part + cubic_part

# ── Network ───────────────────────────────────────────────────────────────────
# Number of training iterations.  Shared by the LR schedule and crux.solve so
# the two values cannot drift apart.
n_steps = 5_000

net = jno.nn.deeponet(
    n_sensors=1,
    sensor_channels=1,
    coord_dim=2,
    basis_functions=40,
    hidden_dim=40,
    n_layers=3,
    key=jax.random.PRNGKey(42),
)
# NOTE(review): the first argument of warmup_cosine is assumed to be the total
# schedule length (it matched the solve step count in the original script).
net.optimizer(optax.adam(1), lr=lrs.warmup_cosine(n_steps, 300, 1e-3, 1e-5))

# x(1−x)y(1−y) vanishes on ∂[0,1]², so multiplying the network output by it
# enforces the homogeneous Dirichlet BCs exactly (no boundary loss needed).
bc_envelope = x * (1 - x) * y * (1 - y)


def trial_solution(time):
    """Return the BC-satisfying ansatz ``net(time, (x, y)) * x(1−x)y(1−y)``.

    Used both for the PDE field and for its t = 0 evaluation, so the
    initial-condition loss always constrains exactly the same ansatz.
    """
    return net(time, jnn.concat([x, y])) * bc_envelope


u = trial_solution(t)

# ── PDE residual ──────────────────────────────────────────────────────────────
pde = jnn.grad(u, t) - eps**2 * jnn.laplacian(u, [x, y]) - u + u**3 - source

# ── Initial condition  (t=0 via 0*t trick) ──────────────────────────────────
# `0 * t` presumably keeps the time input attached to the sampled time variable
# while fixing its value to 0.  Target: u(x, y, 0) = S = sin(πx) sin(πy).
u_at_0 = trial_solution(0 * t)
ini = u_at_0 - S

# Tracks the MSE against the exact solution every 200 steps.
error = jnn.tracker((u - u_exact).mse, interval=200)

# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, ini.mse, error], domain)

print(f"Allen–Cahn 2-D  (ε={eps})")

history = crux.solve(n_steps)
history.plot(f"{dire}/training_history.png")

# ── Compare vs exact  ─────────────────────────────────────────────────────────
# Flattened 1-D vector of the time levels stored on the domain.
time_values = np.array(domain.context["__time__"]).ravel()

# Spatial points of the first time slice (axis 1 of "interior" is the time
# axis, as sliced in eval_snapshots).  Assumes every time slice shares these
# spatial points.  TODO confirm.
pts = np.array(domain.context["interior"][0, 0])
xs = pts[:, 0]
ys = pts[:, 1]
triang = tri.Triangulation(xs, ys)


def eval_snapshots(expr):
    """Evaluate ``expr`` separately on each stored time level.

    For every index ``ti`` into ``time_values``, a deep copy of ``domain``
    (so the shared domain is never mutated) is restricted to the single time
    slice ``ti``. Both ``"__time__"`` and the ``"interior"`` points are
    restricted. ``expr`` is then evaluated on that copy with ``crux.eval``.

    Args:
        expr: jno expression to evaluate (e.g. ``u`` or ``u_exact``).

    Returns:
        2-D ``np.ndarray`` with one row per time level. Row ``ti`` holds
        ``expr`` on time slice ``ti`` (presumably one value per interior
        point).
    """
    # Convert the full space-time arrays once.  Previously this was redone on
    # every iteration, a full copy (e.g. device->host for JAX arrays) per slice.
    all_times = np.asarray(domain.context["__time__"])
    all_interior = np.asarray(domain.context["interior"])

    values = []
    for ti in range(len(time_values)):
        sub_domain = copy.deepcopy(domain)
        # ti:ti+1 (not ti) keeps the sliced axis so the arrays retain their rank.
        sub_domain.context["__time__"] = all_times[ti : ti + 1]
        sub_domain.context["interior"] = all_interior[:, ti : ti + 1, :, :]
        snapshot = np.array(crux.eval(expr, domain=sub_domain))
        values.append(snapshot[0, :, 0])
    return np.stack(values, axis=0)


# Prediction and exact solution on every time level: shape (n_times, n_points).
pred_all = eval_snapshots(u)
true_all = eval_snapshots(u_exact)
n_times = len(time_values)

# Per-time relative L² error (RMS error / RMS of the exact field) and max error.
print(f"\n{'t':>6}  {'rel L²':>12}  {'max |err|':>12}")
print("─" * 36)
for t_raw, p, r in zip(time_values, pred_all, true_all):
    t_val = float(t_raw)
    diff = p - r
    rms_err = np.sqrt(np.mean(diff**2))
    rms_ref = np.sqrt(np.mean(r**2))
    l2_rel = rms_err / (rms_ref + 1e-12)  # epsilon guards a zero exact field
    print(f"{t_val:6.3f}  {l2_rel:12.4e}  {np.abs(diff).max():12.4e}")

# ── Plot ──────────────────────────────────────────────────────────────────────
# First, middle and last time levels.  The set removes duplicates so short time
# grids (fewer than three levels) do not yield repeated columns; sorting keeps
# the original left-to-right order.
snap_idx = sorted({0, n_times // 2, n_times - 1})
# squeeze=False keeps `axes` 2-D even when only a single column is plotted,
# so the axes[row, col] indexing below always works.
fig, axes = plt.subplots(
    3, len(snap_idx), figsize=(4 * len(snap_idx), 9), squeeze=False
)

for col, ti in enumerate(snap_idx):
    t_val = float(time_values[ti])
    p = pred_all[ti, :]
    r = true_all[ti, :]
    e = np.abs(p - r)
    # Exact and PINN rows share the exact field's color range so they can be
    # compared by eye; the error row is autoscaled.
    vmin, vmax = r.min(), r.max()

    rows = [
        (r, "viridis", f"Exact  t={t_val:.2f}"),
        (p, "viridis", f"PINN  t={t_val:.2f}"),
        (e, "hot", f"|err|  t={t_val:.2f}"),
    ]
    for row, (data, cmap, ttl) in enumerate(rows):
        ax = axes[row, col]
        kw = dict(shading="gouraud", cmap=cmap)
        if row < 2:
            kw.update(vmin=vmin, vmax=vmax)
        tc = ax.tripcolor(triang, data, **kw)
        fig.colorbar(tc, ax=ax, shrink=0.8)
        ax.set_title(ttl, fontsize=8)
        ax.set_aspect("equal")
        ax.set_xticks([])
        ax.set_yticks([])

for ax, lbl in zip(axes[:, 0], ["Exact", "PINN", "|Error|"]):
    ax.set_ylabel(lbl, fontsize=10, fontweight="bold")

fig.suptitle(f"Allen–Cahn 2-D  ε={eps}", fontsize=13, y=1.01)
fig.tight_layout()
fig.savefig(f"{dire}/comparison.png", dpi=150, bbox_inches="tight")
plt.close(fig)  # the figure is only saved, never shown; release it
print(f"\nSaved to {dire}/")