Anisotropic Poisson 2D
This example modifies Poisson's equation so diffusion acts with different strength in the horizontal and vertical directions.
Problem Setup
We solve -(a u_xx + b u_yy) = f on the unit square (0,1)^2 with u = 0 on the boundary, with exact solution u(x,y) = sin(pi x) sin(pi y) and coefficients a = 1, b = 3.
Step 1: Set Physical Coefficients
The script defines separate constants a = 1.0 and b = 3.0 before building the residual. Here a weights the curvature in x and b the curvature in y, which is the simplest way to encode directional anisotropy.
Step 2: Create the Unit-Square Domain
Interior points are sampled on the unit square with a mesh size of 0.1 and used to evaluate both the model and the manufactured forcing.
domain = jno.domain.rect(mesh_size=0.1)
x, y, _ = domain.variable("interior")
u_exact = jno.np.sin(pi * x) * jno.np.sin(pi * y)
forcing = (a + b) * pi**2 * u_exact
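The forcing expression comes from substituting the exact solution into the operator. A short derivation, assuming the PDE form -(a u_xx + b u_yy) = f that the script's residual encodes:

```latex
\begin{aligned}
u(x,y) &= \sin(\pi x)\,\sin(\pi y), \\
u_{xx} &= -\pi^2 u, \qquad u_{yy} = -\pi^2 u, \\
f &= -(a\,u_{xx} + b\,u_{yy}) = (a + b)\,\pi^2\,u .
\end{aligned}
```

With a = 1 and b = 3 this gives f = 4 pi^2 u.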
Step 3: Impose Hard Boundary Conditions
The model output is multiplied by x(1-x)y(1-y), so the field is zero on all four edges without an additional boundary loss.
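The effect of the mask can be checked without any training. The following is a minimal plain-Python sketch, not part of the tutorial script: `raw_output` is a hypothetical stand-in for the network, chosen to be nonzero on the boundary, and `hard_bc` applies the same x(1-x)y(1-y) factor.

```python
import math

def raw_output(x, y):
    # Hypothetical stand-in for net(x, y); deliberately nonzero on the boundary.
    return 2.0 + math.cos(3.0 * x) - y

def hard_bc(x, y):
    # Multiply by x(1-x)y(1-y), which vanishes whenever x or y equals 0 or 1.
    return raw_output(x, y) * x * (1.0 - x) * y * (1.0 - y)

# One sample point on each of the four edges of the unit square.
edge_points = [(0.0, 0.4), (1.0, 0.4), (0.6, 0.0), (0.6, 1.0)]
edge_values = [hard_bc(x, y) for x, y in edge_points]

# An interior point, where the mask is nonzero and the network output passes through.
interior_value = hard_bc(0.5, 0.5)
```

Every entry of `edge_values` is exactly zero whatever `raw_output` returns, while `interior_value` is not, so the boundary condition holds by construction and the loss only needs the PDE residual.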
net = jno.nn.wrap(
    foundax.mlp(in_features=2, hidden_dims=64, num_layers=5,
                activation=jax.nn.tanh, key=jax.random.PRNGKey(12))
)
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 80, 0.5, end_value=1e-5)))
u = net(x, y) * x * (1 - x) * y * (1 - y)
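To get a feel for the learning-rate schedule above, here is a plain-Python sketch of a continuous exponential decay clipped at a floor, lr(step) = max(init * rate^(step / transition_steps), end_value). It assumes the default, non-staircase behaviour of `optax.exponential_decay`; the helper `decayed_lr` is illustrative and not part of optax.

```python
import math

def decayed_lr(step, init=1e-3, transition_steps=80, rate=0.5, end_value=1e-5):
    # Continuous exponential decay, clipped from below at end_value.
    return max(init * rate ** (step / transition_steps), end_value)

lr_at_80 = decayed_lr(80)      # one halving period after the start
lr_at_1000 = decayed_lr(1000)  # well past the point where the floor is reached

# Step at which the unclipped curve first reaches end_value.
floor_step = 80 * math.log(1e-5 / 1e-3) / math.log(0.5)
```

Under these assumptions the rate halves every 80 steps and reaches the 1e-5 floor after roughly 530 steps, so most of a long run proceeds at the floor value.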
Step 4: Assemble an Anisotropic Residual
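The residual is -(a u_xx + b u_yy) - f: the second derivatives in x and y are weighted by a and b, which is the main distinction from isotropic Poisson. In the script this is the line pde = -(a * u.xx + b * u.yy) - forcing.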
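The weighted residual can be sanity-checked independently of the library. The sketch below is plain Python and swaps automatic differentiation for central finite differences, so it only illustrates the formula; the function names and the test point (0.3, 0.7) are chosen here for illustration.

```python
import math

a, b = 1.0, 3.0

def u_exact(x, y):
    return math.sin(math.pi * x) * math.sin(math.pi * y)

def forcing(x, y):
    return (a + b) * math.pi**2 * u_exact(x, y)

def residual(u, x, y, h=1e-3):
    # Central differences approximate u_xx and u_yy with O(h^2) error.
    u_xx = (u(x + h, y) - 2.0 * u(x, y) + u(x - h, y)) / h**2
    u_yy = (u(x, y + h) - 2.0 * u(x, y) + u(x, y - h)) / h**2
    return -(a * u_xx + b * u_yy) - forcing(x, y)

def bump(x, y):
    # Satisfies the boundary condition but not the PDE.
    return x * (1.0 - x) * y * (1.0 - y)

res_exact = residual(u_exact, 0.3, 0.7)  # close to zero, up to discretization error
res_bump = residual(bump, 0.3, 0.7)      # clearly nonzero
```

The exact solution drives the residual to nearly zero, while a function that merely satisfies the boundary condition does not. Training minimizes the mean square of this quantity over the interior points.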
Step 5: Solve and Visualize
The script tracks error against the exact solution and plots exact, predicted, and absolute-error fields.
crux = jno.core([pde.mse])
history = crux.solve(40_000)
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))
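The metric in the last line is a relative L2 error with a small constant added to the denominator. A minimal plain-Python sketch of the same formula on a toy pair of vectors (the helper `relative_l2` and the sample values are illustrative, not taken from the script):

```python
import math

def relative_l2(pred, exact, eps=1e-8):
    # ||pred - exact||_2 / (||exact||_2 + eps); eps guards against a zero denominator.
    diff_norm = math.sqrt(sum((p - e) ** 2 for p, e in zip(pred, exact)))
    exact_norm = math.sqrt(sum(e ** 2 for e in exact))
    return diff_norm / (exact_norm + eps)

exact = [3.0, 4.0]  # Euclidean norm 5
pred = [3.0, 4.5]   # off by 0.5 in one entry
err = relative_l2(pred, exact)
```

Here the error norm is 0.5 against a reference norm of 5, so `err` is about 0.1. Because the measure is normalized by the size of the exact field, values are comparable across problems with different solution amplitudes.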
What To Notice
- Anisotropy is often the first step beyond textbook Poisson problems.
- The only major PDE change is the weighted curvature in each coordinate direction.
- This pattern extends naturally to diffusion tensors and heterogeneous media.
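One way to read the last point, sketched for a constant symmetric tensor K (a symbol introduced here for illustration, not used in the script):

```latex
-\nabla\cdot\left(K\,\nabla u\right)
  = -\left(k_{11}\,u_{xx} + 2\,k_{12}\,u_{xy} + k_{22}\,u_{yy}\right),
\qquad
K = \begin{pmatrix} k_{11} & k_{12} \\ k_{12} & k_{22} \end{pmatrix}.
```

Setting k_11 = a, k_22 = b, and k_12 = 0 recovers the operator in this example. For a spatially varying K, terms involving derivatives of the coefficients also appear.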
Script Snippet
"""02 — 2-D anisotropic Poisson equation"""
from pathlib import Path
import foundax
import jax
import optax
from shapely.geometry import box
import jno
π = jno.np.pi
a, b = 1.0, 3.0
domain = jno.domain(box(0, 0, 1, 1), mesh_size=0.1)
x, y, _ = domain.variable("interior")
u_exact = jno.np.sin(π * x) * jno.np.sin(π * y)
forcing = (a + b) * π**2 * u_exact
net = jno.nn.wrap(foundax.mlp(in_features=2, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(12)))
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 1000, 0.5, end_value=1e-5)))
u = (net(x, y) * x * (1 - x) * y * (1 - y)).scalar.bind(x=x, y=y)
pde = -(a * u.xx + b * u.yy) - forcing
crux = jno.core([pde.mse])
crux.solve(5000)
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}")
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(f"02_elliptic/anisotropic_poisson_2d.py | epochs=5000 | rel_L2={rel_l2:.6e}\n")
assert rel_l2 < 1e-1, f"relative L2 error too large: {rel_l2:.3e}"