Biharmonic 1D
Raises the PDE order. Solves a fourth-order biharmonic equation with clamped boundary conditions (both u and u' vanish at the endpoints) using a hard-enforced ansatz.
Problem Setup
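Exact solution: u(x) = x²(1−x)², the u_exact defined in the script snippet. Its fourth derivative is the constant 24, which is the source term in the residual u_xxxx − 24.0. (sin(π x) / π⁴ cannot be the clamped solution, because its first derivative does not vanish at the endpoints.)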
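Written out from the script snippet (residual u_xxxx − 24.0, reference u_exact = x**2 * (1 - x) ** 2), the boundary value problem appears to be the one below. The unit interval is an inference from the boundary conditions being imposed at x=0 and x=1, not something the page states explicitly.

```latex
\begin{aligned}
u''''(x) &= 24, \qquad x \in (0, 1), \\
u(0) &= u(1) = 0, \\
u'(0) &= u'(1) = 0, \\
u(x) &= x^2 (1 - x)^2 = x^4 - 2x^3 + x^2 .
\end{aligned}
```

Differentiating the expanded polynomial four times leaves only the x⁴ term, which contributes 4! = 24.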
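Step 1: Create the Domain
The domain is a 1D line meshed with jno.domain.line(mesh_size=0.05); the boundary conditions below sit at its endpoints x=0 and x=1. domain.variable("interior") returns the interior coordinate x that the ansatz and the residual are built on.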
Step 2: Hard-Enforce the Clamped Boundary Conditions
net = jno.nn.wrap(
    foundax.mlp(in_features=1, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(11))
)
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 10, 0.6, end_value=1e-5)))
u = net(x) * x**2 * (1 - x) ** 2
The x²(1−x)² factor and its first derivative both vanish at x=0 and x=1, so all four clamped boundary conditions are enforced exactly by construction. The network only has to learn the interior factor that multiplies x²(1−x)²; for the script's reference solution u_exact = x²(1−x)², that factor is the constant 1.
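The boundary argument can be checked numerically with a minimal plain-Python sketch. It does not use jno: fake_net is a hypothetical stand-in for the network, and the derivative of the envelope is written out by hand as 2x(1−x)(1−2x).

```python
import math


def envelope(x):
    """Boundary factor x^2 (1 - x)^2 from the ansatz."""
    return x**2 * (1 - x) ** 2


def envelope_dx(x):
    """Hand-derived first derivative of the envelope: 2 x (1 - x) (1 - 2 x)."""
    return 2 * x * (1 - x) * (1 - 2 * x)


def fake_net(x):
    """Hypothetical stand-in for net(x); any smooth function would do."""
    return 3.0 + math.sin(5.0 * x)


def fake_net_dx(x):
    """Derivative of the stand-in network."""
    return 5.0 * math.cos(5.0 * x)


def u(x):
    """Ansatz u = net * envelope."""
    return fake_net(x) * envelope(x)


def u_dx(x):
    """Product rule: (N * E)' = N' * E + N * E'."""
    return fake_net_dx(x) * envelope(x) + fake_net(x) * envelope_dx(x)


# (u, u') at both endpoints: the four clamped conditions
boundary_values = [(u(x), u_dx(x)) for x in (0.0, 1.0)]
print(boundary_values)
```

Every term in u and u' carries a factor of the envelope or its derivative, so the four values are zero whatever the network outputs, while u stays free in the interior.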
Step 3: Build the Fourth-Order Residual
The residual in the script snippet is u.d2(x).d2(x) − 24.0. Calling .d2(x).d2(x) gives the fourth derivative by stacking two second-derivative shortcuts. It is equivalent to u.d(x).d(x).d(x).d(x) but more compact, and far cleaner than four nested jno.np.grad(...) calls.
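The equivalence of two stacked second derivatives and four first derivatives can be illustrated on the envelope polynomial itself. This is a plain-Python sketch on coefficient lists, not jno's .d() or .d2() operators; the helper names d and d2 are chosen here only to mirror them.

```python
def d(coeffs):
    """Differentiate a polynomial stored as [c0, c1, c2, ...], where c_k multiplies x**k."""
    return [k * c for k, c in enumerate(coeffs)][1:] or [0]


def d2(coeffs):
    """Second-derivative shortcut: two first derivatives stacked."""
    return d(d(coeffs))


# x^2 (1 - x)^2 = x^2 - 2 x^3 + x^4
envelope = [0, 0, 1, -2, 1]

via_d2 = d2(d2(envelope))     # two stacked second derivatives
via_d = d(d(d(d(envelope))))  # four first derivatives

print(via_d2, via_d)  # [24] [24]
```

Both routes give the constant 24, which is the source term subtracted in the script's residual.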
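Step 4: Solve
The only loss term is the mean-squared PDE residual, passed as jno.core([pde.mse]); no boundary terms are needed because the ansatz already satisfies the boundary conditions. crux.solve(5000) then runs the training.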
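Training runs under the learning-rate schedule set in Step 2, optax.exponential_decay(1e-3, 10, 0.6, end_value=1e-5). The sketch below re-implements that schedule in plain Python rather than calling optax. It assumes the default non-staircase form, init_value * decay_rate ** (step / transition_steps) floored at end_value, and that the argument to crux.solve is a step count.

```python
def exponential_decay(step, init_value=1e-3, transition_steps=10,
                      decay_rate=0.6, end_value=1e-5):
    """Assumed schedule: continuous exponential decay, clipped from below at end_value."""
    decayed = init_value * decay_rate ** (step / transition_steps)
    return max(decayed, end_value)


# First step at which the learning rate sits on its floor
first_floor_step = next(s for s in range(5000) if exponential_decay(s) == 1e-5)

for step in (0, 10, 50, first_floor_step):
    print(step, exponential_decay(step))
```

Under these assumptions the rate reaches its floor of 1e-5 after roughly 90 steps, so nearly all of a 5000-step solve would run at the floor value. If convergence stalls, the decay parameters are a reasonable first thing to revisit.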
What To Notice
- The clamped-BC ansatz net(x) · x²(1−x)² enforces u = u' = 0 at both endpoints by construction, so the loss needs no boundary terms. In the script below the reference solution is x²(1−x)² itself, so the network only has to converge to the constant 1.
- Higher-order PDEs stay readable when you use the .d2() shortcut. Four nested jno.np.grad(...) calls would obscure the structure.
- Fourth-order accuracy is harder than second-order: expect to need a smaller mesh and longer training than Laplace 1D or Poisson 1D.
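Script Snippet
"""01 - 1-D biharmonic equation (beam-like fourth-order problem)"""
import foundax
import jax
import optax

import jno

# 1-D line domain; x is the interior coordinate variable
domain = jno.domain.line(mesh_size=0.05)
x, _ = domain.variable("interior")

# Reference solution of u'''' = 24 with clamped ends, used for the error check
u_exact = x**2 * (1 - x) ** 2

# Small MLP wrapped for jno, trained with Adam under a decaying learning rate
net = jno.nn.wrap(
    foundax.mlp(
        in_features=1,
        hidden_dims=32,
        num_layers=3,
        key=jax.random.PRNGKey(11),
    )
)
net.optimizer(optax.adam(optax.exponential_decay(1e-3, 10, 0.6, end_value=1e-5)))

# Hard-enforced clamped boundary conditions: u and u' vanish at x=0 and x=1
u = net(x) * x**2 * (1 - x) ** 2

# Fourth derivative via two stacked second derivatives; residual of u'''' = 24
u_xxxx = u.d2(x).d2(x)
pde = u_xxxx - 24.0

# The mean-squared PDE residual is the only loss term
crux = jno.core([pde.mse])
crux.solve(5000)

# Relative L2 error against the reference solution
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jax.numpy.linalg.norm(_u - _u_exact) / (jax.numpy.linalg.norm(_u_exact) + 1e-8))
assert rel_l2 < 1e-1, f"relative L2 error too large: {rel_l2:.3e}"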
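The final assertion compares the prediction with the reference through a relative L2 error. A plain-Python sketch of the same metric gives a feel for what the 1e-1 threshold allows. The 21 evenly spaced sample points and the 5% amplitude error are hypothetical illustration values, not output from jno.

```python
import math


def relative_l2(pred, exact, eps=1e-8):
    """||pred - exact||_2 / (||exact||_2 + eps), mirroring the script's formula."""
    num = math.sqrt(sum((p - e) ** 2 for p, e in zip(pred, exact)))
    den = math.sqrt(sum(e**2 for e in exact)) + eps
    return num / den


# Hypothetical sample points on [0, 1] with spacing 0.05
xs = [i * 0.05 for i in range(21)]
exact = [x**2 * (1 - x) ** 2 for x in xs]

# Hypothetical prediction: the network outputs 1.05 everywhere instead of 1
pred = [1.05 * e for e in exact]

err = relative_l2(pred, exact)
print(f"{err:.3e}")
```

A uniform 5% amplitude error gives a relative L2 error of about 0.05, which still passes the rel_l2 < 1e-1 check. The threshold is therefore a loose sanity bound rather than a tight accuracy claim.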