U-Net 2D — supervised Poisson 2D
Same problem and dataset as the FNO tutorial, different architecture: a convolutional encoder-decoder with skip connections. Data-driven supervised learning, just like FNO.
Problem Setup
Solve −∇²u = f on [0,1]² with u = 0 on the boundary, learning the operator f → u from precomputed (f, u) pairs.
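On the unit square with zero boundary values, the operator f → u can be written out exactly in a sine basis, which also makes Poisson's smoothing concrete. This assumes f is expanded in sine modes. That is a standard choice, but the tutorial does not say how generate_poisson_data actually builds its samples:

```latex
\begin{aligned}
f(x,y) &= \sum_{k,l \ge 1} a_{kl}\,\sin(k\pi x)\sin(l\pi y),\\
-\nabla^2\big[\sin(k\pi x)\sin(l\pi y)\big] &= \pi^2 (k^2 + l^2)\,\sin(k\pi x)\sin(l\pi y),\\
u(x,y) &= \sum_{k,l \ge 1} \frac{a_{kl}}{\pi^2 (k^2 + l^2)}\,\sin(k\pi x)\sin(l\pi y).
\end{aligned}
```

Every mode already vanishes on the boundary. Mode (k, l) is scaled by 1/(π²(k² + l²)), so the high-frequency content of f is strongly attenuated in u.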
Step 1: Generate the Dataset
from create_domain import build_domain_from_arrays, generate_poisson_data
GRID, SAMPLES = 16, 20
forcing, solution = generate_poisson_data(SAMPLES, GRID, n_modes=5, alpha=1.5, seed=42)
domain = build_domain_from_arrays(forcing, solution, GRID)
_f = domain.variable("_f")
_u = domain.variable("_u")
Identical to the FNO setup — comparing the two architectures cleanly requires the same data.
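The body of generate_poisson_data is not shown here. Below is a plausible stand-in, and it is hypothetical. sketch_poisson_data draws random sine-series forcings whose amplitudes decay with mode number. It assumes alpha controls that decay and n_modes sets the modes per direction. It then evaluates the exact solution mode by mode and returns arrays shaped (SAMPLES, GRID, GRID, 1). The real helper may use a different grid, amplitude law, or array layout.

```python
import numpy as np


def sketch_poisson_data(samples, grid, n_modes=5, alpha=1.5, seed=42):
    """HYPOTHETICAL stand-in for generate_poisson_data.

    Returns (f, u) with -lap(u) = f on [0,1]^2 and u = 0 on the boundary,
    both shaped (samples, grid, grid, 1).
    """
    rng = np.random.default_rng(seed)
    x = np.linspace(0.0, 1.0, grid)           # grid includes the boundary points
    X, Y = np.meshgrid(x, x, indexing="ij")
    modes = np.arange(1, n_modes + 1)
    f = np.zeros((samples, grid, grid))
    u = np.zeros((samples, grid, grid))
    for s in range(samples):
        for k in modes:
            for l in modes:
                # assumed amplitude law: higher modes get smaller coefficients
                a = rng.standard_normal() / (k**2 + l**2) ** (alpha / 2)
                phi = np.sin(k * np.pi * X) * np.sin(l * np.pi * Y)
                f[s] += a * phi
                # sine modes are eigenfunctions: -lap(phi) = pi^2 (k^2 + l^2) phi
                u[s] += a / (np.pi**2 * (k**2 + l**2)) * phi
    return f[..., None], u[..., None]          # add a trailing channel axis


forcing, solution = sketch_poisson_data(20, 16)
print(forcing.shape, solution.shape)
```

Building the labels from an exact series rather than a numerical solver keeps them free of discretization error. Any remaining misfit can then be attributed to the model.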
Step 2: U-Net Architecture
import foundax
import jax
import jno
import optax
u = jno.nn.wrap(
foundax.unet2d(
in_channels=1,
out_channels=1,
depth=2, # number of down-sampling stages
wf=4, # width factor: first-stage channels = 2^wf = 16, doubled at each down-sampling stage
norm="layer",
up_mode="upconv",
padding_mode="reflect", # mirror the edges instead of zero-padding
key=jax.random.PRNGKey(0),
)
)
u.optimizer(optax.chain(
optax.clip_by_global_norm(1e-3),
optax.adamw(optax.cosine_decay_schedule(5e-4, 50, alpha=1e-7 / 5e-4), weight_decay=1e-6),
))
Each encoder stage downsamples and doubles the channel count; each decoder stage upsamples and concatenates the corresponding encoder feature map (the "skip connection"). The output has the same spatial resolution as the input, with out_channels=1 for a scalar field.
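Walking the shapes level by level shows what depth and wf imply for this 16x16 problem. This is a minimal sketch, and it makes two assumptions. Following the Step 2 comments, depth counts 2x down-sampling stages. Level i carries 2**(wf + i) channels. unet_level_shapes is a hypothetical helper, and foundax's exact convention may differ; some U-Net implementations count levels rather than down-sampling stages.

```python
def unet_level_shapes(grid, depth=2, wf=4):
    """HYPOTHETICAL (H, W, C) per U-Net level.

    Level 0 is full resolution; each of the `depth` down-sampling stages
    halves H and W and doubles C (assumed convention).
    """
    if grid % (2**depth) != 0:
        raise ValueError(f"grid={grid} must be divisible by 2**depth={2**depth}")
    return [(grid // 2**lvl, grid // 2**lvl, 2 ** (wf + lvl)) for lvl in range(depth + 1)]


encoder = unet_level_shapes(16, depth=2, wf=4)
print("encoder:", encoder)

# The decoder mirrors the encoder: at each resolution the upsampled map from
# the level below is concatenated with the encoder map of the same size.
for (h, w, c_skip), (_, _, c_below) in zip(reversed(encoder[:-1]), reversed(encoder[1:])):
    print(f"{h}x{w}: upsample the {c_below}-channel level, concatenate the {c_skip}-channel skip")
```

The divisibility check matters in a standard U-Net. With depth=2 the grid must be a multiple of 4, or the upsampled maps stop lining up with their skip connections.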
Step 3: Supervised Loss + Solve
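crux = jno.core([(_u - u(_f[0, ...])).mse])
crux.solve(epochs=50, batchsize=10)
The loss is the mean squared error between the network's prediction u(f) and the reference solution _u. Indexing with _f[0, ...] drops a leading singleton axis so the input arrives in the (B, H, W, C) layout U-Net expects. This is minor shape plumbing, but it is worth knowing about when porting between operator architectures.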
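The indexing step can be checked in isolation. This is a minimal sketch with NumPy; jax.numpy indexes the same way. It assumes the batched input is laid out as (1, B, H, W, C). The array is a placeholder, not jno's actual variable object. The sketch also shows why the choice of operation matters. x[0, ...] silently takes the first slice even when that axis is not a singleton, whereas np.squeeze(x, axis=0) refuses.

```python
import numpy as np

B, H, W, C = 10, 16, 16, 1  # batch, height, width, channels (BATCH=10, GRID=16)

# Assumed layout: an extra leading singleton axis in front of (B, H, W, C).
f_wrapped = np.random.default_rng(0).standard_normal((1, B, H, W, C))

x_index = f_wrapped[0, ...]                 # drops axis 0 by taking its first slice
x_squeeze = np.squeeze(f_wrapped, axis=0)   # drops axis 0, but only if it has size 1
assert x_index.shape == x_squeeze.shape == (B, H, W, C)
assert np.array_equal(x_index, x_squeeze)

# If the leading axis were NOT a singleton, indexing would silently discard data...
stacked = np.zeros((2, B, H, W, C))
print(stacked[0, ...].shape)  # (10, 16, 16, 1): the second slice is dropped without warning
# ...while squeeze raises instead.
try:
    np.squeeze(stacked, axis=0)
except ValueError as err:
    print("squeeze refused:", err)
```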
What To Notice
- Multiscale by construction. The encoder-decoder topology captures features at multiple resolutions. This helps when the solution operator combines local detail (sharp features in f) with global structure (Poisson's elliptic smoothing).
- padding_mode="reflect" is preferred over zero-padding (the default in many CNN libraries) for this non-periodic problem. Mirroring the edges avoids the artificial jump to zero that zero-padding introduces around every feature map, a common source of border artifacts. For periodic problems, use padding_mode="circular".
- Compared to FNO: U-Net is convolutional and local; FNO is spectral and global. FNO tends to converge faster on smooth elliptic problems, while U-Net's local kernels are often better at resolving sharp, localized features in the solution. U-Net's kernels are sized in pixels, so unlike FNO's Fourier-space parameterization they do not transfer cleanly to input resolutions that differ from training.
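numpy.pad on a single row shows what each padding choice places beyond the edge. This assumes foundax's "reflect" and "circular" follow the usual CNN convention, as in PyTorch's Conv2d padding_mode. Under that convention "reflect" mirrors without repeating the edge sample, like NumPy's mode="reflect". "circular" wraps around, like mode="wrap".

```python
import numpy as np

row = np.array([1.0, 2.0, 3.0, 4.0])  # one row of a feature map
pad = 2                                # e.g. what a 5x5 kernel needs per side

zero = np.pad(row, pad, mode="constant")   # zero-padding:              0 0 | 1 2 3 4 | 0 0
mirror = np.pad(row, pad, mode="reflect")  # mirror, edge not repeated: 3 2 | 1 2 3 4 | 3 2
wrap = np.pad(row, pad, mode="wrap")       # circular / periodic:       3 4 | 1 2 3 4 | 1 2

for name, padded in [("zero", zero), ("reflect", mirror), ("circular", wrap)]:
    print(f"{name:>8}: {padded}")

# The same call pads a 2D field on all four sides, e.g. one pixel per side
# for a 3x3 kernel: (16, 16) -> (18, 18).
field = np.arange(16.0 * 16.0).reshape(16, 16)
print(np.pad(field, 1, mode="reflect").shape)
```

Zero-padding puts a jump at the edge whenever the boundary values are non-zero. Reflect continues the data without a jump. Wrap treats opposite edges as neighbours, which is only correct on a periodic domain.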
Script Snippet
"""11 — U-Net 2D for Poisson operator learning"""
import foundax
import jax
import optax
from create_domain import build_domain_from_arrays, generate_poisson_data
import jno
KEY = jax.random.PRNGKey(0)
GRID = 16
SAMPLES = 20
EPOCHS = 50
BATCH = 10
# ── Dataset ──────────────────────────────────────────────────────────────────
forcing, solution = generate_poisson_data(SAMPLES, GRID, n_modes=5, alpha=1.5, seed=42)
domain = build_domain_from_arrays(forcing, solution, GRID)
_f = domain.variable("_f")
_u = domain.variable("_u")
# ── Model ────────────────────────────────────────────────────────────────────
u = jno.nn.wrap(
foundax.unet2d(
in_channels=1,
out_channels=1,
depth=2,
wf=4,
norm="layer",
up_mode="upconv",
padding_mode="reflect",
key=KEY,
)
)
u.optimizer(
optax.chain(
optax.clip_by_global_norm(1e-3),
optax.adamw(optax.cosine_decay_schedule(5e-4, EPOCHS, alpha=1e-7 / 5e-4), weight_decay=1e-6),
)
)
# ── Supervised loss + solve ─────────────────────────────────────────────────
# u(_f[0, ...]) drops a leading singleton axis because UNet expects (B, H, W, C).
crux = jno.core([(_u - u(_f[0, ...])).mse])
crux.solve(epochs=EPOCHS, batchsize=BATCH)