Integral Constraints and Flux Monitoring (2-D)
This example solves the 2-D Poisson equation and shows two uses of .integrate() that are not possible with pointwise losses like .mse:
- Tracking a physical observable: the volume mean ∫_Ω u dA is logged throughout training without entering the gradient.
- Soft integral constraint: the same quantity can be added to the loss to accelerate convergence when the PDE residual alone is slow to pin down the solution's magnitude.
Problem Setup
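Solve −∇²u = f on Ω = [0,1]² with u = 0 on ∂Ω, where f(x,y) = 2π² sin(πx) sin(πy).
Exact solution: u(x,y) = sin(πx) sin(πy)
Ω has unit area, so the volume mean equals ∫_Ω u dA. Each factor integrates to ∫₀¹ sin(πt) dt = 2/π, so the exact value is ∫₀¹∫₀¹ sin(πx)sin(πy) dxdy = (2/π)² = 4/π² ≈ 0.405.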
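The value 4/π² is easy to confirm numerically. The sketch below uses a tensor-product Gauss-Legendre rule on the unit square in plain NumPy. The helper name tensor_gauss_legendre and the choice of 16 nodes per axis are illustrative assumptions, not part of jno.

```python
import numpy as np


def tensor_gauss_legendre(n: int):
    """Nodes and weights of an n-by-n Gauss-Legendre rule on [0, 1]^2."""
    t, w = np.polynomial.legendre.leggauss(n)  # nodes and weights on [-1, 1]
    x = 0.5 * (t + 1.0)  # map nodes to [0, 1]
    w = 0.5 * w          # scale weights by the Jacobian 1/2
    X, Y = np.meshgrid(x, x, indexing="ij")
    W = np.outer(w, w)   # tensor-product weights
    return X, Y, W


X, Y, W = tensor_gauss_legendre(16)
u = np.sin(np.pi * X) * np.sin(np.pi * Y)
integral = np.sum(W * u)
print(integral, 4.0 / np.pi**2)
```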
Why integrals matter here
The Dirichlet condition forces u = 0 on the entire boundary. A network that memorises the boundary data without learning the interior would give u ≈ 0 everywhere — satisfying the BC but not the PDE. Tracking ∫_Ω u dA during training immediately reveals this failure mode: if the integral stays near zero, the network has not learned the interior peak.
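The check can be made concrete. Both candidates below satisfy u = 0 on ∂Ω, but only the exact solution has the right integral. This is a NumPy sketch on a midpoint grid, an illustrative choice rather than jno's quadrature. The helper interior_collapsed and its 10% threshold are hypothetical.

```python
import numpy as np

# Midpoint-rule grid on the unit square (illustrative, not jno's scheme).
n = 200
c = (np.arange(n) + 0.5) / n
X, Y = np.meshgrid(c, c, indexing="ij")
cell_area = 1.0 / n**2

TARGET_INTEGRAL = 4.0 / np.pi**2


def volume_integral(u_vals):
    return float(np.sum(u_vals) * cell_area)


def interior_collapsed(integral, target=TARGET_INTEGRAL, frac=0.1):
    """Hypothetical diagnostic: flag a run whose integral is below frac * target."""
    return abs(integral) < frac * abs(target)


u_trivial = np.zeros_like(X)  # satisfies the BC, ignores the PDE
u_exact = np.sin(np.pi * X) * np.sin(np.pi * Y)

for name, vals in [("trivial", u_trivial), ("exact", u_exact)]:
    I = volume_integral(vals)
    print(f"{name:8s} integral={I:.4f} collapsed={interior_collapsed(I)}")
```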
Step 1: Define volume and boundary variables
x, y, _ = domain.variable("interior") # volume points
x_b, y_b, _ = domain.variable("boundary") # boundary points
Step 2: Hard-enforce the Dirichlet BC
The model output is multiplied by x(1−x)y(1−y), which is zero on all four edges, so the network only needs to learn the interior shape:
u = (net(jno.np.concat([x, y], axis=-1)) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y)
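Why the multiplier works regardless of what the network outputs: one of the four factors is exactly zero at every boundary point. The sketch below uses a tiny random tanh MLP in NumPy as a hypothetical stand-in for net. The weights, sizes and the name u_hard are illustrative assumptions, not foundax or jno objects.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the network: a tiny random tanh MLP, R^2 -> R.
W1, b1 = rng.normal(size=(2, 16)), rng.normal(size=16)
W2, b2 = rng.normal(size=(16, 1)), rng.normal(size=1)


def net(xy):
    return np.tanh(xy @ W1 + b1) @ W2 + b2  # shape (N, 1)


def u_hard(x, y):
    """Network output times x(1-x)y(1-y): zero on every edge of the unit square."""
    xy = np.stack([x, y], axis=-1)
    return net(xy)[:, 0] * x * (1 - x) * y * (1 - y)


# Sample points on the four edges of [0, 1]^2.
s = np.linspace(0.0, 1.0, 101)
zeros, ones = np.zeros_like(s), np.ones_like(s)
xb = np.concatenate([s, s, zeros, ones])
yb = np.concatenate([zeros, ones, s, s])

print(np.max(np.abs(u_hard(xb, yb))))  # exactly 0.0
```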
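Step 3: Add the PDE residual
The residual of −∇²u = f, with the forcing built from the interior variables:
forcing = 2 * π**2 * jno.np.sin(π * x) * jno.np.sin(π * y)
pde = -(u.xx + u.yy) - forcing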
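A quick sanity check of this residual in plain NumPy: the exact solution should give a near-zero residual, and the trivial u ≡ 0 should not. Central differences stand in for u.xx and u.yy here. This is an assumption for illustration only, and jno's own derivative machinery is not used.

```python
import numpy as np


def u_exact(x, y):
    return np.sin(np.pi * x) * np.sin(np.pi * y)


def forcing(x, y):
    return 2 * np.pi**2 * np.sin(np.pi * x) * np.sin(np.pi * y)


def residual(u, x, y, h=1e-3):
    """-(u_xx + u_yy) - f using second-order central differences."""
    u_xx = (u(x + h, y) - 2 * u(x, y) + u(x - h, y)) / h**2
    u_yy = (u(x, y + h) - 2 * u(x, y) + u(x, y - h)) / h**2
    return -(u_xx + u_yy) - forcing(x, y)


rng = np.random.default_rng(1)
x, y = rng.uniform(0.05, 0.95, size=(2, 1000))

rms_exact = np.sqrt(np.mean(residual(u_exact, x, y) ** 2))
rms_zero = np.sqrt(np.mean(residual(lambda x, y: np.zeros_like(x), x, y) ** 2))
print(rms_exact, rms_zero)  # tiny (discretisation error only) vs. large
```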
Step 4: Track the volume integral
import jax.numpy as jnp
from jno.numpy import tracker
TARGET_INTEGRAL = 4.0 / jnp.pi ** 2 # ≈ 0.405
vol_mean = tracker(u.integrate(), interval=200)
tracker(...) wraps any scalar expression so that it is evaluated and logged once per interval (here every 200 epochs) but does not contribute to the gradient.
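A plain-Python illustration of those semantics follows. It is not jno's implementation, and the function train and its toy loss are hypothetical. The loop updates a parameter from the training loss only and records a monitored quantity once every interval=200 epochs.

```python
def train(epochs=1000, interval=200, lr=0.1):
    """Toy gradient descent on 0.5 * (a - 1)^2 with a monitored-only quantity.

    The monitored value (a**2, standing in for u.integrate()) is computed every
    `interval` epochs and stored, but it never enters the parameter update.
    """
    a = 0.0
    log = []
    for epoch in range(1, epochs + 1):
        grad = a - 1.0  # gradient of the training loss only
        a -= lr * grad
        if epoch % interval == 0:
            log.append((epoch, a**2))  # logged, not differentiated
    return a, log


a, log = train()
print(len(log), [e for e, _ in log])  # 5 [200, 400, 600, 800, 1000]
```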
Step 5: Optionally add an integral constraint
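To enforce the prescribed mean in the loss itself, add a squared-error penalty on the integral and include it in the loss list. These are the commented-out lines in the script:
integral_loss = (u.integrate() - TARGET_INTEGRAL).square()
losses = [pde.mse, integral_loss, vol_mean]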
Without the constraint the PDE residual alone is usually sufficient. With it, the optimizer receives a direct signal about the solution's global magnitude, which can be useful when the interior is poorly sampled or the PDE residual gradient is small.
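Why the penalty carries a magnitude signal becomes clear with the hypothetical one-parameter family u = a·sin(πx)sin(πy). Its integral is a·4/π², so the penalty (a·4/π² − TARGET_INTEGRAL)² has a gradient in a that points toward a = 1 whenever a is off. The NumPy sketch below writes that derivative out by hand. The scale parameter a is an illustrative assumption, since the real optimizer acts on network weights.

```python
import numpy as np

I_UNIT = 4.0 / np.pi**2  # integral of sin(pi x) sin(pi y) over the unit square
TARGET_INTEGRAL = I_UNIT


def integral_penalty(a):
    """(integral of u - target)^2 for the scaled candidate u = a * sin sin."""
    return (a * I_UNIT - TARGET_INTEGRAL) ** 2


def integral_penalty_grad(a):
    """Hand-derived d/da of integral_penalty."""
    return 2.0 * I_UNIT * (a * I_UNIT - TARGET_INTEGRAL)


for a in (0.0, 0.5, 1.0):
    print(a, integral_penalty(a), integral_penalty_grad(a))

# Plain gradient descent on the penalty alone drives the scale to 1.
a = 0.0
for _ in range(200):
    a -= 1.0 * integral_penalty_grad(a)
print(a)
```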
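Step 6: Solve
Build the solver from the loss list and train:
EPOCHS = 30_000
crux = jno.core(losses).print_shapes()
_history = crux.solve(EPOCHS)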
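The script's learning-rate schedule can be worked out by hand. This assumes optax.exponential_decay's documented non-staircase rule, lr = init_value · decay_rate^(step / transition_steps), with end_value acting as a floor because decay_rate < 1. It also assumes one optimizer step per epoch. Under those assumptions the rate reaches 1e-5 after about 19,932 steps. The function exp_decay below re-derives the rule in plain Python and does not call optax.

```python
import math


def exp_decay(step, init_value=1e-3, transition_steps=3000,
              decay_rate=0.5, end_value=1e-5):
    """Continuous (non-staircase) exponential decay with a lower bound."""
    lr = init_value * decay_rate ** (step / transition_steps)
    return max(lr, end_value)  # decay_rate < 1, so end_value acts as a floor


# Step at which 1e-3 * 0.5**(s / 3000) falls to 1e-5.
floor_step = 3000 * math.log(1e-3 / 1e-5) / math.log(2.0)
print(round(floor_step))  # 19932
print(exp_decay(0), exp_decay(3000), exp_decay(30_000))
```

Under those assumptions, roughly the last 10,000 of the 30,000 epochs run at the floor rate of 1e-5.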
What to notice
- .integrate() returns a scalar placeholder, so it can appear anywhere a regular loss term can.
- The region (volume vs boundary) is inferred automatically from the variable's tag; no extra argument is needed.
- Integration weights are precomputed once at domain creation. They are embedded as JAX constants and reused across all training steps, so adding an integral term carries negligible runtime cost.
- Because .integrate() is differentiable, jax.grad and eqx.filter_grad work through it without modification.
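Precomputed weights make the integral a fixed weighted sum of point values, which is also why it is cheap and differentiable. Its derivative with respect to each point value is simply that point's weight. The sketch assumes midpoint weights on a uniform grid for illustration, since jno's actual quadrature rule and point layout are not described here.

```python
import numpy as np

# Assumption for illustration: midpoint weights on a uniform n-by-n grid.
n = 100
c = (np.arange(n) + 0.5) / n
X, Y = np.meshgrid(c, c, indexing="ij")
weights = np.full(X.size, 1.0 / n**2)  # built once, reused every step


def integrate(u_vals):
    """Integral as a fixed weighted sum: linear in the point values."""
    return weights @ u_vals.ravel()


u_vals = np.sin(np.pi * X) * np.sin(np.pi * Y)
print(integrate(u_vals))  # close to 4 / pi^2

# Linearity: d(integral)/d(u_i) equals weights[i] at any u.
i, eps = 1234, 1e-3
bumped = u_vals.ravel().copy()
bumped[i] += eps
fd = (integrate(bumped.reshape(X.shape)) - integrate(u_vals)) / eps
print(fd, weights[i])
```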
Flux integrals (extension)
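If you also want to monitor the outward flux of F = ∇u through the boundary, request normals and compute F·n by hand. For heat conduction with unit conductivity, the outward heat flux is −F·n, which is the same quantity with the opposite sign.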
x_b, y_b, _, nx, ny = domain.variable("boundary", normals=True)
u_b = net(jno.np.concat([x_b, y_b], axis=-1)) * x_b * (1 - x_b) * y_b * (1 - y_b)
# ∫_∂Ω ∂u/∂n ds should equal ∫_Ω ∇²u dA = −∫_Ω f dA = −8 (divergence theorem / Green's first identity)
outward_flux = (u_b.d(x_b) * nx + u_b.d(y_b) * ny).integrate()
flux_tracker = tracker(outward_flux, interval=500)
jno does not dot F with n automatically — you specify the integrand explicitly.
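The expected value of outward_flux follows from the divergence theorem: ∫_∂Ω ∂u/∂n ds = ∫_Ω ∇²u dA = −∫_Ω f dA = −2π² · 4/π² = −8, with each of the four edges contributing −2. The NumPy check below evaluates the exact gradient on each edge at Gauss-Legendre nodes. It is independent of jno and of any trained network.

```python
import numpy as np

t, w = np.polynomial.legendre.leggauss(20)
s, w = 0.5 * (t + 1.0), 0.5 * w  # Gauss-Legendre rule on [0, 1]


def grad_u(x, y):
    """Gradient of the exact solution u = sin(pi x) sin(pi y)."""
    return (np.pi * np.cos(np.pi * x) * np.sin(np.pi * y),
            np.pi * np.sin(np.pi * x) * np.cos(np.pi * y))


# Each edge: (boundary points, outward unit normal (nx, ny)).
edges = [
    ((s, np.zeros_like(s)), (0.0, -1.0)),  # bottom, y = 0
    ((s, np.ones_like(s)), (0.0, 1.0)),    # top, y = 1
    ((np.zeros_like(s), s), (-1.0, 0.0)),  # left, x = 0
    ((np.ones_like(s), s), (1.0, 0.0)),    # right, x = 1
]

flux = 0.0
for (xb, yb), (nx, ny) in edges:
    ux, uy = grad_u(xb, yb)
    flux += np.sum(w * (ux * nx + uy * ny))  # integral of du/dn along the edge

forcing_integral = 2 * np.pi**2 * (4.0 / np.pi**2)  # integral of f over the domain
print(flux, -forcing_integral)  # both ≈ -8
```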
Script
"""06 — Integral constraints and flux monitoring (2-D Poisson)"""
from pathlib import Path
import foundax
import jax
import jax.numpy as jnp
import optax
from jno.numpy import tracker
from shapely.geometry import box
import jno
π = jno.np.pi
# ── Domain ─────────────────────────────────────────────────────────────────────
domain = jno.domain(box(0, 0, 1, 1), mesh_size=0.05)
x, y, _ = domain.variable("interior")
x_b, y_b, _ = domain.variable("boundary")
domain.summary()
# ── Analytic forcing and boundary data ─────────────────────────────────────────
forcing = 2 * π**2 * jno.np.sin(π * x) * jno.np.sin(π * y)
u_exact_b = jno.np.sin(π * x_b) * jno.np.sin(π * y_b) # = 0 on ∂Ω
# Exact volume integral: ∫₀¹∫₀¹ sin(πx)sin(πy) dxdy = (2/π)² = 4/π²
TARGET_INTEGRAL = 4.0 / float(jnp.pi) ** 2 # ≈ 0.4053
# ── Model ──────────────────────────────────────────────────────────────────────
net = jno.nn.wrap(
    foundax.mlp(
        in_features=2,
        hidden_dims=64,
        num_layers=4,
        activation=jax.nn.tanh,
        key=jax.random.PRNGKey(0),
    )
)
net.optimizer(
    optax.adam(
        optax.exponential_decay(
            init_value=1e-3,
            transition_steps=3000,
            decay_rate=0.5,
            end_value=1e-5,
        )
    )
)
# Hard-enforce u=0 on ∂Ω by multiplying by x(1-x)y(1-y).
# The network then only needs to learn the interior shape.
u = (net(jno.np.concat([x, y], axis=-1)) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y)
# ── Losses ─────────────────────────────────────────────────────────────────────
# Standard PDE residual
pde = -(u.xx + u.yy) - forcing
# Volume-mean tracker — logged every 200 epochs, does not enter the gradient.
# After convergence this should approach TARGET_INTEGRAL ≈ 0.405.
vol_mean = tracker(u.integrate(), interval=200)
# Optional soft integral constraint — uncomment to add it to the loss.
# This can accelerate convergence when the PDE residual alone is slow to
# pin the global magnitude of the solution.
#
# integral_loss = (u.integrate() - TARGET_INTEGRAL).square()
# losses = [pde.mse, integral_loss, vol_mean]
losses = [pde.mse, vol_mean]
# ── Solve ──────────────────────────────────────────────────────────────────────
EPOCHS = 30_000
crux = jno.core(losses).print_shapes()
_history = crux.solve(EPOCHS)
# ── Evaluate ───────────────────────────────────────────────────────────────────
u_pred, u_ref = crux.eval([u, jno.np.sin(π * x) * jno.np.sin(π * y)])
rel_l2 = float(jnp.linalg.norm(u_pred - u_ref) / (jnp.linalg.norm(u_ref) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}")
print(f"Target integral: {TARGET_INTEGRAL:.6f}")
# ── Record result ──────────────────────────────────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(
        f"06_integration/flux_conservation_2d.py | epochs={EPOCHS} "
        f"| rel_L2={rel_l2:.6e} | target_integral={TARGET_INTEGRAL:.6f}\n"
    )
assert rel_l2 < 0.15, f"Relative L2 error too large: {rel_l2:.3e}"
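To put the final assertion's threshold in context, consider how the script's metric behaves. For a prediction a·u_ref, the relative L2 error is |a − 1|, so a 15% amplitude error sits right at the limit. The collapsed u ≈ 0 solution scores about 1.0 and fails clearly. The NumPy sketch below reuses the script's formula on random sample points, which are an illustrative assumption.

```python
import numpy as np


def rel_l2(u_pred, u_ref, eps=1e-8):
    """Same formula as the script: ||u_pred - u_ref|| / (||u_ref|| + eps)."""
    return float(np.linalg.norm(u_pred - u_ref) / (np.linalg.norm(u_ref) + eps))


rng = np.random.default_rng(0)
xy = rng.uniform(size=(500, 2))
u_ref = np.sin(np.pi * xy[:, 0]) * np.sin(np.pi * xy[:, 1])

print(rel_l2(1.05 * u_ref, u_ref))          # ≈ 0.05: a 5% amplitude error
print(rel_l2(np.zeros_like(u_ref), u_ref))  # ≈ 1.0: the collapsed u ≈ 0 solution
```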