Surrogate Inversion
This tutorial shows how to flip a trained PINN into an inverse solver by turning its input into a trainable parameter.
The Idea
After training a forward PINN you have a differentiable surrogate \(u_\theta : x \mapsto u(x)\). The usual direction is forward: given \(x\), predict \(u\).
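For inversion, we want to go the other way: given an observation \(u_\text{obs}\), find \(x^*\) such that
\[ u_\theta(x^*) = u_\text{obs}, \qquad \text{i.e.} \qquad x^* = \arg\min_x \big(u_\theta(x) - u_\text{obs}\big)^2. \]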
Because the surrogate is differentiable end-to-end, we can simply declare the query input as a jno.np.parameter and minimise the mismatch \(u_\theta(x) - u_\text{obs}\) with gradient descent; no new machinery is needed.
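To make the idea concrete without any PINN machinery, here is a minimal NumPy sketch under one loud assumption: the exact solution \(\sin(\pi x)\) and its hand-written derivative stand in for the trained surrogate and its autodiff gradient. The helper names `u_theta`, `du_theta_dx` and `invert` are hypothetical and not part of jno.

```python
import numpy as np

# Assumption: the exact solution stands in for a trained, frozen PINN.
# A real surrogate would supply du/dx through automatic differentiation.
def u_theta(x):
    return np.sin(np.pi * x)

def du_theta_dx(x):
    return np.pi * np.cos(np.pi * x)

def invert(u_obs, x0, lr=0.05, steps=200):
    """Plain gradient descent on L(x) = (u_theta(x) - u_obs)**2 over the input x.

    The surrogate is never modified; only the query point x is updated."""
    x = float(x0)
    for _ in range(steps):
        grad = 2.0 * (u_theta(x) - u_obs) * du_theta_dx(x)  # chain rule: dL/dx
        x -= lr * grad
    return x

u_obs = np.sin(np.pi * 0.3)
x_star = invert(u_obs, x0=0.1)
print(f"recovered x* = {x_star:.6f}")  # → recovered x* = 0.300000
```

Only `x` changes between iterations; the surrogate is treated as a fixed function, which is the role the tutorial assigns to `u_net.freeze()`.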
Problem
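1D Poisson on \([0,1]\) with homogeneous Dirichlet boundary conditions:
\[ u''(x) = -\pi^2 \sin(\pi x), \qquad u(0) = u(1) = 0. \]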
Exact solution: \(u(x) = \sin(\pi x)\).
Inverse task: given the measurement \(u_\text{obs} = \sin(\pi \cdot 0.3) \approx 0.809\), recover \(x^* = 0.3\).
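A quick NumPy check of the problem statement, assuming nothing beyond NumPy: a central finite difference confirms that \(\sin(\pi x)\) satisfies the residual up to \(O(h^2)\) truncation error, and `np.arcsin` exposes the two points in \([0,1]\) that produce the same measurement.

```python
import numpy as np

# Manufactured solution on a uniform grid with spacing h = 0.01
x = np.linspace(0.0, 1.0, 101)
h = x[1] - x[0]
u = np.sin(np.pi * x)

# Central second difference at interior nodes, plugged into u'' + pi^2 sin(pi x)
u_xx = (u[2:] - 2.0 * u[1:-1] + u[:-2]) / h**2
residual = u_xx + np.pi**2 * np.sin(np.pi * x[1:-1])
print(f"max |residual| = {np.abs(residual).max():.2e}")  # O(h^2) truncation error

# The measurement and both of its pre-images on [0, 1]
u_obs = np.sin(np.pi * 0.3)
preimages = [np.arcsin(u_obs) / np.pi, 1.0 - np.arcsin(u_obs) / np.pi]
print(f"u_obs = {u_obs:.4f}, pre-images = {preimages}")
```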
Phase 1 — Forward Solve
Train a PINN in the usual way:
import jax
import optax
import foundax
import jno
π = jno.np.pi
k1, k2 = jax.random.split(jax.random.PRNGKey(42))
domain = jno.domain.line(mesh_size=0.01)
x, _ = domain.variable("interior")
u_net = jno.nn.wrap(foundax.mlp(in_features=1, output_dim=1, hidden_dims=32, num_layers=3, key=k1))
u_net.optimizer(optax.adam(1e-3))
u = u_net(x) * x * (1 - x) # hard zero Dirichlet BCs
pde = u.dd(x) + π**2 * jno.np.sin(π * x) # PDE residual: u'' + π² sin(πx) = 0
crux_fwd = jno.core([pde.mse])
crux_fwd.solve(3_000)
The network now acts as a cheap, differentiable surrogate of the Poisson solution.
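The factor `x * (1 - x)` in the forward snippet is what turns the boundary conditions into hard constraints. The sketch below illustrates this with a randomly initialised NumPy MLP; the layer sizes, tanh activation and weight scaling are assumptions for illustration and are not necessarily what `foundax.mlp` builds.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1 -> 32 -> 32 -> 1 tanh MLP with random (untrained) weights
sizes = [1, 32, 32, 1]
params = [(rng.normal(size=(m, n)) / np.sqrt(m), np.zeros(n))
          for m, n in zip(sizes[:-1], sizes[1:])]

def mlp(x):
    h = x.reshape(-1, 1)
    for i, (W, b) in enumerate(params):
        h = h @ W + b
        if i < len(params) - 1:  # no activation on the output layer
            h = np.tanh(h)
    return h[:, 0]

def u(x):
    # x(1 - x) vanishes at x = 0 and x = 1, so u(0) = u(1) = 0
    # for any network weights, trained or not.
    return mlp(x) * x * (1 - x)

print(u(np.array([0.0, 1.0])))  # both entries are (signed) zero
```

Because the constraint holds for every weight setting, the loss only needs the PDE term, and the same factor must be reapplied whenever the network is evaluated at a new point.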
Phase 2 — Inversion
Step 1: Freeze the Surrogate
u_net.freeze()
The weights are now fixed. We will only optimise the query input.
Step 2: Declare the Input as a Parameter
x_query = jno.np.parameter((1,), name="x_query")
x_query.initialize(jax.nn.initializers.constant(0.1)) # initial guess
x_query.optimizer(optax.adam(5e-3))
x_query is a scalar trainable variable (shape (1,)), declared exactly like a PDE coefficient in inverse_parameter.py; the difference is that here it is fed to the network as its input.
Step 3: Build the Inverse Loss
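import jax.numpy as jnp
x_true = 0.3
u_obs = float(jnp.sin(jnp.pi * x_true)) # "measured" value ≈ 0.809
# Apply the same BC factor that was used during training
u_at_query = u_net(x_query) * x_query * (1 - x_query)
inv_loss = (u_at_query - u_obs) ** 2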
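For reference, writing \(N(x)\) for the raw network output (notation introduced here, not used by jno), the surrogate is \(u_\theta(x) = N(x)\,x(1-x)\), and the gradient the optimiser follows with respect to `x_query` is, by the chain rule:

\[
\begin{aligned}
L(x) &= \big(u_\theta(x) - u_\text{obs}\big)^2, \\
\frac{dL}{dx} &= 2\big(u_\theta(x) - u_\text{obs}\big)\,u_\theta'(x), \\
u_\theta'(x) &= N'(x)\,x(1-x) + N(x)\,(1-2x).
\end{aligned}
\]

The gradient vanishes at a true solution but also wherever \(u_\theta'(x) = 0\); for the exact solution that happens at \(x = 0.5\), the peak of \(\sin(\pi x)\), which is where gradient descent can stall.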
Step 4: Solve on a Single-Point Domain
The loss does not depend on any mesh points; it is purely a function of x_query. Use jno.domain.from_array to create a minimal one-point domain:
import numpy as np
inv_domain = jno.domain.from_array({"pt": np.zeros((1, 1))})
crux_inv = jno.core([inv_loss.mean], domain=inv_domain)
crux_inv.solve(500)
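To see what the 500 Adam iterations do, here is a hand-written Adam loop in NumPy using the tutorial's settings (initial guess 0.1, learning rate \(5\times10^{-3}\)) and the optax default moments \(\beta_1 = 0.9\), \(\beta_2 = 0.999\), \(\epsilon = 10^{-8}\). The exact solution again stands in for the frozen PINN; this is an assumption, not a description of what `crux_inv.solve` runs internally, and `adam_invert` is a hypothetical helper.

```python
import numpy as np

# Assumption: the exact solution stands in for the frozen PINN surrogate.
def u_theta(x):
    return np.sin(np.pi * x)

def du_theta_dx(x):
    return np.pi * np.cos(np.pi * x)

def adam_invert(u_obs, x0=0.1, lr=5e-3, steps=500, b1=0.9, b2=0.999, eps=1e-8):
    """Hand-written Adam on the scalar input x (same update rule as optax.adam)."""
    x, m, v = float(x0), 0.0, 0.0
    for t in range(1, steps + 1):
        g = 2.0 * (u_theta(x) - u_obs) * du_theta_dx(x)  # dL/dx
        m = b1 * m + (1 - b1) * g      # first-moment estimate
        v = b2 * v + (1 - b2) * g * g  # second-moment estimate
        m_hat = m / (1 - b1**t)        # bias corrections
        v_hat = v / (1 - b2**t)
        x -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return x

x_true = 0.3
u_obs = float(np.sin(np.pi * x_true))
x_rec = adam_invert(u_obs)
print(f"recovered x = {x_rec:.4f}, |error| = {abs(x_rec - x_true):.2e}")
```

Adam's step is typically on the order of the learning rate, so covering the distance of 0.2 from the initial guess takes at least roughly 40 iterations at \(5\times10^{-3}\); the remaining iterations refine the estimate.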
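Step 5: Read Back the Result
x_recovered = float(crux_inv.eval([x_query])[0])
print(f"[inverse] recovered x = {x_recovered:.4f}")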
What To Notice
- u_net.freeze() prevents the surrogate weights from changing during inversion.
- Swapping the spatial variable for a jno.np.parameter is the only change needed to go from forward to inverse mode.
- A single-point domain (jno.domain.from_array) is sufficient when the loss has no spatial structure.
- The inversion is gradient-based, so initialisation matters. For functions with multiple pre-images, initialise near the expected region: here \(\sin(\pi x) = u_\text{obs}\) holds at both \(x = 0.3\) and \(x = 0.7\), and the initial guess of 0.1 lies in the basin of 0.3.
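The multiple-pre-image caveat can be checked directly. Assuming the exact solution as the surrogate, the NumPy sketch below samples the inverse loss on a grid, finds its interior minima, and checks the sign of \(dL/dx\): under gradient flow (small enough steps), descent moves toward 0.3 from anywhere left of the stationary point at 0.5, and toward 0.7 from anywhere right of it.

```python
import numpy as np

# Assumption: the exact solution stands in for the frozen surrogate.
u_obs = np.sin(np.pi * 0.3)
x = np.linspace(0.0, 1.0, 1001)
u = np.sin(np.pi * x)
loss = (u - u_obs) ** 2
dloss = 2.0 * (u - u_obs) * np.pi * np.cos(np.pi * x)  # dL/dx by the chain rule

# Strict interior local minima of the sampled loss
is_min = (loss[1:-1] < loss[:-2]) & (loss[1:-1] < loss[2:])
minima = x[1:-1][is_min]
print("local minima:", minima)

# Descent moves against dL/dx: negative slope means x increases
moves_right_below_03 = bool(np.all(dloss[x < 0.29] < 0))
moves_left_between = bool(np.all(dloss[(x > 0.31) & (x < 0.49)] > 0))
moves_right_above_05 = bool(np.all(dloss[(x > 0.51) & (x < 0.69)] < 0))
print(moves_right_below_03, moves_left_between, moves_right_above_05)
```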
Going Further
- Add multiple observations at different locations to make the inverse problem uniquely determined.
- Combine with a field-level inverse problem (recovering a coefficient \(k(x)\)) — see the Inverse Problems guide.
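- Extend to 2D: the query becomes jno.np.parameter((2,)) and the frozen 2D PINN evaluates at that point. A single scalar observation then generally constrains the query only to a level curve of the surrogate, so more than one observation is needed for a unique answer.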
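A hedged 2D illustration of that point, assuming \(\sin(\pi x)\sin(\pi y)\) as a stand-in for a frozen 2D PINN and plain gradient descent in place of the tutorial's Adam: with one scalar observation, two mirrored initial guesses converge to two different points that both reproduce the measurement. The names `u2`, `grad_u2` and `invert2` are illustrative only.

```python
import numpy as np

# Hypothetical stand-in for a frozen 2D PINN: u(x, y) = sin(pi x) sin(pi y)
def u2(p):
    return np.sin(np.pi * p[0]) * np.sin(np.pi * p[1])

def grad_u2(p):
    return np.pi * np.array([
        np.cos(np.pi * p[0]) * np.sin(np.pi * p[1]),
        np.sin(np.pi * p[0]) * np.cos(np.pi * p[1]),
    ])

def invert2(u_obs, p0, lr=0.05, steps=2000):
    """Plain gradient descent on (u2(p) - u_obs)**2 over the 2-vector query p."""
    p = np.array(p0, dtype=float)
    for _ in range(steps):
        p -= lr * 2.0 * (u2(p) - u_obs) * grad_u2(p)
    return p

u_obs = u2(np.array([0.3, 0.3]))  # one scalar "measurement" taken at (0.3, 0.3)
results = [invert2(u_obs, p0) for p0 in ([0.1, 0.2], [0.2, 0.1])]
for p in results:
    print(p, u2(p))  # both match u_obs, yet the two points differ
```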
Script Snippet
"""05 — Surrogate inversion: recover a hidden input from a frozen PINN"""
from pathlib import Path
import foundax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import jno
π = jno.np.pi
# ── Keys ─────────────────────────────────────────────────────────────────────
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
# ══════════════════════════════════════════════════════════════════════════════
# Phase 1 — Forward solve
# ══════════════════════════════════════════════════════════════════════════════
domain = jno.domain.line(mesh_size=0.01)
x, _ = domain.variable("interior")
u_net = jno.nn.wrap(foundax.mlp(in_features=1, output_dim=1, hidden_dims=32, num_layers=3, key=k1))
u_net.optimizer(optax.adam(1e-3))
u = u_net(x) * x * (1 - x) # hard zero Dirichlet BCs
pde = u.dd(x) + π**2 * jno.np.sin(π * x) # residual: u'' + π²sin(πx) = 0
crux_fwd = jno.core([pde.mse])
crux_fwd.solve(3_000)
# Verify forward accuracy before inverting
_u, _u_exact = crux_fwd.eval([u, jno.np.sin(π * x)])
fwd_err = float(jnp.linalg.norm(_u - _u_exact) / (jnp.linalg.norm(_u_exact) + 1e-8))
print(f"[forward] rel-L2 error: {fwd_err:.3e}")
# ══════════════════════════════════════════════════════════════════════════════
# Phase 2 — Inverse: recover x* from u(x*) = u_obs
# ══════════════════════════════════════════════════════════════════════════════
# Freeze the trained surrogate — its weights will not change during inversion
u_net.freeze()
x_true = 0.3
u_obs = float(jnp.sin(jnp.pi * x_true)) # "measured" value ≈ 0.809
print(f"[inverse] target u_obs = {u_obs:.4f} (x_true = {x_true})")
# x_query is the variable we optimise — it starts as an unknown scalar in [0, 1]
x_query = jno.np.parameter((1,), name="x_query")
x_query.initialize(jax.nn.initializers.constant(0.1)) # initial guess: 0.1
x_query.optimizer(optax.adam(5e-3))
# Evaluate the frozen surrogate at the query point (same BC factor as training)
u_at_query = u_net(x_query) * x_query * (1 - x_query)
# Loss: drive u_net(x_query) towards the observation
inv_loss = (u_at_query - u_obs) ** 2
# Single-point domain — the loss has no spatial dependence on the mesh
inv_domain = jno.domain.from_array({"pt": np.zeros((1, 1))})
crux_inv = jno.core([inv_loss.mean], domain=inv_domain)
crux_inv.solve(500)
# ── Results ──────────────────────────────────────────────────────────────────
_x_q = crux_inv.eval([x_query])[0]
x_recovered = float(_x_q)
abs_err = abs(x_recovered - x_true)
print(f"[inverse] recovered x = {x_recovered:.4f} (true = {x_true:.4f})")
print(f"[inverse] absolute error: {abs_err:.4f}")
# ── Assertions ────────────────────────────────────────────────────────────────
assert fwd_err < 0.05, f"Forward solve inaccurate: rel-L2 = {fwd_err:.3e}"
assert abs_err < 0.05, f"Inversion error too large: |x_recovered − x_true| = {abs_err:.4f}"
# ── Result tracking ───────────────────────────────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(
        f"05_coupled_and_inverse/surrogate_inversion.py"
        f" | fwd_epochs=3000 | inv_epochs=500"
        f" | fwd_rel_L2={fwd_err:.6e}"
        f" | x_recovered={x_recovered:.6f}"
        f" | abs_err={abs_err:.6f}\n"
    )