Variable-Coefficient Poisson 2D
Same square geometry as the constant-coefficient example, but with a spatially varying conductivity field.
Problem Setup
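Solve the variable-coefficient Poisson problem on the unit square with homogeneous Dirichlet boundary conditions,

$$
-\nabla\cdot\big(\kappa(x,y)\,\nabla u(x,y)\big) = f(x,y) \quad \text{in } \Omega = (0,1)^2, \qquad u = 0 \ \text{on } \partial\Omega,
$$

with $\kappa(x,y) = 1 + x + y$ and exact solution $u(x,y) = \sin(\pi x)\sin(\pi y)$.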
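The forcing $f$ is manufactured from the exact solution: substitute $u$ into the operator and expand with the product rule. Since $\nabla\kappa = (1, 1)$ and $\Delta u = -2\pi^2 u$ for this $u$, the right-hand side used in the `pde` residual follows directly:

$$
\begin{aligned}
-\nabla\cdot(\kappa\nabla u) &= -\kappa\,\Delta u - \nabla\kappa\cdot\nabla u \\
&= 2\pi^2\,\kappa\,\sin(\pi x)\sin(\pi y) - \pi\cos(\pi x)\sin(\pi y) - \pi\sin(\pi x)\cos(\pi y) =: f(x,y).
\end{aligned}
$$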
Step 1 — Domain and coefficient field
$\kappa$ is built directly from the sampled coordinates, so the PDE coefficient varies pointwise across the domain.
domain = jno.domain(box(0, 0, 1, 1), mesh_size=0.05)
x, y, _ = domain.variable("interior")
κ = 1 + x + y
u_exact = jno.np.sin(π * x) * jno.np.sin(π * y)
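Because $\kappa = 1 + x + y$ stays between 1 and 3 on the unit square, the operator is uniformly elliptic and the coefficient varies by a factor of three across the domain. The plain-Python sketch below checks this on a uniform grid with spacing 0.05; that grid is a hypothetical stand-in for the mesh `jno.domain` generates, not the library's actual sampling.

```python
import math

# Hypothetical 21 x 21 uniform grid (spacing 0.05) standing in for jno's mesh samples.
N = 21
GRID = [i / (N - 1) for i in range(N)]

def kappa(x: float, y: float) -> float:
    # Conductivity field from the tutorial: kappa = 1 + x + y.
    return 1.0 + x + y

def u_exact(x: float, y: float) -> float:
    # Manufactured exact solution sin(pi x) sin(pi y).
    return math.sin(math.pi * x) * math.sin(math.pi * y)

kappa_vals = [kappa(x, y) for x in GRID for y in GRID]
u_vals = [u_exact(x, y) for x in GRID for y in GRID]

print(f"kappa range: [{min(kappa_vals):.2f}, {max(kappa_vals):.2f}]")
print(f"max |u_exact| on the grid: {max(abs(v) for v in u_vals):.2f}")
```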
Step 2 — Build the flux, then take its divergence
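Rather than writing $\Delta u$, the script forms the flux $\kappa\nabla u$ component by component and takes its divergence, differentiating $\kappa\,\partial_x u$ along $x$ and $\kappa\,\partial_y u$ along $y$. The factor $x(1-x)\,y(1-y)$ multiplying `net(x, y)` vanishes on all four edges of the square, so the homogeneous Dirichlet BC is enforced exactly (a hard constraint) instead of through a boundary loss term.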
net = jno.nn.wrap(foundax.mlp(in_features=2, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(13)))
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 1000, 0.5, end_value=1e-5)))
# Multiplicative hard BC — the x(1-x)y(1-y) factor's derivatives flow through.
u = (net(x, y) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y)
# −∂_x(κ ∂_x u) − ∂_y(κ ∂_y u) = forcing. Named-partial syntax reads
# component-by-component as if from the math, including the product rule.
pde = -((κ * u.x).d(x) + (κ * u.y).d(y)) - (
2 * π**2 * κ * u_exact - π * jno.np.cos(π * x) * jno.np.sin(π * y) - π * jno.np.sin(π * x) * jno.np.cos(π * y)
)
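To sanity-check the forcing expression in `pde`, one can compare it with a finite-difference evaluation of $-\nabla\cdot(\kappa\nabla u)$ at a few interior points. The sketch below uses the standard conservative five-point stencil for variable coefficients (with $\kappa$ sampled at half-points) in plain Python; it is independent of jno, and the helper names are hypothetical.

```python
import math

PI = math.pi

def kappa(x, y):
    return 1.0 + x + y

def u_exact(x, y):
    return math.sin(PI * x) * math.sin(PI * y)

def forcing(x, y):
    # Same expression the tutorial subtracts in the `pde` residual.
    return (2 * PI**2 * kappa(x, y) * u_exact(x, y)
            - PI * math.cos(PI * x) * math.sin(PI * y)
            - PI * math.sin(PI * x) * math.cos(PI * y))

def neg_div_flux_fd(x, y, h=1e-3):
    # Conservative five-point stencil: kappa is sampled at the half-points
    # (x +/- h/2, y) and (x, y +/- h/2), so each term approximates
    # d/dx(kappa du/dx) or d/dy(kappa du/dy) with O(h^2) error.
    u0 = u_exact(x, y)
    d_dx = (kappa(x + h / 2, y) * (u_exact(x + h, y) - u0)
            - kappa(x - h / 2, y) * (u0 - u_exact(x - h, y))) / h**2
    d_dy = (kappa(x, y + h / 2) * (u_exact(x, y + h) - u0)
            - kappa(x, y - h / 2) * (u0 - u_exact(x, y - h))) / h**2
    return -(d_dx + d_dy)

POINTS = [(0.3, 0.6), (0.5, 0.5), (0.8, 0.2)]
for x, y in POINTS:
    fd, exact = neg_div_flux_fd(x, y), forcing(x, y)
    print(f"({x}, {y}): FD = {fd:.6f}  forcing = {exact:.6f}  |diff| = {abs(fd - exact):.1e}")
```

Halving `h` should shrink the discrepancy by roughly a factor of four, as expected for a second-order stencil, until floating-point roundoff starts to dominate.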
Step 3 — Train
The residual → core → solve workflow is unchanged from the constant-coefficient case; only the forcing term and the spatially varying $\kappa$ differ.
grad_norms = jno.trackers.gradient_norms(interval=500)
crux = jno.core([pde.mse])
crux.solve(5_000, callbacks=[grad_norms])
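The optimizer attached to `net` in Step 2 uses `optax.exponential_decay(1e-3, 1000, 0.5, end_value=1e-5)`: the learning rate halves every 1000 steps (smoothly, since `staircase` defaults to `False`) and is floored at `1e-5`. The plain-Python sketch below restates that schedule formula rather than calling optax, and it assumes `crux.solve(5_000)` takes one optimizer step per epoch.

```python
# Restatement of optax.exponential_decay's formula (not a call into optax):
#   lr(step) = init_value * decay_rate ** (step / transition_steps), floored at end_value
# Assumption: crux.solve(5_000) takes one optimizer step per epoch.

def exponential_decay(step, init_value=1e-3, transition_steps=1000,
                      decay_rate=0.5, end_value=1e-5):
    lr = init_value * decay_rate ** (step / transition_steps)
    # With decay_rate < 1, end_value acts as a lower bound.
    return max(lr, end_value)

for step in (0, 1000, 2500, 5000, 10_000):
    print(f"step {step:>6}: lr = {exponential_decay(step):.3e}")
```

Under that assumption the floor is never reached during this run: the rate ends near `3.1e-5` at step 5000, and `1e-5` only applies after roughly 6,650 steps.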
What to notice
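- The script writes the divergence component by component with named partials: `(κ * u.x).d(x) + (κ * u.y).d(y)` mirrors $\partial_x(\kappa\,\partial_x u) + \partial_y(\kappa\,\partial_y u)$ term for term, and differentiating each product applies the product rule automatically. The same operator can also be written in vector form: `u.grad(x, y)` returns a `VectorView`, and multiplying it by the scalar `κ` preserves the view type (Placeholder × VectorView → VectorView), so the chain `κ * u.grad(x, y)` → `.div(x, y)` reads exactly like the math $\nabla\cdot(\kappa\nabla u)$.
- Hard BC enforcement via the multiplicative ansatz keeps the loss focused on the PDE residual alone: there is no boundary-loss weighting to tune.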
- Variable coefficients are a common bridge from toy PDEs to physically meaningful media.
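The multiplicative ansatz works for any network output: the factor $x(1-x)\,y(1-y)$ is zero on every edge of the unit square, so $u$ vanishes there regardless of what the network predicts. The plain-Python sketch below illustrates this with a hypothetical stand-in function `toy_net` in place of the foundax MLP.

```python
import math

def toy_net(x, y):
    # Hypothetical stand-in for the trained network: any finite function works.
    return 3.0 + math.exp(x) * math.cos(5 * y)

def u_ansatz(x, y):
    # Same multiplicative hard-BC construction as the tutorial's `u`.
    return toy_net(x, y) * x * (1 - x) * y * (1 - y)

# Sample points on all four edges of the unit square.
ts = [i / 10 for i in range(11)]
boundary = ([(t, 0.0) for t in ts] + [(t, 1.0) for t in ts]
            + [(0.0, t) for t in ts] + [(1.0, t) for t in ts])
max_boundary = max(abs(u_ansatz(x, y)) for x, y in boundary)

print(f"max |u| on the boundary: {max_boundary}")
print(f"u at the center: {u_ansatz(0.5, 0.5):.4f}")
```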
Full script
"""02 — 2-D variable-coefficient Poisson (shapely domain + named partials + tracker)"""
from pathlib import Path
import foundax
import jax
import optax
from shapely.geometry import box
import jno
π = jno.np.pi
domain = jno.domain(box(0, 0, 1, 1), mesh_size=0.05)
x, y, _ = domain.variable("interior")
κ = 1 + x + y
u_exact = jno.np.sin(π * x) * jno.np.sin(π * y)
net = jno.nn.wrap(foundax.mlp(in_features=2, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(13)))
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 1000, 0.5, end_value=1e-5)))
# Multiplicative hard BC — the x(1-x)y(1-y) factor's derivatives flow through.
u = (net(x, y) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y)
# −∂_x(κ ∂_x u) − ∂_y(κ ∂_y u) = forcing. Named-partial syntax reads
# component-by-component as if from the math, including the product rule.
pde = -((κ * u.x).d(x) + (κ * u.y).d(y)) - (
2 * π**2 * κ * u_exact - π * jno.np.cos(π * x) * jno.np.sin(π * y) - π * jno.np.sin(π * x) * jno.np.cos(π * y)
)
grad_norms = jno.trackers.gradient_norms(interval=500)
crux = jno.core([pde.mse])
crux.solve(5_000, callbacks=[grad_norms])
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}")
if grad_norms.value is not None:
    print(f"Final ∇L norm: {grad_norms.value['norms']}")
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(f"02_elliptic/variable_coefficient_poisson_2d.py | epochs=5000 | rel_L2={rel_l2:.6e}\n")
assert rel_l2 < 1e-1, f"relative L2 error too large: {rel_l2:.3e}"
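The pass criterion at the end of the script is a discrete relative $L^2$ error over the interior sample points, $\|u_\theta - u\|_2 / (\|u\|_2 + 10^{-8})$, where the small constant only guards against division by zero. A plain-Python restatement, using made-up sample values purely for illustration:

```python
import math

def relative_l2(pred, exact, eps=1e-8):
    # Discrete relative L2 error over matching sample points, as in the script's rel_l2.
    diff = math.sqrt(sum((p - e) ** 2 for p, e in zip(pred, exact)))
    ref = math.sqrt(sum(e ** 2 for e in exact))
    return diff / (ref + eps)

# Illustrative values only (not results from the tutorial).
exact = [0.0, 0.5, 1.0, 0.5]
pred = [0.01, 0.48, 0.97, 0.52]
err = relative_l2(pred, exact)
print(f"Relative L2 error: {err:.4e}")
print("passes < 1e-1:", err < 1e-1)
```

Because the samples are unweighted mesh points, this approximates the continuous relative $L^2$ norm only to the extent that the points cover the domain roughly uniformly.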