Pathfinder warm-start via .initialize()
Demonstrates jno's logdensity-aware initializer hook. Pathfinder
(Zhang et al. 2022) runs L-BFGS on the loss-derived log-density and
uses the inverse-Hessian estimates along the optimisation trajectory to
fit a normal approximation q to the posterior. From that fitted q we get:
- a warm starting position: the MAP-ish state.position for K=1 chains, or K i.i.d. samples for K>1 (proper over-dispersion);
- a diagonal inverse_mass_matrix estimate from the per-dimension variance of M draws from q.
Exposed through the existing .initialize() API — no new kwargs on
.bayesian():
a.initialize(jno.bayesian.pathfinder(maxiter=30, num_samples=200))
a.bayesian(blackjax.nuts, step_size=1e-2, warmup=100, adapt=True)
.bayesian()'s warmup and adapt keep their current meaning and
apply after pathfinder. When adapt=True, window adaptation
re-runs from pathfinder's warm position and may further refine
step_size; when adapt=False, pathfinder's inverse mass matrix (IMM)
is final and the user's step_size is kept verbatim.
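As a rough illustration of the two quantities pathfinder hands to the sampler, the sketch below derives start positions and a diagonal inverse mass matrix from a fitted normal q. It is not jno's implementation: it assumes q is diagonal (Pathfinder's actual approximation carries a richer covariance structure), uses the centre of q as a stand-in for the MAP-ish position when K=1, and relies only on the Python standard library. The function name warm_start_from_q and the example mean/std values are hypothetical.

```python
import random
import statistics


def warm_start_from_q(mean, std, num_chains, num_samples, seed=0):
    """Derive start positions and a diagonal IMM from a diagonal normal q.

    mean, std: per-dimension parameters of the fitted q (assumed diagonal).
    """
    rng = random.Random(seed)

    def draw():
        return [rng.gauss(m, s) for m, s in zip(mean, std)]

    # K = 1: start at the centre of q; K > 1: K i.i.d. draws from q.
    if num_chains == 1:
        positions = [list(mean)]
    else:
        positions = [draw() for _ in range(num_chains)]

    # Diagonal inverse mass matrix: per-dimension sample variance of M draws.
    draws = [draw() for _ in range(num_samples)]
    imm_diag = [statistics.variance(column) for column in zip(*draws)]
    return positions, imm_diag


positions, imm_diag = warm_start_from_q(
    mean=[50.0, -30.0], std=[0.5, 2.0], num_chains=2, num_samples=200
)
```

With M = 200 draws the variance estimates are noisy at roughly the 10% level, which is usually adequate for a mass-matrix starting point that window adaptation may refine later.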
The behaviour matrix
| .initialize(pathfinder(...)) | adapt | What runs |
|---|---|---|
| not set | True (default) | window adaptation from the user's init |
| not set | False | user's init, user's step_size, no warmup |
| set | True | pathfinder → window: warm position + IMM, then window refines step_size |
| set | False | pathfinder only: warm position + pathfinder IMM, user's step_size |
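The four rows of the matrix reduce to two independent switches. The helper below is a hypothetical restatement of the table as code, not part of jno; the name sampling_plan and the stage labels are made up for illustration.

```python
def sampling_plan(pathfinder_set: bool, adapt: bool) -> list[str]:
    """Return the ordered stages implied by the behaviour matrix."""
    stages = []
    if pathfinder_set:
        # Pathfinder supplies the warm position and the IMM estimate.
        stages.append("pathfinder")
    if adapt:
        # Window adaptation starts from whatever position is current.
        stages.append("window adaptation")
    stages.append("sampling")
    return stages


for pf in (False, True):
    for adapt in (True, False):
        print(pf, adapt, sampling_plan(pf, adapt))
```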
Numbers from the tutorial
Three side-by-side runs on the harmonic regression problem from T02
(truth A = 50.0, B = -30.0, default zero init). Each run uses
num_chains=2.
| Run | .initialize(...) | warmup, adapt | A | B | R-hat A | R-hat B | Wall |
|---|---|---|---|---|---|---|---|
| baseline | none | 300, True | 49.55 | -29.77 | 1.038 | 1.000 | 6.5 s |
| pf-only | pathfinder(30, 200) | 0, False | 49.61 | -29.60 | 1.007 | 1.006 | 12.7 s |
| chained | pathfinder(30, 200) | 100, True | 49.63 | -29.70 | 1.004 | 0.998 | 8.3 s |
Both pathfinder-touched runs lower R-hat for A (1.007 and 1.004, versus 1.038 for the baseline), meaning the chains agree more tightly on the coefficient the baseline found harder; R-hat for B is already close to 1 in all three runs. Pathfinder's L-BFGS lands both chains in the same basin before sampling, so they don't need to mix into agreement themselves. Wall-clock is dominated by JIT compile on this tiny problem; on production-scale B-PINN runs the L-BFGS cost amortises quickly.
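For readers who want to see what the R-hat columns measure, here is a minimal sketch of the classic (non-split) Gelman-Rubin statistic for a scalar parameter. Whether jno.bayesian.rhat uses this exact variant is an assumption; the function name gelman_rubin_rhat and the synthetic chains are illustrative only.

```python
import random
import statistics


def gelman_rubin_rhat(chains):
    """Classic (non-split) potential scale reduction for equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    # W: mean within-chain variance; B/n: variance of the chain means.
    within = statistics.fmean(statistics.variance(c) for c in chains)
    between_over_n = statistics.variance(chain_means)
    pooled = (n - 1) / n * within + between_over_n
    return (pooled / within) ** 0.5


rng = random.Random(0)
# Two chains sampling the same distribution, and two stuck at different means.
agree = [[rng.gauss(50.0, 1.0) for _ in range(300)] for _ in range(2)]
apart = [[rng.gauss(mu, 1.0) for _ in range(300)] for mu in (49.0, 51.0)]
rhat_agree = gelman_rubin_rhat(agree)
rhat_apart = gelman_rubin_rhat(apart)
```

Chains that agree give a value near 1, while chains centred on different means push it well above 1. With only K=2 short chains the statistic is noisy and can dip slightly below 1, as in the chained row.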
Extending — writing your own initializer
The .initialize() hook is generic. Any class with requires_logdensity = True
plus a __call__ matching the _BayesianInitializer contract will be detected
and dispatched the same way:
import jno

class _MyInit(jno.bayesian._BayesianInitializer):
    def __call__(self, rng_key, logdensity_fn, position, num_chains):
        # ... your algorithm ...
        return new_position, {"inverse_mass_matrix": ...}  # IMM optional

a.initialize(_MyInit())
a.bayesian(blackjax.nuts, step_size=1e-2, warmup=100, adapt=True)
Future jno initializers — Laplace approximation (Magnani et al. 2024), SVGD (Liu & Wang 2016), MAP via Adam — slot in as additional subclasses on the same hook.
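To make the call contract concrete, here is a standalone sketch of a MAP-style initializer using plain gradient ascent. It deliberately does not subclass jno.bayesian._BayesianInitializer so that it runs without jno. It treats position as a flat list of floats, ignores rng_key and num_chains, and approximates gradients with finite differences where a real implementation would more likely use jax.grad. The class name GradientAscentInit, its hyperparameters, and the toy log-density are all hypothetical, and returning an empty dict assumes the IMM entry may simply be left out.

```python
class GradientAscentInit:
    """Hypothetical MAP-style initializer following the call signature above."""

    requires_logdensity = True

    def __init__(self, steps=200, learning_rate=0.1, eps=1e-5):
        self.steps = steps
        self.learning_rate = learning_rate
        self.eps = eps

    def _grad(self, logdensity_fn, position):
        # Central finite differences, one coordinate at a time.
        grad = []
        for i in range(len(position)):
            up = list(position)
            down = list(position)
            up[i] += self.eps
            down[i] -= self.eps
            grad.append((logdensity_fn(up) - logdensity_fn(down)) / (2 * self.eps))
        return grad

    def __call__(self, rng_key, logdensity_fn, position, num_chains):
        current = list(position)
        for _ in range(self.steps):
            grad = self._grad(logdensity_fn, current)
            current = [p + self.learning_rate * g for p, g in zip(current, grad)]
        # No inverse_mass_matrix entry: assumed optional, as noted above.
        return current, {}


def toy_logdensity(position):
    # Unit-variance Gaussian centred on the tutorial's truth (A, B).
    a, b = position
    return -0.5 * ((a - 50.0) ** 2 + (b + 30.0) ** 2)


new_position, extras = GradientAscentInit()(None, toy_logdensity, [0.0, 0.0], 1)
```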
Composition with existing features
- Masks: .mask(M).bayesian() + pathfinder works: pathfinder runs against the masked subset's log-density; the unmasked complement stays at init.
- Multi-chain: when num_chains > 1, pathfinder samples K distinct starting positions from the fitted q. Strictly better dispersion than the init_jitter heuristic, which is silently overridden when pathfinder is set.
- Non-IMM kernels (MALA / SGLD / SGHMC): pathfinder's warm position is applied; the IMM update is silently dropped by the signature gate. Sampler runs as before.
- substeps=: not compatible (initializer runs against the full loss, kernel sees only substep-local constraints); raises a clear error at solve start.
- .vi(...): not compatible (VI initialises its own variational distribution from state.mu = position); raises a clear error at solve start.
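One way a signature gate like the one mentioned for non-IMM kernels could work is to inspect the kernel builder's parameters and forward only the keyword arguments it accepts. The sketch below illustrates that idea with the standard library; it is not jno's actual gate, and the helper filter_supported_kwargs and both builder functions are hypothetical stand-ins rather than real blackjax signatures.

```python
import inspect


def filter_supported_kwargs(fn, **candidates):
    """Keep only the keyword arguments that fn's signature can accept."""
    params = inspect.signature(fn).parameters
    # A **kwargs parameter accepts everything, so nothing is dropped.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(candidates)
    return {name: value for name, value in candidates.items() if name in params}


# Hypothetical kernel builders, not the real blackjax signatures.
def build_nuts_like(step_size, inverse_mass_matrix):
    return ("nuts-like", step_size, inverse_mass_matrix)


def build_mala_like(step_size):
    return ("mala-like", step_size)


update = {"step_size": 1e-2, "inverse_mass_matrix": [0.25, 4.0]}
nuts_kwargs = filter_supported_kwargs(build_nuts_like, **update)
mala_kwargs = filter_supported_kwargs(build_mala_like, **update)
```

The trade-off of a silent drop is convenience over visibility: the same initializer works across kernels, but nothing flags that the IMM estimate went unused.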
Reference
Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49. arXiv:2108.03782
Script
"""11 — Pathfinder warm-start via ``.initialize()``"""
import os
# Three sequential solve() calls share device memory; on small GPUs the
# second JIT compile can OOM. Pinning to CPU keeps the tutorial portable;
# remove this line on a host with enough VRAM (~6 GiB suffices).
os.environ.setdefault("JAX_PLATFORMS", "cpu")
import time # noqa: E402
from pathlib import Path # noqa: E402
import blackjax # noqa: E402
import jax # noqa: E402
import jax.numpy as jnp # noqa: E402
import jno # noqa: E402
# ── Shared problem definition ───────────────────────────────────────────────
π = jno.np.pi
A_true, B_true = 50.0, -30.0 # large-magnitude truth → posterior mode far from default zero init
def _build_problem():
    """Build a fresh domain + parameters at default zero init. Returns
    ``(domain, a, b, residual)`` ready for ``crux.solve``.
    """
    domain = jno.domain.line(mesh_size=0.02)
    x, _ = domain.variable("interior")
    target = A_true * jno.np.sin(π * x) + B_true * jno.np.cos(π * x)
    k1, k2 = jax.random.split(jax.random.PRNGKey(0), 2)
    a = jno.np.parameter((1,), key=k1, name="a")
    b = jno.np.parameter((1,), key=k2, name="b")
    residual = a * jno.np.sin(π * x) + b * jno.np.cos(π * x) - target
    return domain, a, b, residual
def _run(label, configure_a, configure_b, total_epochs):
    """Run one solve(); return per-parameter summary + wall-clock."""
    domain, a, b, residual = _build_problem()
    configure_a(a)
    configure_b(b)
    crux = jno.core([residual.mse])
    t0 = time.perf_counter()
    crux.solve(total_epochs)
    wall = time.perf_counter() - t0
    a_chain = a.posterior_samples  # (K, N, 1)
    b_chain = b.posterior_samples
    rhat_a = float(jno.bayesian.rhat(a_chain)[0])
    rhat_b = float(jno.bayesian.rhat(b_chain)[0])
    A_mean = float(jnp.mean(a_chain))
    B_mean = float(jnp.mean(b_chain))
    print(f"[{label:10s}] A={A_mean:+.3f} B={B_mean:+.3f} R-hat A={rhat_a:.3f} R-hat B={rhat_b:.3f} wall={wall:.2f}s")
    return {
        "label": label,
        "A_mean": A_mean,
        "B_mean": B_mean,
        "rhat_a": rhat_a,
        "rhat_b": rhat_b,
        "wall": wall,
    }
# ── Run 1: baseline — window adaptation from bad init ───────────────────────
baseline = _run(
    label="baseline",
    configure_a=lambda p: p.bayesian(blackjax.nuts, step_size=1e-2, warmup=300, keep=300, num_chains=2, adapt=True),
    configure_b=lambda p: p.bayesian(blackjax.nuts, step_size=1e-2, warmup=300, keep=300, num_chains=2, adapt=True),
    total_epochs=600,
)
# ── Run 2: pathfinder only — no window adaptation ──────────────────────────
def _pf_only(p):
    p.initialize(jno.bayesian.pathfinder(maxiter=30, num_samples=200))
    p.bayesian(
        blackjax.nuts,
        step_size=1e-2,
        inverse_mass_matrix=jnp.ones(1),
        warmup=0,
        keep=300,
        num_chains=2,
        adapt=False,
    )
pathfinder_only = _run(label="pf-only", configure_a=_pf_only, configure_b=_pf_only, total_epochs=300)
# ── Run 3: chained — pathfinder + window adaptation ─────────────────────────
def _pf_chain(p):
    p.initialize(jno.bayesian.pathfinder(maxiter=30, num_samples=200))
    p.bayesian(
        blackjax.nuts,
        step_size=1e-2,
        warmup=100,
        keep=300,
        num_chains=2,
        adapt=True,
    )
chained = _run(label="chained", configure_a=_pf_chain, configure_b=_pf_chain, total_epochs=400)
# ── Append summary row to tutorial_results.txt ──────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    for run in (baseline, pathfinder_only, chained):
        rel_A = abs(run["A_mean"] - A_true) / abs(A_true)
        rel_B = abs(run["B_mean"] - B_true) / abs(B_true)
        f.write(
            f"10_bayesian_pinns/11_pathfinder_init.py | run={run['label']:10s} | "
            f"rel_A={rel_A:.4f} | rel_B={rel_B:.4f} | "
            f"rhat_a={run['rhat_a']:.3f} | rhat_b={run['rhat_b']:.3f} | "
            f"wall={run['wall']:.2f}s\n"
        )
# ── Asserts ─────────────────────────────────────────────────────────────────
# All three runs recover (A, B). The pathfinder-touched runs do so with
# materially less warmup (or none); the baseline pays for full window
# adaptation from the bad init.
for run in (baseline, pathfinder_only, chained):
    rel_A = abs(run["A_mean"] - A_true) / abs(A_true)
    rel_B = abs(run["B_mean"] - B_true) / abs(B_true)
    assert rel_A < 0.3, f"{run['label']}: A off by {rel_A:.2%}"
    assert rel_B < 0.3, f"{run['label']}: B off by {rel_B:.2%}"
# Chained warm-start should produce R-hat at least as good as the
# baseline (warm-start can't hurt mixing). Loose tolerance — multichain
# R-hat with K=2 on a short chain is intrinsically noisy.
assert chained["rhat_a"] <= baseline["rhat_a"] + 0.5, (
    f"chained R-hat A worse than baseline: {chained['rhat_a']:.3f} vs {baseline['rhat_a']:.3f}"
)
assert chained["rhat_b"] <= baseline["rhat_b"] + 0.5, (
    f"chained R-hat B worse than baseline: {chained['rhat_b']:.3f} vs {baseline['rhat_b']:.3f}"
)