Skip to content

Coupled Parabolic 2D

This example extends the coupled two-field approach to a transient (time-dependent) setting.

Problem Setup

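The script solves two time-dependent PDEs with cross-coupling terms, using manufactured transient reference solutions:

$$
u_t - \Delta u + v = f, \qquad v_t - \Delta v + u = g,
$$

with exact solutions $u = e^{-t}\sin(\pi x)\sin(\pi y)$ and $v = e^{-t}\sin(2\pi x)\sin(\pi y)$. Substituting them into the equations gives the source terms $f = (2\pi^2 - 1)\,u + v$ and $g = (5\pi^2 - 1)\,v + u$.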

Step 1: Build Two Time-Dependent Fields

Both unknowns depend on space and time, so the sampled domain and constraint set are larger than in the stationary case.

T_end = 1.0  # end of the simulated time window, which starts at t = 0

# Space-time domain: a rectangular spatial mesh (mesh_size=0.05) plus a time
# axis over [0, T_end].
# NOTE(review): the third `time` entry presumably sets the number of time
# samples (10 here, 4 in the full script below) -- confirm in the jno docs.
domain = jno.domain(
    constructor=jno.domain.rect(mesh_size=0.05),
    time=(0, T_end, 10),
)
# Interior points feed the PDE residuals; "initial" points feed the initial
# conditions.
x, y, t   = domain.variable("interior")
x0, y0, t0 = domain.variable("initial")

# Manufactured reference solutions and the source terms they induce.
# `pi` is not defined in this excerpt; the full script binds it as
# `π = jno.np.pi`.
u_exact = jno.np.exp(-t) * jno.np.sin(pi * x) * jno.np.sin(pi * y)
v_exact = jno.np.exp(-t) * jno.np.sin(2 * pi * x) * jno.np.sin(pi * y)
f = (2 * pi**2 - 1) * u_exact + v_exact
g = (5 * pi**2 - 1) * v_exact + u_exact

# One DeepONet per unknown; the different PRNG seeds (24, 25) give the two
# networks independent initial weights.
u_net = jno.nn.wrap(foundax.deeponet(n_sensors=1, coord_dim=2, n_outputs=1,
                                      n_layers=5, basis_functions=96, hidden_dim=64,
                                      key=jax.random.PRNGKey(24)))
v_net = jno.nn.wrap(foundax.deeponet(n_sensors=1, coord_dim=2, n_outputs=1,
                                      n_layers=5, basis_functions=96, hidden_dim=64,
                                      key=jax.random.PRNGKey(25)))
# Each network gets its own Adam optimizer. The schedule arguments are elided
# in this excerpt; the full script uses
# warmup_cosine_decay_schedule(0.0, 1e-3, 50, 5000, 1e-5).
for net in [u_net, v_net]:
    net.optimizer(optax.adam(optax.warmup_cosine_decay_schedule(...)))

# Stack the spatial coordinates into the networks' 2-D coordinate input
# (coord_dim=2).
xy  = jno.np.concat([x, y])
xy0 = jno.np.concat([x0, y0])
# Hard-boundary ansatz: the factor x(1-x)y(1-y) makes each field vanish at
# x in {0, 1} and y in {0, 1}, where the exact solutions also vanish.
u   = u_net(t,  xy)  * x  * (1 - x)  * y  * (1 - y)
v   = v_net(t,  xy)  * x  * (1 - x)  * y  * (1 - y)
u0  = u_net(t0, xy0) * x0 * (1 - x0) * y0 * (1 - y0)
v0  = v_net(t0, xy0) * x0 * (1 - x0) * y0 * (1 - y0)

Step 2: Add Initial Conditions for Both Fields

Each unknown needs its own initial condition in addition to the coupled PDE residuals.

# Initial-condition residuals: at t = 0 the factor e^{-t} equals 1, so the
# targets are just the spatial parts of u_exact and v_exact.
ini_u = u0 - jno.np.sin(pi * x0) * jno.np.sin(pi * y0)
ini_v = v0 - jno.np.sin(2 * pi * x0) * jno.np.sin(pi * y0)

Step 3: Train the System Jointly

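All four losses (the two coupled PDE residuals and the two initial conditions) are optimized together. This keeps the two networks consistent with each other through the coupling terms and with the governing equations.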

# Coupled PDE residuals for  u_t - Δu + v = f  and  v_t - Δv + u = g.
pde_u = u.d(t) - jno.np.laplacian(u, [x, y]) + v - f
pde_v = v.d(t) - jno.np.laplacian(v, [x, y]) + u - g

# A single problem over all four mean-squared residuals, so both networks are
# optimized jointly.
# NOTE(review): the argument to `solve` is presumably the number of training
# iterations (the full script below uses 5_000) -- confirm in the jno docs.
crux    = jno.core([pde_u.mse, pde_v.mse, ini_u.mse, ini_v.mse])
history = crux.solve(40_000)

# Evaluate the learned and the exact fields for comparison.
_u, _u_exact, _v, _v_exact = crux.eval([u, u_exact, v, v_exact])

What To Notice

  • Coupling and time dependence can be combined cleanly in one workflow.
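  • The same hard-boundary ansatz (multiplying each network output by x(1 − x)·y(1 − y)) forces both unknowns to vanish at x ∈ {0, 1} and y ∈ {0, 1}, so no separate boundary loss is needed.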
  • This is a good reference for multi-field transient PDEs.

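Script Snippet

The complete script below is a lighter configuration of the same problem than the step-by-step excerpts above. It uses:

  • smaller networks (`n_layers=3`, `basis_functions=48`, `hidden_dim=32` instead of 5 / 96 / 64);
  • a time axis of `(0, T_end, 4)` instead of `(0, T_end, 10)`;
  • `crux.solve(5_000)` instead of `crux.solve(40_000)`;
  • the `.t`, `.xx`, `.yy` derivative shorthands instead of `.d(t)` and `jno.np.laplacian`.

It finishes by reporting the relative L2 error of each field.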

"""05 — Coupled parabolic system in 2-D"""

from pathlib import Path

import foundax
import jax
import optax

import jno

# --- Problem constants and space-time sampling -----------------------------
π = jno.np.pi
T_end = 1.0  # simulation window is [0, T_end]

domain = jno.domain.rect(
    mesh_size=0.05,
    time=(0, T_end, 4),  # NOTE(review): third entry presumably = number of time samples
)

# Interior points drive the PDE residuals; "initial" points drive the
# initial-condition residuals.
x, y, t = domain.variable("interior")
x0, y0, t0 = domain.variable("initial")

# --- Manufactured solutions ------------------------------------------------
# Both fields decay like e^{-t}. Spatially, sin(πx)sin(πy) and sin(2πx)sin(πy)
# satisfy -Δφ = 2π² φ and -Δφ = 5π² φ respectively.
_sin, _exp = jno.np.sin, jno.np.exp
u_exact = _exp(-t) * _sin(π * x) * _sin(π * y)
v_exact = _exp(-t) * _sin(2 * π * x) * _sin(π * y)

# Source terms from substituting the exact fields into
#   u_t - Δu + v = f   and   v_t - Δv + u = g,
# where u_t = -u and v_t = -v.
_coef_u = 2 * π**2 - 1  # (u_t - Δu) / u for u_exact
_coef_v = 5 * π**2 - 1  # (v_t - Δv) / v for v_exact
f = _coef_u * u_exact + v_exact
g = _coef_v * v_exact + u_exact


def _net(
    key: int,
    *,
    n_layers: int = 3,
    basis_functions: int = 48,
    hidden_dim: int = 32,
    peak_lr: float = 1e-3,
    warmup_steps: int = 50,
    decay_steps: int = 5000,
    end_lr: float = 1e-5,
):
    """Build a wrapped DeepONet with its own Adam optimizer attached.

    In this script each network is called as ``net(t, xy)``, with one sensor
    input (``n_sensors=1``, presumably the time ``t``) and 2-D coordinates
    (``coord_dim=2``), producing a single output channel.

    Args:
        key: Integer seed for ``jax.random.PRNGKey``. Use a distinct key per
            network so the fields start from independent weights.
        n_layers: DeepONet depth.
        basis_functions: Number of DeepONet basis functions.
        hidden_dim: Hidden layer width.
        peak_lr: Learning rate reached at the end of the linear warmup.
        warmup_steps: Length of the linear warmup from 0.0 to ``peak_lr``.
        decay_steps: Total schedule length, *including* the warmup. The cosine
            decay toward ``end_lr`` spans ``decay_steps - warmup_steps`` steps.
        end_lr: Final learning rate after the cosine decay.

    Returns:
        The ``jno.nn.wrap``-ped network, with its optimizer configured.
    """
    net = jno.nn.wrap(
        foundax.deeponet(
            n_sensors=1,
            coord_dim=2,
            n_outputs=1,
            n_layers=n_layers,
            basis_functions=basis_functions,
            hidden_dim=hidden_dim,
            key=jax.random.PRNGKey(key),
        )
    )
    # Keyword arguments make the schedule self-describing; the defaults
    # reproduce the original positional call (0.0, 1e-3, 50, 5000, 1e-5).
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        end_value=end_lr,
    )
    net.optimizer(optax.adam(schedule))
    return net


# Distinct seeds give the two fields independent initial weights.
u_net, v_net = _net(24), _net(25)

# Stack the spatial coordinates into the networks' 2-D coordinate input.
xy = jno.np.concat([x, y])
xy0 = jno.np.concat([x0, y0])


def ansatz(raw, xa, ya):
    """Multiply a raw network output by the boundary factor x(1-x)·y(1-y).

    The factor is zero at x in {0, 1} and y in {0, 1}, so the returned field
    vanishes there by construction (a hard boundary constraint). The
    manufactured solutions also vanish there, so no boundary loss is needed.
    """
    return raw * xa * (1 - xa) * ya * (1 - ya)


# NOTE(review): `.scalar.bind(x=x, y=y, t=t)` presumably registers the
# coordinates so the `.t`, `.xx`, `.yy` derivative shorthands used in the PDE
# residuals resolve -- confirm in the jno docs. u0/v0 are used only as values
# in the initial-condition residuals, so they are left unbound.
u = ansatz(u_net(t, xy), x, y).scalar.bind(x=x, y=y, t=t)
v = ansatz(v_net(t, xy), x, y).scalar.bind(x=x, y=y, t=t)
u0 = ansatz(u_net(t0, xy0), x0, y0)
v0 = ansatz(v_net(t0, xy0), x0, y0)

# PDE residuals of the coupled system  u_t - Δu + v = f,  v_t - Δv + u = g.
pde_u = u.t - (u.xx + u.yy) + v - f
pde_v = v.t - (v.xx + v.yy) + u - g
# Initial-condition residuals: at t = 0 the factor e^{-t} is 1, so the targets
# are the spatial parts of the exact solutions.
ini_u = u0 - jno.np.sin(π * x0) * jno.np.sin(π * y0)
ini_v = v0 - jno.np.sin(2 * π * x0) * jno.np.sin(π * y0)

# Single source of truth for the training length. It is used both by `solve`
# and in the results log, so the two cannot drift apart.
# NOTE(review): `_net`'s learning-rate schedule also spans 5000 steps;
# presumably one optimizer step per epoch, so keep the two in sync.
N_EPOCHS = 5_000
# Relative L2 tolerance that both fields must meet for the tutorial to pass.
REL_L2_TOL = 2e-1

crux = jno.core([pde_u.mse, pde_v.mse, ini_u.mse, ini_v.mse])
crux.solve(N_EPOCHS)


def _rel_l2(pred, ref, eps: float = 1e-8) -> float:
    """Return ``||pred - ref||_2 / (||ref||_2 + eps)`` as a Python float.

    ``eps`` guards against division by zero when ``ref`` is identically zero.
    """
    norm = jax.numpy.linalg.norm
    return float(norm(pred - ref) / (norm(ref) + eps))


_u, _u_exact, _v, _v_exact = crux.eval([u, u_exact, v, v_exact])
rel_l2_u = _rel_l2(_u, _u_exact)
rel_l2_v = _rel_l2(_v, _v_exact)
print(f"u rel_L2 = {rel_l2_u:.4e}    v rel_L2 = {rel_l2_v:.4e}")

# Append, so existing lines in the file are kept; each line is tagged with
# this script's path. The write happens before the assertions below, so a
# failing run is still logged.
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with results_file.open("a", encoding="utf-8") as out:
    out.write(
        f"05_coupled_and_inverse/coupled_parabolic_2d.py | epochs={N_EPOCHS}"
        f" | rel_L2_u={rel_l2_u:.6e} | rel_L2_v={rel_l2_v:.6e}\n"
    )

assert rel_l2_u < REL_L2_TOL, f"u relative L2 error too large: {rel_l2_u:.3e}"
assert rel_l2_v < REL_L2_TOL, f"v relative L2 error too large: {rel_l2_v:.3e}"