Inverse B-PINN: multi-coefficient regression
This is a pure inverse problem with no PDE residual: the goal is to recover
two coefficients (A, B) of the parametric model d(x) = A sin(πx) + B cos(πx)
from observations of d. Each coefficient is configured with its own NUTS
kernel via .bayesian(blackjax.nuts, ...), and solve() dispatches them
in parallel. The output is the posterior mean and a 90% credible interval
for each coefficient.
This is the cleanest demonstration of jNO's per-parameter Bayesian configurator — every scalar carries its own kernel, no mixed mode, no surrogate to worry about. The same code shape works for any number of coefficients.
Script
"""02 — Bayesian PINN inverse problem: multi-coefficient regression"""
from pathlib import Path
import blackjax
import jax
import jax.numpy as jnp
import jno
# ── Domain & "measured" data ──────────────────────────────────────────────────
# The "measurements" are synthetic: the parametric model is evaluated at the
# known ground-truth coefficients on the "interior" variable of a line domain.
π = jno.np.pi
domain = jno.domain.line(mesh_size=0.02)
x, _ = domain.variable("interior")  # second returned value is not used here

A_true = 3.14
B_true = -2.71

# Build the two basis contributions separately, then sum them.
sine_part = A_true * jno.np.sin(π * x)
cosine_part = B_true * jno.np.cos(π * x)
target = sine_part + cosine_part
# ── Trainable scalar parameters with per-parameter NUTS samplers ──────────────
# One PRNG key per coefficient, both derived from a single fixed seed.
root_key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(root_key, 2)
a = jno.np.parameter((1,), key=k1, name="a")
b = jno.np.parameter((1,), key=k2, name="b")

# Both coefficients receive the same sampler settings.
# adapt=True (default) runs blackjax.window_adaptation for `warmup` steps and
# tunes step_size + inverse_mass_matrix automatically, so the `step_size`
# given here is only the adapter's initial guess.
nuts_settings = {
    "step_size": 1e-2,
    "warmup": 300,
    "keep": 500,
}
for coeff in (a, b):
    coeff.bayesian(blackjax.nuts, **nuts_settings)
# ── Residual + solve ──────────────────────────────────────────────────────────
# Pure data misfit: the model evaluated with the trainable (a, b) minus the
# synthetic target. No PDE term enters the objective.
prediction = a * jno.np.sin(π * x) + b * jno.np.cos(π * x)
residual = prediction - target

# The mean-squared residual is the only objective handed to the solver.
objectives = [residual.mse]
crux = jno.core(objectives)
crux.solve(800)
# ── Posterior summary — raw chain → user post-processes ───────────────────────
def _summarize(chain):
    """Return ``(mean, q05, q95)`` of *chain* as plain Python floats.

    The 5 % and 95 % quantiles bound the central 90 % credible interval.
    """
    mean = float(jnp.mean(chain))
    lo, hi = (float(q) for q in jnp.quantile(chain, jnp.array([0.05, 0.95])))
    return mean, lo, hi


def _relative_error(estimate, truth):
    """Return ``|estimate - truth| / |truth|``."""
    return abs(estimate - truth) / abs(truth)


# Original note gives the chain shape as (500, 1), presumably (keep, 1).
a_chain = a.posterior_samples
b_chain = b.posterior_samples

A_mean, A_lo, A_hi = _summarize(a_chain)
B_mean, B_lo, B_hi = _summarize(b_chain)
print(f"A = {A_mean:.3f} 90% CI = [{A_lo:.3f}, {A_hi:.3f}] truth = {A_true}")
print(f"B = {B_mean:.3f} 90% CI = [{B_lo:.3f}, {B_hi:.3f}] truth = {B_true}")

rel_A = _relative_error(A_mean, A_true)
rel_B = _relative_error(B_mean, B_true)

# Append one line to the shared results log, two levels above this script's
# directory. The write happens before the checks below, so a failing run is
# still recorded. The explicit encoding keeps the file independent of the
# platform's default locale.
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a", encoding="utf-8") as f:
    f.write(f"10_bayesian_pinns/02_inverse_multi_coefficient.py | epochs=800 | rel_A={rel_A:.4f} | rel_B={rel_B:.4f}\n")

# Sanity check: each posterior mean must land within 30 % of the ground truth.
assert rel_A < 0.3, f"A posterior mean off by {rel_A:.2%}"
assert rel_B < 0.3, f"B posterior mean off by {rel_B:.2%}"