
Integro-Differential Equation

This example solves an integro-differential equation (IDE): an equation in which the unknown function appears under both a derivative and an integral. It shows how .d(x) and .integrate() compose in the same residual.

Equation

\[u'(x) + u(x) = g(x) + \int_0^1 u(t)\, dt, \qquad x \in [0,1], \quad u(0) = 0\]

Exact solution: \(u^*(x) = \sin(\pi x)\)

Derivation of g

\[u'(x) = \pi\cos(\pi x), \qquad u(x) = \sin(\pi x), \qquad \int_0^1 \sin(\pi t)\,dt = \frac{2}{\pi}\]
\[g(x) = \pi\cos(\pi x) + \sin(\pi x) - \frac{2}{\pi}\]
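
The derivation can be sanity-checked numerically. The sketch below uses plain NumPy rather than jno, evaluates the exact solution on a fine grid chosen only for this check, and confirms that the residual of the IDE vanishes up to quadrature error. The grid size and the trapezoidal rule are assumptions made for illustration.

```python
import numpy as np

# Exact solution and forcing term from the derivation above.
def u_exact(x):
    return np.sin(np.pi * x)

def g(x):
    return np.pi * np.cos(np.pi * x) + np.sin(np.pi * x) - 2.0 / np.pi

# Fine grid used only for this check (not the tutorial mesh).
x = np.linspace(0.0, 1.0, 2001)
h = x[1] - x[0]

# Trapezoidal approximation of the integral of u over [0, 1].
u = u_exact(x)
integral = h * (u.sum() - 0.5 * (u[0] + u[-1]))

# Analytic derivative of the exact solution.
du = np.pi * np.cos(np.pi * x)

# Residual of u' + u = g + integral; it should vanish up to quadrature error.
residual = du + u - g(x) - integral
max_residual = np.max(np.abs(residual))
print(f"integral = {integral:.8f}, 2/pi = {2 / np.pi:.8f}")
print(f"max |residual| = {max_residual:.2e}")
```

The only source of error here is the quadrature, so the residual is constant in x and shrinks as the grid is refined.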

Why this is different

In a standard PINN the residual involves only pointwise quantities — derivatives at \(x\). Here the residual also includes \(\int_0^1 u(t)\,dt\), which is a scalar that couples the solution at every mesh point.

jno handles this transparently: .integrate() returns a scalar placeholder that flows through the same computation graph as .d(x). Both appear in the same MSE loss with no extra bookkeeping.
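
To make the coupling concrete, here is a minimal NumPy sketch of a mesh-based integral. It assumes a uniform mesh and trapezoidal weights; jno's actual quadrature rule may differ, and the names t, w, and C are illustrative only.

```python
import numpy as np

# Mesh nodes on [0, 1] with spacing 0.05 (assumed uniform for this sketch).
h = 0.05
t = np.linspace(0.0, 1.0, 21)

# Trapezoidal quadrature weights (an assumption; jno's rule may differ).
w = np.full(t.shape, h)
w[0] = w[-1] = 0.5 * h

u = np.sin(np.pi * t)   # nodal values of a trial function
C = float(w @ u)        # a single number approximating the integral

# dC/du_i = w_i > 0 for every node, so every nodal value influences C.
print(f"C = {C:.5f}, exact 2/pi = {2 / np.pi:.5f}")
print("all weights positive:", bool(np.all(w > 0)))
```

Because every weight is nonzero, changing the solution at any single node changes C, and therefore changes the residual at every collocation point.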

Hard boundary condition

The Dirichlet condition \(u(0) = 0\) is enforced by the network ansatz

\[u(x) = \text{net}(x) \cdot x\]

Multiplying by \(x\) forces \(u(0) = 0\) for any weight configuration, so the optimizer never needs to "discover" the boundary condition — it is automatically satisfied throughout training.
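
The claim that the ansatz holds for any weights is easy to check. The sketch below uses a hypothetical one-hidden-layer NumPy network (tiny_mlp is a stand-in, not the foundax model used in the script) and evaluates the ansatz at x = 0 for several random weight draws.

```python
import numpy as np

def tiny_mlp(x, params):
    """Hypothetical stand-in for net(x): one tanh hidden layer."""
    W1, b1, W2, b2 = params
    return np.tanh(x @ W1 + b1) @ W2 + b2

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 21).reshape(-1, 1)   # (N, 1); first row is x = 0

boundary_values = []
for _ in range(5):
    # Fresh random weights each time, standing in for different training states.
    params = (rng.normal(size=(1, 16)), rng.normal(size=16),
              rng.normal(size=(16, 1)), rng.normal(size=1))
    u = tiny_mlp(x, params) * x                # ansatz: u(x) = net(x) * x
    boundary_values.append(float(u[0, 0]))

print(boundary_values)
```

Every entry is zero because net(0) is finite and is multiplied by exactly 0, regardless of the weights.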

Building the residual

x, _ = domain.variable("interior")

u  = net(x) * x   # hard BC: u(0) = 0

C  = u.integrate()           # scalar: ∫₀¹ u(t) dt
du = u.d(x)                  # (N, 1): u'(x) at every collocation point

residual = du + u - g - C    # IDE residual

C is a scalar, identical for every row. JAX broadcasting combines it with the (N, 1) arrays du, u, and g without any manual reshaping.
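
The broadcasting behaviour can be reproduced outside jno. The sketch below uses NumPy, whose broadcasting rules jax.numpy follows, with the exact solution standing in for the network output; all arrays here are illustrative, not jno placeholders.

```python
import numpy as np

N = 21
x = np.linspace(0.0, 1.0, N).reshape(N, 1)

u = np.sin(np.pi * x)             # (N, 1)
du = np.pi * np.cos(np.pi * x)    # (N, 1)
g = du + u - 2.0 / np.pi          # (N, 1)
C = 2.0 / np.pi                   # plain scalar

# The scalar is subtracted from every row; no reshape or tile is needed.
residual = du + u - g - C
print(residual.shape)             # (21, 1)
```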

Shape summary

| Expression | Shape | Note |
|---|---|---|
| u | (N, 1) | network output |
| u.d(x) | (N, 1) | pointwise derivative |
| u.integrate() | scalar | integral over all mesh points |
| g | (N, 1) | forcing term |
| residual | (N, 1) | broadcast scalar + vectors |

Step-by-step

Step 1 — Domain and variable

domain = jno.domain.line(mesh_size=0.05)
x, _ = domain.variable("interior")

Step 2 — Hard-BC ansatz

u = net(x) * x   # u(0) = 0 for any net weights

Step 3 — Compose operators

C  = u.integrate()   # scalar: feeds into every row of the residual
du = u.d(x)          # pointwise: (N, 1)

Finite-difference derivative on the integration mesh

The integration uses the mesh nodes as quadrature points. Since the same mesh is already available, swapping the pointwise derivative to a finite-difference scheme avoids the autodiff tape and runs noticeably faster on dense 1-D meshes:

du = u.d(x, scheme="finite_difference")

For 2-D and 3-D meshes you'd also need compute_mesh_connectivity=True on the domain (the 1-D case here precomputes connectivity by default). FD and autodiff agree to within mesh-resolution error; pick FD when the per-step gradient tape dominates training time.
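
What "agree to within mesh-resolution error" means can be illustrated with a generic finite-difference stencil. The sketch below uses np.gradient as a stand-in (jno's finite-difference scheme is not shown in this tutorial and may use a different stencil) and compares it with the analytic derivative of the exact solution on two meshes.

```python
import numpy as np

def fd_error(n_intervals):
    """Max error of a second-order finite difference for u = sin(pi x)."""
    x = np.linspace(0.0, 1.0, n_intervals + 1)
    h = x[1] - x[0]
    u = np.sin(np.pi * x)
    # Central differences in the interior, second-order one-sided at the ends.
    du_fd = np.gradient(u, h, edge_order=2)
    du_exact = np.pi * np.cos(np.pi * x)
    return np.max(np.abs(du_fd - du_exact))

err_coarse = fd_error(20)   # spacing 0.05, as in this tutorial
err_fine = fd_error(40)     # spacing 0.025
print(f"h = 0.050: {err_coarse:.3e}")
print(f"h = 0.025: {err_fine:.3e}")
print(f"ratio: {err_coarse / err_fine:.2f}")
```

For a second-order stencil, halving the spacing should cut the error by roughly a factor of four, which is the sense in which the discrepancy is controlled by the mesh resolution.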

Step 4 — Form and solve

residual = du + u - g - C
crux = jno.core([residual.mse])
crux.solve(30_000)
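
As a cross-check that involves neither jno nor a neural network, the same residual can be minimised with a polynomial trial function, for which the problem reduces to linear least squares. This is a swapped-in technique for illustration, not the method used in this tutorial; the number of coefficients K and all variable names are assumptions.

```python
import numpy as np

K = 8                                   # number of polynomial coefficients (arbitrary)
x = np.linspace(0.0, 1.0, 21)           # collocation points, spacing 0.05
g = np.pi * np.cos(np.pi * x) + np.sin(np.pi * x) - 2.0 / np.pi

k = np.arange(K)
# Trial function u(x) = x * sum_k c_k x^k, so u(0) = 0 by construction.
U = x[:, None] ** (k + 1)               # basis for u:   x^(k+1)
dU = (k + 1) * x[:, None] ** k          # basis for u':  (k+1) x^k
int_U = 1.0 / (k + 2)                   # exact integral of x^(k+1) over [0, 1]

# The residual du + u - g - C is linear in c: (dU + U - int_U) c = g.
# int_U has shape (K,) and broadcasts across all rows, like the scalar C.
A = dU + U - int_U
c, *_ = np.linalg.lstsq(A, g, rcond=None)

u_pred = U @ c
u_ref = np.sin(np.pi * x)
rel_l2 = np.linalg.norm(u_pred - u_ref) / np.linalg.norm(u_ref)
print(f"relative L2 error: {rel_l2:.2e}")
```

The row-independent term int_U plays the role of C: it enters every equation of the system, which is how the integral couples all collocation points.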

What to notice

  • .integrate() returns a scalar here — no var= argument, because there is no outer collocation variable to hold fixed. The result is a single number representing \(\int_0^1 u(t)\,dt\).
  • .d(x) and .integrate() compose — both produce Placeholder objects that participate in the same expression graph.
  • Gradients flow through both operators — the loss is differentiable with respect to the network weights simultaneously through the derivative term and the integral term.
  • The script asserts a relative L2 error below 10%, using 21 interior points (mesh_size=0.05) and 30 000 training steps.

Contrast with the Fredholm tutorials

| Feature | Fredholm separable | Fredholm non-separable | IDE (this example) |
|---|---|---|---|
| Integral result | scalar | (N, 1) | scalar |
| .integrate(var=...) | no | yes | no |
| Derivative in residual | no | no | yes |
| Boundary condition | none | none | hard BC ansatz |

Script

"""06 — Integro-Differential Equation (IDE)"""

from pathlib import Path

import foundax
import jax
import jax.numpy as jnp
import optax

import jno

π = jno.np.pi

# ── Domain ─────────────────────────────────────────────────────────────────────
domain = jno.domain.line(mesh_size=0.05)

x, _ = domain.variable("interior")

domain.summary()

# ── Forcing term  g(x) = π cos(πx) + sin(πx) − 2/π ──────────────────────────
pi_val = float(jnp.pi)
g = π * jno.np.cos(π * x) + jno.np.sin(π * x) - 2.0 / pi_val

# ── Model ──────────────────────────────────────────────────────────────────────
net = jno.nn.wrap(
    foundax.mlp(
        in_features=1,
        hidden_dims=32,
        num_layers=3,
        activation=jax.nn.tanh,
        key=jax.random.PRNGKey(0),
    )
)
net.optimizer(
    optax.adam(
        optax.exponential_decay(
            init_value=1e-3,
            transition_steps=5_000,
            decay_rate=0.5,
            end_value=1e-5,
        )
    )
)

# Hard boundary condition: u(0) = 0. Multiplying by x enforces it for any network weights.
u = net(x) * x

# ── IDE residual ──────────────────────────────────────────────────────────────
# ∫₀¹ u(t) dt is a scalar (not a function of x).
# u.d(x) is the pointwise derivative with shape (N, 1); the scalar broadcasts against it.
C = u.integrate()  # scalar: ∫₀¹ u(t) dt
du = u.d(x)  # (N, 1): u'(x) at every collocation point
residual = du + u - g - C  # IDE residual at every collocation point

# ── Solve ──────────────────────────────────────────────────────────────────────
EPOCHS = 30_000
crux = jno.core([residual.mse]).print_shapes()
_history = crux.solve(EPOCHS)

# ── Evaluate ───────────────────────────────────────────────────────────────────
u_exact = jno.np.sin(π * x)
u_pred, u_ref = crux.eval([u, u_exact])

rel_l2 = float(jnp.linalg.norm(u_pred - u_ref) / (jnp.linalg.norm(u_ref) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e}   (exact solution: u(x) = sin(πx))")

# ── Record result ──────────────────────────────────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f_out:
    f_out.write(f"06_integration/integro_differential.py | epochs={EPOCHS} | rel_L2={rel_l2:.6e}\n")

assert rel_l2 < 0.10, f"Relative L2 error too large: {rel_l2:.3e}"
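
The learning-rate schedule passed to optax.adam in the script can be reasoned about with a few lines of plain Python. The sketch below assumes the default, non-staircase form of optax.exponential_decay, with end_value acting as a lower bound when decay_rate is below 1; the function is a hand-written stand-in, not the optax implementation.

```python
def exponential_decay(step, init_value=1e-3, transition_steps=5_000,
                      decay_rate=0.5, end_value=1e-5):
    """Continuous exponential decay with a lower bound (illustrative stand-in)."""
    lr = init_value * decay_rate ** (step / transition_steps)
    return max(lr, end_value)

# Learning rate at a few points of the 30 000-step run.
for step in (0, 5_000, 10_000, 30_000):
    print(step, exponential_decay(step))
```

Under these assumptions the rate halves every 5 000 steps and ends the run at about 1.56e-5, so the 1e-5 floor is not reached within 30 000 steps.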