Skip to content

Heat 2D

This example extends the heat equation to a square domain and shows how to inspect the learned solution at multiple time slices.

Problem Setup

The PDE is u_t = alpha Delta u on the unit square with homogeneous Dirichlet boundaries and a sinusoidal initial state.

Step 1: Build the 2D Space-Time Geometry

The script samples interior space-time points on a rectangular domain and uses a separate initial-time slice for the starting condition.

Step 2: Use a DeepONet With a Hard Spatial Envelope

The model output is multiplied by x(1-x)y(1-y) so the boundary is satisfied on all four edges.

Step 3: Combine PDE and Initial Losses

The transient residual enforces the heat equation, while a dedicated initial-condition residual anchors the solution at t = 0.

Step 4: Plot Time Snapshots

One of the nice features of this script is explicit evaluation on selected time slices so you can inspect how the field evolves.

What To Notice

  • This is the natural 2D extension of Heat 1D.
  • Snapshot evaluation is a good debugging tool for time-dependent solves.
  • The same ideas generalize to more complex transient PDEs.

Script Snippet

"""03 — 2-D heat equation  (parabolic, time-dependent)

Problem
-------
    ∂u/∂t = α ∇²u,   (x,y) ∈ [0,1]²,  t ∈ [0, 0.5]
    u = 0 on ∂Ω  (homogeneous Dirichlet BCs)
    u(x,y,0) = sin(πx) sin(πy)

Analytical solution
-------------------
    u(x,y,t) = exp(−2απ²t) sin(πx) sin(πy)

The x(1−x)y(1−y) factor in the ansatz hard-enforces the Dirichlet BCs on the
unit-square boundary for all times.  The initial condition is a soft constraint
evaluated on the "initial" domain tag.
"""

import copy
import jax
import jno
import jno.jnp_ops as jnn
import optax
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.tri as tri
from jno import LearningRateSchedule as lrs

π = jnn.pi
dire = jno.setup(__file__)

α = 0.1  # thermal diffusivity
T_end = 0.5  # final time
N_t = 8  # number of time slices

# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain(
    constructor=jno.domain.rect(mesh_size=(0.05)),
    time=(0, T_end, N_t),
    compute_mesh_connectivity=False,
)
x, y, t = domain.variable("interior")
x0, y0, _ = domain.variable("initial")
domain.summary()
# ── Analytical solution ───────────────────────────────────────────────────────
u_exact = jnn.exp(-2 * α * π**2 * t) * jnn.sin(π * x) * jnn.sin(π * y)

# ── Network ───────────────────────────────────────────────────────────────────
net = jnn.nn.deeponet(
    n_sensors=1,
    sensor_channels=1,
    coord_dim=2,
    basis_functions=40,
    hidden_dim=40,
    n_layers=3,
    key=jax.random.PRNGKey(0),
)
net.optimizer(optax.adam(1), lr=lrs.warmup_cosine(10_000, 1_000, 1e-3, 1e-5))
net.summary()
xy = jnn.concat([x, y])
xy0 = jnn.concat([x0, y0])

u = net(t, xy) * x * (1 - x) * y * (1 - y)
u0 = net(0.0, xy0) * x0 * (1 - x0) * y0 * (1 - y0)

# ── Constraints ───────────────────────────────────────────────────────────────
pde = jnn.grad(u, t) - α * jnn.laplacian(u, [x, y])
ini = u0 - jnn.sin(π * x0) * jnn.sin(π * y0)
error = jnn.tracker((u - u_exact).mse, interval=200)

# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, ini.mse, error], domain).print_shapes()
history = crux.solve(10_000)
history.plot(f"{dire}/training_history.png")

# ── Plot ──────────────────────────────────────────────────────────────────────
spacetime_pts = np.array(domain.context["interior"][0])
time_values = np.array(domain.context["__time__"]).reshape(-1)

pts = spacetime_pts[0]
xs, ys = pts[:, 0], pts[:, 1]
triang = tri.Triangulation(xs, ys)

n_t = time_values.shape[0]

snap_idx = [0, n_t // 2, n_t - 1]
fig, axes = plt.subplots(3, len(snap_idx), figsize=(4 * len(snap_idx), 9))

for col, ti in enumerate(snap_idx):
    t_val = float(time_values[ti])

    sub_domain = copy.deepcopy(domain)
    sub_domain.context["__time__"] = np.asarray(domain.context["__time__"])[ti : ti + 1]
    sub_domain.context["interior"] = np.asarray(domain.context["interior"])[:, ti : ti + 1, :, :]

    p = np.array(crux.eval(u, domain=sub_domain))[0, :, 0]
    r = np.array(crux.eval(u_exact, domain=sub_domain))[0, :, 0]
    e = np.abs(p - r)
    vmin, vmax = r.min(), r.max()

    for row, (data, cmap, title) in enumerate(
        [
            (r, "viridis", f"Exact  t={t_val:.2f}"),
            (p, "viridis", f"PINN  t={t_val:.2f}"),
            (e, "hot", f"|err|  t={t_val:.2f}"),
        ]
    ):
        kw = dict(shading="gouraud", cmap=cmap)
        if row < 2:
            kw.update(vmin=vmin, vmax=vmax)
        tc = axes[row, col].tripcolor(triang, data, **kw)
        fig.colorbar(tc, ax=axes[row, col], shrink=0.8)
        axes[row, col].set_title(title, fontsize=8)
        axes[row, col].set_aspect("equal")

for ax, lbl in zip(axes[:, 0], ["Exact", "PINN", "|Error|"]):
    ax.set_ylabel(lbl, fontsize=10, fontweight="bold")

plt.suptitle(f"2-D heat equation  α={α}", fontsize=13)
plt.tight_layout()
plt.savefig(f"{dire}/comparison.png", dpi=150, bbox_inches="tight")
print(f"Saved to {dire}/")