Fredholm Equation with Non-Separable Kernel
This example solves a Fredholm integral equation of the second kind whose kernel depends on both the evaluation point and the integration dummy simultaneously. It requires the .integrate(var=x) API introduced for non-separable kernels.
Equation
\[ u(x) = f(x) + \int_0^1 (x + t)\,u(t)\,dt, \qquad x \in [0, 1] \]
Exact solution: \(u^*(x) = \sin(\pi x)\)
Derivation of f
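Substitute the exact solution into \(u(x) = f(x) + \int_0^1 (x+t)\,u(t)\,dt\) and solve for \(f\). This gives the forcing term used in the script. The two moments of \(\sin(\pi t)\) come from direct integration and one integration by parts:

```latex
\begin{aligned}
\int_0^1 \sin(\pi t)\,dt &= \frac{2}{\pi}, \\
\int_0^1 t\,\sin(\pi t)\,dt &= \left[-\frac{t\cos(\pi t)}{\pi}\right]_0^1 + \frac{1}{\pi}\int_0^1 \cos(\pi t)\,dt = \frac{1}{\pi}, \\
f(x) &= \sin(\pi x) - x\int_0^1 \sin(\pi t)\,dt - \int_0^1 t\,\sin(\pi t)\,dt
      = \sin(\pi x) - \frac{2x}{\pi} - \frac{1}{\pi}.
\end{aligned}
```

This is the expression assigned to f in the script.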
Why this requires .integrate(var=x)
The kernel \(K(x,t) = x + t\) is non-separable in the sense that it is not a single product \(g(x)\,h(t)\). For a fixed collocation point \(x_i\), the integrand \((x_i + t)\,u(t)\) is a different function of \(t\) for every \(x_i\). The result \(\int_0^1 (x+t)\,u(t)\,dt\) is therefore a function of \(x\). It is returned as an \((N,1)\) array over the collocation points, not as a scalar.
With the separable trick (previous tutorial), you would split \(K = x\cdot 1 + 1\cdot t\) and compute two independent scalar integrals. With .integrate(var=x), you write the kernel directly and jno handles the vectorisation via jax.vmap.
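The sketch below evaluates \(\int_0^1 (x+t)\,u(t)\,dt\) both ways in plain JAX. It is not jno code. Three stand-ins are assumptions: a uniform 21-point grid, trapezoid weights in place of jno's quadrature, and \(u(t) = \sin(\pi t)\) in place of the network. The separable route needs two scalar moments. The direct route builds the full kernel matrix \(K_{ij} = x_i + t_j\) and reduces over \(t\).

```python
import jax.numpy as jnp

# Assumed setup: 21 uniform nodes on [0, 1] (mirrors mesh_size=0.05), trapezoid weights.
N = 21
pts = jnp.linspace(0.0, 1.0, N)
h = pts[1] - pts[0]
w = jnp.full(N, h).at[0].set(h / 2).at[-1].set(h / 2)

u = jnp.sin(jnp.pi * pts)  # stand-in for the network output u(t)

# Separable route: K = x*1 + 1*t, so two scalar integrals suffice.
m0 = jnp.sum(w * u)        # approximates the integral of u(t)
m1 = jnp.sum(w * pts * u)  # approximates the integral of t*u(t)
separable = (pts * m0 + m1)[:, None]  # (N, 1)

# Direct route: K_ij = x_i + t_j, then a weighted sum over t for each x_i.
K = pts[:, None] + pts[None, :]       # (N, N)
direct = ((K * u[None, :]) @ w)[:, None]  # (N, 1)

exact = (2 * pts / jnp.pi + 1 / jnp.pi)[:, None]
print(direct.shape)                                 # (21, 1)
print(float(jnp.max(jnp.abs(direct - separable))))  # round-off only
print(float(jnp.max(jnp.abs(direct - exact))))      # quadrature error, O(h^2)
```

Both arrays have shape (N, 1). The direct route corresponds to what .integrate(var=x) lets you write without first deriving the split by hand.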
Step 1: Two variables from the same domain call
x, _ = domain.variable("interior") # outer collocation variable
t, _ = domain.variable("interior") # inner integration dummy — no flag needed!
Both x and t point to the same mesh, but they are distinct Python objects. The evaluator uses their object identity to decide which one to keep fixed (the one passed to var=) and which one to sweep (everything else).
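The toy below illustrates the identity rule. It assumes nothing about jno's internals. The hypothetical Var class gives two variables the same tag and the same mesh. The hypothetical split_roles helper then separates them with `is` rather than `==`.

```python
class Var:
    """Hypothetical stand-in for a variable: a tag plus a reference to its mesh."""

    def __init__(self, tag, mesh):
        self.tag = tag
        self.mesh = mesh


def split_roles(variables, outer):
    """Hypothetical helper: the variable that *is* `outer` stays fixed; the rest are swept."""
    fixed = [v for v in variables if v is outer]
    swept = [v for v in variables if v is not outer]
    return fixed, swept


mesh = [0.0, 0.5, 1.0]
x = Var("interior", mesh)  # outer collocation variable
t = Var("interior", mesh)  # inner integration dummy

fixed, swept = split_roles([x, t], outer=x)
print(x.tag == t.tag, x.mesh is t.mesh)  # True True: same tag, same mesh
print(fixed[0] is x, swept[0] is t)      # True True: identity picks the roles
```

Passing outer=t instead would reverse the roles without changing either variable.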
Step 2: The network appears at both roles
u_x = net(x) # evaluated at the N collocation points — what we want to learn
u_t = net(t) # same weights, evaluated at the N integration points
The same trained weights power both evaluations. Gradients flow through both u_x and the integral over u_t simultaneously.
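The sketch below shows the same idea in plain JAX. Two stand-ins are assumptions: a hypothetical two-parameter model net(params, s) = a·sin(πs) + b replaces the MLP, and trapezoid weights replace jno's quadrature. Freezing the integral branch with jax.lax.stop_gradient changes the gradient, which shows that the integral over u_t contributes to each update.

```python
import jax
import jax.numpy as jnp

# Assumed setup: 21 uniform nodes on [0, 1] with trapezoid weights.
N = 21
pts = jnp.linspace(0.0, 1.0, N)
h = pts[1] - pts[0]
w = jnp.full(N, h).at[0].set(h / 2).at[-1].set(h / 2)
f = jnp.sin(jnp.pi * pts) - 2 * pts / jnp.pi - 1 / jnp.pi


def net(params, s):
    # Hypothetical model standing in for the MLP; one set of weights for both roles.
    return params["a"] * jnp.sin(jnp.pi * s) + params["b"]


def loss(params, freeze_integral=False):
    u_x = net(params, pts)  # outer role: collocation points
    u_t = net(params, pts)  # inner role: integration points, same weights
    if freeze_integral:
        u_t = jax.lax.stop_gradient(u_t)  # cut the gradient path through the integral
    integral = jax.vmap(lambda xi: jnp.sum(w * (xi + pts) * u_t))(pts)
    return jnp.mean((u_x - f - integral) ** 2)


params = {"a": jnp.array(0.0), "b": jnp.array(0.0)}  # arbitrary starting point
g_full = jax.grad(loss)(params)
g_frozen = jax.grad(lambda p: loss(p, freeze_integral=True))(params)
print(float(g_full["a"]), float(g_frozen["a"]))  # the two values differ
```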
Step 3: Form the non-separable integral
integral_term = ((x + t) * u_t).integrate(var=x)
var=x declares x as the outer variable. The evaluator:
- Fixes x at each collocation point via jax.vmap.
- Evaluates \((x_i + t)\,u(t)\) over all mesh points for t.
- Returns a weighted sum per outer point, with shape (N, 1).
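The three steps can be sketched in plain JAX with a hypothetical helper, integrate_outer, which is not part of jno. Here jax.vmap fixes each outer point, the integrand is evaluated over every integration point, and a weighted sum reduces each row. The trapezoid weights are an assumption; jno's actual quadrature rule may differ.

```python
import jax
import jax.numpy as jnp


def integrate_outer(integrand, xs, ts, weights):
    """Hypothetical helper: integral of integrand(x_i, t) over t for every x_i, shape (N, 1)."""

    def one_point(x_i):
        values = integrand(x_i, ts)       # step 2: sweep t over all mesh points
        return jnp.sum(weights * values)  # step 3: weighted sum for this x_i

    return jax.vmap(one_point)(xs)[:, None]  # step 1: vmap fixes each x_i in turn


N = 21
pts = jnp.linspace(0.0, 1.0, N)
h = pts[1] - pts[0]
w = jnp.full(N, h).at[0].set(h / 2).at[-1].set(h / 2)  # assumed trapezoid weights


def u(s):
    return jnp.sin(jnp.pi * s)  # stand-in for the network


result = integrate_outer(lambda x_i, t: (x_i + t) * u(t), pts, pts, w)
print(result.shape)  # (21, 1)
```

Each outer point triggers N integrand evaluations, so one call costs N² evaluations in total.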
The shape matches u_x and f, so the residual is formed naturally:
residual = u_x - f - integral_term
Step 4: Solve
crux = jno.core([residual.mse]).print_shapes()
_history = crux.solve(30_000)
Chaining two integrals (bonus)
Because .integrate(var=x) returns a standard (N, 1) placeholder, you can chain a second .integrate() on top to reduce it to a scalar. For example, the iterated double integral
\[ \int_0^1 \left( \int_0^1 (x + t)\,dt \right) dx = \int_0^1 \left( x + \tfrac{1}{2} \right) dx = 1 \]
can be verified in jno without any network:
x, _ = domain.variable("interior")
t, _ = domain.variable("interior")
inner = (x + t).integrate(var=x) # (N, 1): g(x) = x + 0.5
result = inner.integrate() # scalar: ∫₀¹ g(x) dx = 1.0
The inner call sweeps t; the outer scalar call then integrates g(x) over x.
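The same iterated integral can be checked in plain JAX, without jno. This assumes a uniform 21-point grid with trapezoid weights. The trapezoid rule is exact for linear integrands, so both the inner \(g(x) = x + 0.5\) and the outer value 1.0 should come out exact up to float32 round-off.

```python
import jax
import jax.numpy as jnp

N = 21
pts = jnp.linspace(0.0, 1.0, N)
h = pts[1] - pts[0]
w = jnp.full(N, h).at[0].set(h / 2).at[-1].set(h / 2)  # trapezoid weights

# Inner integral: for each fixed x_i, a weighted sum over t. Shape (N,).
inner = jax.vmap(lambda x_i: jnp.sum(w * (x_i + pts)))(pts)

# Outer integral: reduce g(x) over x to a scalar.
result = jnp.sum(w * inner)

print(float(jnp.max(jnp.abs(inner - (pts + 0.5)))))  # ~0: g(x) = x + 0.5
print(float(result))                                  # ~1.0
```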
What to notice
- No flag on domain.variable(). The only API change is var=x on .integrate().
- Object identity distinguishes roles. x and t are the same type with the same tag; what makes x the outer variable is that you pass it to var=.
- N² network evaluations per step. For each of the N collocation points, the integrand is evaluated at all N integration points. Keep the mesh coarse (or use a small MLP) to control cost; jax.vmap + JIT compiles this to an efficient batched kernel.
- Relative L2 error < 10 % is achieved here with only 21 interior points and 30 000 steps.
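The snippet below tabulates N² for a few mesh sizes to show how quickly the quadratic cost grows. It assumes a uniform mesh on [0, 1] with N = 1/h + 1 nodes, which gives 21 points for h = 0.05. jno's actual node count may differ.

```python
# Assumption: a uniform 1D mesh on [0, 1] with N = 1/h + 1 nodes.
def integrand_evals(mesh_size):
    n = round(1.0 / mesh_size) + 1
    return n, n * n  # N collocation points times N integration points


for h in (0.1, 0.05, 0.025):
    n, evals = integrand_evals(h)
    print(f"mesh_size={h}: N={n}, N^2={evals}")
```

Halving the mesh size roughly quadruples the integrand evaluations per step: 121, then 441, then 1681.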
Script
"""06 — Fredholm integral equation with non-separable kernel"""
from pathlib import Path
import foundax
import jax
import jax.numpy as jnp
import optax
import jno
π = jno.np.pi
# ── Domain ─────────────────────────────────────────────────────────────────────
domain = jno.domain.line(mesh_size=0.05)
x, _ = domain.variable("interior") # outer collocation variable
t, _ = domain.variable("interior") # inner integration dummy
domain.summary()
# ── Forcing term f(x) = sin(πx) − 2x/π − 1/π ────────────────────────────────
pi_val = float(jnp.pi)
f = jno.np.sin(π * x) - 2.0 * x / pi_val - 1.0 / pi_val
# ── Model ──────────────────────────────────────────────────────────────────────
net = jno.nn.wrap(
foundax.mlp(
in_features=1,
hidden_dims=32,
num_layers=3,
activation=jax.nn.tanh,
key=jax.random.PRNGKey(0),
)
)
net.optimizer(
optax.adam(
optax.exponential_decay(
init_value=1e-3,
transition_steps=5_000,
decay_rate=0.5,
end_value=1e-5,
)
)
)
u_x = net(x) # network evaluated at collocation points (N, 1)
u_t = net(t) # same network, evaluated at integration points (N, 1)
# ── Non-separable Fredholm residual ───────────────────────────────────────────
# ∫₀¹ (x + t) · u(t) dt — result is (N, 1): depends on x, not a scalar.
# var=x tells the evaluator: keep x fixed, sweep t over the full mesh.
integral_term = ((x + t) * u_t).integrate(var=x)
residual = u_x - f - integral_term
# ── Solve ──────────────────────────────────────────────────────────────────────
EPOCHS = 30_000
crux = jno.core([residual.mse]).print_shapes()
_history = crux.solve(EPOCHS)
# ── Evaluate ───────────────────────────────────────────────────────────────────
u_exact = jno.np.sin(π * x)
u_pred, u_ref = crux.eval([u_x, u_exact])
rel_l2 = float(jnp.linalg.norm(u_pred - u_ref) / (jnp.linalg.norm(u_ref) + 1e-8))
print(f"Relative L2 error: {rel_l2:.4e} (exact solution: u(x) = sin(πx))")
# ── Record result ──────────────────────────────────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f_out:
f_out.write(f"06_integration/fredholm_nonseparable.py | epochs={EPOCHS} | rel_L2={rel_l2:.6e}\n")
assert rel_l2 < 0.10, f"Relative L2 error too large: {rel_l2:.3e}"