Diffusion-Reaction 1D

This example solves a steady 1D PDE that combines diffusion and a linear reaction term.

Problem Setup

-u''(x) + sigma u(x) = f(x),   x in [0,1],   u(0) = u(1) = 0

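with the manufactured solution u(x) = sin(pi x), which satisfies both boundary conditions and implies the forcing f(x) = (pi^2 + sigma) sin(pi x).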
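Substituting the manufactured solution into the operator gives the forcing used in the script and confirms the boundary values:

```latex
\begin{aligned}
u(x) &= \sin(\pi x), \qquad u''(x) = -\pi^2 \sin(\pi x),\\
f(x) &= -u''(x) + \sigma\, u(x) = \pi^2 \sin(\pi x) + \sigma \sin(\pi x) = (\pi^2 + \sigma)\sin(\pi x),\\
u(0) &= \sin 0 = 0, \qquad u(1) = \sin \pi = 0.
\end{aligned}
```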

Step 1: Set the Reaction Strength

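The reaction coefficient sigma (10.0 in the script) sets how strongly the reaction term sigma u competes with the diffusion term -u''. The script's docstring suggests trying sigma in {1, 10, 100}.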
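For the manufactured solution, -u'' = pi^2 sin(pi x) and sigma u = sigma sin(pi x), so the ratio of reaction to diffusion is the constant sigma / pi^2 at every interior point. The small NumPy check below is not part of the jno script; it simply evaluates both terms on a grid. The two terms balance at sigma = pi^2 ≈ 9.87, so the script's default sigma = 10 sits close to that crossover.

```python
import numpy as np

# Interior grid only: at x = 0 and x = 1 both terms vanish and the ratio is undefined.
x = np.linspace(0.0, 1.0, 201)[1:-1]
u = np.sin(np.pi * x)
diffusion = np.pi**2 * u  # -u''(x) for u(x) = sin(pi x)


def reaction_to_diffusion_ratio(sigma):
    """Pointwise ratio sigma*u / (-u''); returns (min, max) over the grid."""
    ratio = (sigma * u) / diffusion
    return ratio.min(), ratio.max()


for sigma in (1.0, 10.0, 100.0):
    r_min, r_max = reaction_to_diffusion_ratio(sigma)
    print(f"sigma = {sigma:6.1f}  ratio in [{r_min:.4f}, {r_max:.4f}]  "
          f"(sigma/pi^2 = {sigma / np.pi**2:.4f})")
```

Because the sin factor cancels, the minimum and maximum agree, which is why a single number (sigma / pi^2) summarizes the balance for this particular solution.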

Step 2: Build a Line Domain and Exact Solution

The script samples [0, 1] with jno.domain.line(mesh_size=0.005), taking separate interior and boundary point sets, then builds the manufactured forcing term and keeps the exact solution for error tracking.
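For inspection outside jno, the same arrays can be built with plain NumPy. This sketch assumes that mesh_size=0.005 corresponds to a roughly uniform grid whose boundary set is just the two endpoints; the point layout jno.domain.line actually produces may differ.

```python
import numpy as np

sigma = 10.0
mesh_size = 0.005  # same spacing the script passes to jno.domain.line

# Assumed uniform grid on [0, 1]; a stand-in for jno's interior/boundary split.
n = int(round(1.0 / mesh_size)) + 1  # 201 nodes including both endpoints
x_all = np.linspace(0.0, 1.0, n)
x_interior = x_all[1:-1]
x_boundary = x_all[[0, -1]]

# Exact solution and manufactured forcing on the interior points.
u_exact = np.sin(np.pi * x_interior)
forcing = (np.pi**2 + sigma) * np.sin(np.pi * x_interior)

print(x_interior.size, "interior points;", x_boundary, "boundary points")
```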

Step 3: Hard-Enforce Boundary Conditions

The network output is multiplied by x(1-x), so u(0) = u(1) = 0 holds by construction and the loss needs no boundary term.
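The effect of the x(1-x) factor can be checked with any network. The sketch below uses a small hand-written tanh MLP in JAX as a hypothetical stand-in for jno.nn.mlp (whose internals are not shown here); whatever the weights, the product vanishes exactly at both endpoints.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical tiny MLP (not jno.nn.mlp): a list of (W, b) pairs."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        W = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params


def mlp(params, x):
    h = jnp.atleast_1d(x)  # scalar x -> shape (1,)
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return (h @ W + b)[0]  # scalar output


def u(params, x):
    # Multiplying by x(1 - x) forces u(0) = u(1) = 0 for any parameters.
    return mlp(params, x) * x * (1.0 - x)


params = init_mlp(jax.random.PRNGKey(0), [1, 64, 64, 1])
print(u(params, 0.0), u(params, 0.5), u(params, 1.0))
```

The trade-off of this hard constraint is that the network never sees boundary data directly; it only has to fit the interior residual.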

Step 4: Balance Diffusion and Reaction in the Residual

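The residual is assembled as -u_xx + sigma u - forcing, with u_xx obtained by applying jnn.grad twice, and training minimizes its mean squared value (pde.mse). Because the network must balance diffusion and reaction in the same equation, this is a clean example of two physical effects combined in one PDE.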
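The same residual can be written directly in JAX. In this sketch (plain jax.grad and jax.vmap, not the jnn operators) the second derivative comes from nesting jax.grad, and the mean squared residual plays the role of pde.mse. The exact solution should give a loss at round-off level, while a deliberately wrong field should not.

```python
import jax
import jax.numpy as jnp

sigma = 10.0


def u_exact(x):
    return jnp.sin(jnp.pi * x)


def u_wrong(x):
    # Deliberately 10% too large, so the residual should not vanish.
    return 1.1 * jnp.sin(jnp.pi * x)


def forcing(x):
    return (jnp.pi**2 + sigma) * jnp.sin(jnp.pi * x)


def residual(u_fn, x):
    """Pointwise residual -u''(x) + sigma*u(x) - f(x) for a scalar function u_fn."""
    u_xx = jax.grad(jax.grad(u_fn))(x)  # nested grad gives the second derivative
    return -u_xx + sigma * u_fn(x) - forcing(x)


def pde_mse(u_fn, xs):
    # Mean squared residual over the grid, analogous to pde.mse in the script.
    r = jax.vmap(lambda x: residual(u_fn, x))(xs)
    return jnp.mean(r**2)


xs = jnp.linspace(0.0, 1.0, 101)
print("exact solution loss:", float(pde_mse(u_exact, xs)))
print("wrong solution loss:", float(pde_mse(u_wrong, xs)))
```

For u_wrong the residual is 0.1 (pi^2 + sigma) sin(pi x), so both the diffusion and the reaction mismatch contribute to the loss.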

Step 5: Measure Error and Plot

After 10,000 training steps, the script prints the mean absolute error against the exact solution and saves the training history, the solution comparison, and the pointwise error plot.
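The evaluation step sorts the points before comparing, presumably because the sampled interior points are not guaranteed to be in order and line plots connect points in array order. The sketch below mimics that with random points and a hypothetical perturbed prediction in place of crux.eval(u); the mean absolute error itself does not depend on the ordering.

```python
import numpy as np

rng = np.random.default_rng(0)
pts = rng.uniform(0.0, 1.0, size=200)  # unordered sample points
true_vals = np.sin(np.pi * pts)
# Hypothetical "network" output: exact solution plus a small oscillating error.
pred_vals = true_vals + 1e-3 * np.sin(7 * np.pi * pts)

idx = np.argsort(pts)  # sort once so the curves plot left to right
xs = pts[idx]
pred = pred_vals[idx]
true = true_vals[idx]

mae = np.abs(pred - true).mean()
print(f"Mean absolute error: {mae:.6e}")
```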

What To Notice

  • Increasing sigma makes the problem stiffer.
  • The overall training loop remains similar to the basic 1D examples.
  • This is a useful template before adding time dependence.
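To try the sigma values suggested in the script's docstring (1, 10, 100) without training a network, a classical second-order finite-difference solve can serve as a reference. This is a swapped-in technique, not part of the jno example; it only confirms that the manufactured forcing and the zero boundary values are consistent for each sigma.

```python
import numpy as np


def solve_fd(sigma, n=200):
    """Second-order finite-difference solve of -u'' + sigma*u = f on [0, 1]
    with u(0) = u(1) = 0, using n interior nodes."""
    h = 1.0 / (n + 1)
    x = np.linspace(h, 1.0 - h, n)  # interior nodes with spacing h
    f = (np.pi**2 + sigma) * np.sin(np.pi * x)
    # Tridiagonal matrix for -u'' (central differences) plus sigma on the diagonal.
    main = np.full(n, 2.0 / h**2 + sigma)
    off = np.full(n - 1, -1.0 / h**2)
    A = np.diag(main) + np.diag(off, 1) + np.diag(off, -1)
    u = np.linalg.solve(A, f)
    return x, u


for sigma in (1.0, 10.0, 100.0):
    x, u = solve_fd(sigma)
    err = np.abs(u - np.sin(np.pi * x)).max()
    print(f"sigma = {sigma:6.1f}  max |error| = {err:.3e}")
```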

Script Snippet

"""02 — 1-D diffusion-reaction equation  (steady)

Problem
-------
    −u''(x) + σ u(x) = f(x),   x ∈ [0, 1],   u(0) = u(1) = 0

Manufactured solution
---------------------
    u(x) = sin(πx)
    f(x) = (π² + σ) sin(πx)

This tests whether the network can balance diffusion (−u'') and reaction (σu).
Large σ makes the reaction term dominant; try σ ∈ {1, 10, 100}.
"""

import jax
import jno
import jno.jnp_ops as jnn
import optax
import numpy as np
import matplotlib.pyplot as plt
from jno import LearningRateSchedule as lrs

π = jnn.pi
dire = jno.setup(__file__)

# ── Physical parameter ────────────────────────────────────────────────────────
σ = 10.0  # reaction coefficient — increase to make the problem stiffer

# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain(constructor=jno.domain.line(mesh_size=0.005))
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")  # boundary points; not used in the loss, since BCs are hard-enforced below

# ── Manufactured solution and forcing ─────────────────────────────────────────
u_exact = jnn.sin(π * x)
forcing = (π**2 + σ) * jnn.sin(π * x)  # f = −u_exact'' + σ u_exact

# ── Network (hard BCs via x(1−x) factor) ─────────────────────────────────────
u_net = jno.nn.mlp(
    in_features=1,
    hidden_dims=64,
    num_layers=4,
    key=jax.random.PRNGKey(0),
).optimizer(optax.adam(1), lr=lrs.exponential(1e-3, 0.5, 10_000, 1e-5))

u = u_net(x) * x * (1 - x)

# ── PDE residual:  −u'' + σu − f = 0 ──────────────────────────────────────────
u_xx = jnn.grad(jnn.grad(u, x), x)
pde = -u_xx + σ * u - forcing
error = jnn.tracker((u - u_exact).mse, interval=200)

# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, error], domain)
history = crux.solve(10_000)
history.plot(f"{dire}/training_history.png")

# ── Evaluate ──────────────────────────────────────────────────────────────────
pts = np.array(crux.domain_data.context["interior"][0, 0, :, 0])
idx = np.argsort(pts)
xs = pts[idx]
pred = np.array(crux.eval(u)).reshape(xs.shape[0], 1)[:, 0][idx]
true = np.array(crux.eval(u_exact)).reshape(xs.shape[0], 1)[:, 0][idx]

mae = np.abs(pred - true).mean()
print(f"σ = {σ}  |  Mean absolute error: {mae:.6e}")

# ── Plot ──────────────────────────────────────────────────────────────────────
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.set_title(f"Diffusion-reaction  σ={σ}")
ax1.plot(xs, pred, label="PINN")
ax1.plot(xs, true, "--", label="exact")
ax1.set_xlabel("x")
ax1.legend()

ax2.set_title("Pointwise |error|")
ax2.plot(xs, np.abs(pred - true))
ax2.set_xlabel("x")
plt.tight_layout()
plt.savefig(f"{dire}/solution.png", dpi=150)
print(f"Saved to {dire}/")