SVGD warm-start via .initialize()
Stein Variational Gradient Descent (SVGD) as a log-density-aware
initializer. It is the third concrete initializer for the `.initialize()`
hook that landed in Phase 12:
```python
a.initialize(jno.bayesian.svgd(num_iters=300, num_particles=32))
a.bayesian(blackjax.nuts, step_size=1e-2, warmup=100, adapt=True)
```
Algorithm
1. Seed particles. `num_particles` particles are placed around the user-supplied position by adding Gaussian noise with standard deviation `init_jitter`. The default `num_particles = max(num_chains, 32)` ensures we always have enough particles for a stable variance estimate, even when the caller asks for only one chain. `init_jitter` defaults to `None` → `max(0.1 * std(position), 1e-3)`: a scale-aware spread of one-tenth of the parameter scale, floored at `1e-3`. Pass an explicit positive float (e.g. `init_jitter=0.5`) to override.
2. Run SVGD. Each particle $x_i$ moves along a kernelised functional gradient, $x_i \leftarrow x_i + \epsilon\,\phi^*(x_i)$ with step size $\epsilon$ (Liu & Wang 2016, eq. 8), where

   $$ \phi^*(x) = \frac{1}{N}\sum_{j=1}^{N} \left[ k(x_j, x)\, \nabla_{x_j} \log p(x_j) + \nabla_{x_j} k(x_j, x) \right] $$

   The first term pulls each particle toward higher posterior density. The second term, the gradient of the RBF kernel, pushes particles apart so they spread out. jno's wrapper calls `blackjax.svgd` inside a `jax.lax.scan` over `num_iters` iterations.
3. Use the particle cloud as the warm-start.
   - `num_chains=1`: the particle-cloud mean is the warm position.
   - `num_chains>1`: the first `num_chains` particles serve as K distinct warm positions. The repulsive kernel dynamics already provide over-dispersion, so no additional jitter is needed.

   The per-dimension particle variance (plus a small ridge) is returned as the diagonal `inverse_mass_matrix`.
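The three steps above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not jno's wrapper (which delegates to `blackjax.svgd`, whose optimiser and kernel settings may differ): the target is a hypothetical 2-D Gaussian centred on the tutorial's truth with standard deviation 0.5, the RBF bandwidth uses the median heuristic from Liu & Wang (2016), the update is a plain fixed-step ascent, and the ridge `1e-6` is a placeholder for the unspecified "small ridge". All function names are hypothetical.

```python
import numpy as np

# Hypothetical target: N(MU, SIGMA^2 I), centred on the tutorial's (A, B) truth.
MU = np.array([3.14, -2.71])
SIGMA = 0.5


def grad_log_p(x):
    """Score of the Gaussian target, evaluated row-wise on (N, D) particles."""
    return -(x - MU) / SIGMA**2


def rbf_kernel(x):
    """k[j, i] = exp(-|x_j - x_i|^2 / h) and its gradient with respect to x_j."""
    diff = x[:, None, :] - x[None, :, :]                 # diff[j, i] = x_j - x_i
    sq = np.sum(diff**2, axis=-1)                         # (N, N) squared distances
    h = max(np.median(sq) / np.log(x.shape[0]), 1e-8)     # median-heuristic bandwidth
    k = np.exp(-sq / h)
    grad_k = -2.0 / h * diff * k[:, :, None]              # d k(x_j, x_i) / d x_j
    return k, grad_k


def svgd_direction(x):
    """phi*(x_i) = (1/N) sum_j [k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i)]."""
    k, grad_k = rbf_kernel(x)
    return (k.T @ grad_log_p(x) + grad_k.sum(axis=0)) / x.shape[0]


def svgd_warm_start(position, num_chains, num_iters=500, step=0.1,
                    num_particles=None, init_jitter=None, ridge=1e-6, seed=0):
    rng = np.random.default_rng(seed)
    # Step 1: seed particles around the user-supplied position.
    if num_particles is None:
        num_particles = max(num_chains, 32)
    if init_jitter is None:
        init_jitter = max(0.1 * float(np.std(position)), 1e-3)
    particles = position + init_jitter * rng.standard_normal((num_particles, position.size))
    # Step 2: fixed-step SVGD updates x_i <- x_i + step * phi*(x_i).
    for _ in range(num_iters):
        particles = particles + step * svgd_direction(particles)
    # Step 3: cloud mean for one chain, the first K particles for K > 1.
    if num_chains == 1:
        warm = particles.mean(axis=0, keepdims=True)
    else:
        warm = particles[:num_chains]
    # Diagonal inverse mass matrix: per-dimension particle variance plus a ridge.
    inverse_mass_matrix = particles.var(axis=0) + ridge
    return warm, inverse_mass_matrix, particles


warm, inv_mass, particles = svgd_warm_start(np.zeros(2), num_chains=2, init_jitter=2.0)
print(warm.shape, inv_mass)
```

Seeding from zeros with `init_jitter=2.0` mirrors the tutorial script; with the default jitter the spread would fall to the `1e-3` floor, because `std` of an all-zero position is 0.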
Trade-offs vs Pathfinder / Laplace
| Aspect | Pathfinder | Laplace | SVGD |
|---|---|---|---|
| Posterior approximation | unimodal Gaussian from L-BFGS factors | unimodal Gaussian at MAP | particle cloud (can be multi-modal) |
| Multi-modal | underestimates | local only | captures with enough particles |
| Compute cost per step | one L-BFGS line search | one gradient eval | O(N²) pairwise kernel evals |
| Memory | low | O(D²) full, O(D) diagonal | O(N × D) |
| When to pick | unimodal, fast, no tuning | exact Hessian wanted | suspected multi-modality, willing to pay N² cost |

Here N is `num_particles` and D is the number of parameters.
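To put the Memory and Compute rows in numbers, here is back-of-the-envelope arithmetic for a hypothetical model with D = 10,000 parameters (a small BNN) and N = 32 particles, assuming 4-byte float32 storage. Only the dominant big-O terms from the table are counted; constants such as Pathfinder's L-BFGS history are ignored, and the function names are illustrative.

```python
def posterior_state_mb(D, N, bytes_per_float=4):
    """Rough storage (MB) of each approximation's dominant term."""
    to_mb = 1e6
    return {
        "laplace_full": D * D * bytes_per_float / to_mb,   # O(D^2) Hessian
        "laplace_diag": D * bytes_per_float / to_mb,       # O(D) diagonal
        "svgd": N * D * bytes_per_float / to_mb,           # O(N x D) particle cloud
    }


def svgd_kernel_pairs_per_iter(N):
    """O(N^2): every particle interacts with every particle, itself included."""
    return N * N


mem = posterior_state_mb(D=10_000, N=32)
print(mem)                              # {'laplace_full': 400.0, 'laplace_diag': 0.04, 'svgd': 1.28}
print(svgd_kernel_pairs_per_iter(32))   # 1024
```

At this size the full-Hessian Laplace state is roughly 300× the SVGD cloud, while SVGD pays 1024 kernel evaluations per iteration independent of D (each evaluation is itself O(D)).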
Numbers from the tutorial
T02-scale problem (truth A = 3.14, B = -2.71); two side-by-side
runs with `num_chains=2`:

| Run | `.initialize(...)` | A | B | R-hat A | R-hat B | Wall |
|---|---|---|---|---|---|---|
| baseline | none | 3.148 | -2.673 | 1.016 | 0.999 | 6.5 s |
| svgd | `svgd(num_iters=300, num_particles=32)` | 3.211 | -2.505 | 1.003 | 1.004 | 9.3 s |
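Both runs recover the truth: the largest relative error is SVGD's B (-2.505 vs -2.71, about 7.6%), well inside the script's 30% tolerance. R-hat is close to 1 in both runs; SVGD is tighter on A (1.003 vs 1.016) and comparable on B (1.004 vs 0.999). Among the three initializers SVGD gives the tightest R-hat, though only marginally: on a unimodal problem Pathfinder ≈ Laplace ≈ SVGD. SVGD's advantage shows on multi-modal posteriors, which this short tutorial doesn't exhibit.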
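The accuracy claim can be checked directly against the table with the same relative-error metric and 0.3 tolerance that the script asserts. The values below are copied from the table; nothing is re-run.

```python
A_TRUE, B_TRUE = 3.14, -2.71

# Posterior means reported in the table above.
reported = {
    "baseline": {"A": 3.148, "B": -2.673},
    "svgd": {"A": 3.211, "B": -2.505},
}

# Relative error |estimate - truth| / |truth|, as in the script's asserts.
rel_err = {
    run: {
        "A": abs(v["A"] - A_TRUE) / abs(A_TRUE),
        "B": abs(v["B"] - B_TRUE) / abs(B_TRUE),
    }
    for run, v in reported.items()
}

for run, errs in rel_err.items():
    print(f"{run:8s} rel_A={errs['A']:.4f} rel_B={errs['B']:.4f}")
```

All four errors sit far below 0.3; the largest is SVGD's B at about 7.6%.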
Composition with existing features
Identical to Pathfinder's composition matrix.
Masks, multi-chain, non-IMM kernels, and substeps / VI guards all
behave the same, because they are handled by the shared
`_BayesianInitializer` dispatch helpers, not by anything SVGD-specific.
When to use SVGD
- You suspect a multi-modal posterior. SVGD's repulsive kernel can keep particles in distinct modes, which the single-Gaussian approximations of Pathfinder and Laplace cannot represent.
- You can afford `num_particles²` pairwise kernel evaluations per iteration. For `num_particles=32` and `num_iters=300` that is 32² × 300 ≈ 307k kernel evaluations: small for scalar PDE coefficients, modest for BNNs.
- You want a variational method that is deterministic once seeded. Pathfinder's ELBO sampling is stochastic, and Laplace depends on its gradient-descent MAP search converging; SVGD's only randomness is the initial particle seeding.
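To see why the repulsive term matters for the first point, here is a 1-D NumPy sketch on a hypothetical equal mixture of N(-3, 1) and N(3, 1). It uses the paper's update rule with an RBF kernel, a median-heuristic bandwidth and a fixed step; these choices are assumptions for illustration, not blackjax's implementation. Particles are seeded around 0 with `init_jitter=2.0`, as in the tutorial script.

```python
import numpy as np


def grad_log_mixture(x):
    """Score of 0.5 N(-3, 1) + 0.5 N(3, 1): d/dx log p = -x + 3 tanh(3x)."""
    return -x + 3.0 * np.tanh(3.0 * x)


def svgd_step(x, step=0.1):
    """One fixed-step SVGD update on (N, D) particles with an RBF kernel."""
    diff = x[:, None, :] - x[None, :, :]                    # diff[j, i] = x_j - x_i
    sq = np.sum(diff**2, axis=-1)
    h = max(np.median(sq) / np.log(x.shape[0]), 1e-8)       # median heuristic
    k = np.exp(-sq / h)
    attract = k.T @ grad_log_mixture(x)                     # sum_j k(x_j, x_i) grad log p(x_j)
    repel = (-2.0 / h * diff * k[:, :, None]).sum(axis=0)   # sum_j grad_{x_j} k(x_j, x_i)
    return x + step * (attract + repel) / x.shape[0]


rng = np.random.default_rng(0)
particles = 2.0 * rng.standard_normal((32, 1))   # seeded around 0 with jitter 2.0
for _ in range(300):
    particles = svgd_step(particles)

left = int(np.sum(particles < 0))
right = int(np.sum(particles > 0))
print(f"left of 0: {left} particles, mean {particles[particles < 0].mean():.2f}")
print(f"right of 0: {right} particles, mean {particles[particles > 0].mean():.2f}")
```

The cloud splits into two groups centred near -3 and +3. Pathfinder and Laplace, each a single Gaussian built around one optimisation trajectory, would describe only one of the two modes.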
For posteriors you expect to be unimodal, Pathfinder is usually the cheaper default.
Reference
Liu, Q., & Wang, D. (2016). Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. Advances in Neural Information Processing Systems (NeurIPS), 29, 2378-2386. arXiv:1608.04471. See §3 for the SVGD update rule, derived as the steepest-descent direction of the KL divergence in an RKHS, whose magnitude is the kernelised Stein discrepancy.
Script
"""13 — SVGD warm-start via ``.initialize()``"""
import os
# Two sequential solve() calls share device memory; pin CPU for
# portability (remove on hosts with enough VRAM).
os.environ.setdefault("JAX_PLATFORMS", "cpu")
import time # noqa: E402
from pathlib import Path # noqa: E402
import blackjax # noqa: E402
import jax # noqa: E402
import jax.numpy as jnp # noqa: E402
import jno # noqa: E402
π = jno.np.pi
A_true, B_true = 3.14, -2.71
def _build_problem():
domain = jno.domain.line(mesh_size=0.02)
x, _ = domain.variable("interior")
target = A_true * jno.np.sin(π * x) + B_true * jno.np.cos(π * x)
k1, k2 = jax.random.split(jax.random.PRNGKey(0), 2)
a = jno.np.parameter((1,), key=k1, name="a")
b = jno.np.parameter((1,), key=k2, name="b")
residual = a * jno.np.sin(π * x) + b * jno.np.cos(π * x) - target
return domain, a, b, residual
def _run(label, configure_a, configure_b, total_epochs):
domain, a, b, residual = _build_problem()
configure_a(a)
configure_b(b)
crux = jno.core([residual.mse])
t0 = time.perf_counter()
crux.solve(total_epochs)
wall = time.perf_counter() - t0
a_chain = a.posterior_samples
b_chain = b.posterior_samples
rhat_a = float(jno.bayesian.rhat(a_chain)[0])
rhat_b = float(jno.bayesian.rhat(b_chain)[0])
A_mean = float(jnp.mean(a_chain))
B_mean = float(jnp.mean(b_chain))
print(f"[{label:10s}] A={A_mean:+.3f} B={B_mean:+.3f} R-hat A={rhat_a:.3f} R-hat B={rhat_b:.3f} wall={wall:.2f}s")
return {
"label": label,
"A_mean": A_mean,
"B_mean": B_mean,
"rhat_a": rhat_a,
"rhat_b": rhat_b,
"wall": wall,
}
# ── Run 1: baseline window adaptation from default zero init ────────────────
baseline = _run(
label="baseline",
configure_a=lambda p: p.bayesian(blackjax.nuts, step_size=1e-2, warmup=300, keep=300, num_chains=2, adapt=True),
configure_b=lambda p: p.bayesian(blackjax.nuts, step_size=1e-2, warmup=300, keep=300, num_chains=2, adapt=True),
total_epochs=600,
)
# ── Run 2: SVGD warm-start ──────────────────────────────────────────────────
def _svgd_only(p):
# 32 particles, 300 SVGD iterations, init spread = 2. The
# particle ensemble approximates the posterior and serves as the
# warm-start (mean for K=1, K distinct particles for K>1).
p.initialize(
jno.bayesian.svgd(
num_iters=300,
num_particles=32,
init_jitter=2.0,
)
)
p.bayesian(
blackjax.nuts,
step_size=1e-2,
inverse_mass_matrix=jnp.ones(1),
warmup=0,
keep=300,
num_chains=2,
adapt=False,
)
svgd_only = _run(label="svgd", configure_a=_svgd_only, configure_b=_svgd_only, total_epochs=300)
# ── Append summary to tutorial_results.txt ──────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
for run in (baseline, svgd_only):
rel_A = abs(run["A_mean"] - A_true) / abs(A_true)
rel_B = abs(run["B_mean"] - B_true) / abs(B_true)
f.write(
f"10_bayesian_pinns/13_svgd_init.py | run={run['label']:10s} | "
f"rel_A={rel_A:.4f} | rel_B={rel_B:.4f} | "
f"rhat_a={run['rhat_a']:.3f} | rhat_b={run['rhat_b']:.3f} | "
f"wall={run['wall']:.2f}s\n"
)
# ── Asserts ─────────────────────────────────────────────────────────────────
for run in (baseline, svgd_only):
rel_A = abs(run["A_mean"] - A_true) / abs(A_true)
rel_B = abs(run["B_mean"] - B_true) / abs(B_true)
assert rel_A < 0.3, f"{run['label']}: A off by {rel_A:.2%}"
assert rel_B < 0.3, f"{run['label']}: B off by {rel_B:.2%}"
# SVGD warm-start should produce R-hat at least as good as the
# baseline; loose tolerance since K=2 R-hat is noisy.
assert svgd_only["rhat_a"] <= baseline["rhat_a"] + 0.5, (
f"svgd R-hat A worse than baseline: {svgd_only['rhat_a']:.3f} vs {baseline['rhat_a']:.3f}"
)