
SVGD warm-start via .initialize()

This page uses Stein Variational Gradient Descent (SVGD) as a log-density-aware initializer. It is the third concrete initializer for the .initialize() hook that landed in Phase 12:

a.initialize(jno.bayesian.svgd(num_iters=300, num_particles=32))
a.bayesian(blackjax.nuts, step_size=1e-2, warmup=100, adapt=True)

Algorithm

  1. Seed particles. num_particles particles are placed around the user-supplied position by adding Gaussian noise with standard deviation init_jitter. The default is num_particles = max(num_chains, 32), so there are always enough particles for a stable variance estimate even when the caller asks for a single chain. init_jitter defaults to None, which resolves to max(0.1 * std(position), 1e-3): a scale-aware spread of one-tenth of the parameter scale, floored at 1e-3. Pass an explicit positive float (e.g. init_jitter=0.5) to override it.
  2. Run SVGD. Each of the N = num_particles particles is updated along the kernelised functional gradient (Liu & Wang 2016, eq. 8):

$$ \phi^*(x) = \frac{1}{N}\sum_{j} \left[ k(x_j, x)\, \nabla_{x_j} \log p(x_j) + \nabla_{x_j} k(x_j, x) \right] $$

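The first term is a kernel-weighted average of the score ∇ log p, which pulls each particle toward regions of higher posterior density. The second term, the gradient of the RBF kernel, pushes particles apart so they spread out: for $k(x_j, x) = \exp(-\lVert x_j - x \rVert^2 / h)$ with bandwidth $h$, it equals $\frac{2}{h}(x - x_j)\,k(x_j, x)$, a vector pointing from $x_j$ toward $x$. jno's wrapper calls blackjax.svgd inside a jax.lax.scan over num_iters iterations.
  3. Use the particle cloud as the warm-start.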

  • num_chains=1 — particle-cloud mean as the warm position.
  • num_chains>1 — first num_chains particles as K distinct warm positions. The repulsive kernel dynamics already provide proper over-dispersion; no additional jitter is needed.

The per-dimension variance of the particle cloud (plus a small ridge) is returned as the diagonal inverse_mass_matrix.
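The three steps above can be illustrated with a minimal pure-JAX sketch. It is not jno's code (jno delegates to blackjax.svgd) and rests on stated assumptions: a plain fixed-step update instead of an optimiser, an RBF kernel whose bandwidth comes from the median heuristic h = med²/log N used by Liu & Wang (2016), and a ridge of 1e-6. The helper names svgd_warm_start, svgd_direction, median_bandwidth and rbf, and the step_size default, are hypothetical.

```python
import jax
import jax.numpy as jnp


def rbf(xj, xi, h):
    # RBF kernel k(x_j, x_i) = exp(-||x_j - x_i||^2 / h).
    return jnp.exp(-jnp.sum((xj - xi) ** 2) / h)


def median_bandwidth(x):
    # Assumption: median heuristic h = med^2 / log(N) on squared pairwise
    # distances (self-pairs included), floored so a collapsed cloud cannot divide by zero.
    sq = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return jnp.maximum(jnp.median(sq) / jnp.log(x.shape[0]), 1e-8)


def svgd_direction(x, grad_logp, h):
    # phi*(x_i) = (1/N) sum_j [k(x_j, x_i) grad log p(x_j) + grad_{x_j} k(x_j, x_i)]
    n = x.shape[0]
    scores = jax.vmap(grad_logp)(x)  # (N, D): grad log p at every particle

    def over_pairs(f):
        # Evaluate f(x_j, x_i, h) for every ordered pair; result is indexed [j, i, ...].
        return jax.vmap(lambda xj: jax.vmap(lambda xi: f(xj, xi, h))(x))(x)

    K = over_pairs(rbf)                            # (N, N)
    grad_K = over_pairs(jax.grad(rbf, argnums=0))  # (N, N, D), gradient w.r.t. x_j
    # Attraction: sum_j K[j, i] * score_j.  Repulsion: sum_j grad_K[j, i].
    return (K.T @ scores + grad_K.sum(axis=0)) / n


def svgd_warm_start(logp, position, key, num_chains=1, num_particles=None,
                    num_iters=300, step_size=0.05, init_jitter=None, ridge=1e-6):
    position = jnp.atleast_1d(jnp.asarray(position, dtype=jnp.float32))
    if num_particles is None:
        num_particles = max(num_chains, 32)
    if init_jitter is None:
        init_jitter = jnp.maximum(0.1 * jnp.std(position), 1e-3)

    # 1. Seed particles with Gaussian noise around the supplied position.
    particles = position + init_jitter * jax.random.normal(key, (num_particles, position.size))
    grad_logp = jax.grad(logp)

    # 2. Assumption: plain fixed-step ascent along phi*, in place of an optimiser.
    def step(x, _):
        return x + step_size * svgd_direction(x, grad_logp, median_bandwidth(x)), None

    particles, _ = jax.lax.scan(step, particles, None, length=num_iters)

    # 3. Cloud mean for one chain, first K particles for K chains, and the
    #    per-dimension variance plus a ridge as a diagonal inverse mass matrix.
    warm = particles.mean(axis=0) if num_chains == 1 else particles[:num_chains]
    inverse_mass_matrix = jnp.var(particles, axis=0) + ridge
    return warm, inverse_mass_matrix


# Toy target: independent Gaussian with mean (3, -2) and standard deviation 0.5.
mu = jnp.array([3.0, -2.0])


def logp(x):
    return -0.5 * jnp.sum(((x - mu) / 0.5) ** 2)


warm, imm = svgd_warm_start(logp, jnp.zeros(2), jax.random.PRNGKey(0),
                            num_iters=1000, init_jitter=1.0)
print("warm position:", warm, "diagonal inverse mass matrix:", imm)
```

Because this sketch takes fixed steps rather than using an adaptive optimiser, the usage lines give it more iterations (1000) than the tutorial's 300, and pass an explicit init_jitter (with a zero position the default would fall back to the 1e-3 floor). The warm position should land near the target mean and the diagonal entries near the target variance of 0.25.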

Trade-offs vs Pathfinder / Laplace

| Aspect | Pathfinder | Laplace | SVGD |
|---|---|---|---|
| Posterior approximation | unimodal Gaussian from L-BFGS factors | unimodal Gaussian at MAP | particle cloud (can be multi-modal) |
| Multi-modal posteriors | underestimates | local only | captures, with enough particles |
| Compute cost per step | one L-BFGS line search | one gradient eval | O(N²) pairwise kernel evals |
| Memory | low | O(D²) full, O(D) diagonal | O(N × D) |
| When to pick | unimodal, fast, no tuning | exact Hessian wanted | suspected multi-modality, willing to pay the cost |

Here N is num_particles and D is the number of parameter dimensions.

Numbers from the tutorial

A T02-scale problem (true values A = 3.14, B = -2.71), solved in two side-by-side runs with num_chains=2:

| Run | .initialize(...) | A | B | R-hat A | R-hat B | Wall time |
|---|---|---|---|---|---|---|
| baseline | none | 3.148 | -2.673 | 1.016 | 0.999 | 6.5 s |
| svgd | svgd(num_iters=300, num_particles=32) | 3.211 | -2.505 | 1.003 | 1.004 | 9.3 s |

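Both runs recover the true values to within 8% relative error (the script asserts < 30%). The SVGD warm-start lowers R-hat for A (1.003 vs 1.016), while R-hat for B is essentially unchanged (1.004 vs 0.999), at the cost of 2.8 s more wall time. On a unimodal problem like this one, the Pathfinder, Laplace and SVGD initializers perform about the same; SVGD's advantage shows on multi-modal posteriors, which this short tutorial does not exhibit.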
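The R-hat columns compare the spread between chains with the spread within them. A minimal sketch of the classic (non-split) Gelman-Rubin statistic shows why a chain stuck away from the others inflates it. It assumes draws shaped (num_chains, num_samples) for one scalar parameter; jno.bayesian.rhat may implement a split or rank-normalised variant, so its numbers need not match, and gelman_rubin_rhat is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gelman_rubin_rhat(chains):
    """Classic (non-split) Gelman-Rubin R-hat for one scalar parameter.

    chains: array of shape (num_chains, num_samples).
    """
    _, n = chains.shape
    W = jnp.var(chains, axis=1, ddof=1).mean()    # mean within-chain variance
    B = n * jnp.var(chains.mean(axis=1), ddof=1)  # between-chain variance
    var_hat = (n - 1) / n * W + B / n             # pooled posterior-variance estimate
    return jnp.sqrt(var_hat / W)


key = jax.random.PRNGKey(0)
# Two well-mixed chains around A = 3.14.
mixed = 3.14 + 0.05 * jax.random.normal(key, (2, 300))
# Same draws, but the second chain sits 0.5 away from the first.
stuck = mixed + jnp.array([[0.0], [0.5]])

rhat_mixed = float(gelman_rubin_rhat(mixed))
rhat_stuck = float(gelman_rubin_rhat(stuck))
print(f"mixed R-hat = {rhat_mixed:.3f}, stuck R-hat = {rhat_stuck:.3f}")
```

With this formula R-hat can fall slightly below 1 when the chain means happen to agree more closely than the within-chain noise predicts, which is consistent with entries such as 0.999 in the table.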

Composition with existing features

Identical to Pathfinder's composition matrix: masks, multi-chain runs, non-IMM kernels (kernels that take no inverse_mass_matrix) and substeps / VI guards all work the same way, because they are handled by the shared _BayesianInitializer dispatch helpers rather than by anything SVGD-specific.

When to use SVGD

  • You suspect a multi-modal posterior. SVGD's repulsive kernel can keep particles spread across distinct modes, whereas Pathfinder and Laplace each fit a single unimodal Gaussian.
  • You can afford num_particles² pairwise kernel evaluations per iteration. For num_particles=32 and num_iters=300 that is 32² × 300 ≈ 307k kernel evaluations: small for scalar PDE coefficients, modest for BNNs.
  • You want a variational method that is deterministic once the particles are seeded (Pathfinder's ELBO estimate draws Monte Carlo samples and is therefore stochastic; Laplace requires the gradient-descent MAP search to converge; SVGD's only randomness is the initial particle seeding).
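The kernel-evaluation figure above follows from simple counting. The sketch below assumes, as in the update rule, one kernel evaluation (with its gradient) per ordered particle pair and one log-density gradient per particle in every iteration; svgd_cost is a hypothetical helper, not part of jno.

```python
def svgd_cost(num_particles: int, num_iters: int) -> tuple[int, int]:
    """Count the work in a plain SVGD run (assumes every ordered pair is evaluated)."""
    # Each iteration evaluates k(x_j, x_i) and grad_{x_j} k(x_j, x_i) for all N^2 pairs...
    kernel_evals = num_particles**2 * num_iters
    # ...and one grad log p(x_j) per particle.
    score_evals = num_particles * num_iters
    return kernel_evals, score_evals


print(svgd_cost(32, 300))  # (307200, 9600)
print(svgd_cost(64, 300))  # (1228800, 19200): doubling N quadruples kernel work
```

Doubling the particle count quadruples the kernel work but only doubles the number of log-density gradients, so for expensive models such as BNNs the gradient evaluations, not the kernel term, may dominate wall time.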

For posteriors you expect to be unimodal, Pathfinder is usually the cheaper default.

Reference

Liu, Q., & Wang, D. (2016). Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. Advances in Neural Information Processing Systems (NeurIPS), 29, 2378-2386. arXiv:1608.04471. §3 covers the SVGD update rule and its link to the kernelised Stein discrepancy.

Script

"""13 — SVGD warm-start via ``.initialize()``"""

import os

# Two sequential solve() calls share device memory; pin CPU for
# portability (remove on hosts with enough VRAM).
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import time  # noqa: E402
from pathlib import Path  # noqa: E402

import blackjax  # noqa: E402
import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402

import jno  # noqa: E402

π = jno.np.pi
A_true, B_true = 3.14, -2.71


def _build_problem():
    domain = jno.domain.line(mesh_size=0.02)
    x, _ = domain.variable("interior")
    target = A_true * jno.np.sin(π * x) + B_true * jno.np.cos(π * x)

    k1, k2 = jax.random.split(jax.random.PRNGKey(0), 2)
    a = jno.np.parameter((1,), key=k1, name="a")
    b = jno.np.parameter((1,), key=k2, name="b")

    residual = a * jno.np.sin(π * x) + b * jno.np.cos(π * x) - target
    return domain, a, b, residual


def _run(label, configure_a, configure_b, total_epochs):
    domain, a, b, residual = _build_problem()
    configure_a(a)
    configure_b(b)
    crux = jno.core([residual.mse])
    t0 = time.perf_counter()
    crux.solve(total_epochs)
    wall = time.perf_counter() - t0

    a_chain = a.posterior_samples
    b_chain = b.posterior_samples
    rhat_a = float(jno.bayesian.rhat(a_chain)[0])
    rhat_b = float(jno.bayesian.rhat(b_chain)[0])
    A_mean = float(jnp.mean(a_chain))
    B_mean = float(jnp.mean(b_chain))
    print(f"[{label:10s}] A={A_mean:+.3f}  B={B_mean:+.3f}  R-hat A={rhat_a:.3f}  R-hat B={rhat_b:.3f}  wall={wall:.2f}s")
    return {
        "label": label,
        "A_mean": A_mean,
        "B_mean": B_mean,
        "rhat_a": rhat_a,
        "rhat_b": rhat_b,
        "wall": wall,
    }


# ── Run 1: baseline window adaptation from default zero init ────────────────
baseline = _run(
    label="baseline",
    configure_a=lambda p: p.bayesian(blackjax.nuts, step_size=1e-2, warmup=300, keep=300, num_chains=2, adapt=True),
    configure_b=lambda p: p.bayesian(blackjax.nuts, step_size=1e-2, warmup=300, keep=300, num_chains=2, adapt=True),
    total_epochs=600,
)


# ── Run 2: SVGD warm-start ──────────────────────────────────────────────────
def _svgd_only(p):
    # 32 particles, 300 SVGD iterations, init spread = 2.  The
    # particle ensemble approximates the posterior and serves as the
    # warm-start (mean for K=1, K distinct particles for K>1).
    p.initialize(
        jno.bayesian.svgd(
            num_iters=300,
            num_particles=32,
            init_jitter=2.0,
        )
    )
    p.bayesian(
        blackjax.nuts,
        step_size=1e-2,
        inverse_mass_matrix=jnp.ones(1),
        warmup=0,
        keep=300,
        num_chains=2,
        adapt=False,
    )


svgd_only = _run(label="svgd", configure_a=_svgd_only, configure_b=_svgd_only, total_epochs=300)

# ── Append summary to tutorial_results.txt ──────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    for run in (baseline, svgd_only):
        rel_A = abs(run["A_mean"] - A_true) / abs(A_true)
        rel_B = abs(run["B_mean"] - B_true) / abs(B_true)
        f.write(
            f"10_bayesian_pinns/13_svgd_init.py | run={run['label']:10s} | "
            f"rel_A={rel_A:.4f} | rel_B={rel_B:.4f} | "
            f"rhat_a={run['rhat_a']:.3f} | rhat_b={run['rhat_b']:.3f} | "
            f"wall={run['wall']:.2f}s\n"
        )

# ── Asserts ─────────────────────────────────────────────────────────────────
for run in (baseline, svgd_only):
    rel_A = abs(run["A_mean"] - A_true) / abs(A_true)
    rel_B = abs(run["B_mean"] - B_true) / abs(B_true)
    assert rel_A < 0.3, f"{run['label']}: A off by {rel_A:.2%}"
    assert rel_B < 0.3, f"{run['label']}: B off by {rel_B:.2%}"

# SVGD warm-start should produce R-hat at least as good as the
# baseline; loose tolerance since K=2 R-hat is noisy.
assert svgd_only["rhat_a"] <= baseline["rhat_a"] + 0.5, (
    f"svgd R-hat A worse than baseline: {svgd_only['rhat_a']:.3f} vs {baseline['rhat_a']:.3f}"
)