
Laplace warm-start via .initialize()

A Laplace approximation used as a log-density-aware initializer. It slots into the same .initialize() hook that pathfinder uses, but runs a different algorithm:

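import blackjax
import jax
import jno
import optax

# `a` is a jno parameter, built as in the script below.
a = jno.np.parameter((1,), key=jax.random.PRNGKey(0), name="a")
a.initialize(jno.bayesian.laplace(
    map_steps=300,
    map_optimizer=optax.adam(1e-1),
    hessian_strategy="diagonal",
))
a.bayesian(blackjax.nuts, step_size=1e-2, warmup=100, adapt=True)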

Algorithm

  1. MAP via optax. Optimise -log p(θ | data) with the supplied optimiser (Adam by default). Runs as a JIT-compiled jax.lax.scan over map_steps iterations.
  2. Hessian at the MAP: H = -∇²log p(θ | data), evaluated at θ = MAP. Two strategies:

    • hessian_strategy="full": the full (D, D) Hessian via jax.hessian. Numerically clean, but memory grows as O(D²). Suitable for D up to roughly 1000.
    • hessian_strategy="diagonal" (default): the diagonal of H, computed with D Hessian-vector probes. Memory cost is O(D), which BNN-scale problems require. Compute cost is similar to the full strategy, but no D×D matrix is ever materialised.
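  3. Posterior approximation N(MAP, H⁻¹). For num_chains=1 the warm position is the MAP; for num_chains=K>1, K over-dispersed warm positions are sampled from this Gaussian. The diagonal of H⁻¹ is returned as the kernel's inverse_mass_matrix. With hessian_strategy="diagonal" only diag(H) is available, so the inverse is necessarily taken elementwise (1/H_ii), which matches the true diagonal of H⁻¹ only when the off-diagonal curvature is negligible.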

A small ridge (default 1e-6) is added to H before any inversion / Cholesky to guard against ill-conditioned Hessians at non-converged MAP estimates.
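Step 3 and the ridge can be sketched for a dense H (the "full" strategy). This is a minimal sketch under stated assumptions: laplace_warm_start is a hypothetical helper, not jno API, and its scale argument is a stand-in for however over-dispersion is actually applied (scale=1 samples the Laplace Gaussian exactly). It uses the identity that if H = L Lᵀ and z ~ N(0, I), then L⁻ᵀ z ~ N(0, H⁻¹).

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def laplace_warm_start(theta_map, H, num_chains, key, ridge=1e-6, scale=1.0):
    """Hypothetical helper: warm positions and IMM from N(theta_map, (H + ridge*I)^-1)."""
    D = theta_map.shape[0]
    eye = jnp.eye(D, dtype=H.dtype)
    H_reg = H + ridge * eye                     # ridge guards against ill-conditioning
    L = jnp.linalg.cholesky(H_reg)              # H_reg = L @ L.T
    cov = jsl.cho_solve((L, True), eye)         # (H + ridge*I)^-1
    inverse_mass_matrix = jnp.diag(cov)
    if num_chains == 1:
        return theta_map[None, :], inverse_mass_matrix   # single chain starts at the MAP
    # z ~ N(0, I)  =>  L^{-T} z ~ N(0, H_reg^{-1}); solve L^T x = z instead of inverting.
    z = jax.random.normal(key, (num_chains, D), dtype=H.dtype)
    offsets = jsl.solve_triangular(L.T, z.T, lower=False).T
    return theta_map + scale * offsets, inverse_mass_matrix


H = jnp.array([[4.0, 1.0], [1.0, 2.0]])
theta_map = jnp.array([3.14, -2.71])
positions, imm = laplace_warm_start(theta_map, H, num_chains=4, key=jax.random.PRNGKey(0))
# positions.shape == (4, 2); imm is approximately [0.2857, 0.5714] = diag(H^-1)
```

For this H the elementwise inverse used by a diagonal-only strategy, 1/diag(H) = [0.25, 0.5], is noticeably smaller than the exact diag(H⁻¹) = [2/7, 4/7], because the off-diagonal curvature is ignored.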

Trade-offs vs Pathfinder

| Aspect | Pathfinder | Laplace |
|---|---|---|
| MAP search | L-BFGS (quasi-Newton) | gradient descent (Adam by default) |
| Hessian | low-rank inverse-Hessian approximation from the L-BFGS path | exact ∇²log p at the MAP (diagonal or full) |
| Robustness on multi-modal posteriors | better: explores the L-BFGS path | local approximation only |
| Cost for large D | dominated by L-BFGS line searches | full Hessian: O(D²) memory; diagonal: O(D) |
| Failure mode | falls back to a normal approximation that may underestimate posterior variance | needs the ridge if H is ill-conditioned |

Both produce a Gaussian approximation suitable as a warm start; they just get there by different routes.
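The quadratic-versus-linear memory row is easy to make concrete. A minimal sketch, assuming float32 storage (4 bytes per entry) and counting only the Hessian object itself; hessian_memory_bytes is a hypothetical helper:

```python
def hessian_memory_bytes(D, strategy, bytes_per_entry=4):
    """Bytes needed to store the Hessian object alone (float32 by default)."""
    if strategy == "full":
        return D * D * bytes_per_entry      # (D, D) matrix
    if strategy == "diagonal":
        return D * bytes_per_entry          # length-D vector
    raise ValueError(f"unknown strategy: {strategy!r}")


for D in (1_000, 100_000, 10_000_000):
    full_gb = hessian_memory_bytes(D, "full") / 1e9
    diag_mb = hessian_memory_bytes(D, "diagonal") / 1e6
    print(f"D={D:>10,}  full={full_gb:>12,.3f} GB  diagonal={diag_mb:>8,.3f} MB")
```

At D = 1,000 the full matrix is only 4 MB, consistent with the ~1000-parameter threshold given for the full strategy; at D = 100,000 it is already 40 GB, while the diagonal is 0.4 MB.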

Numbers from the tutorial

T02-scale problem (truth A = 3.14, B = -2.71); two side-by-side runs with num_chains=2:

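| Run | .initialize(...) | A (posterior mean) | B (posterior mean) | R-hat A | R-hat B | Wall |
|---|---|---|---|---|---|---|
| baseline | none | 3.148 | -2.673 | 1.016 | 0.999 | 6.6 s |
| laplace | laplace(map_steps=300, optax.adam(1e-1)) | 3.236 | -2.553 | 1.013 | 1.005 | 9.8 s |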

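Both runs recover the truth to within the script's 30% tolerance, although the baseline means are closer (relative errors below 2%, versus roughly 3% on A and 6% on B for Laplace). All R-hat values are within 0.02 of 1, so mixing is comparable: Laplace is marginally better on A (1.013 vs 1.016) and marginally worse on B (1.005 vs 0.999). Wall-clock is higher for Laplace even though it runs 300 epochs instead of 600, because the Laplace path (MAP search plus Hessian) incurs its own JIT compile. For production problems, where L-BFGS or window-adaptation runs would otherwise dominate, this fixed cost is amortised.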
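The relative errors for the two runs can be recomputed from the table's posterior means with the same formula the script uses for rel_A and rel_B (relative_errors is a hypothetical helper name):

```python
A_TRUE, B_TRUE = 3.14, -2.71

# Posterior means copied from the results table.
RUNS = {
    "baseline": (3.148, -2.673),
    "laplace": (3.236, -2.553),
}


def relative_errors(a_mean, b_mean):
    """Same formula as the script: |estimate - truth| / |truth|."""
    return abs(a_mean - A_TRUE) / abs(A_TRUE), abs(b_mean - B_TRUE) / abs(B_TRUE)


for label, (a_mean, b_mean) in RUNS.items():
    rel_a, rel_b = relative_errors(a_mean, b_mean)
    # baseline: rel_A≈0.0025, rel_B≈0.0137; laplace: rel_A≈0.0306, rel_B≈0.0579
    print(f"{label:9s} rel_A={rel_a:.4f}  rel_B={rel_b:.4f}  within 30%: {max(rel_a, rel_b) < 0.3}")
```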

Composition with existing features

Composition is identical to pathfinder's: masks, multi-chain, non-IMM kernels, and substeps / VI guards all behave the same way. The mechanism is the shared _BayesianInitializer hook; nothing pathfinder-specific is involved.

When to use Laplace

  • The posterior is unimodal and well-approximated by a Gaussian.
  • You want an exact mass-matrix estimate at the MAP rather than the L-BFGS approximation pathfinder produces.
  • You can afford the MAP-search optimiser steps (Adam can be slow on steep posteriors; tune map_optimizer and map_steps accordingly).

For multi-modal posteriors, large D, or when you want a single robust warm-start with no tuning, pathfinder is usually the better default.
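The multi-modal caveat is easy to see on a toy target. In this hypothetical sketch (not jno code; plain gradient descent stands in for the Adam MAP search), a Laplace approximation is fitted to an equal mixture of N(-3, 1) and N(+3, 1). The search lands on one mode, and the resulting Gaussian puts almost no mass on the other, even though each mode carries half of the posterior mass.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def log_density(x):
    # Hypothetical bimodal target: 0.5 * N(-3, 1) + 0.5 * N(+3, 1).
    comps = jnp.stack([norm.logpdf(x, -3.0, 1.0), norm.logpdf(x, 3.0, 1.0)])
    return logsumexp(comps) + jnp.log(0.5)


def neg_log_density(x):
    return -log_density(x)


grad_fn = jax.jit(jax.grad(neg_log_density))
x_map = jnp.array(0.5)        # start just right of the density's minimum at x = 0
for _ in range(200):          # plain gradient descent on -log p
    x_map = x_map - 0.1 * grad_fn(x_map)

h = jax.hessian(neg_log_density)(x_map)   # scalar H = -d²/dx² log p at the mode
laplace_std = 1.0 / jnp.sqrt(h)
# Probability the Laplace Gaussian assigns to x < 0; the true posterior assigns 0.5.
laplace_mass_left = norm.cdf(0.0, loc=x_map, scale=laplace_std)
```

Here x_map ends near 3 with H close to 1, so the Laplace Gaussian gives x < 0 a probability of about Φ(-3) ≈ 0.0013 instead of 0.5.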

References

  • MacKay, D. J. C. (1992). A Practical Bayesian Framework for Backpropagation Networks. §6 (Laplace approximation around the posterior mode). Neural Computation, 4(3), 448-472. doi:10.1162/neco.1992.4.3.448
  • Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., & Hennig, P. (2021). Laplace Redux — Effortless Bayesian Deep Learning. §2 (Laplace approximations for neural networks). NeurIPS 2021. arXiv:2106.14806
  • Magnani, E., Krämer, N., Pförtner, M., & Hennig, P. (2024). Linearization Turns Neural Operators into Function-Valued Gaussian Processes. §3 (linearised-Laplace for neural operators). arXiv:2406.05072

Script

"""12 — Laplace warm-start via ``.initialize()``"""

import os

# Two sequential solve() calls share device memory; pin CPU for
# portability (remove on hosts with enough VRAM).
os.environ.setdefault("JAX_PLATFORMS", "cpu")

import time  # noqa: E402
from pathlib import Path  # noqa: E402

import blackjax  # noqa: E402
import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402
import optax  # noqa: E402

import jno  # noqa: E402

π = jno.np.pi
# T02-scale truth — Laplace's Adam-based MAP search converges in
# ``map_steps=300`` for this magnitude.  For larger-scale truths
# (e.g. T11's (50, -30)), bump ``map_steps`` or use ``optax.sgd``
# with a tuned learning-rate schedule.
A_true, B_true = 3.14, -2.71


def _build_problem():
    domain = jno.domain.line(mesh_size=0.02)
    x, _ = domain.variable("interior")
    target = A_true * jno.np.sin(π * x) + B_true * jno.np.cos(π * x)

    k1, k2 = jax.random.split(jax.random.PRNGKey(0), 2)
    a = jno.np.parameter((1,), key=k1, name="a")
    b = jno.np.parameter((1,), key=k2, name="b")

    residual = a * jno.np.sin(π * x) + b * jno.np.cos(π * x) - target
    return domain, a, b, residual


def _run(label, configure_a, configure_b, total_epochs):
    domain, a, b, residual = _build_problem()
    configure_a(a)
    configure_b(b)
    crux = jno.core([residual.mse])
    t0 = time.perf_counter()
    crux.solve(total_epochs)
    wall = time.perf_counter() - t0

    a_chain = a.posterior_samples
    b_chain = b.posterior_samples
    rhat_a = float(jno.bayesian.rhat(a_chain)[0])
    rhat_b = float(jno.bayesian.rhat(b_chain)[0])
    A_mean = float(jnp.mean(a_chain))
    B_mean = float(jnp.mean(b_chain))
    print(f"[{label:10s}] A={A_mean:+.3f}  B={B_mean:+.3f}  R-hat A={rhat_a:.3f}  R-hat B={rhat_b:.3f}  wall={wall:.2f}s")
    return {
        "label": label,
        "A_mean": A_mean,
        "B_mean": B_mean,
        "rhat_a": rhat_a,
        "rhat_b": rhat_b,
        "wall": wall,
    }


# ── Run 1: baseline window adaptation from default zero init ────────────────
baseline = _run(
    label="baseline",
    configure_a=lambda p: p.bayesian(blackjax.nuts, step_size=1e-2, warmup=300, keep=300, num_chains=2, adapt=True),
    configure_b=lambda p: p.bayesian(blackjax.nuts, step_size=1e-2, warmup=300, keep=300, num_chains=2, adapt=True),
    total_epochs=600,
)


# ── Run 2: Laplace warm-start, no window adaptation ─────────────────────────
def _laplace_only(p):
    # Laplace finds the MAP via 300 Adam steps then forms a Gaussian
    # approximation; the diagonal of H^{-1} becomes the IMM.
    p.initialize(
        jno.bayesian.laplace(
            map_steps=300,
            map_optimizer=optax.adam(1e-1),
            hessian_strategy="diagonal",
        )
    )
    p.bayesian(
        blackjax.nuts,
        step_size=1e-2,
        inverse_mass_matrix=jnp.ones(1),
        warmup=0,
        keep=300,
        num_chains=2,
        adapt=False,
    )


laplace_only = _run(label="laplace", configure_a=_laplace_only, configure_b=_laplace_only, total_epochs=300)

# ── Append summary to tutorial_results.txt ──────────────────────────────────
results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    for run in (baseline, laplace_only):
        rel_A = abs(run["A_mean"] - A_true) / abs(A_true)
        rel_B = abs(run["B_mean"] - B_true) / abs(B_true)
        f.write(
            f"10_bayesian_pinns/12_laplace_init.py | run={run['label']:10s} | "
            f"rel_A={rel_A:.4f} | rel_B={rel_B:.4f} | "
            f"rhat_a={run['rhat_a']:.3f} | rhat_b={run['rhat_b']:.3f} | "
            f"wall={run['wall']:.2f}s\n"
        )

# ── Asserts ─────────────────────────────────────────────────────────────────
for run in (baseline, laplace_only):
    rel_A = abs(run["A_mean"] - A_true) / abs(A_true)
    rel_B = abs(run["B_mean"] - B_true) / abs(B_true)
    assert rel_A < 0.3, f"{run['label']}: A off by {rel_A:.2%}"
    assert rel_B < 0.3, f"{run['label']}: B off by {rel_B:.2%}"

# Sanity check that the Laplace warm-start does not badly degrade mixing
# relative to the baseline.  Loose tolerance (+0.5): multichain R-hat
# with K=2 on a short chain is intrinsically noisy.
assert laplace_only["rhat_a"] <= baseline["rhat_a"] + 0.5, (
    f"laplace R-hat A worse than baseline: {laplace_only['rhat_a']:.3f} vs {baseline['rhat_a']:.3f}"
)