Heat 1D
This example solves the transient 1D heat equation and introduces time as an explicit input to the model.
Problem Setup
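The script solves the diffusion equation u_t = α u_xx with α = 0.1 on x ∈ [0, 1], t ∈ [0, 0.5]. The boundaries are zero Dirichlet, u(0, t) = u(1, t) = 0, and the initial condition is the sinusoid u(x, 0) = sin(πx). The exact solution u(x, t) = exp(−απ²t) sin(πx) is used to track the error during training.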
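To sanity-check the exact solution, the following JAX sketch uses automatic differentiation to confirm that u(x, t) = exp(−απ²t) sin(πx) satisfies u_t − α u_xx = 0 on a coarse grid and vanishes at x = 0 and x = 1. It is not part of the jno script, and the grid sizes are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

alpha = 0.1  # thermal diffusivity, matching the script

def u_exact(x, t):
    # Closed-form solution u(x, t) = exp(-alpha * pi^2 * t) * sin(pi * x)
    return jnp.exp(-alpha * jnp.pi**2 * t) * jnp.sin(jnp.pi * x)

# Partial derivatives of the scalar function via autodiff
u_t = jax.grad(u_exact, argnums=1)
u_xx = jax.grad(jax.grad(u_exact, argnums=0), argnums=0)

def residual(x, t):
    # PDE residual u_t - alpha * u_xx; should vanish everywhere
    return u_t(x, t) - alpha * u_xx(x, t)

# Coarse space-time grid covering x in [0, 1], t in [0, 0.5]
X, T = jnp.meshgrid(jnp.linspace(0.0, 1.0, 11), jnp.linspace(0.0, 0.5, 6))
res = jax.vmap(residual)(X.ravel(), T.ravel())

print(float(jnp.max(jnp.abs(res))))                     # tiny: float32 round-off only
print(float(u_exact(0.0, 0.3)), float(u_exact(1.0, 0.3)))  # both essentially zero
```

Because u_t = −απ²u and u_xx = −π²u, the two terms cancel exactly; any nonzero residual here is floating-point error.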
Step 1: Build a Space-Time Domain
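The domain combines a 1D line mesh (mesh_size=0.02) with a time axis given by time=(0, T_end, 10). Points are grouped by tag: "interior" supplies the (x, t) pairs where the PDE is enforced, and "initial" supplies the spatial points at t = 0 where the initial profile is imposed.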
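A rough picture of what the two tags contain can be sketched with plain NumPy. This is not how jno samples its domain: it assumes that mesh_size=0.02 yields 51 evenly spaced nodes on [0, 1] and that the time argument (0, T_end, 10) means 10 evenly spaced levels including both endpoints. The array names interior and initial only mirror the tag strings in the script.

```python
import numpy as np

mesh_size = 0.02
T_end, n_time = 0.5, 10

# Spatial nodes on [0, 1]; round() guards against float drift in 1 / 0.02
n_x = int(round(1.0 / mesh_size)) + 1
x_nodes = np.linspace(0.0, 1.0, n_x)

# Assumed: evenly spaced time levels including t = 0 and t = T_end
t_levels = np.linspace(0.0, T_end, n_time)

# "interior"-style set: every (x, t) pair, shape (n_time * n_x, 2)
X, T = np.meshgrid(x_nodes, t_levels)  # both have shape (n_time, n_x)
interior = np.stack([X.ravel(), T.ravel()], axis=1)

# "initial"-style set: the spatial nodes at t = 0 only, shape (n_x, 2)
initial = np.stack([x_nodes, np.zeros_like(x_nodes)], axis=1)

print(interior.shape, initial.shape)  # (510, 2) (51, 2)
```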
Step 2: Use a DeepONet-Style Model
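The example uses a DeepONet architecture in PINN mode, so a single network represents the time-dependent field u(x, t) over the full domain and is trained from the PDE residual and initial condition rather than from solution data. Judging from the constructor arguments, time is passed as the single branch input (n_sensors=1) and x as the trunk coordinate (coord_dim=1).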
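To make the branch/trunk split concrete, here is a minimal DeepONet written directly in JAX. It is a hypothetical stand-in, not jno's jnn.nn.deeponet. It assumes time is the branch input and x the trunk input, as the arguments n_sensors=1 and coord_dim=1 suggest, and it reuses the sizes hidden_dim=32, basis_functions=32, n_layers=3 from the script. The prediction is the dot product of the branch coefficients with the trunk basis, plus a scalar bias.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    # One (W, b) pair per layer, Glorot-style scaled normal weights
    keys = jax.random.split(key, len(sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        W = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / (n_in + n_out))
        params.append((W, jnp.zeros(n_out)))
    return params

def mlp(params, z):
    # tanh hidden layers followed by a linear output layer
    for W, b in params[:-1]:
        z = jnp.tanh(z @ W + b)
    W, b = params[-1]
    return z @ W + b

def init_deeponet(key, hidden_dim=32, basis_functions=32, n_layers=3):
    # Branch and trunk each map a scalar input to `basis_functions` features
    sizes = [1] + [hidden_dim] * n_layers + [basis_functions]
    k_branch, k_trunk = jax.random.split(key)
    return {"branch": init_mlp(k_branch, sizes),
            "trunk": init_mlp(k_trunk, sizes),
            "bias": jnp.zeros(())}

def deeponet(params, t, x):
    # Assumed split: branch encodes the time t, trunk encodes the coordinate x
    coeffs = mlp(params["branch"], jnp.atleast_1d(t))  # shape (basis_functions,)
    basis = mlp(params["trunk"], jnp.atleast_1d(x))    # shape (basis_functions,)
    return jnp.dot(coeffs, basis) + params["bias"]     # scalar prediction

params = init_deeponet(jax.random.PRNGKey(0))
# Evaluate a small batch of (t, x) pairs with vmap
ts = jnp.array([0.0, 0.25, 0.5])
xs = jnp.array([0.5, 0.5, 0.5])
print(jax.vmap(deeponet, in_axes=(None, 0, 0))(params, ts, xs).shape)  # (3,)
```

tanh is used here because the PDE loss needs smooth second derivatives in x; a ReLU network is piecewise linear, so its u_xx would be zero almost everywhere.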
Step 3: Hard-Enforce Spatial Boundary Conditions
Multiplying the network output by the spatial envelope x(1 − x) makes u vanish at x = 0 and x = 1 for every t, so the Dirichlet conditions hold exactly and need no loss term.
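The envelope works for any network output. The following NumPy sketch illustrates this with an arbitrary smooth function standing in for the trained net; the function net here is hypothetical.

```python
import numpy as np

def net(t, x):
    # Hypothetical stand-in for the trained network: any smooth function of (t, x)
    return 1.0 + np.cos(3.0 * t) * np.exp(x)

def u(t, x):
    # Ansatz from the script: the envelope x(1 - x) vanishes at x = 0 and x = 1
    return net(t, x) * x * (1.0 - x)

ts = np.linspace(0.0, 0.5, 5)
print(u(ts, 0.0), u(ts, 1.0))  # all zeros, for every t
```

The ansatz does not restrict what the network can represent here: sin(πx) / (x(1 − x)) is smooth and bounded on [0, 1], tending to π at both endpoints, so a suitable net(t, x) exists for the exact solution.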
Step 4: Add the Initial Condition as a Separate Constraint
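The PDE residual u_t − α u_xx is penalized on the interior points, while a second mean-squared loss enforces the known initial profile sin(πx) on the initial points at t = 0. Unlike the boundary conditions, the initial condition is a soft constraint, so it is satisfied only approximately.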
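The two-term objective can be sketched in JAX as follows. A toy two-parameter model stands in for the network (hypothetical, chosen so the exact solution is reachable at a = 1, k = −απ²), the collocation points are random, and the two terms are weighted equally. The script does not show how jno.core weights pde.mse and ini.mse.

```python
import jax
import jax.numpy as jnp

alpha = 0.1

def u(params, t, x):
    # Hypothetical toy model standing in for the network; zero at x = 0 and x = 1
    return params["a"] * jnp.exp(params["k"] * t) * jnp.sin(jnp.pi * x)

def pde_residual(params, t, x):
    # u_t - alpha * u_xx at a single space-time point, via autodiff
    u_t = jax.grad(u, argnums=1)(params, t, x)
    u_xx = jax.grad(jax.grad(u, argnums=2), argnums=2)(params, t, x)
    return u_t - alpha * u_xx

def loss(params, t_int, x_int, x_init):
    # Interior physics term + initial-condition term, equally weighted
    r_pde = jax.vmap(pde_residual, in_axes=(None, 0, 0))(params, t_int, x_int)
    r_ic = u(params, 0.0, x_init) - jnp.sin(jnp.pi * x_init)
    return jnp.mean(r_pde**2) + jnp.mean(r_ic**2)

# Random interior collocation points and a uniform grid at t = 0
key_t, key_x = jax.random.split(jax.random.PRNGKey(0))
t_int = jax.random.uniform(key_t, (256,), maxval=0.5)
x_int = jax.random.uniform(key_x, (256,))
x_init = jnp.linspace(0.0, 1.0, 51)

exact = {"a": 1.0, "k": -alpha * jnp.pi**2}
wrong = {"a": 0.5, "k": 0.0}
print(float(loss(exact, t_int, x_int, x_init)))  # essentially zero
print(float(loss(wrong, t_int, x_int, x_init)))  # clearly positive
```

At the exact parameters both terms vanish up to round-off. Any other choice leaves a positive residual in at least one term: the PDE term pins the decay rate k, and the initial-condition term pins the amplitude a.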
What To Notice
- Time-dependent PDEs need both interior physics and initial data.
- The jNO workflow stays similar even though the field now depends on multiple coordinates.
- This is the cleanest parabolic starting point in the tutorial set.
Script Snippet
"""03 — 1-D heat equation (parabolic, time-dependent)
Problem
-------
∂u/∂t = α ∂²u/∂x², x ∈ [0,1], t ∈ [0, 0.5]
u(0, t) = u(1, t) = 0 (homogeneous Dirichlet)
u(x, 0) = sin(πx) (initial condition)
Analytical solution
-------------------
u(x, t) = exp(−απ²t) sin(πx)
Network ansatz
--------------
u ≈ net(t, x) · x (1−x) — hard-enforces BCs for all t
The initial condition is implemented as a soft constraint evaluated on a
separate "initial" tag.
"""
import copy
import jax
import jno
import jno.jnp_ops as jnn
import optax
import numpy as np
import matplotlib.pyplot as plt
from jno import LearningRateSchedule as lrs
π = jnn.pi
dire = jno.setup(__file__)
# ── Physical parameter ────────────────────────────────────────────────────────
α = 0.1  # thermal diffusivity
T_end = 0.5  # final time
# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain(
    constructor=jno.domain.line(mesh_size=0.02),
    time=(0, T_end, 10),
)
x, t = domain.variable("interior")
x0, _ = domain.variable("initial")
# ── Analytical solution ───────────────────────────────────────────────────────
u_exact = jnn.exp(-α * π**2 * t) * jnn.sin(π * x)
# ── Network ───────────────────────────────────────────────────────────────────
net = jnn.nn.deeponet(
    n_sensors=1,
    sensor_channels=1,
    coord_dim=1,
    basis_functions=32,
    hidden_dim=32,
    n_layers=3,
    key=jax.random.PRNGKey(0),
)
net.optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.7, 10_000, 1e-5))
u = net(t, x) * x * (1 - x)  # hard Dirichlet BCs
u0 = net(0.0, x0) * x0 * (1 - x0)  # same ansatz evaluated at t = 0
# ── Constraints ───────────────────────────────────────────────────────────────
# PDE: u_t − α u_xx = 0
pde = jnn.grad(u, t) - α * jnn.grad(jnn.grad(u, x), x)
# IC: u(x, 0) = sin(πx)
ini = u0 - jnn.sin(π * x0)
error = jnn.tracker((u - u_exact).mse, interval=200)
# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, ini.mse, error], domain)
history = crux.solve(10_000)
history.plot(f"{dire}/training_history.png")
# ── Plot: solution at selected time slices ────────────────────────────────────
pts = np.array(domain.context["interior"][0, 0, :, 0])  # spatial coords
idx = np.argsort(pts)
xs = pts[idx]
time_values = np.array(domain.context["__time__"]).reshape(-1)
n_t = time_values.shape[0]
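def eval_snapshots(expr):
    """Evaluate `expr` one time level at a time; returns an (n_t, n_x) array."""
    values = []
    for ti in range(n_t):
        # Restrict a copy of the domain to the single time level ti
        sub_domain = copy.deepcopy(domain)
        sub_domain.context["__time__"] = np.asarray(domain.context["__time__"])[ti : ti + 1]
        sub_domain.context["interior"] = np.asarray(domain.context["interior"])[:, ti : ti + 1, :, :]
        # Extract the 1-D spatial profile at this time level
        values.append(np.array(crux.eval(expr, domain=sub_domain))[0, :, 0])
    return np.stack(values, axis=0)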
pred_all = eval_snapshots(u)
true_all = eval_snapshots(u_exact)
snap_times = [0, n_t // 4, n_t // 2, n_t - 1]
fig, axes = plt.subplots(1, len(snap_times), figsize=(14, 4), sharey=True)
for ax, ti in zip(axes, snap_times):
    t_val = float(time_values[ti])
    p = pred_all[ti, :][idx]  # reorder by x for plotting
    r = true_all[ti, :][idx]
    ax.plot(xs, r, "--", label="exact")
    ax.plot(xs, p, label="PINN")
    ax.set_title(f"t = {t_val:.3f}")
    ax.set_xlabel("x")
    if ax is axes[0]:
        # y-label and legend only on the first panel (shared y-axis)
        ax.set_ylabel("u")
        ax.legend()
plt.suptitle(f"1-D heat equation α={α}", fontsize=13)
plt.tight_layout()
plt.savefig(f"{dire}/solution.png", dpi=150)
print(f"Saved to {dire}/")
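To read the snapshot plots quantitatively, it helps to know how far the exact solution should decay. The sketch below evaluates the exact peak value u(0.5, t) = exp(−απ²t) at the four snapshot indices used above. It assumes, without confirmation from the script, that time=(0, T_end, 10) produces 10 evenly spaced levels including both endpoints.

```python
import numpy as np

alpha, T_end, n_t = 0.1, 0.5, 10

# Assumption: the time levels are evenly spaced on [0, T_end], endpoints included
time_values = np.linspace(0.0, T_end, n_t)
snap_times = [0, n_t // 4, n_t // 2, n_t - 1]  # same indices as the script

# Exact peak u(0.5, t) = exp(-alpha * pi^2 * t) * sin(pi / 2) at each snapshot
peaks = [float(np.exp(-alpha * np.pi**2 * time_values[ti])) for ti in snap_times]

for ti, peak in zip(snap_times, peaks):
    print(f"t = {time_values[ti]:.3f}  exact peak = {peak:.4f}")
```

Under these assumptions the peak falls from 1 at t = 0 to about 0.61 at t = 0.5, so the curves in the last panel should sit visibly below the initial profile.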