Inverse Parameter
This example is an inverse problem rather than a field solve: it learns unknown scalar coefficients from residual constraints.
Problem Setup
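The script builds a synthetic target from known coefficients (A_true, B_true, C_true). It then introduces trainable scalar parameters a, b and c and fits them until the parameterized expression reproduces that target.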
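The model a·sin(πx) + b·cos(πx) + c·x(1−x) is linear in a, b and c, so this inverse problem also has a closed-form least-squares solution. That solution is a useful reference for what the optimizer should converge to. The sketch below uses plain NumPy rather than jNO. It assumes the interior points form a uniform grid on (0, 1) with spacing 0.01, and that the script's loss is the mean squared residual; jNO's actual sampling of domain.line may differ.

```python
import numpy as np

# Ground-truth coefficients, as in the script
A_true, B_true, C_true = 3.14, -2.71, 42.0

# Assumption: interior points are a uniform grid on (0, 1) with spacing 0.01
x = np.linspace(0.01, 0.99, 99)

# Synthetic target built from the true coefficients
target = A_true * np.sin(np.pi * x) + B_true * np.cos(np.pi * x) + C_true * x * (1 - x)

# Design matrix: one column per basis function multiplying a, b, c
Phi = np.stack([np.sin(np.pi * x), np.cos(np.pi * x), x * (1 - x)], axis=1)

# Linear least squares minimizes the sum (equivalently, the mean) of squared residuals
coeffs, *_ = np.linalg.lstsq(Phi, target, rcond=None)
a_ls, b_ls, c_ls = coeffs
print(f"least squares: a={a_ls:.3f} b={b_ls:.3f} c={c_ls:.3f}")
```

The three columns are linearly independent on this grid, so the fit is exact. If a trained run lands far from these values, the cause is the optimization rather than identifiability.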
Step 1: Treat Parameters as Learnable Objects
Instead of only training a neural field, the script creates scalar parameter models that participate in optimization.
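import jax
import jno
π = jno.np.pi
# True coefficients, used only to generate the synthetic target
A_true, B_true, C_true = 3.14, -2.71, 42.0
# 1D domain with mesh_size=0.01; x holds the interior points
domain = jno.domain.line(mesh_size=0.01)
x, _ = domain.variable("interior")
target = A_true * jno.np.sin(π * x) + B_true * jno.np.cos(π * x) + C_true * x * (1 - x)
# One independent PRNG key per trainable scalar of shape (1,)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
a = jno.np.parameter((1,), key=k1, name="a")
b = jno.np.parameter((1,), key=k2, name="b")
c = jno.np.parameter((1,), key=k3, name="c")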
Step 2: Build Residuals From Data Relationships
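The optimization target is a data-misfit residual rather than a spatial PDE. The residual is the basis expression evaluated with the trainable a, b and c, minus the synthetic target, at every interior point. Each parameter then gets its own optimizer.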
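Written out, the objective and its gradient take the following form. This assumes residual.mse averages the squared residual over the N interior points x_i. The shorthand φ_1 = sin(πx), φ_2 = cos(πx) and φ_3 = x(1−x) is introduced here for readability.

```latex
\begin{aligned}
t(x) &= A_{\text{true}}\,\phi_1(x) + B_{\text{true}}\,\phi_2(x) + C_{\text{true}}\,\phi_3(x), \\
r_i &= a\,\phi_1(x_i) + b\,\phi_2(x_i) + c\,\phi_3(x_i) - t(x_i), \\
\mathcal{L}(a, b, c) &= \frac{1}{N} \sum_{i=1}^{N} r_i^2, \\
\frac{\partial \mathcal{L}}{\partial a} &= \frac{2}{N} \sum_{i=1}^{N} r_i\,\phi_1(x_i), \quad
\frac{\partial \mathcal{L}}{\partial b} = \frac{2}{N} \sum_{i=1}^{N} r_i\,\phi_2(x_i), \quad
\frac{\partial \mathcal{L}}{\partial c} = \frac{2}{N} \sum_{i=1}^{N} r_i\,\phi_3(x_i).
\end{aligned}
```

The loss is quadratic in (a, b, c). When the basis functions are linearly independent on the sample points, it reaches its minimum of zero exactly at (A_true, B_true, C_true).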
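import optax
# Misfit between the parameterized expression and the synthetic target
residual = (a * jno.np.sin(π * x) + b * jno.np.cos(π * x) + c * x * (1 - x)) - target
# Attach an Adam optimizer (learning rate 1e-2) to each trainable parameter
for param in (a, b, c):
    param.optimizer(optax.adam(1e-2))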
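The following NumPy sketch shows concretely what attaching optax.adam(1e-2) to each scalar amounts to. Every parameter keeps its own Adam moment estimates and is updated from the mean-squared-residual gradient. This is not jNO's implementation of crux.solve. It assumes the same uniform grid as a stand-in for the interior points, zero initial values (the script uses random keys), and Adam defaults β1 = 0.9, β2 = 0.999, ε = 1e-8, which match optax's defaults. Adam acts element-wise, so one Adam over the stacked vector (a, b, c) is equivalent to three independent Adams with identical settings.

```python
import numpy as np

x = np.linspace(0.01, 0.99, 99)  # assumed interior grid
Phi = np.stack([np.sin(np.pi * x), np.cos(np.pi * x), x * (1 - x)], axis=1)
target = Phi @ np.array([3.14, -2.71, 42.0])  # same synthetic target as the script


def mse_and_grad(theta):
    """Mean squared residual and its gradient with respect to theta = (a, b, c)."""
    r = Phi @ theta - target
    loss = np.mean(r ** 2)
    grad = 2.0 * Phi.T @ r / r.size
    return loss, grad


def adam(theta, steps, lr=1e-2, b1=0.9, b2=0.999, eps=1e-8):
    """Plain Adam; each entry of theta has its own first and second moment estimates."""
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    losses = []
    for t in range(1, steps + 1):
        loss, g = mse_and_grad(theta)
        losses.append(loss)
        m = b1 * m + (1 - b1) * g          # running mean of gradients
        v = b2 * v + (1 - b2) * g ** 2     # running mean of squared gradients
        m_hat = m / (1 - b1 ** t)          # bias corrections
        v_hat = v / (1 - b2 ** t)
        theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, losses


theta, losses = adam(np.zeros(3), steps=10_000)
print(f"loss: {losses[0]:.3e} -> {losses[-1]:.3e}, theta = {theta}")
```

Swapping in a different optimizer only changes the adam function. The residual and its gradient stay the same, which mirrors how the script separates the residual definition from the per-parameter optimizer.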
Step 3: Solve and Inspect Learned Coefficients
After optimization, crux.eval returns the trained values of a, b and c, which are then printed.
crux = jno.core([residual.mse])
history = crux.solve(30000)
_a, _b, _c = crux.eval([a, b, c])
print(f"Recovered parameters: a={_a[0]:.3f}, b={_b[0]:.3f}, c={_c[0]:.3f}")
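When inspecting the recovered values, note that the three basis functions are not equally easy to separate. On [0, 1], cos(πx) is orthogonal to both sin(πx) and x(1−x), because it is antisymmetric about x = 0.5 while they are symmetric. By contrast, sin(πx) and x(1−x) are nearly parallel: their continuous cosine similarity is (4/π³)/√(1/60) ≈ 0.9993. The NumPy sketch below checks this on the same assumed uniform grid. A similarity this close to 1 means many (a, c) combinations produce almost the same curve. Typically, then, b settles quickly while a and c need many more iterations.

```python
import numpy as np

x = np.linspace(0.01, 0.99, 99)  # assumed interior grid, symmetric about 0.5
phi_a = np.sin(np.pi * x)
phi_b = np.cos(np.pi * x)
phi_c = x * (1 - x)


def cosine(u, v):
    """Cosine similarity of two sampled basis functions."""
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


sim_ab = cosine(phi_a, phi_b)
sim_ac = cosine(phi_a, phi_c)
sim_bc = cosine(phi_b, phi_c)

# A large condition number of the Gram matrix means slow convergence for gradient methods
Phi = np.stack([phi_a, phi_b, phi_c], axis=1)
cond = np.linalg.cond(Phi.T @ Phi)
print(f"sim(a,b)={sim_ab:.3f} sim(a,c)={sim_ac:.4f} sim(b,c)={sim_bc:.3f} cond={cond:.1e}")
```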
What To Notice
- jNO can optimize standalone trainable parameters, not only neural fields.
- Inverse problems often reuse the same core workflow with different residual definitions.
- This example is a good template for coefficient discovery and calibration.
Going Further
For field identification (recovering a spatially varying k(x,y) rather than a scalar), see the Inverse Problems guide, which covers:
- field.regularize(...): smooth, TV, positivity and bounded penalties on identified fields
- Model.constrain(transform): hard parameter constraints via paramax reparameterization
- jno.domain.from_array: attaching sparse sensor observations without file I/O
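The penalty names listed above correspond to widely used regularizers for a discretized field k_i. The NumPy sketch below uses common textbook forms of these penalties. It adds a softplus reparameterization as one example of a hard positivity constraint. It illustrates the ideas only and is not the implementation behind field.regularize or Model.constrain; see the Inverse Problems guide for their actual options.

```python
import numpy as np


def smooth_penalty(k, h):
    """Mean squared finite-difference gradient: penalizes rapid variation."""
    return float(np.mean((np.diff(k) / h) ** 2))


def tv_penalty(k, h):
    """Total variation: mean absolute gradient, which tolerates sharp jumps."""
    return float(np.mean(np.abs(np.diff(k) / h)))


def positivity_penalty(k):
    """Squared hinge on negative values; zero when k >= 0 everywhere."""
    return float(np.mean(np.minimum(k, 0.0) ** 2))


def bounded_penalty(k, lo, hi):
    """Squared hinge outside [lo, hi]; zero when every value is within bounds."""
    below = np.minimum(k - lo, 0.0)
    above = np.maximum(k - hi, 0.0)
    return float(np.mean(below ** 2 + above ** 2))


def softplus(z):
    """Hard positivity via reparameterization: k = log(1 + exp(z)) > 0 for any real z."""
    return np.logaddexp(0.0, z)


x = np.linspace(0.0, 1.0, 101)
h = x[1] - x[0]
k_smooth = 1.0 + 0.5 * np.sin(np.pi * x)   # smooth bump from 1.0 up to 1.5 and back
k_step = np.where(x < 0.5, 1.0, 2.0)       # sharp interface with a jump of 1.0

print("smooth:", smooth_penalty(k_smooth, h), smooth_penalty(k_step, h))
print("tv:    ", tv_penalty(k_smooth, h), tv_penalty(k_step, h))
```

For these two fields, the smoothness penalty is far larger for the step. The TV penalty is about the same for both, since each field varies by a total of 1. This is why TV-type regularization is usually chosen when sharp interfaces are expected.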
Script Snippet
"""05 — Inverse parameter identification"""
from pathlib import Path
import jax
import optax
import jno
π = jno.np.pi
A_true, B_true, C_true = 3.14, -2.71, 42.0
domain = jno.domain.line(mesh_size=0.01)
x, _ = domain.variable("interior")
target = A_true * jno.np.sin(π * x) + B_true * jno.np.cos(π * x) + C_true * x * (1 - x)
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
a = jno.np.parameter((1,), key=k1, name="a")
b = jno.np.parameter((1,), key=k2, name="b")
c = jno.np.parameter((1,), key=k3, name="c")
residual = (a * jno.np.sin(π * x) + b * jno.np.cos(π * x) + c * x * (1 - x)) - target
for param in (a, b, c):
    param.optimizer(optax.adam(1e-2))
crux = jno.core([residual.mse])
crux.solve(30_000)
_a, _b, _c = crux.eval([a, b, c])
print(f"Recovered: a={_a[0]:.3f} b={_b[0]:.3f} c={_c[0]:.3f} (truth: {A_true}, {B_true}, {C_true})")
rel_l2_a = float(jax.numpy.linalg.norm(_a - A_true) / (jax.numpy.linalg.norm(A_true) + 1e-8))
rel_l2_b = float(jax.numpy.linalg.norm(_b - B_true) / (jax.numpy.linalg.norm(B_true) + 1e-8))
rel_l2_c = float(jax.numpy.linalg.norm(_c - C_true) / (jax.numpy.linalg.norm(C_true) + 1e-8))
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(
        f"05_coupled_and_inverse/inverse_parameter.py | epochs=30000"
        f" | rel_L2_a={rel_l2_a:.6e} | rel_L2_b={rel_l2_b:.6e} | rel_L2_c={rel_l2_c:.6e}\n"
    )
assert rel_l2_a < 1e-1, f"a rel_L2 too large: {rel_l2_a:.3e}"
assert rel_l2_b < 1e-1, f"b rel_L2 too large: {rel_l2_b:.3e}"
assert rel_l2_c < 1e-1, f"c rel_L2 too large: {rel_l2_c:.3e}"