
Mean-field Variational Inference on a BNN regressor

Same gapped regression problem as Tutorial 07, but trained with mean-field Variational Inference (VI) instead of SGLD. The 32 noisy observations of u(x) = sin³(6x) lie in [-0.8, -0.2] ∪ [0.2, 0.8], which leaves a deliberate gap around x = 0. Each weight θ_i of a small MLP gets an independent Gaussian variational marginal q(θ_i) = N(μ_i, σ_i²). The joint q(θ) = ∏_i q(θ_i) is fit by maximising the evidence lower bound (ELBO), and 600 i.i.d. samples are then drawn from the fitted q for posterior summaries.
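The ELBO being maximised is E_q[log p(y, θ) − log q(θ)], estimated by Monte Carlo from reparameterised draws θ = μ + σ·ε with ε ~ N(0, 1). The sketch below illustrates that estimator on a hypothetical one-weight linear model (not the tutorial's MLP, and not jno's or blackjax's code); the model, data, and the names `log_joint` and `mc_elbo` are assumptions for illustration. Its `num_samples` argument presumably plays the same role as `num_samples=8` in the API call further down. The model is conjugate, so the exact posterior is available as a reference point.

```python
import math
import random

# Hypothetical toy model (not the tutorial's MLP): a single weight w with
# prior w ~ N(0, TAU^2) and likelihood y_i ~ N(w * x_i, SIGMA^2).
TAU, SIGMA = 1.0, 0.5
XS = [-0.8, -0.5, -0.2, 0.2, 0.5, 0.8]
YS = [-0.6, -0.3, -0.2, 0.1, 0.4, 0.5]


def normal_logpdf(v, mean, std):
    return -0.5 * ((v - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def log_joint(w):
    """log p(w) + sum_i log p(y_i | w)."""
    return normal_logpdf(w, 0.0, TAU) + sum(
        normal_logpdf(y, w * x, SIGMA) for x, y in zip(XS, YS)
    )


def mc_elbo(mu, rho, num_samples, seed):
    """Monte Carlo estimate of E_q[log p(y, w) - log q(w)] for q = N(mu, exp(rho)^2)."""
    rng = random.Random(seed)
    std = math.exp(rho)
    total = 0.0
    for _ in range(num_samples):
        w = mu + std * rng.gauss(0.0, 1.0)  # reparameterised draw from q
        total += log_joint(w) - normal_logpdf(w, mu, std)
    return total / num_samples


# Exact posterior of this conjugate model, used only as a reference point.
post_prec = 1.0 / TAU**2 + sum(x * x for x in XS) / SIGMA**2
post_mean = sum(x * y for x, y in zip(XS, YS)) / SIGMA**2 / post_prec
post_rho = -0.5 * math.log(post_prec)  # log of the posterior std

print(mc_elbo(post_mean, post_rho, num_samples=8, seed=0))
print(mc_elbo(post_mean + 0.5, post_rho, num_samples=1000, seed=0))
```

At the exact posterior every draw contributes the same value, log p(y), so the estimate has zero variance there; any other q scores lower on average by KL(q ‖ p(θ | y)).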

VI vs SGLD on the same problem

| Metric | SGLD (T07) | Mean-field VI (T09) |
|---|---|---|
| In-data rel-L2 of posterior mean | ≈ 0.19 | ≈ 0.24 |
| In-data 90 % band width (median) | ≈ 1.53 | ≈ 0.32 |
| In-gap 90 % band width (median) | ≈ 3.74 | ≈ 0.44 |
| Gap / in-data band ratio | ≈ 2.45× | ≈ 1.38× |
| Wall-clock | similar (CPU, < 30 s) | similar (CPU, < 30 s) |

VI gives much tighter bands (roughly 5× narrower in the data region), but its gap / in-data ratio is also smaller, because mean-field's per-weight independence assumption cannot represent correlations between weights and typically underestimates the uncertainty those correlations carry. Both methods reproduce the qualitative BNN behaviour (the band widens in the gap); VI simply does so with much smaller absolute widths.
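A textbook toy calculation makes the direction of this bias concrete. For a Gaussian target p = N(m, Σ) with precision Λ = Σ⁻¹, the fully factorised q that minimises KL(q ‖ p) has factor variances 1/Λ_ii, which never exceed the true marginal variances Σ_ii (see e.g. Bishop, Pattern Recognition and Machine Learning, Ch. 10). The sketch below evaluates this for a hypothetical 2-D Gaussian; `meanfield_std_2d` is an illustrative helper, and the result shows the mechanism only, not a measurement on the tutorial's network.

```python
import math


def meanfield_std_2d(var1, var2, corr):
    """Std devs of the KL(q||p)-optimal factorised Gaussian q for a 2-D Gaussian p.

    For p = N(m, Sigma), the optimal factor q_i has variance 1 / Lambda_ii,
    where Lambda = Sigma^{-1} is the precision matrix.
    """
    cov = corr * math.sqrt(var1 * var2)
    det = var1 * var2 - cov**2
    # Diagonal of the 2x2 inverse: Lambda_11 = var2 / det, Lambda_22 = var1 / det.
    lam11, lam22 = var2 / det, var1 / det
    return math.sqrt(1.0 / lam11), math.sqrt(1.0 / lam22)


for corr in (0.0, 0.5, 0.9, 0.99):
    q_std, _ = meanfield_std_2d(1.0, 1.0, corr)
    print(f"corr={corr:4.2f}  true marginal std=1.000  mean-field std={q_std:.3f}")
```

With unit variances the factorised variance is 1 − corr², so at corr = 0.9 the mean-field std is √0.19 ≈ 0.44 against a true marginal std of 1.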

Why scale the residual by √N?

Look at the residual definition in the script:

residual = (y_pred - y_train_const) / sigma_obs * jnp.sqrt(N_obs)

The Gaussian-noise log-likelihood, log p(y | θ) = −½ Σ_i r_i² + const with standardised residuals r_i = (y_pred_i − y_i) / σ_obs, is a sum over data points, but residual.mse returns the mean of r_i². Multiplying the residual by √N makes residual.mse equal to Σ_i r_i², the sum that appears in the log-likelihood. Without this rescaling the likelihood term is N times too small, the prior dominates, and VI stays stuck near its initialisation.
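A quick numerical check of the identity, in plain Python with hypothetical residual values. `mse` is a stand-in that assumes residual.mse is a plain mean over data points, as described above, and `gaussian_loglik` is a hypothetical helper.

```python
import math

# Hypothetical standardised residuals r_i = (y_pred_i - y_i) / sigma_obs.
r = [0.3, -1.2, 0.8, 2.0, -0.5]
N = len(r)


def mse(values):
    """Mean of squares; stands in for residual.mse (assumed to be a plain mean)."""
    return sum(v * v for v in values) / len(values)


def gaussian_loglik(res, sigma_obs):
    """log p(y | theta) for i.i.d. N(0, sigma_obs^2) noise, given standardised residuals."""
    return -0.5 * sum(v * v for v in res) - len(res) * math.log(sigma_obs * math.sqrt(2 * math.pi))


sum_sq = sum(v * v for v in r)
scaled = [v * math.sqrt(N) for v in r]

print(mse(r))       # sum_sq / N: the likelihood signal is N times too small
print(mse(scaled))  # matches sum_sq, the quantity inside the log-likelihood
print(gaussian_loglik(r, sigma_obs=0.1))
```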

(For MCMC tutorials in this section the same scaling would tighten the posterior; it's not strictly required since MCMC's gradient signal is more robust to magnitude than VI's stochastic ELBO gradient.)
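Written out, the missing √N factor amounts to a tempered posterior. This derivation assumes the loss enters the target density as the Gaussian negative log-likelihood (with its ½ factor); that is an assumption about how jno combines loss and prior, not something stated above. Here f_θ(x_i) is the network prediction (y_pred in the script):

```latex
\begin{aligned}
r_i(\theta) &= \frac{f_\theta(x_i) - y_i}{\sigma_{\mathrm{obs}}}, \qquad
\log p(\mathbf{y}\mid\theta) = -\tfrac{1}{2}\sum_{i=1}^{N} r_i(\theta)^2 + \mathrm{const},\\
-\tfrac{1}{2}\,\frac{1}{N}\sum_{i=1}^{N} r_i(\theta)^2
  &= \frac{1}{N}\,\log p(\mathbf{y}\mid\theta) + \mathrm{const}
  \;\Longrightarrow\;
  \pi_{\text{unscaled}}(\theta) \propto p(\theta)\,p(\mathbf{y}\mid\theta)^{1/N}.
\end{aligned}
```

With the √N rescaling the exponent returns to 1 and the ordinary posterior p(θ | y) ∝ p(θ) p(y | θ) is recovered, which is why the same scaling would also tighten an MCMC posterior.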

How jno's mean-field VI is initialised

jno overrides two of blackjax's defaults at solve-start (see jno/bayesian.py:init_state):

  1. state.mu = position (the model's initial weights), rather than blackjax's zeros — gives VI a sensible starting point on non-trivial architectures.
  2. state.rho = -3.0 everywhere (initial std ≈ 0.05), rather than blackjax's larger default — keeps initial MC ELBO samples close to the mean so the gradient estimator is low-variance from the start. The optimiser then grows rho where the posterior is genuinely wide.
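A stdlib-only sketch of these two overrides, using a dict-of-lists parameter layout and assuming the standard deviation is exp(ρ), which is consistent with the stated initial std ≈ 0.05 at ρ = −3 (exp(−3) ≈ 0.0498). `init_meanfield` and `sample_q` are hypothetical names for illustration, not jno's init_state.

```python
import math
import random


def init_meanfield(position, rho0=-3.0):
    """Hypothetical mirror of the two overrides: mu <- position, rho <- rho0."""
    mu = {name: list(vals) for name, vals in position.items()}  # copy, don't alias
    rho = {name: [rho0] * len(vals) for name, vals in position.items()}
    return mu, rho


def sample_q(mu, rho, rng):
    """One reparameterised draw theta = mu + exp(rho) * eps, eps ~ N(0, 1)."""
    return {
        name: [m + math.exp(r) * rng.gauss(0.0, 1.0) for m, r in zip(mu[name], rho[name])]
        for name in mu
    }


# Hypothetical initial weights for a tiny layer.
position = {"w1": [0.4, -1.1, 0.7], "b1": [0.0, 0.0, 0.0]}
mu, rho = init_meanfield(position)
draw = sample_q(mu, rho, random.Random(0))

print(f"initial std = exp(-3) = {math.exp(-3.0):.4f}")
print(draw["w1"])  # each entry typically lies within a few multiples of 0.05 of position["w1"]
```

Because every draw stays close to the initial weights, early MC ELBO estimates are evaluated near a sensible network rather than at randomly scattered parameters.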

API used

import blackjax, optax

u_net.vi(
    blackjax.meanfield_vi,
    optimizer=optax.adam(5e-3),
    num_samples=8,
    posterior_draws=600,
)
crux.solve(6000)              # 6000 ELBO optimisation steps
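Conceptually, u_net.vi(...) followed by crux.solve(6000) runs stochastic gradient ascent on the ELBO over (μ, ρ). The loop below is a minimal sketch of that idea on a hypothetical one-weight linear model, not blackjax's implementation: plain gradient ascent with a decaying step stands in for optax.adam(5e-3), the Gaussian entropy (ρ + const) is used in closed form, and σ = exp(ρ) is assumed. `grad_log_joint` and `fit_meanfield` are illustrative names.

```python
import math
import random

# Hypothetical toy model: w ~ N(0, 1), y_i ~ N(w * x_i, SIGMA^2).
SIGMA = 0.5
XS = [-0.8, -0.5, -0.2, 0.2, 0.5, 0.8]
YS = [-0.6, -0.3, -0.2, 0.1, 0.4, 0.5]


def grad_log_joint(w):
    """d/dw [log p(w) + sum_i log p(y_i | w)] for the toy model."""
    return -w + sum(x * (y - w * x) for x, y in zip(XS, YS)) / SIGMA**2


def fit_meanfield(num_steps=3000, num_samples=8, lr0=0.05, seed=0):
    """Stochastic gradient ascent on the ELBO for q(w) = N(mu, exp(rho)^2)."""
    rng = random.Random(seed)
    mu, rho = 0.0, -3.0  # rho = -3 mirrors the initialisation described above
    for t in range(num_steps):
        std = math.exp(rho)
        g_mu = g_rho = 0.0
        for _ in range(num_samples):
            eps = rng.gauss(0.0, 1.0)
            g = grad_log_joint(mu + std * eps)  # gradient at a reparameterised draw
            g_mu += g                           # chain rule: dw/dmu = 1
            g_rho += g * std * eps              # chain rule: dw/drho = exp(rho) * eps
        g_mu /= num_samples
        g_rho = g_rho / num_samples + 1.0  # entropy of q is rho + const
        lr = lr0 / (1.0 + t / 200.0)       # decaying step in place of optax.adam
        mu += lr * g_mu
        rho += lr * g_rho
    return mu, math.exp(rho)


post_prec = 1.0 + sum(x * x for x in XS) / SIGMA**2
post_mean = sum(x * y for x, y in zip(XS, YS)) / SIGMA**2 / post_prec
mu, std = fit_meanfield()
print(f"VI:    mu={mu:.3f}  std={std:.3f}")
print(f"exact: mu={post_mean:.3f}  std={1 / math.sqrt(post_prec):.3f}")
```

With a single weight the mean-field family contains the exact posterior, so the fitted (μ, σ) should land close to the exact values; on a multi-weight MLP, the factorisation is where the approximation error enters.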

After solve, u_net.posterior_samples has shape (1, 600, *param), the same (chain, draw, ...) layout ArviZ uses for MCMC posteriors, so crux.eval(samples="auto"), jno.bayesian.{rhat, ess}, and the wandb stats work identically.
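One way to picture that layout: i.i.d. reparameterised draws from the fitted q stacked as (chain=1, draw, *param). The NumPy sketch below does this for a single hypothetical (16, 1) weight array, again assuming std = exp(ρ); `draw_posterior_samples` is a hypothetical helper, not jno's API.

```python
import numpy as np


def draw_posterior_samples(mu, rho, num_draws, seed=0):
    """i.i.d. draws from q = N(mu, exp(rho)^2), shaped (1, num_draws, *mu.shape)."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((1, num_draws) + mu.shape)
    return mu + np.exp(rho) * eps  # mu and rho broadcast over the two leading axes


mu = np.zeros((16, 1))           # hypothetical (16, 1) weight array
rho = np.full((16, 1), -3.0)
samples = draw_posterior_samples(mu, rho, num_draws=600)
print(samples.shape)  # (1, 600, 16, 1)

# Summaries reduce over (chain, draw), as the tutorial script does.
lo, hi = np.quantile(samples, [0.05, 0.95], axis=(0, 1))
print(lo.shape, hi.shape)  # (16, 1) (16, 1)
```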

References

  • Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic Differentiation Variational Inference. JMLR 18(1), 430-474.
  • Hoffman, M. D., Blei, D. M., Wang, C., & Paisley, J. (2013). Stochastic Variational Inference. JMLR 14(1), 1303-1347.
  • Yang, L., Meng, X., & Karniadakis, G. E. (2021). B-PINNs: Bayesian physics-informed neural networks for forward and inverse PDE problems with noisy data. JCP 425, 109913 (uses VI on BNN PINN in §3.2).

Script

"""09 — BNN regression via Variational Inference (mean-field)"""

from pathlib import Path

import blackjax
import foundax
import jax
import jax.numpy as jnp
import numpy as np
import optax

import jno

# ── Synthetic training data — same 32 points + gap as Tutorial 07 ───────────
sigma_obs = 0.1
rng_np = np.random.default_rng(0)

x_left = np.linspace(-0.8, -0.2, 16)
x_right = np.linspace(0.2, 0.8, 16)
x_train_np = np.concatenate([x_left, x_right]).astype(np.float32)
y_clean = np.sin(6.0 * x_train_np) ** 3
y_train_np = (y_clean + sigma_obs * rng_np.normal(size=x_train_np.shape)).astype(np.float32)

train_dom = jno.domain.from_array({"train": x_train_np.reshape(-1, 1)})
x_train_var, _ = train_dom.variable("train")
if "train" in train_dom.context:
    train_dom.context["train"] = train_dom.context["train"].astype(np.float32)
y_train_const = jnp.asarray(y_train_np).reshape(-1, 1)

# ── Bayesian MLP — every weight fit via mean-field VI ────────────────────────
u_net = jno.nn.wrap(
    foundax.mlp(
        in_features=1,
        hidden_dims=16,
        num_layers=2,
        key=jax.random.PRNGKey(0),
    )
)
# Mean-field VI: q(θ) is a product of independent Gaussians, one per
# weight.  blackjax optimises the ELBO via the supplied optax
# optimiser; after solve, ``posterior_draws`` i.i.d. samples are
# drawn from the fitted q and stored on net.posterior_samples in the
# same (1, N, *param) layout as the MCMC path.
u_net.vi(
    blackjax.meanfield_vi,
    optimizer=optax.adam(5e-3),
    num_samples=8,
    posterior_draws=600,
)

# For VI on a BNN the likelihood scaling matters: ``residual.mse`` is
# ``mean_i (y_pred_i - y_i)² / σ²``, but the canonical Gaussian-noise
# log-likelihood is the **sum** over data points.  We rescale by
# ``sqrt(N)`` so that ``residual.mse`` equals the sum-of-squared-
# standardised-residuals — this puts the right magnitude on the
# likelihood term and gives VI a strong gradient signal.
N_obs = float(x_train_np.shape[0])
y_pred = u_net(x_train_var)
residual = (y_pred - y_train_const) / sigma_obs * jnp.sqrt(N_obs)

crux = jno.core([residual.mse])
crux.solve(6000)  # 6000 ELBO optimisation steps

# ── Predict on a dense eval grid (same as T07) ──────────────────────────────
eval_dom = jno.domain.line(x_range=(-1.0, 1.0), mesh_size=0.02)
x_eval, _ = eval_dom.variable("interior")

# auto-chain: u_net carries posterior_samples from the fitted q.
u_chain = crux.eval([u_net(x_eval)], domain=eval_dom)  # (1, 600, n_eval, 1)
u_mean = jnp.mean(u_chain, axis=(0, 1))
u_lo, u_hi = jnp.quantile(u_chain, jnp.array([0.05, 0.95]), axis=(0, 1))

band_width = np.asarray(u_hi - u_lo).reshape(-1)
x_eval_np = np.asarray(eval_dom.context["interior"]).reshape(-1)

u_truth = np.sin(6.0 * x_eval_np) ** 3
u_mean_np = np.asarray(u_mean.reshape(-1))
in_data = (np.abs(x_eval_np) >= 0.2) & (np.abs(x_eval_np) <= 0.8)
in_gap = np.abs(x_eval_np) < 0.2

rel_l2_in_data = float(np.linalg.norm(u_mean_np[in_data] - u_truth[in_data]) / (np.linalg.norm(u_truth[in_data]) + 1e-8))
band_in_data = float(np.median(band_width[in_data]))
band_in_gap = float(np.median(band_width[in_gap]))
band_ratio = band_in_gap / max(band_in_data, 1e-12)

print(f"[vi-bnn] in-data rel-L2 of posterior mean : {rel_l2_in_data:.4f}")
print(f"         band width  in-data (median)     : {band_in_data:.4f}")
print(f"         band width  in-gap  (median)     : {band_in_gap:.4f}")
print(f"         gap / data band ratio            : {band_ratio:.2f}")

results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(
        f"10_bayesian_pinns/09_vi_bnn_regression.py | epochs=6000 | "
        f"rel_L2_in_data={rel_l2_in_data:.4f} | band_in_data={band_in_data:.4f} | "
        f"band_in_gap={band_in_gap:.4f} | band_ratio={band_ratio:.2f}\n"
    )

# Loose asserts. VI's in-data band is tighter than SGLD's on this problem,
# but the posterior mean must still fit the data, the band must not
# collapse, and it must still widen in the data gap (gap / data ratio > 1).
assert rel_l2_in_data < 0.5, f"posterior mean in-data rel-L2 too high: {rel_l2_in_data:.3e}"
assert band_in_data > 1e-4, f"in-data band collapsed: {band_in_data:.3e}"
assert band_in_gap > band_in_data, (
    f"VI predictive band did not widen in the gap (in-data {band_in_data:.4f}, gap {band_in_gap:.4f})"
)