
Inverse: ODE rate constant (no surrogate)

Bayesian inference of the rate constant of a first-order ODE. The model du/dt = -k u(t) with u(0) = 1 has the closed-form solution u(t) = exp(-k t). We plug this analytical expression directly into the likelihood, with no neural surrogate involved, and let NUTS sample a fixed-target posterior over k from sparse, noisy observations.
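Written out, the target that NUTS samples is the following. Here y_i are the noisy observations at times t_i (u_obs in the script), σ is sigma_obs, and p(k) is a prior that this page does not specify. This is a sketch of the standard Gaussian-likelihood posterior, not a transcription of jno's internals.

```latex
\begin{aligned}
\frac{du}{dt} &= -k\,u(t), \qquad u(0) = 1 \quad\Longrightarrow\quad u(t; k) = e^{-k t},\\
y_i &= e^{-k_{\mathrm{true}} t_i} + \varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0, \sigma^2), \qquad i = 1, \dots, N,\\
\log p(k \mid y) &= -\frac{1}{2\sigma^2} \sum_{i=1}^{N} \bigl( y_i - e^{-k t_i} \bigr)^2 + \log p(k) + \text{const}.
\end{aligned}
```

Apart from k, every quantity on the right-hand side is fixed once the data are drawn, which is what makes the target fixed.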

This is the simplest member of a broad family of real-world rate-constant inverse problems (radioactive decay, first-order pharmacokinetic elimination, single-compartment epidemic dynamics). Linka et al. 2022 use the same Bayesian inference recipe for COVID-19 SIR modelling; the only difference is the dimensionality of the state vector and the noise model.

Why no neural surrogate? When the forward model has a closed form (or a cheap numerical integrator), wrapping it in a PINN forces mixed-mode inference, in which the likelihood seen by the sampler depends on a surrogate that is still being fitted; this makes the posterior brittle to hyperparameters. A direct analytical likelihood gives a well-defined Bayesian inference problem in which hyperparameters affect only chain efficiency, not the target distribution. See Inverse FEM Diffusivity for the pattern when no closed form is available: jNO's FEM solver provides a differentiable forward model and blackjax samples the posterior directly.
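Because k is one-dimensional and the forward model is closed-form, the fixed target can be checked without any sampler by evaluating it on a grid. The sketch below is plain NumPy, not jno or blackjax. It assumes a flat prior on k over the grid range and 39 evenly spaced observation times standing in for the script's interior mesh points, and it draws its own noise with its own seed, so its numbers will not match the script's output.

```python
import numpy as np

# Same physical setup as the tutorial script
k_true, T_end, sigma_obs = 0.5, 4.0, 0.05

# Assumed observation times: 39 interior points of [0, T_end], spacing 0.1
t = np.linspace(0.1, T_end - 0.1, 39)

# One fixed noise realisation, drawn once from a seeded generator
rng = np.random.default_rng(0)
u_obs = np.exp(-k_true * t) + rng.normal(0.0, sigma_obs, size=t.shape)


def log_posterior(k):
    """Exact Gaussian log-likelihood (summed, scaled by 1/(2 sigma^2)) plus a flat prior.

    k may be a scalar or a 1-D array of candidate rate constants; the result
    has one entry per candidate.
    """
    k = np.asarray(k, dtype=float)
    resid = u_obs - np.exp(-np.outer(k, t))  # shape (len(k), N)
    return -0.5 * np.sum((resid / sigma_obs) ** 2, axis=-1)


# Evaluate the unnormalised posterior on a dense grid, then normalise it
k_grid = np.linspace(0.3, 0.7, 4001)
logp = log_posterior(k_grid)
w = np.exp(logp - logp.max())  # subtract the max for numerical stability
w /= w.sum()

k_mean = float(np.sum(w * k_grid))
k_std = float(np.sqrt(np.sum(w * (k_grid - k_mean) ** 2)))
print(f"grid posterior: k = {k_mean:.4f} ± {k_std:.4f}")
```

Since log_posterior depends only on k and the fixed arrays t and u_obs, any correctly tuned sampler targets this same density; settings such as step size and warmup length change how efficiently it is explored, not what it converges to.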

References

Linka, K., Schäfer, A., Meng, X., Zou, Z., Karniadakis, G. E., & Kuhl, E. (2022). Bayesian Physics-Informed Neural Networks for real-world nonlinear dynamical systems. Computer Methods in Applied Mechanics and Engineering, 402, 115346.

Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research, 15(1), 1593-1623.

Script

"""04 — Bayesian inverse for ODE rate constant (no surrogate, closed-form)"""

from pathlib import Path

import blackjax
import jax
import jax.numpy as jnp

import jno

# ── Physical setup ────────────────────────────────────────────────────────────
k_true = 0.5
T_end = 4.0
sigma_obs = 0.05  # observation-noise standard deviation

# ── Domain (1-D line over the time axis: x ≡ t) ───────────────────────────────
domain = jno.domain.line(x_range=(0.0, T_end), mesh_size=0.1)
t, _ = domain.variable("interior")

# ── Synthetic noisy observations of u(t) under k_true ─────────────────────────
# jno.noise.gaussian is re-evaluated at every step from the solver's PRNG
# key.  Because that key derives from a fixed global seed, every evaluation
# yields the same noise realisation, so NUTS sees one fixed likelihood for
# the whole chain.
u_obs = jno.np.exp(-k_true * t) + jno.noise.gaussian(std=sigma_obs)

# ── Bayesian rate constant — NUTS with closed-form forward model ─────────────
k = jno.np.parameter((1,), key=jax.random.PRNGKey(0), name="k")
k.bayesian(
    blackjax.nuts,
    step_size=1e-1,  # initial guess; window adaptation tunes it
    warmup=400,
    keep=600,
    # adapt=True default: window_adaptation runs for `warmup` steps and
    # tunes step_size + inverse_mass_matrix.  This is pure Bayesian
    # inference (no mixed mode), so adaptation is well-defined.
)

# ── Closed-form forward model: u(t; k) = exp(-k · t) ──────────────────────────
u_model = jno.np.exp(-k * t)

# ── Likelihood: Gaussian-noise residual scaled by σ ───────────────────────────
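# Dividing the residual by σ gives the 1/σ² weighting of the Gaussian
# log-likelihood once it is squared.  Note that jno's `.mse` averages the
# squared residual over the N points rather than summing it: if -mse is used
# as the log-density without rescaling, it equals the exact Gaussian
# log-likelihood -(1/2) Σ ((u_model - u_obs)/σ)² multiplied by 2/N (up to an
# additive constant), which widens the posterior.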
residual = (u_model - u_obs) / sigma_obs

# ── Solve — pure Bayesian, no surrogate, no substeps needed ──────────────────
crux = jno.core([residual.mse])
crux.solve(1000)

# ── Posterior summary ────────────────────────────────────────────────────────
k_chain = k.posterior_samples  # shape (600, 1)
k_mean = float(jnp.mean(k_chain))
k_std = float(jnp.std(k_chain))
k_lo, k_hi = (float(v) for v in jnp.quantile(k_chain, jnp.array([0.05, 0.95])))

print(f"k = {k_mean:.4f} ± {k_std:.4f}")
print(f"   90% CI = [{k_lo:.4f}, {k_hi:.4f}]   truth = {k_true}")

rel_k = abs(k_mean - k_true) / abs(k_true)

results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(f"10_bayesian_pinns/04_inverse_ode_decay.py | epochs=1000 | rel_k={rel_k:.4f} | CI_width={k_hi - k_lo:.4f}\n")

assert rel_k < 0.1, f"posterior-mean k off by {rel_k:.2%}"
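The likelihood comment in the script notes that `.mse` averages the squared residual. If the log-density were taken as minus that mean, it would be the exact Gaussian log-likelihood multiplied by 2/N, and the posterior would be flattened. This page does not show whether jno rescales `.mse` internally, so the plain-NumPy sketch below only illustrates the size of the effect under that assumption. The observation times, seed, flat prior on the grid, and the helper `grid_std` are all hypothetical choices made for illustration.

```python
import numpy as np

k_true, sigma_obs = 0.5, 0.05
t = np.linspace(0.1, 3.9, 39)  # assumed observation times
N = t.size
rng = np.random.default_rng(1)
u_obs = np.exp(-k_true * t) + rng.normal(0.0, sigma_obs, size=N)

# Flat prior over this grid; scaled residuals for every candidate k
k_grid = np.linspace(0.0, 2.0, 20001)
r = (u_obs - np.exp(-np.outer(k_grid, t))) / sigma_obs  # shape (len(k_grid), N)


def grid_std(logp):
    """Posterior standard deviation of k from unnormalised log-density values on k_grid."""
    w = np.exp(logp - logp.max())
    w /= w.sum()
    mean = np.sum(w * k_grid)
    return float(np.sqrt(np.sum(w * (k_grid - mean) ** 2)))


std_exact = grid_std(-0.5 * np.sum(r**2, axis=1))  # exact Gaussian log-likelihood
std_mse = grid_std(-np.mean(r**2, axis=1))         # minus the mean squared residual

print(f"exact: {std_exact:.4f}   mse-based: {std_mse:.4f}")
print(f"ratio: {std_mse / std_exact:.2f}   sqrt(N/2) = {np.sqrt(N / 2):.2f}")
```

For a near-Gaussian posterior the width ratio is close to sqrt(N/2); skew from the nonlinear forward model moves it somewhat. Multiplying a log-density by a positive constant leaves its mode unchanged, so this kind of rescaling mainly affects the reported ± and CI width rather than the posterior mean.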