Fredholm Integral Equation of the Second Kind
This example solves a Fredholm integral equation of the second kind — an equation where the unknown function appears both outside and inside an integral. The exact solution is known, so the trained network's accuracy can be measured directly.
Equation
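\[ u(x) = f(x) + \int_0^1 x\,t\,u(t)\,dt, \qquad x \in [0, 1], \]
with the forcing term \(f\) chosen so that the exact solution is
\[ u^*(x) = \sin(\pi x). \]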
Derivation of f
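Substituting \(u^* = \sin(\pi x)\) into the equation:
\[ \sin(\pi x) = f(x) + x \int_0^1 t \sin(\pi t)\,dt. \]
The integral evaluates to (integration by parts):
\[ \int_0^1 t \sin(\pi t)\,dt = \left[-\frac{t\cos(\pi t)}{\pi}\right]_0^1 + \frac{1}{\pi}\int_0^1 \cos(\pi t)\,dt = \frac{1}{\pi} + 0 = \frac{1}{\pi}, \]
so \(x \cdot \int_0^1 t \cdot \sin(\pi t) \, dt = x / \pi\), giving:
\[ f(x) = \sin(\pi x) - \frac{x}{\pi}. \]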
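The following sketch checks the derivation numerically. It uses plain NumPy and is independent of jno. It approximates the integral with a composite trapezoid rule and confirms that \(u^* = \sin(\pi x)\) reproduces itself when substituted into the right-hand side. The choice of 1001 nodes is arbitrary and only for illustration.

```python
import numpy as np

# Uniform nodes on [0, 1]; finer than the tutorial mesh so quadrature error is tiny.
t = np.linspace(0.0, 1.0, 1001)


def trapezoid(values: np.ndarray, nodes: np.ndarray) -> float:
    """Composite trapezoid rule, written out to avoid np.trapz/np.trapezoid naming differences."""
    dt = np.diff(nodes)
    return float(np.sum(0.5 * (values[1:] + values[:-1]) * dt))


# Integral from the derivation; analytically equal to 1/pi.
integral = trapezoid(t * np.sin(np.pi * t), t)
print(integral, 1.0 / np.pi)

# Forcing term from the derivation and the candidate exact solution.
f = np.sin(np.pi * t) - t / np.pi
u_star = np.sin(np.pi * t)

# Right-hand side f(x) + x * integral of t * u*(t) should reproduce u*(x).
rhs = f + t * trapezoid(t * u_star, t)
max_err = float(np.max(np.abs(rhs - u_star)))
print(max_err)
```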
Why this is interesting
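Integral equations do not fit the standard PINN recipe, which builds each residual pointwise from derivatives of the network. Here there is no ODE or PDE to differentiate, and the residual at each point depends on an integral of \(u\) over the whole domain. The .integrate() operator fills this gap. The key insight is that the degenerate (separable) kernel \(K(x,t) = x \cdot t\) lets the two-argument integral collapse into a simple product:
\[ \int_0^1 x\,t\,u(t)\,dt = x \cdot C, \qquad C = \int_0^1 t\,u(t)\,dt. \]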
\(C\) is a scalar independent of \(x\). jno computes it with one .integrate() call over the full mesh.
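The sketch below shows what the separable kernel saves. It compares two routes to the same result. The general route forms the full \(N \times N\) kernel matrix \(K_{ij} = x_i t_j\). The collapsed route computes one scalar \(C\) and multiplies by \(x\). It uses NumPy with trapezoid weights on a uniform grid of spacing 0.01. The weights are an assumption standing in for jno's own quadrature.

```python
import numpy as np

# Nodes on [0, 1] with spacing 0.01, matching the tutorial's mesh_size.
x = np.linspace(0.0, 1.0, 101)
h = x[1] - x[0]

# Trapezoid weights (assumed stand-in for jno's nodal volume weights).
w = np.full_like(x, h)
w[0] = w[-1] = h / 2

# Any nodal values work; the exact solution is used for concreteness.
u = np.sin(np.pi * x)

# General route: full kernel matrix K[i, j] = x_i * t_j, then quadrature. O(N^2) work.
K = np.outer(x, x)
full = K @ (w * u)

# Separable route: one scalar C = sum_j w_j t_j u_j, broadcast over x. O(N) work.
C = np.dot(w, x * u)
collapsed = x * C

print(np.max(np.abs(full - collapsed)))
```

For a general kernel only the first route is available. The product structure of \(x \cdot t\) is what reduces the integral term to a single weighted sum.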
Step 1: Set up the domain
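The 1D mesh on \([0,1]\), built with jno.domain.line(mesh_size=0.01), doubles as the integration mesh, so no separate quadrature rule is needed. domain.variable("interior") returns the coordinate x, which serves both as the set of collocation points and as the integration variable.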
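Using the mesh for quadrature means each node needs an integration weight. One common construction on a 1D mesh splits each element's length equally between its two end nodes. This construction is an assumption here, not taken from jno. It reproduces the composite trapezoid rule. The helper nodal_volume_weights below is hypothetical and only illustrates the idea.

```python
import numpy as np


def nodal_volume_weights(nodes: np.ndarray) -> np.ndarray:
    """Hypothetical helper: lump each 1D element's length equally onto its two end nodes.

    This illustrates the idea of nodal volume weights; it is not jno's implementation.
    """
    lengths = np.diff(nodes)   # element lengths
    w = np.zeros_like(nodes)
    w[:-1] += 0.5 * lengths    # left end node of each element
    w[1:] += 0.5 * lengths     # right end node of each element
    return w


# A 1D mesh on [0, 1] with spacing 0.01 gives 101 nodes.
nodes = np.linspace(0.0, 1.0, 101)
w = nodal_volume_weights(nodes)

total = float(w.sum())                               # length of [0, 1]
linear = float(w @ (3.0 * nodes + 2.0))              # exact value: 3/2 + 2 = 3.5
approx = float(w @ (nodes * np.sin(np.pi * nodes)))  # exact value: 1/pi
print(total, linear, approx - 1.0 / np.pi)
```

The rule is exact for linear integrands. For \(t \sin(\pi t)\) it is accurate to \(O(h^2)\).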
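Step 2: Define the forcing term
The forcing term \(f(x) = \sin(\pi x) - x/\pi\) is an expression in the domain variable x, built with jno.np operations (f = jno.np.sin(π * x) - x / pi_val in the script). It is therefore evaluated on the same nodes as the network.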
Step 3: Build the model
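```python
# x comes from Step 1; imports are as in the full script.
net = jno.nn.wrap(
    foundax.mlp(in_features=1, hidden_dims=64, num_layers=4,
                activation=jax.nn.tanh, key=jax.random.PRNGKey(0))
)
# u is the network evaluated at every collocation point x
u = net(x)
```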
Step 4: Formulate the residual
```python
# C = ∫₀¹ t · u(t) dt — scalar, the same for every x
C = (x * u).integrate()
# Pointwise residual R(xᵢ) = u(xᵢ) − f(xᵢ) − xᵢ · C
residual = u - f - x * C
```
(x * u).integrate() evaluates the integrand t · u(t) at all mesh nodes (here x is the dummy variable \(t\)), then reduces to a scalar using nodal volume weights. Multiplying by the outer x broadcasts that scalar back to the collocation grid.
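The two lines of the residual can be mirrored in NumPy to make the reduction and the broadcast explicit. The sketch assumes trapezoid-style nodal weights on a uniform grid with spacing 0.01, the tutorial's mesh_size. residual_mse is a hypothetical helper and not part of jno. For the exact solution only the quadrature error survives. A perturbed candidate leaves a clearly nonzero residual.

```python
import numpy as np

# Collocation/integration nodes and trapezoid-style weights (assumed to match .integrate()).
x = np.linspace(0.0, 1.0, 101)
w = np.full_like(x, 0.01)
w[0] = w[-1] = 0.005

f = np.sin(np.pi * x) - x / np.pi


def residual_mse(u: np.ndarray) -> float:
    """Hypothetical helper mirroring C = (x * u).integrate(); residual = u - f - x * C."""
    C = np.dot(w, x * u)   # one weighted sum over all nodes -> scalar
    r = u - f - x * C      # scalar C broadcast back over every node
    return float(np.mean(r ** 2))


mse_exact = residual_mse(np.sin(np.pi * x))            # only quadrature error remains
mse_wrong = residual_mse(np.sin(np.pi * x) + 0.1 * x)  # perturbed candidate
print(mse_exact, mse_wrong)
```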
Both uses of the network — as the integrand inside C and as the outer u — go through the same shared parameters. Gradients flow through both paths simultaneously.
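The residual is linear in the nodal values of \(u\), so its Jacobian with respect to those values can be written down exactly. It has two parts. The outer \(u\) contributes an identity term. The integral path contributes a dense rank-one term \(-x\,(w \odot x)^\top\), because every node's value enters \(C\). The NumPy sketch below, under the same trapezoid-weight assumption, checks that formula against finite differences. In the actual script, jax.grad carries these two paths through the shared network to its parameters by the chain rule.

```python
import numpy as np

x = np.linspace(0.0, 1.0, 101)
w = np.full_like(x, 0.01)  # assumed trapezoid weights
w[0] = w[-1] = 0.005
f = np.sin(np.pi * x) - x / np.pi


def residual(u: np.ndarray) -> np.ndarray:
    return u - f - x * np.dot(w, x * u)


# Analytic Jacobian dR_i/du_j = delta_ij - x_i * w_j * x_j:
# identity from the outer u, dense rank-one term from the integral path.
J = np.eye(x.size) - np.outer(x, w * x)

# Forward-difference check; the residual is linear, so agreement is limited only by rounding.
u0 = np.cos(3.0 * x)
eps = 1e-6
J_fd = np.column_stack(
    [(residual(u0 + eps * e) - residual(u0)) / eps for e in np.eye(x.size)]
)
print(np.max(np.abs(J - J_fd)))
```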
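Step 5: Solve
The loss is the mean squared residual, residual.mse, passed to jno.core. The script trains for 50 000 epochs with Adam and an exponentially decaying learning rate: initial value 1e-3, decaying by a factor of 0.5 every 5 000 steps, floored at 1e-5.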
What to notice
- No boundary conditions are needed. The integral equation is posed on the interior only, and boundary values follow from the equation itself. At \(x = 0\) the integral term vanishes, so \(u(0) = f(0) = 0\).
- .integrate() is differentiable. jax.grad and eqx.filter_grad propagate through it, so gradients flow through the scalar \(C\) back to the shared network parameters alongside the pointwise residual terms.
- JIT-friendly. Mesh weights are precomputed at domain creation and embedded as JAX constants, so the integral adds no overhead beyond a single dot product per step.
- Relative L2 error < 5 % is achievable with a modest 4-layer MLP and 50 000 Adam steps.
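The 5 % target can be put in context with a direct solve. The discretized equation is a linear system, so the Nyström method can solve it outright. Nyström is a classical technique swapped in here for comparison only; jno does not use it. With the same assumed trapezoid weights, \((I - K W)\,u = f\) is nonsingular. The only nonzero eigenvalue of the rank-one matrix \(K W\) is \(\sum_j w_j x_j^2 \approx 1/3 \neq 1\). The result is a baseline that separates quadrature error from the network's approximation and optimization error.

```python
import numpy as np

x = np.linspace(0.0, 1.0, 101)
w = np.full_like(x, 0.01)  # assumed trapezoid weights
w[0] = w[-1] = 0.005
f = np.sin(np.pi * x) - x / np.pi

# Nystrom discretization: u_i - sum_j w_j * x_i * t_j * u_j = f_i, i.e. (I - K W) u = f.
A = np.eye(x.size) - np.outer(x, w * x)
u_nys = np.linalg.solve(A, f)

u_exact = np.sin(np.pi * x)
rel_l2 = float(np.linalg.norm(u_nys - u_exact) / np.linalg.norm(u_exact))
print(f"Nystrom relative L2 error: {rel_l2:.2e}")
```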
Script
"""06 — Fredholm integral equation of the second kind"""
from pathlib import Path
import foundax
import jax
import jax.numpy as jnp
import optax
import jno
π = jno.np.pi
# ── Domain ─────────────────────────────────────────────────────────────────────
domain = jno.domain.line(mesh_size=0.01)
x, _ = domain.variable("interior")
domain.summary()
# ── Forcing term f(x) = sin(πx) − x/π ───────────────────────────────────────
pi_val = float(jnp.pi)
f = jno.np.sin(π * x) - x / pi_val
# ── Model ──────────────────────────────────────────────────────────────────────
net = jno.nn.wrap(
foundax.mlp(
in_features=1,
hidden_dims=64,
num_layers=4,
activation=jax.nn.tanh,
key=jax.random.PRNGKey(0),
)
)
net.optimizer(
optax.adam(
optax.exponential_decay(
init_value=1e-3,
transition_steps=5_000,
decay_rate=0.5,
end_value=1e-5,
)
)
)
u = net(x)
# ── Fredholm residual ──────────────────────────────────────────────────────────
# C = ∫₀¹ t · u(t) dt — scalar, independent of x.
# .integrate() evaluates the integrand over all mesh nodes and sums with
# nodal volume weights. Here x is the integration variable (dummy variable t).
C = (x * u).integrate()
# Pointwise residual R(xᵢ) = u(xᵢ) − f(xᵢ) − xᵢ · C
residual = u - f - x * C
# ── Solve ──────────────────────────────────────────────────────────────────────
EPOCHS = 50_000
crux = jno.core([residual.mse]).print_shapes()
_history = crux.solve(EPOCHS)
# ── Evaluate ───────────────────────────────────────────────────────────────────
u_exact = jno.np.sin(π * x)
u_pred, u_ref = crux.eval([u, u_exact])
rel_l2 = float(jnp.linalg.norm(u_pred - u_ref) / (jnp.linalg.norm(u_ref) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e} (exact solution: u(x) = sin(πx))")
# ── Record result ──────────────────────────────────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f_out:
f_out.write(f"06_integration/fredholm_integral_equation.py | epochs={EPOCHS} | rel_L2={rel_l2:.6e}\n")
assert rel_l2 < 0.05, f"Relative L2 error too large: {rel_l2:.3e}"
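The script's optimizer uses optax.exponential_decay with staircase left at its default (False). The rate is therefore \(10^{-3} \cdot 0.5^{\text{step}/5000}\), floored at end_value = 1e-5. The NumPy re-implementation below is a sketch that assumes one schedule step per epoch. Under that assumption the rate reaches the floor after roughly 33 000 of the 50 000 steps, so about the final third of training runs at the minimum rate.

```python
import numpy as np


def lr(step, init=1e-3, transition_steps=5_000, rate=0.5, end_value=1e-5):
    """Sketch mirroring optax.exponential_decay(staircase=False) with decay_rate < 1:
    continuous exponential decay, clipped from below at end_value."""
    return np.maximum(init * rate ** (np.asarray(step) / transition_steps), end_value)


steps = np.arange(50_000)  # assumes one schedule step per epoch
schedule = lr(steps)
first_floor = int(np.argmax(schedule <= 1e-5))  # first step at the 1e-5 floor
print(schedule[0], schedule[5_000], first_floor)
```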