Integro-Differential Equation
This example solves an integro-differential equation (IDE), an equation in which the unknown appears both under a derivative and under an integral. It shows how .d(x) and .integrate() compose in the same residual.
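Equation
Find \(u(x)\) on \([0, 1]\) such that
\[
u'(x) + u(x) - \int_0^1 u(t)\,dt = g(x), \qquad u(0) = 0,
\]
with forcing term \(g(x) = \pi\cos(\pi x) + \sin(\pi x) - 2/\pi\).
Exact solution: \(u^*(x) = \sin(\pi x)\)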
Derivation of g
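The forcing term follows from substituting the exact solution \(u^*(x) = \sin(\pi x)\) into \(u'(x) + u(x) - \int_0^1 u(t)\,dt = g(x)\), which is the equation encoded by the residual du + u - g - C. Evaluating each term:

\[
\begin{aligned}
u^{*\prime}(x) &= \pi\cos(\pi x), \\
\int_0^1 \sin(\pi t)\,dt &= \left[-\frac{\cos(\pi t)}{\pi}\right]_0^1 = \frac{1}{\pi} + \frac{1}{\pi} = \frac{2}{\pi}, \\
g(x) &= u^{*\prime}(x) + u^*(x) - \int_0^1 u^*(t)\,dt = \pi\cos(\pi x) + \sin(\pi x) - \frac{2}{\pi}.
\end{aligned}
\]

Since \(\sin(0) = 0\), the exact solution also satisfies the boundary condition \(u(0) = 0\).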
Why this is different
In a standard PINN the residual involves only pointwise quantities, such as derivatives at \(x\). Here the residual also includes \(\int_0^1 u(t)\,dt\), a single scalar that depends on the solution at every mesh point, so the residual at any one point is coupled to the solution everywhere.
jno handles this transparently: .integrate() returns a scalar placeholder that flows through the same computation graph as .d(x). Both appear in the same MSE loss with no extra bookkeeping.
Hard boundary condition
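The Dirichlet condition \(u(0) = 0\) is enforced by the network ansatz \(u(x) = x\,\mathrm{NN}_\theta(x)\), where \(\mathrm{NN}_\theta\) is the MLP with weights \(\theta\).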
Multiplying by \(x\) forces \(u(0) = 0\) for any weight configuration, so the optimizer never needs to "discover" the boundary condition — it is automatically satisfied throughout training.
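A minimal NumPy sketch of why the ansatz works. It uses a hypothetical randomly initialised tanh MLP (init_params and mlp are illustrative stand-ins for net, not jno or foundax APIs): whatever the weights, the row where x = 0 is exactly zero.

```python
import numpy as np


def init_params(rng, sizes=(1, 32, 32, 1)):
    """Random weights for a hypothetical tanh MLP (a stand-in for net, not the foundax API)."""
    return [
        (rng.normal(size=(m, n)), rng.normal(size=(n,)))
        for m, n in zip(sizes[:-1], sizes[1:])
    ]


def mlp(params, x):
    """Forward pass: tanh hidden layers, linear output; x has shape (N, 1)."""
    h = x
    for W, b in params[:-1]:
        h = np.tanh(h @ W + b)
    W, b = params[-1]
    return h @ W + b  # (N, 1)


x = np.linspace(0.0, 1.0, 21).reshape(-1, 1)  # (N, 1) mesh nodes; x[0] == 0

u_at_zero = []
for seed in range(5):
    params = init_params(np.random.default_rng(seed))
    u = mlp(params, x) * x  # hard-BC ansatz u(x) = x * NN(x)
    u_at_zero.append(float(u[0, 0]))

print(u_at_zero)  # every entry is zero, whatever the weights
```

The raw network value mlp(params, x)[0, 0] is generally nonzero; only the product with x vanishes at the boundary.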
Building the residual
x, _ = domain.variable("interior")
u = net(x) * x # hard BC: u(0) = 0
C = u.integrate() # scalar: ∫₀¹ u(t) dt
du = u.d(x) # (N, 1): u'(x) at every collocation point
residual = du + u - g - C # IDE residual
C is a scalar that is the same for every row — JAX broadcasting adds it to the (N, 1) arrays du, u, and g without any manual reshaping.
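The same broadcasting can be checked in plain NumPy, which follows the same rules as jax.numpy. This sketch evaluates the residual at the exact solution on 21 mesh nodes; it assumes a composite trapezoidal rule for the integral, which may differ from the quadrature jno uses internally.

```python
import numpy as np

# Mesh nodes on [0, 1] with spacing 0.05, shaped (N, 1) like the collocation variable x.
x = np.linspace(0.0, 1.0, 21).reshape(-1, 1)

u = np.sin(np.pi * x)                 # (N, 1): exact solution u*(x)
du = np.pi * np.cos(np.pi * x)        # (N, 1): exact derivative
g = np.pi * np.cos(np.pi * x) + np.sin(np.pi * x) - 2.0 / np.pi  # (N, 1)

# Composite trapezoidal rule over the mesh nodes (an assumed rule): one scalar.
xs, us = x[:, 0], u[:, 0]
C = float(np.sum(0.5 * (us[1:] + us[:-1]) * np.diff(xs)))

residual = du + u - g - C             # scalar C broadcasts across all N rows
print(C, 2.0 / np.pi)                 # quadrature estimate vs exact integral
print(residual.shape, np.abs(residual).max())
```

At the exact solution every row of the residual reduces to the quadrature error \(2/\pi - C\), so it is the same small constant in each row, which makes a convenient sanity check.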
Shape summary
| Expression | Shape | Note |
|---|---|---|
| u | (N, 1) | network output |
| u.d(x) | (N, 1) | pointwise derivative |
| u.integrate() | scalar | integral over all mesh points |
| g | (N, 1) | forcing term |
| residual | (N, 1) | broadcast scalar + vectors |
Step-by-step
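Step 1 — Domain and variable
Create a 1-D line domain with jno.domain.line(mesh_size=0.05) and take the interior collocation variable with x, _ = domain.variable("interior").
Step 2 — Hard-BC ansatz
Wrap the MLP with jno.nn.wrap and define u = net(x) * x, so that u(0) = 0 holds for any weights.
Step 3 — Compose operators
Build C = u.integrate(), du = u.d(x), and residual = du + u - g - C, as in the residual construction above.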
Finite-difference derivative on the integration mesh
The integral uses the mesh nodes as quadrature points. Because that mesh is already available, the pointwise derivative can also be computed with a finite-difference scheme on the same nodes. This avoids the autodiff tape for u'(x) and can run noticeably faster on dense 1-D meshes.
For 2-D and 3-D meshes you'd also need compute_mesh_connectivity=True on the domain (the 1-D case here precomputes connectivity by default). FD and autodiff agree to within mesh-resolution error; pick FD when the per-step gradient tape dominates training time.
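The claim that FD and autodiff agree to within mesh-resolution error can be illustrated without jno. This NumPy sketch uses np.gradient as a stand-in finite-difference scheme (jno's own FD option is not shown here) and compares it with the exact derivative, which is what autodiff would return for sin(πx).

```python
import numpy as np


def fd_max_error(h):
    """Max |FD derivative - exact derivative| of sin(pi x) on a uniform mesh of [0, 1]."""
    n = int(round(1.0 / h)) + 1
    x = np.linspace(0.0, 1.0, n)
    u = np.sin(np.pi * x)
    # Second-order central differences inside, second-order one-sided at the ends.
    du_fd = np.gradient(u, x, edge_order=2)
    du_exact = np.pi * np.cos(np.pi * x)
    return np.abs(du_fd - du_exact).max()


e_coarse = fd_max_error(0.05)   # mesh_size used in this example
e_fine = fd_max_error(0.025)
print(e_coarse, e_fine, e_coarse / e_fine)  # error shrinks roughly 4x when h halves
```

The gap between the two derivatives is a few hundredths at h = 0.05 and falls roughly quadratically with the mesh size, which is the sense in which the two approaches agree.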
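Step 4 — Form and solve
Pass residual.mse to jno.core and train with crux.solve(EPOCHS) for EPOCHS = 30_000 steps; the Adam optimizer uses a learning rate that decays exponentially from 1e-3 towards 1e-5.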
What to notice
- .integrate() returns a scalar here. There is no var= argument because there is no outer collocation variable to hold fixed; the result is a single number representing \(\int_0^1 u(t)\,dt\).
- .d(x) and .integrate() compose: both produce Placeholder objects that participate in the same expression graph.
- Gradients flow through both operators: the loss is differentiable with respect to the network weights through the derivative term and the integral term simultaneously.
- Relative L2 error < 10% is achieved with 21 interior points and 30 000 steps.
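To see gradients flow through both operators outside jno, here is a minimal JAX sketch. It assumes a hypothetical two-parameter model with the same hard-BC ansatz, autodiff for u'(x), and a trapezoidal rule for the integral; none of these names are jno APIs.

```python
import jax
import jax.numpy as jnp


# Hypothetical two-parameter model with the hard-BC ansatz: u(x; w) = x * tanh(w0 * x + w1).
def u_fn(w, x):
    return x * jnp.tanh(w[0] * x + w[1])


xs = jnp.linspace(0.0, 1.0, 21)  # mesh nodes used as collocation and quadrature points


def integral(w):
    """Trapezoidal estimate of the scalar integral of u over [0, 1] on the mesh nodes."""
    u = u_fn(w, xs)
    return jnp.sum(0.5 * (u[1:] + u[:-1]) * jnp.diff(xs))


def loss(w):
    u = u_fn(w, xs)                                                      # (N,)
    du = jax.vmap(jax.grad(u_fn, argnums=1), in_axes=(None, 0))(w, xs)  # (N,): u'(x) by autodiff
    g = jnp.pi * jnp.cos(jnp.pi * xs) + jnp.sin(jnp.pi * xs) - 2.0 / jnp.pi
    residual = du + u - g - integral(w)                                  # scalar broadcast over N
    return jnp.mean(residual ** 2)


w0 = jnp.array([1.0, 0.0])
grads = jax.grad(loss)(w0)              # gradient through both the derivative and the integral
grad_integral = jax.grad(integral)(w0)  # the integral term alone also carries a gradient
print(grads, grad_integral)
```

Because the integral is an ordinary differentiable function of the weights, jax.grad needs no special handling for it.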
Contrast with the Fredholm tutorials
| Feature | Fredholm separable | Fredholm non-separable | IDE (this example) |
|---|---|---|---|
| Integral result | scalar | (N, 1) | scalar |
| Uses .integrate(var=...) | no | yes | no |
| Derivative in residual | no | no | yes |
| Boundary condition | none | none | hard BC ansatz |
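The "Integral result" row comes down to whether the integrand depends on the outer variable. A NumPy sketch with assumed trapezoidal weights and a hypothetical kernel K(x, t) = exp(-|x - t|) shows the two shapes:

```python
import numpy as np

x = np.linspace(0.0, 1.0, 21).reshape(-1, 1)   # (N, 1) mesh nodes
u = np.sin(np.pi * x)                          # (N, 1) sample function values

# Trapezoidal quadrature weights on the mesh nodes (an assumed rule).
h = x[1, 0] - x[0, 0]
w = np.full(x.shape[0], h)
w[[0, -1]] = h / 2                             # (N,)

# IDE case: the integral of u(t) over t is one number shared by every row.
C = w @ u[:, 0]                                # scalar

# Non-separable case: the integral of K(x, t) u(t) over t depends on x, one value per row.
K = np.exp(-np.abs(x - x.T))                   # (N, N) hypothetical kernel K(x_i, t_j)
F = K @ (w[:, None] * u)                       # (N, 1)

print(np.shape(C), F.shape)                    # () (21, 1)
```

The scalar C can be broadcast into a pointwise residual directly, while F already has one entry per collocation point.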
Script
"""06 — Integro-Differential Equation (IDE)"""
from pathlib import Path
import foundax
import jax
import jax.numpy as jnp
import optax
import jno
π = jno.np.pi
# ── Domain ─────────────────────────────────────────────────────────────────────
domain = jno.domain.line(mesh_size=0.05)
x, _ = domain.variable("interior")
domain.summary()
# ── Forcing term g(x) = π cos(πx) + sin(πx) − 2/π ──────────────────────────
pi_val = float(jnp.pi)
g = π * jno.np.cos(π * x) + jno.np.sin(π * x) - 2.0 / pi_val
# ── Model ──────────────────────────────────────────────────────────────────────
net = jno.nn.wrap(
    foundax.mlp(
        in_features=1,
        hidden_dims=32,
        num_layers=3,
        activation=jax.nn.tanh,
        key=jax.random.PRNGKey(0),
    )
)
net.optimizer(
    optax.adam(
        optax.exponential_decay(
            init_value=1e-3,
            transition_steps=5_000,
            decay_rate=0.5,
            end_value=1e-5,
        )
    )
)
# Hard boundary condition: u(0) = 0. Multiplying by x enforces it for any network weights.
u = net(x) * x
# ── IDE residual ──────────────────────────────────────────────────────────────
# ∫₀¹ u(t) dt is a scalar (not a function of x).
# u.d(x) is the pointwise derivative; u and u.d(x) are both (N, 1).
C = u.integrate() # scalar: ∫₀¹ u(t) dt
du = u.d(x) # (N, 1): u'(x) at every collocation point
residual = du + u - g - C # IDE residual at every collocation point
# ── Solve ──────────────────────────────────────────────────────────────────────
EPOCHS = 30_000
crux = jno.core([residual.mse]).print_shapes()
_history = crux.solve(EPOCHS)
# ── Evaluate ───────────────────────────────────────────────────────────────────
u_exact = jno.np.sin(π * x)
u_pred, u_ref = crux.eval([u, u_exact])
rel_l2 = float(jnp.linalg.norm(u_pred - u_ref) / (jnp.linalg.norm(u_ref) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e} (exact solution: u(x) = sin(πx))")
# ── Record result ──────────────────────────────────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f_out:
    f_out.write(f"06_integration/integro_differential.py | epochs={EPOCHS} | rel_L2={rel_l2:.6e}\n")
assert rel_l2 < 0.10, f"Relative L2 error too large: {rel_l2:.3e}"