DeepONet — parametric Poisson 2D
Train a DeepONet to solve a one-parameter family of Poisson problems by PDE-residual learning. The network never sees ground-truth solutions, only the physics, and learns the operator k ↦ u(·) across the whole range k ∈ [0.5, 1.5].
Problem Setup
500 random values of k are sampled once, before training starts. The parametric domain replicates the spatial mesh across all 500 samples, so the PDE residual is defined at every (k, x, y) triple; training then works through these samples in minibatches (Step 4).
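Reading the residual from Step 3 back as an equation (our restatement, not a formula stated by jno), each sample solves a Poisson problem on the rectangle Ω = (0, 2) × (0, 1) with zero Dirichlet data. Because k is a constant, every member of the family is a rescaling of the k = 1 solution, which gives a cheap sanity check for the learned operator:

$$
\begin{aligned}
&k\,(u_{xx} + u_{yy}) + 1 = 0 \quad \text{in } \Omega = (0,2)\times(0,1), \qquad u = 0 \ \text{on } \partial\Omega, \\
&\Delta u_k = -\frac{1}{k} \;\Longrightarrow\; u_k(x,y) = \frac{u_1(x,y)}{k}, \qquad k \in [0.5,\,1.5].
\end{aligned}
$$

The implication holds because u_1 / k satisfies both the equation and the boundary condition, and the Dirichlet problem has a unique solution.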
Step 1: Parametric Domain
Multiplying a domain by an integer B replicates it across B independent samples. This is the operator-learning pattern:
```python
N_SAMPLES = 500
dom = N_SAMPLES * jno.domain.rect(mesh_size=0.05, x_range=(0, 2), y_range=(0, 1))
x, y, _ = dom.variable("interior")
k_values = jax.random.uniform(jax.random.PRNGKey(0), shape=(N_SAMPLES, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", k_values)
```
`k_values` has shape `(N_SAMPLES, 1, 1)`, i.e. one scalar k per sample. Registering it with `dom.variable("k", k_values)` attaches it to the domain as a tensor variable, so it can be used inside symbolic expressions alongside x and y.
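jno's internal array layout is not documented here, so the following is only a sketch of what "replicating the mesh" means in plain JAX. It assumes a (samples, points, features) layout and uses a uniform grid with spacing 0.05 as a stand-in for the interior mesh; with those assumptions, the `(N_SAMPLES, 1, 1)` shape of `k_values` broadcasts over the point axis:

```python
import jax
import jax.numpy as jnp

N_SAMPLES = 500

# Stand-in for the interior mesh: uniform grid with spacing 0.05 on (0, 2) x (0, 1),
# boundary nodes excluded (an assumption; jno's mesher may place points differently).
xs = jnp.linspace(0.0, 2.0, 41)[1:-1]  # 39 interior x-nodes
ys = jnp.linspace(0.0, 1.0, 21)[1:-1]  # 19 interior y-nodes
X, Y = jnp.meshgrid(xs, ys, indexing="ij")
points = jnp.stack([X.ravel(), Y.ravel()], axis=-1)  # (P, 2) with P = 741

# One scalar k per sample, same shape as in the tutorial.
k_values = jax.random.uniform(
    jax.random.PRNGKey(0), shape=(N_SAMPLES, 1, 1), minval=0.5, maxval=1.5
)

# "N_SAMPLES * domain": the same points repeated for every sample -> (N_SAMPLES, P, 2)
coords = jnp.broadcast_to(points, (N_SAMPLES, *points.shape))

# k is constant within a sample, so it broadcasts over the point axis -> (N_SAMPLES, P, 1)
k_per_point = jnp.broadcast_to(k_values, (N_SAMPLES, points.shape[0], 1))

print(coords.shape, k_per_point.shape)  # (500, 741, 2) (500, 741, 1)
```

Every point of sample i sees the same k, which is exactly what a residual of the form k · (u_xx + u_yy) + 1 needs.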
Step 2: DeepONet Network
The branch input is the scalar k (a "function" observed at a single sensor); the trunk input is the query coordinate (x, y). Each sub-network encodes its input into a vector of length `basis_functions` = 32, and the output is the dot product of the two encoded vectors:
```python
net = jno.nn.wrap(
    foundax.deeponet(
        n_sensors=1,  # branch input dimensionality
        coord_dim=2,  # trunk input dimensionality
        basis_functions=32,
        hidden_dim=128,
        activation=jax.numpy.tanh,
        key=jax.random.PRNGKey(0),
    )
)
# Cosine decay of the learning rate from 1e-3 to 1e-5 (alpha = final / initial) over 20_000 steps
net.optimizer(optax.adam(optax.cosine_decay_schedule(1e-3, 20_000, alpha=1e-5 / 1e-3)))
```
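The branch/trunk structure can be sketched from scratch in plain JAX. This is the generic unstacked DeepONet, not foundax's implementation: the number of hidden layers, the scalar output bias, and the initialisation below are assumptions. Only the sizes (`n_sensors=1`, `coord_dim=2`, `basis_functions=32`, `hidden_dim=128`, tanh) mirror the call above.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """List of (W, b) pairs for a fully connected net with the given layer sizes."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(sub, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
        for sub, m, n in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp(params, h):
    """tanh hidden layers, linear output layer; acts on the last axis of h."""
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return h @ w + b

def init_deeponet(key, n_sensors=1, coord_dim=2, basis_functions=32, hidden_dim=128):
    kb, kt = jax.random.split(key)
    return {
        "branch": init_mlp(kb, [n_sensors, hidden_dim, hidden_dim, basis_functions]),
        "trunk": init_mlp(kt, [coord_dim, hidden_dim, hidden_dim, basis_functions]),
        "bias": jnp.zeros(()),
    }

def deeponet(params, k, xy):
    """k: (B, n_sensors), xy: (B, P, coord_dim) -> u: (B, P)."""
    b = mlp(params["branch"], k)  # (B, basis_functions): coefficients from k
    t = mlp(params["trunk"], xy)  # (B, P, basis_functions): basis functions at (x, y)
    return jnp.einsum("bq,bpq->bp", b, t) + params["bias"]  # dot product over the basis axis

params = init_deeponet(jax.random.PRNGKey(0))
k = jnp.array([[0.5], [1.0], [1.5]])                        # 3 values of k, one "sensor" each
xy = jax.random.uniform(jax.random.PRNGKey(1), (3, 10, 2))  # 10 query points per sample
u = deeponet(params, k, xy)                                 # (3, 10)
```

The branch output acts as k-dependent coefficients and the trunk output as a learned basis over (x, y); neither network ever sees the other's input.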
Step 3: Hard BCs + PDE Residual
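```python
# Hard Dirichlet BC: the factor x(2 - x)y(1 - y) is zero on x = 0, x = 2, y = 0 and y = 1
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
# Residual of k (u_xx + u_yy) + 1 = 0
pde = k * (u.d2(x) + u.d2(y)) + 1.0
```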
The multiplicative ansatz x(2-x)y(1-y) vanishes on all four edges of the rectangle [0, 2] × [0, 1], so u satisfies the homogeneous Dirichlet BC exactly for every sample and the boundary needs no loss term. The residual `pde` encodes k(u_xx + u_yy) + 1 = 0.
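A pointwise version of Step 3 in plain JAX, with nested `jax.grad` standing in for jno's `u.d2(...)` (how jno differentiates internally is not covered here). `net_out` is a hypothetical placeholder for the DeepONet output at a fixed k:

```python
import jax
import jax.numpy as jnp

def distance_factor(x, y):
    """x(2 - x) y(1 - y): zero on x = 0, x = 2, y = 0 and y = 1."""
    return x * (2.0 - x) * y * (1.0 - y)

def hard_bc(net_out):
    """Wrap any smooth (x, y) -> scalar function so that u = 0 on the boundary."""
    return lambda x, y: net_out(x, y) * distance_factor(x, y)

def residual(u, k, x, y):
    """k (u_xx + u_yy) + 1 at a single point, via nested jax.grad."""
    u_xx = jax.grad(jax.grad(u, argnums=0), argnums=0)(x, y)
    u_yy = jax.grad(jax.grad(u, argnums=1), argnums=1)(x, y)
    return k * (u_xx + u_yy) + 1.0

net_out = lambda x, y: jnp.sin(x) + y  # hypothetical stand-in for the network
u = hard_bc(net_out)

# One point on each edge: all four values are exactly zero, whatever net_out is
print([float(u(0.0, 0.3)), float(u(2.0, 0.3)), float(u(1.2, 0.0)), float(u(1.2, 1.0))])
print(float(residual(u, 1.0, 1.0, 0.5)))  # nonzero: net_out is not a solution
```

Only the residual has to be driven to zero by training; the boundary condition holds by construction.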
Step 4: Solve
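The solver is built with `crux = jno.core(constraints=[pde.mse])` and trained with `crux.solve(epochs=EPOCHS, batchsize=32)`, as in the full script. Here `batchsize=32` means each gradient step uses 32 of the 500 parametric samples, i.e. a stochastic minibatch in k-space.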
What To Notice
- One network, one training run, 500 PDE solutions. After convergence, `crux.eval(u)` returns the solution field for every sampled `k` without any retraining.
- Branch/trunk factorisation is the operator-learning interpretation of "separation of variables in parameter space". It is cheap to scale (the trunk depends only on (x, y) and is the same for all samples), which makes DeepONet far cheaper than training a separate PINN for each value of `k`.
- Pure PDE-residual training. No solution data is supplied; the network learns from physics alone. Compare with the FNO and U-Net tutorials, which use a precomputed `(f, u)` dataset.
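The "cheap to scale" point can be made concrete. Since the trunk depends only on (x, y), its features T (P × basis_functions) can be computed once, and every sample's solution is one row of B Tᵀ, where B (N × basis_functions) stacks the branch outputs. A sketch with random placeholder features; the sizes (500 samples, 741 grid points, 32 basis functions) are assumptions matching the tutorial, and whether foundax actually caches the trunk this way is not stated:

```python
import jax
import jax.numpy as jnp

N, P, Q = 500, 741, 32  # samples, mesh points, basis_functions

kb, kt = jax.random.split(jax.random.PRNGKey(0))
branch_feats = jax.random.normal(kb, (N, Q))  # placeholder branch outputs, one row per k
trunk_feats = jax.random.normal(kt, (P, Q))   # placeholder trunk outputs, shared by all k

# All N solution fields at once: a single (N, Q) @ (Q, P) matrix product -> (N, P)
u_all = branch_feats @ trunk_feats.T

# The solution for one sample reuses the same trunk features; only its branch row changes
u_sample_7 = trunk_feats @ branch_feats[7]  # (P,)
```

The trunk costs P network evaluations regardless of N; each extra value of k adds one branch evaluation and one length-Q dot product per point.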
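Script Snippet
The complete script below uses smaller settings than the walkthrough above (`N_SAMPLES = 50` and `EPOCHS = 2_000` instead of 500 and 20_000) and builds the rectangle from a shapely `box(0, 0, 2, 1)` passed to `jno.domain`.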
"""11 — DeepONet 2D for parametric Poisson"""
import foundax
import jax
import optax
from shapely.geometry import box
import jno
KEY = jax.random.PRNGKey(0)
N_SAMPLES = 50
EPOCHS = 2_000
# ── Parametric domain — replicate one mesh across N_SAMPLES random k values ──
dom = N_SAMPLES * jno.domain(box(0, 0, 2, 1), mesh_size=0.05)
x, y, _ = dom.variable("interior")
k_values = jax.random.uniform(KEY, shape=(N_SAMPLES, 1, 1), minval=0.5, maxval=1.5)
k = dom.variable("k", k_values)
# ── Network ──────────────────────────────────────────────────────────────────
net = jno.nn.wrap(
foundax.deeponet(
n_sensors=1, # branch input is the scalar k
coord_dim=2, # trunk input is (x, y)
basis_functions=32,
hidden_dim=128,
activation=jax.numpy.tanh,
key=KEY,
)
)
net.optimizer(optax.adam(optax.cosine_decay_schedule(1e-3, EPOCHS, alpha=1e-5 / 1e-3)))
# ── Hard BC ansatz + PDE residual ────────────────────────────────────────────
u = net(k, jno.np.concat([x, y], axis=-1)) * x * (2 - x) * y * (1 - y)
pde = k * (u.d2(x) + u.d2(y)) + 1.0
# ── Solve ────────────────────────────────────────────────────────────────────
crux = jno.core(constraints=[pde.mse])
crux.solve(epochs=EPOCHS, batchsize=32)
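To check what `crux.eval(u)` returns, an independent reference helps. The sketch below is a standard 5-point finite-difference solve of k(u_xx + u_yy) + 1 = 0 with u = 0 on the boundary, assuming a uniform grid with spacing 0.05 (jno's mesh points need not coincide with these nodes, and mapping between the two is not covered here). `poisson_reference` is a hypothetical helper, not part of jno. Because k is constant, u_k = u_1 / k, so one solve at k = 1 covers the whole family:

```python
import jax.numpy as jnp

def second_difference(m, h):
    """1-D 3-point second-derivative matrix on m interior nodes, zero Dirichlet ends."""
    off = jnp.ones(m - 1)
    return (jnp.diag(-2.0 * jnp.ones(m)) + jnp.diag(off, 1) + jnp.diag(off, -1)) / h**2

def poisson_reference(k, h=0.05, lx=2.0, ly=1.0):
    """Solve k (u_xx + u_yy) + 1 = 0 on (0, lx) x (0, ly) with u = 0 on the boundary.

    Returns u on the interior nodes as an (mx, my) array, x along axis 0.
    """
    mx, my = round(lx / h) - 1, round(ly / h) - 1  # 39 x 19 interior nodes for h = 0.05
    lap = jnp.kron(second_difference(mx, h), jnp.eye(my)) + jnp.kron(
        jnp.eye(mx), second_difference(my, h)
    )
    u = jnp.linalg.solve(k * lap, -jnp.ones(mx * my))
    return u.reshape(mx, my)

u1 = poisson_reference(1.0)
u2 = poisson_reference(2.0)
print(u1.shape, bool(jnp.allclose(u2, u1 / 2.0, rtol=1e-4)))  # (39, 19) True
```

Given u1, the reference for any sampled k is simply u1 / k, which can be compared with the trained network's output at matching points.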