
Heat 1D

Transient 1D heat equation, the introductory time-dependent example. Both the initial condition and the spatial boundary conditions are enforced softly, as penalty terms in the loss. This keeps the IC and BC roles explicit and works on any geometry, not just the unit interval.

Problem Setup

u_t = α u_xx on (x, t) ∈ [0, 1] × [0, 0.5], with u(0, t) = u(1, t) = 0 (Dirichlet) and u(x, 0) = sin(π x) (initial). Exact solution: u(x, t) = e^{-α π² t} sin(π x).
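Before training anything, it is worth confirming that the stated exact solution really satisfies all three conditions. The sketch below uses plain NumPy rather than jno and checks the PDE with central finite differences at one interior point (the point x = 0.3, t = 0.2 and step h = 1e-4 are arbitrary choices). It also checks the boundary values at x = 0 and x = 1 and the initial slice at t = 0.

```python
import numpy as np

alpha = 0.1

def u_exact(x, t):
    """Closed-form solution u(x, t) = exp(-alpha * pi^2 * t) * sin(pi * x)."""
    return np.exp(-alpha * np.pi**2 * t) * np.sin(np.pi * x)

# Central finite differences at an arbitrary interior point
x, t, h = 0.3, 0.2, 1e-4
u_t = (u_exact(x, t + h) - u_exact(x, t - h)) / (2 * h)
u_xx = (u_exact(x + h, t) - 2 * u_exact(x, t) + u_exact(x - h, t)) / h**2
pde_residual = u_t - alpha * u_xx

# Dirichlet boundary: u(0, t) = u(1, t) = 0 for all t in [0, 0.5]
ts = np.linspace(0.0, 0.5, 11)
bc_max = max(np.abs(u_exact(0.0, ts)).max(), np.abs(u_exact(1.0, ts)).max())

# Initial condition: u(x, 0) = sin(pi x)
xs = np.linspace(0.0, 1.0, 101)
ic_max = np.abs(u_exact(xs, 0.0) - np.sin(np.pi * xs)).max()

print(f"PDE residual: {pde_residual:.2e}")
print(f"max |BC|:     {bc_max:.2e}")
print(f"max |IC err|: {ic_max:.2e}")
```

All three numbers sit at round-off or finite-difference truncation level, which is the target a trained network is chasing.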

Step 1: Build a Space-Time Domain

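import jno

π = jno.np.pi   # used by the initial condition in Step 3
α = 0.1         # thermal diffusivity
T_end = 0.5     # final time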

domain = jno.domain.line(mesh_size=0.01, time=(0, T_end, 10))
x, t   = domain.variable("interior")   # full interior of space-time
x0, t0 = domain.variable("initial")    # t = 0 slice
xb, tb = domain.variable("boundary")   # x = 0 and x = 1 at all t

Three tags are sampled: one for the PDE residual, one for the IC, and one for the BC.
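To make the three tags concrete, here is a NumPy sketch of what the point sets could look like. `sample_space_time` is a hypothetical helper, not part of jno. It assumes a uniform spatial grid of spacing `mesh_size` on [0, 1] and reads `time=(0, T_end, 10)` as 10 equally spaced time levels. How jno actually draws these points (grid or random, and whether the interior includes t = 0) is not stated on this page.

```python
import numpy as np

def sample_space_time(mesh_size, t_start, t_end, n_t):
    """Hypothetical stand-in for the 'interior', 'initial' and 'boundary' tags.

    Assumes a uniform x-grid with spacing `mesh_size` on [0, 1] and
    `n_t` equally spaced time levels on [t_start, t_end].
    """
    n_x = int(round(1.0 / mesh_size)) + 1
    xs = np.linspace(0.0, 1.0, n_x)
    ts = np.linspace(t_start, t_end, n_t)

    # Interior: strictly inside in x, every time level after t_start
    X, T = np.meshgrid(xs[1:-1], ts[1:], indexing="ij")
    interior = (X.ravel(), T.ravel())

    # Initial slice: every x at t = t_start
    initial = (xs, np.full_like(xs, t_start))

    # Spatial boundary: x = 0 and x = 1 at every time level
    xb = np.concatenate([np.zeros_like(ts), np.ones_like(ts)])
    tb = np.concatenate([ts, ts])
    boundary = (xb, tb)
    return interior, initial, boundary

interior, initial, boundary = sample_space_time(0.01, 0.0, 0.5, 10)
for name, (xp, tp) in [("interior", interior), ("initial", initial), ("boundary", boundary)]:
    print(name, xp.shape)
# interior (891,)
# initial (101,)
# boundary (20,)
```

Each set feeds exactly one loss term, which is why the walkthrough unpacks three separate (x, t) pairs.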

Step 2: Bare-Network Ansatz

import foundax
import jax
import optax

net = jno.nn.wrap(
    foundax.deeponet(
        n_sensors=1, coord_dim=1, n_outputs=1,
        n_layers=3, basis_functions=64, hidden_dim=32,
        key=jax.random.PRNGKey(0),
    )
)
# Adam with an exponentially decaying learning rate
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 10000, 0.9, end_value=1e-5)))

u = net(t, x)   # no multiplicative ansatz — IC + BC enforced via loss terms below
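A DeepONet output is a sum of products of branch outputs and trunk outputs. The NumPy sketch below illustrates that structure for a call like `net(t, x)`. It assumes, without confirmation from this page, that the first argument t feeds the branch and x feeds the trunk. The tanh MLPs and the omitted bias are illustrative choices and not foundax's implementation; only the widths (hidden 32, 64 basis functions) echo the Step 2 arguments.

```python
import numpy as np

rng = np.random.default_rng(0)

def init_mlp(sizes):
    """Random weights for a fully connected net with the given layer sizes."""
    return [(rng.normal(size=(m, n)) / np.sqrt(m), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def mlp(params, z):
    """tanh hidden layers followed by a linear output layer."""
    for W, b in params[:-1]:
        z = np.tanh(z @ W + b)
    W, b = params[-1]
    return z @ W + b

hidden, p = 32, 64                          # hidden width, number of basis functions
branch = init_mlp([1, hidden, hidden, p])   # assumed to consume t
trunk = init_mlp([1, hidden, hidden, p])    # assumed to consume x

def deeponet_grid(t, x):
    """u(t, x) = sum_k b_k(t) * tau_k(x), evaluated on the full t-by-x grid (no bias)."""
    b = mlp(branch, t[:, None])    # (n_t, p) time coefficients
    tau = mlp(trunk, x[:, None])   # (n_x, p) spatial basis functions
    return b @ tau.T               # (n_t, n_x)

t = np.linspace(0.0, 0.5, 10)
x = np.linspace(0.0, 1.0, 101)
U = deeponet_grid(t, x)

# The exact solution is a single separable product, so on the same grid it is rank 1
alpha = 0.1
U_exact = np.outer(np.exp(-alpha * np.pi**2 * t), np.sin(np.pi * x))
print(U.shape, np.linalg.matrix_rank(U_exact))   # (10, 101) 1
```

Because the target is one time factor times one space factor, even a single branch/trunk pair could represent it in principle; the extra basis functions give the optimiser room while it gets there.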

Step 3: Three Constraints — PDE + IC + BC

# Interior PDE residual:  u_t − α u_xx = 0
pde = u.d(t) - α * u.d2(x)

# Initial condition:  net(0, x) = sin(πx)
ic = net(t0, x0) - jno.np.sin(π * x0)

# Spatial boundary:  net(t, 0) = net(t, 1) = 0
bc = net(tb, xb)

crux = jno.core([pde.mse, ic.mse, bc.mse])
history = crux.solve(10000)
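The sketch below shows, in plain NumPy, how the three MSE terms react to different kinds of error. It assumes that `jno.core` adds the three terms with equal weights, which this page does not state. The candidate function is hypothetical: the exact solution plus a perturbation c·t. That perturbation leaves the IC untouched but shifts the PDE residual by c and the boundary values by c·t, so each loss term picks up a different part of the error. Derivatives are written out analytically to avoid finite-difference noise.

```python
import numpy as np

alpha, c = 0.1, 0.05   # diffusivity and size of a deliberate perturbation

def candidate(x, t):
    """Exact solution plus c*t: keeps the IC, breaks the PDE and the BC."""
    return np.exp(-alpha * np.pi**2 * t) * np.sin(np.pi * x) + c * t

def candidate_u_t(x, t):
    return -alpha * np.pi**2 * np.exp(-alpha * np.pi**2 * t) * np.sin(np.pi * x) + c

def candidate_u_xx(x, t):
    return -np.pi**2 * np.exp(-alpha * np.pi**2 * t) * np.sin(np.pi * x)

xs = np.linspace(0.0, 1.0, 101)
ts = np.linspace(0.0, 0.5, 10)
X, T = np.meshgrid(xs, ts, indexing="ij")

pde = candidate_u_t(X, T) - alpha * candidate_u_xx(X, T)   # equals c everywhere
ic = candidate(xs, 0.0) - np.sin(np.pi * xs)                # exactly zero
bc = candidate(np.concatenate([np.zeros_like(ts), np.ones_like(ts)]),
               np.concatenate([ts, ts]))                    # approximately c * t

losses = {name: float(np.mean(r**2)) for name, r in [("pde", pde), ("ic", ic), ("bc", bc)]}
total = sum(losses.values())   # equal weights assumed
print(losses, total)
```

Here the PDE term is c², the IC term is zero and the BC term is c² times the mean of t² over the boundary points, so minimising the sum drives c to zero.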

What To Notice

  • Each physical condition is its own constraint term: the PDE, the IC, and the BC are three separate scalars that the optimiser balances, and no ansatz hides any of them.
  • For unit-interval Dirichlet problems, a hard ansatz u = sin(πx) + t · net(t,x) · x(1−x) satisfies the IC and BC by construction and removes two of the three losses (see the Laplace 1D example for the hard-ansatz pattern). The soft pattern shown here generalises to arbitrary geometries and to PDEs where no clean ansatz exists.
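  • DeepONet is used here in PINN mode (a single instance, no parameter sweep). Because t and x enter through separate subnetworks, the output is a sum of products of time-dependent coefficients and spatial basis functions. This matches the separable form e^{-απ²t} · sin(πx) of the exact solution, so a small network can capture both space and time dependence.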
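The hard-ansatz claim in the second bullet can be checked numerically. In the NumPy sketch below, `net` is a hypothetical stand-in (a random one-layer tanh network), because the point is that the IC and BC hold for any network output: the factor t vanishes at t = 0 and the factor x(1−x) vanishes at x = 0 and x = 1.

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.normal(size=(2, 16))
v = rng.normal(size=16)

def net(t, x):
    """Arbitrary stand-in network: any smooth function of (t, x) will do."""
    z = np.stack([t, x], axis=-1)
    return np.tanh(z @ W) @ v

def hard_ansatz(t, x):
    """u = sin(pi x) + t * net(t, x) * x (1 - x)."""
    return np.sin(np.pi * x) + t * net(t, x) * x * (1.0 - x)

xs = np.linspace(0.0, 1.0, 101)
ts = np.linspace(0.0, 0.5, 10)

# IC: at t = 0 the network term is multiplied by zero
ic_err = np.abs(hard_ansatz(np.zeros_like(xs), xs) - np.sin(np.pi * xs)).max()
# BC: at x = 0 and x = 1 the factor x (1 - x) is zero
bc_err = max(np.abs(hard_ansatz(ts, np.zeros_like(ts))).max(),
             np.abs(hard_ansatz(ts, np.ones_like(ts))).max())
print(ic_err, bc_err)
```

Only the PDE residual is left to train with this ansatz; the remaining error in the BC is the round-off in sin(π) ≈ 1.2e-16.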

Script Snippet

"""03 — 1-D heat equation (parabolic, time-dependent)"""

import foundax
import jax
import optax

import jno

π = jno.np.pi
α = 0.1
T_end = 0.5

domain = jno.domain.line(mesh_size=0.05, time=(0, T_end, 4))  # coarser than in Step 1
x, t = domain.variable("interior")
x0, t0 = domain.variable("initial")
xb, tb = domain.variable("boundary")

u_exact = jno.np.exp(-α * π**2 * t) * jno.np.sin(π * x)  # closed-form reference solution

net = jno.nn.wrap(
    foundax.deeponet(
        n_sensors=1,
        coord_dim=1,
        n_outputs=1,
        n_layers=3,
        basis_functions=48,
        hidden_dim=32,
        key=jax.random.PRNGKey(0),
    )
)
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 2000, 0.5, end_value=1e-5)))

u = net(t, x).scalar.bind(x=x, t=t)

pde = u.t - α * u.xx  # PDE residual
ic = net(t0, x0) - jno.np.sin(π * x0)  # initial condition (t=0 slice)
bc = net(tb, xb)  # spatial boundary (u=0)

crux = jno.core([pde.mse, ic.mse, bc.mse])
crux.solve(5000)

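# Evaluate the trained network and the exact solution on the interior sample
_u, _u_exact = crux.eval([u, u_exact])
# Relative L2 error; the 1e-8 guards against division by zero
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))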
print(f"Relative L2 error: {rel_l2:.4e}")
assert rel_l2 < 2e-1, f"relative L2 error too large: {rel_l2:.3e}"
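Both the Step 2 walkthrough and the script use `optax.exponential_decay`, but with different settings. The plain-Python sketch below replicates the documented formula (continuous decay, `transition_begin=0`, `end_value` acting as a floor when `decay_rate < 1`) without calling optax. It shows where each run ends up. The helper name is hypothetical.

```python
def exponential_decay(step, init_value, transition_steps, decay_rate, end_value=None):
    """Continuous (non-staircase) exponential decay with transition_begin = 0,
    mirroring the formula documented for optax.exponential_decay."""
    value = init_value * decay_rate ** (step / transition_steps)
    if end_value is not None and decay_rate < 1:
        value = max(value, end_value)   # end_value is a floor when decaying
    return value

# Step 2 walkthrough: 10000 steps, factor 0.9 per 10000 steps
walkthrough_final = exponential_decay(10000, 1e-3, 10000, 0.9, end_value=1e-5)
# Script: 5000 steps, factor 0.5 per 2000 steps
script_final = exponential_decay(5000, 1e-3, 2000, 0.5, end_value=1e-5)

print(f"walkthrough final lr: {walkthrough_final:.2e}")   # 9.00e-04
print(f"script final lr:      {script_final:.2e}")        # 1.77e-04
```

In neither configuration does the learning rate come close to the `end_value=1e-5` floor; the walkthrough's rate decays by only 10% over the whole run.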