FieldView — wave-equation PINN audit
A two-stage tutorial for the 2-D wave equation. Stage 1 trains a DeepONet
PINN via crux.solve() (AD-based u.tt, u.xx, u.yy through the network
graph). Stage 2 audits the trained network's own prediction: it evaluates
the trained model on a structured grid and re-checks the wave equation with
second-order finite differences on that grid output — an independent,
discretisation-level test of whether the model satisfies the PDE, rather than a
comparison against a hand-built analytic field.
Problem setup
Mode-(1,1) standing wave on [0,1]² with u = 0 on the boundary: u(t, x, y) = cos(ωt) sin(πx) sin(πy) and
ω = c·π√2, so u_tt = −ω² u = c²(u_xx + u_yy).
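Substituting the separable form term by term gives the dispersion relation and the training horizon. The half-period on the last line is the value the script stores as _T_MAX_TRAIN:

```latex
\begin{aligned}
u &= \cos(\omega t)\,\sin(\pi x)\,\sin(\pi y),\\
u_{tt} &= -\omega^2 u, \qquad u_{xx} = u_{yy} = -\pi^2 u,\\
u_{tt} - c^2\,(u_{xx} + u_{yy}) &= \left(-\omega^2 + 2\pi^2 c^2\right) u = 0
  \iff \omega = c\,\pi\sqrt{2},\\
T_{\max} = \tfrac{1}{2}\cdot\tfrac{2\pi}{\omega} &= \frac{\pi}{\omega} = \frac{1}{c\sqrt{2}} \approx 0.707 \quad (c = 1).
\end{aligned}
```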
Stage 1 — PINN training
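A DeepONet maps (t, x, y) to a scalar u. It is called as net(t, xy): the time value
goes in the first slot (n_sensors=1) and the concatenated (x, y) coordinate in the
second (coord_dim=2). Calling .scalar.bind() registers the coordinate variables, so
u.tt, u.xx and u.yy trace AD derivatives through the network graph.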
import foundax
import jno
π = jno.np.pi
C, EPOCHS_TRAIN = 1.0, 3_000
OMEGA = C * π * 2**0.5
T_MAX = π / OMEGA # half-period
dom_pinn = jno.domain.rect(mesh_size=0.05, time=(0, T_MAX, 6))
x_p, y_p, t_p = dom_pinn.variable("interior")
x0_p, y0_p, t0_p = dom_pinn.variable("initial")
xb_p, yb_p, tb_p = dom_pinn.variable("boundary")
xy_p, xy0_p = jno.np.concat([x_p, y_p]), jno.np.concat([x0_p, y0_p])
net = jno.nn.wrap(foundax.deeponet(n_sensors=1, coord_dim=2, ...))
u_p = net(t_p, xy_p).scalar.bind(x=x_p, y=y_p, t=t_p)
u0_p = net(t0_p, xy0_p).scalar.bind(x=x0_p, y=y0_p, t=t0_p)
pde_wave = u_p.tt - C**2 * (u_p.xx + u_p.yy) # wave PDE
ic_disp = net(t0_p, xy0_p) - jno.np.sin(π * x0_p) * jno.np.sin(π * y0_p) # u(0) = sin·sin
ic_vel = u0_p.t # u_t(0) = 0
bc_wall = net(tb_p, jno.np.concat([xb_p, yb_p])) # Dirichlet u = 0
crux_pinn = jno.core([pde_wave.mse, ic_disp.mse, ic_vel.mse, bc_wall.mse], domain=dom_pinn)
crux_pinn.solve(EPOCHS_TRAIN)
Both initial conditions are required for the second-order wave equation:
displacement and velocity at t = 0.
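The velocity condition is what pins the solution down. Because the equation is second order in time, sin(ωt) sin(πx) sin(πy) is also a solution. As a result, (cos ωt + sin ωt) sin(πx) sin(πy) has the same displacement at t = 0 but a nonzero initial velocity. Below is a minimal NumPy sketch that checks both constraints on the two fields. All names here, including the hypothetical u_alt, are illustrative and are not part of the jno API:

```python
import numpy as np

C = 1.0
OMEGA = C * np.pi * np.sqrt(2)


def spatial(x, y):
    # Mode-(1,1) shape; zero (up to rounding) on the walls x, y in {0, 1}
    return np.sin(np.pi * x) * np.sin(np.pi * y)


def u_exact(t, x, y):
    # The tutorial's standing wave: starts from rest
    return np.cos(OMEGA * t) * spatial(x, y)


def u_alt(t, x, y):
    # Hypothetical second solution: sin(ωt) also satisfies T'' = -ω² T
    return (np.cos(OMEGA * t) + np.sin(OMEGA * t)) * spatial(x, y)


def dudt_at_zero(f, x, y, h=1e-5):
    # Central difference in t around t = 0
    return (f(h, x, y) - f(-h, x, y)) / (2.0 * h)


x, y = np.meshgrid(np.linspace(0.0, 1.0, 11), np.linspace(0.0, 1.0, 11))
for name, f in [("u_exact", u_exact), ("u_alt", u_alt)]:
    disp_err = np.abs(f(0.0, x, y) - spatial(x, y)).max()
    max_vel = np.abs(dudt_at_zero(f, x, y)).max()
    print(f"{name}: displacement error {disp_err:.1e}, max |u_t(0)| {max_vel:.3f}")
```

Only u_exact satisfies both conditions. A network converging to u_alt would drive the ic_disp loss to zero while leaving ic_vel large.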
Stage 2 — FieldView audit of the trained prediction
On a structured grid, .field.bind(...) treats x, y, t as grid axes rather than as
differentiable network inputs. The derivatives of net(t, xy) are therefore taken with
FD stencils on the model's grid output instead of with AD:
T, H = 20, 32
dom_fd = jno.domain.equi_distant_rect(nx=H - 1, ny=H - 1, time=(0.0, T_MAX, T))
xg, yg, tg = dom_fd.variable("interior")
xyg = jno.np.concat([xg, yg])
u = net(tg, xyg).field.bind(x=xg, y=yg, t=tg)
wave_res = u.tt - C**2 * (u.xx + u.yy) # nested temporal + spatial FD
Both the prediction and the residual are evaluated through the trained crux on the
FD grid. min_consecutive=T keeps every frame in the window, so the nested u.tt
stencil has all of its temporal neighbours:
analytic = jno.np.cos(OMEGA * tg) * jno.np.sin(π * xg) * jno.np.sin(π * yg)
pred, exact = crux_pinn.eval([net(tg, xyg), analytic], domain=dom_fd, min_consecutive=T)
res_mse = crux_pinn.eval(wave_res.expr.mse, domain=dom_fd, min_consecutive=T)
crux.eval(..., domain=dom_fd) reuses the trained weights but swaps in the grid
domain, so the FD residual is computed on exactly the network the PINN produced.
Same PDE, two derivative engines
- Stage 1: .scalar.bind() → AD through the network graph (training).
- Stage 2: .field.bind() → FD on the grid output (audit).
The expression u.tt - c²(u.xx + u.yy) is written identically both times — one
differentiates the graph, the other the grid values.
The audit reports the prediction's relative L2 against the closed-form solution
and the FD wave residual of that same prediction — how well the trained
network satisfies u_tt = c²Δu when checked discretely on the grid.
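One caveat when reading the FD residual: even the exact solution does not have zero residual on a finite grid, because the stencils carry truncation error of order h². The NumPy sketch below runs the Stage 2 check on the analytic field as a stand-in for the network output. It rests on three assumptions: equidistant nodes that include the endpoints, the nested central difference in time described for u.tt, and standard 3-point stencils in space. jno's actual stencils and grid layout may differ:

```python
import numpy as np

C = 1.0
OMEGA = C * np.pi * np.sqrt(2)
T_MAX = np.pi / OMEGA  # half-period, as in the script


def grid_field(T, H):
    """Analytic standing wave on T frames and an H x H grid (endpoints included)."""
    t = np.linspace(0.0, T_MAX, T)
    s = np.linspace(0.0, 1.0, H)
    tt, xx, yy = np.meshgrid(t, s, s, indexing="ij")
    U = np.cos(OMEGA * tt) * np.sin(np.pi * xx) * np.sin(np.pi * yy)
    return U, t[1] - t[0], s[1] - s[0]  # field, dt, dx (dx == dy here)


def _sl(a, start, stop, axis):
    # Slice `a` along one axis, leaving the others whole
    idx = [slice(None)] * a.ndim
    idx[axis] = slice(start, stop)
    return a[tuple(idx)]


def d1(a, h, axis):
    # Central first difference; drops one sample at each end of `axis`
    return (_sl(a, 2, None, axis) - _sl(a, None, -2, axis)) / (2.0 * h)


def d2(a, h, axis):
    # Standard 3-point second difference; drops one sample at each end of `axis`
    return (_sl(a, 2, None, axis) - 2.0 * _sl(a, 1, -1, axis) + _sl(a, None, -2, axis)) / h**2


def fd_wave_residual(U, dt, dx, c=C):
    u_tt = d1(d1(U, dt, 0), dt, 0)                             # nested in time: (T-4, H, H)
    lap = d2(U, dx, 1)[:, :, 1:-1] + d2(U, dx, 2)[:, 1:-1, :]  # (T, H-2, H-2)
    return u_tt[:, 1:-1, 1:-1] - c**2 * lap[2:-2]              # common block (T-4, H-2, H-2)


for T, H in [(20, 32), (39, 63)]:
    U, dt, dx = grid_field(T, H)
    res = fd_wave_residual(U, dt, dx)
    print(f"T={T:2d} H={H:2d} residual shape {res.shape}  MSE {np.mean(res**2):.2e}")
```

Halving both spacings cuts the residual MSE by roughly 15×, close to the 16× expected from second-order stencils. A trained network's res_mse is therefore best judged against this floor for the same (T, H).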
PINN accuracy is run-dependent
The wave PINN is ill-conditioned; combined with XLA autotune and the
persistent compile cache, training can land in different basins between runs.
The script's asserts are loose guards. rel_l2 < 1 for both stages confirms only that
the model beats the trivial zero predictor, and res_mse < 1 checks that the FD
residual is finite and bounded. Treat the printed numbers as indicative and confirm
final accuracy and timing on GPU.
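The threshold of 1 comes from the metric itself. With the script's formula, the zero predictor scores 1, up to the 1e-8 in the denominator, while a uniform amplitude error scores its own size. A small NumPy sketch on the analytic field (the grid sizes mirror Stage 2; the predictors are illustrative):

```python
import numpy as np


def rel_l2(pred, exact, eps=1e-8):
    # Same formula as the script: ||pred - exact|| / (||exact|| + eps)
    return float(np.linalg.norm(pred - exact) / (np.linalg.norm(exact) + eps))


omega = np.pi * np.sqrt(2)  # c = 1
t, x, y = np.meshgrid(
    np.linspace(0.0, np.pi / omega, 20), np.linspace(0.0, 1.0, 32), np.linspace(0.0, 1.0, 32), indexing="ij"
)
exact = np.cos(omega * t) * np.sin(np.pi * x) * np.sin(np.pi * y)

print(rel_l2(np.zeros_like(exact), exact))  # zero predictor: a hair below 1
print(rel_l2(0.9 * exact, exact))           # uniform 10% amplitude error
print(rel_l2(-exact, exact))                # sign-flipped field
```

Any rel_l2 below 1 therefore beats predicting zero. Even a uniform 10% amplitude error already scores only 0.1, which shows how much room the guard leaves.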
What to notice
- Audit the model, not a stand-in. Stage 2 differentiates the trained network's own output — a discretisation-level check that the PINN satisfies the PDE, independent of the AD objective it was trained on.
- net(t, xy).field.bind(...) is live. FieldView wraps the model call directly; there is no need to materialise and store an array first.
- u.tt is nested FD. The second temporal derivative applies the central-difference stencil twice over the buffered window; min_consecutive=T guarantees the frames are present.
- crux.eval(domain=...) runs the trained weights on a different (grid) domain. This is the standard way to evaluate a jNO model on new inputs.
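The nested temporal stencil can be checked in isolation. Assume u.tt is the central first difference applied twice, as the note describes; jno's internal implementation may differ. It then collapses to the standard 3-point second difference with spacing 2h and needs a five-frame footprint, so T frames yield T - 4 residual frames. A NumPy sketch on a 1-D signal:

```python
import numpy as np


def central_d1(f, h):
    # Central first difference (f[i+1] - f[i-1]) / (2h); drops one sample at each end
    return (f[2:] - f[:-2]) / (2.0 * h)


h = 0.05
t = np.arange(20) * h  # 20 frames, matching T = 20 in the script
f = np.cos(3.0 * t)

nested = central_d1(central_d1(f, h), h)                  # 20 -> 18 -> 16 samples
wide = (f[4:] - 2.0 * f[2:-2] + f[:-4]) / (2.0 * h) ** 2  # 3-point second difference, spacing 2h
exact = -9.0 * np.cos(3.0 * t[2:-2])                     # analytic f'' at the same frames

print(nested.shape)               # (16,)
print(np.allclose(nested, wide))  # True
print(f"max error vs f'': {np.abs(nested - exact).max():.3f}")
```

The wider effective spacing is the price of nesting: the temporal truncation error scales with (2h)² rather than h².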
Script
"""11 — FieldView audit of a trained wave-equation PINN
A DeepONet is trained as a PINN on the 2-D wave equation
u_tt = c²(u_xx + u_yy) on [0,1]², t ∈ [0, T_max], u = 0 on ∂[0,1]²,
with the closed-form standing wave u(t, x, y) = cos(ωt) sin(πx) sin(πy),
ω = c·π√2.
Stage 1 — PINN training via ``crux.solve`` (AD-based ``u.tt``, ``u.xx``, ``u.yy``
through the network graph). This is ordinary point-wise PINN training.
Stage 2 — FieldView audit of the **trained network's own prediction**: evaluate
the trained DeepONet on a structured grid and re-check the wave equation with
second-order *finite differences* on that grid output. On a grid, ``.field.bind(...)``
treats the coordinates as grid axes rather than differentiable network inputs, so
derivatives of ``net(t, xy)`` use FD stencils: an independent, discretisation-level
check of whether the model actually satisfies the PDE (not a comparison against a
hand-built analytic field).
"""
import foundax
import jax
import numpy as np
import optax
import jno
KEY = jax.random.PRNGKey(0)
π = jno.np.pi
# ═══════════════════════════════════════════════════════════════════════════════
# Stage 1 — PINN training with crux.solve()
# 2D wave equation via AD-based u.tt and u.xx + u.yy (.scalar.bind pattern).
# Analytic: u(t,x,y) = cos(ωt) sin(πx) sin(πy), ω = c·π√2
# ═══════════════════════════════════════════════════════════════════════════════
_C = 1.0
_OMEGA = _C * np.pi * np.sqrt(2)
_T_MAX_TRAIN = 0.5 * (2.0 * np.pi / _OMEGA) # half-period ≈ 0.707 s
EPOCHS_TRAIN = 3_000
dom_pinn = jno.domain.rect(mesh_size=0.05, time=(0, _T_MAX_TRAIN, 6))
x_p, y_p, t_p = dom_pinn.variable("interior")
x0_p, y0_p, t0_p = dom_pinn.variable("initial")
xb_p, yb_p, tb_p = dom_pinn.variable("boundary")
net = jno.nn.wrap(
foundax.deeponet(
n_sensors=1,
coord_dim=2,
n_outputs=1,
n_layers=4,
basis_functions=64,
hidden_dim=48,
activation=jax.nn.tanh,
key=KEY,
)
)
net.optimizer(optax.adam(optax.warmup_cosine_decay_schedule(1e-6, 1e-3, 100, EPOCHS_TRAIN, 1e-6)))
xy_p = jno.np.concat([x_p, y_p])
xy0_p = jno.np.concat([x0_p, y0_p])
u_p = net(t_p, xy_p).scalar.bind(x=x_p, y=y_p, t=t_p)
u0_p = net(t0_p, xy0_p).scalar.bind(x=x0_p, y=y0_p, t=t0_p)
pde_wave = u_p.tt - _C**2 * (u_p.xx + u_p.yy)
ic_disp = net(t0_p, xy0_p) - jno.np.sin(π * x0_p) * jno.np.sin(π * y0_p)
ic_vel = u0_p.t # u_t(t=0) = 0 for this solution
bc_wall = net(tb_p, jno.np.concat([xb_p, yb_p]))
crux_pinn = jno.core([pde_wave.mse, ic_disp.mse, ic_vel.mse, bc_wall.mse], domain=dom_pinn)
crux_pinn.solve(EPOCHS_TRAIN)
_u_p, _u_exact_p = crux_pinn.eval(
[
u_p,
jno.np.cos(_OMEGA * t_p) * jno.np.sin(π * x_p) * jno.np.sin(π * y_p),
]
)
rel_l2 = float(jax.numpy.linalg.norm(_u_p - _u_exact_p) / (jax.numpy.linalg.norm(_u_exact_p) + 1e-8))
print(f"Stage 1 Relative L2 error: {rel_l2:.4e}")
# ═══════════════════════════════════════════════════════════════════════════════
# Stage 2 — FieldView audit of the TRAINED network's prediction
# Evaluate the trained DeepONet on a structured grid and re-check the wave
# equation with finite differences on ITS output — not on an analytic field.
# ═══════════════════════════════════════════════════════════════════════════════
T, H = 20, 32
dom_fd = jno.domain.equi_distant_rect(nx=H - 1, ny=H - 1, time=(0.0, _T_MAX_TRAIN, T))
xg, yg, tg = dom_fd.variable("interior")
xyg = jno.np.concat([xg, yg])
# .field makes x, y, t grid axes: derivatives of the net's grid output use FD, not
# AD. This is a discretisation-level audit of the same trained weights.
u = net(tg, xyg).field.bind(x=xg, y=yg, t=tg)
wave_res = u.tt - _C**2 * (u.xx + u.yy) # second-order temporal + spatial FD
# Evaluate through the TRAINED crux on the FD grid; min_consecutive=T holds every
# frame so the nested u.tt temporal stencil has its neighbours.
analytic = jno.np.cos(_OMEGA * tg) * jno.np.sin(π * xg) * jno.np.sin(π * yg)
pred, exact = crux_pinn.eval([net(tg, xyg), analytic], domain=dom_fd, min_consecutive=T)
pred, exact = np.asarray(pred), np.asarray(exact)
rel_l2_fd = float(np.linalg.norm(pred - exact) / (np.linalg.norm(exact) + 1e-8))
res_mse = float(np.mean(np.asarray(crux_pinn.eval(wave_res.expr.mse, domain=dom_fd, min_consecutive=T))))
print("Stage 2 FieldView audit of the trained network's prediction")
print(f" Relative L2 vs analytic (grid) : {rel_l2_fd:.3e}")
print(f" FD wave residual u_tt − c²Δu (MSE): {res_mse:.3e}")
# ── Tolerance checks ──────────────────────────────────────────────────────────
# Wave-PINN accuracy varies run-to-run (ill-conditioned + XLA autotune/compile
# cache), so the L2 guards are loose — they confirm the trained net beats a trivial
# predictor. The FD residual audits how well that prediction satisfies the wave
# equation on the grid. Final accuracy / 180 s timing should be confirmed on GPU.
assert np.isfinite(rel_l2) and rel_l2 < 1.0, f"Stage 1 PINN diverged (rel_l2={rel_l2:.3e})"
assert np.isfinite(rel_l2_fd) and rel_l2_fd < 1.0, f"prediction no better than zero (rel_l2={rel_l2_fd:.3e})"
assert np.isfinite(res_mse) and res_mse < 1.0, f"FD wave residual not sane (mse={res_mse:.3e})"