
Pattern B — Bayesian Last Layer (head sampled, body Adam-trained)

This is the practical Bayesian Last Layer recipe: a feature-extractor MLP body is trained deterministically with Adam, while the output layer (the "head") is MCMC-sampled to quantify predictive uncertainty. Compared with Tutorial 10 (head Bayesian, body frozen at random init), Pattern B trains the body at the same time, so the head sees a learned feature map rather than a random one. The result is a far more accurate posterior mean, while the predictive band stays about as wide.

What changed: from Pattern A to Pattern B

# Tutorial 10 — Pattern A: body frozen at init, head sampled.
net.mask(head_mask).bayesian(blackjax.sgld, step_size=1e-3, ...)

# Tutorial 14 — Pattern B: body Adam-trained, head sampled.
net.optimizer(optax.adam(5e-3))                      # NEW — body trained
net.mask(head_mask).bayesian(blackjax.sgld, ...)     # head sampled

Pattern B adds a single line, net.optimizer(...), on top of the Pattern A recipe. Previously, the masked Bayesian configurator could not coexist with a global optimiser on the same model; Phase 15 lifts that restriction by using composite keys in opt_states (one entry per backend per layer).
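Why composite keys help can be shown with plain dictionaries. The sketch below is a hypothetical illustration, not jno's actual opt_states layout: suppose both backends needed state for the same layer. A key made of the layer name alone gives them one shared slot, so the second write silently replaces the first, while a (backend, layer) key gives each backend its own entry.

```python
# Hypothetical illustration only: the key layout and state values are
# placeholders, not jno's real opt_states structure.
from typing import Any, Dict, Tuple

# Keyed by layer alone: both backends compete for the same slot.
layer_keyed: Dict[str, Any] = {}
layer_keyed["output_layer"] = "optax state placeholder"
layer_keyed["output_layer"] = "blackjax state placeholder"  # overwrites the optax entry

# Keyed by (backend, layer): one entry per backend per layer, no collision.
StateKey = Tuple[str, str]
composite: Dict[StateKey, Any] = {}
composite[("optax", "output_layer")] = "optax state placeholder"
composite[("blackjax", "output_layer")] = "blackjax state placeholder"

print(len(layer_keyed), len(composite))
```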

How it works under the hood

At each step:

  1. Compute the full-loss gradient: a single jax.value_and_grad call over the entire trainable pytree.
  2. Optax update — masked to the body. The optax chain is wrapped in optax.masked with the complement of the Bayesian mask, so only the body leaves receive updates.
  3. MCMC kernel step — masked to the head. The kernel's log-density closure reassembles the current body (just updated) with each candidate head, evaluates the loss, and samples the head.

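This alternation follows a Metropolis-within-Gibbs / stochastic-approximation EM (SAEM) pattern: the body is updated as a point estimate, and the head is sampled conditional on the current body. For K > 1 chains, the body's gradient is computed at the chain-0 head sample. This is an SAEM simplification; properly averaging the gradient across chains would cost K forward passes per step.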
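The difference between the chain-0 shortcut and full averaging can be sketched in JAX. The toy model and function names are hypothetical; the point is that the shortcut needs one gradient evaluation per step, while averaging needs one per chain (vectorised here with jax.vmap over a leading K axis on the head leaves).

```python
# Sketch: body gradient at chain 0's head vs averaged over all K heads.
# The toy model and names are assumptions, not jno's implementation.
import jax
import jax.numpy as jnp

SIGMA = 0.05

def loss(body, head, x, y):
    feats = jnp.tanh(x @ body["W"] + body["b"])
    r = (feats @ head["w"] + head["c"] - y) / SIGMA
    return 0.5 * jnp.sum(r ** 2)

body_grad = jax.grad(loss, argnums=0)  # gradient w.r.t. the shared body

def body_grad_chain0(body, heads, x, y):
    # SAEM-style shortcut: a single gradient, taken at chain 0's head.
    head0 = jax.tree_util.tree_map(lambda h: h[0], heads)
    return body_grad(body, head0, x, y)

def body_grad_averaged(body, heads, x, y):
    # Full average: one gradient per chain, then the mean over the K axis.
    per_chain = jax.vmap(body_grad, in_axes=(None, 0, None, None))(body, heads, x, y)
    return jax.tree_util.tree_map(lambda g: jnp.mean(g, axis=0), per_chain)

# Usage: K = 4 head samples sharing one body.
K = 4
kb, kh = jax.random.split(jax.random.PRNGKey(1))
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = jnp.sin(jnp.pi * x)
body = {"W": 0.5 * jax.random.normal(kb, (1, 16)), "b": jnp.zeros(16)}
heads = {"w": 0.1 * jax.random.normal(kh, (K, 16, 1)), "c": jnp.zeros((K, 1))}
g0 = body_grad_chain0(body, heads, x, y)
g_avg = body_grad_averaged(body, heads, x, y)

# With a single chain the two estimates coincide.
heads_1 = jax.tree_util.tree_map(lambda h: h[:1], heads)
g0_single = body_grad_chain0(body, heads_1, x, y)
g_avg_single = body_grad_averaged(body, heads_1, x, y)
```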

Numbers from the tutorial

The task is the same sin(πx) regression problem as T07 / T10 (32 noisy observations, σ = 0.05), trained for 2300 epochs (warmup 1500, keep 400, thin 2).
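The epoch budget is the warmup plus the kept samples times the thinning interval. A small Python check, assuming one sample is kept every thin epochs starting right after warmup (the exact epoch indices jno records are an assumption):

```python
# Sampling schedule implied by warmup=1500, keep=400, thin=2.
warmup, keep, thin = 1500, 400, 2
epochs = warmup + keep * thin                      # 1500 + 400 * 2
kept_epochs = list(range(warmup, epochs, thin))    # assumed kept-epoch indices
print(epochs, len(kept_epochs))                    # prints: 2300 400
```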

Pattern   Body             Head   Posterior-mean rel-L2 vs sin(πx)   Median 90 % band width   Max head-leaf variance
A (T10)   frozen at init   SGLD   0.394                              0.240                    6.5 × 10⁻²
B (T14)   Adam-trained     SGLD   0.063                              0.269                    2.3 × 10⁻¹

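Pattern B reduces the posterior-mean error by roughly 6× (rel-L2 0.394 → 0.063) at a comparable band width (median 0.240 → 0.269). Adam pulls the body's feature map toward one that can represent sin(πx), so the sampled heads are centred on a much better fit. The head posterior itself does not narrow: the maximum head-leaf variance grows from 6.5 × 10⁻² to 2.3 × 10⁻¹.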

Composition with the rest of the API

Pattern B composes cleanly with every existing feature:

  • Multi-chain (num_chains=K) — kernel state is K-leading-masked; the body is a single shared point estimate; the body's gradient is computed at the chain-0 representative head sample.
  • Initializers.initialize(jno.bayesian.pathfinder(...)) warm-starts the masked head subset: Pathfinder runs only on the head's log-density, and the body's optax loop then continues from the warm-started head.
  • R-hat / ESS diagnostics — work unchanged on the masked head chain.
  • Auto-IMM injection — fires on the masked subset only (the head dimension D).
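Because each head leaf of the chain is an ordinary array, standard convergence diagnostics apply directly. Below is a NumPy sketch of split-R-hat, assuming each head leaf in posterior_samples has shape (K, N, *leaf_shape). It is the textbook split formula without rank normalisation, so it may differ from the estimator jno reports; splitting each chain in half also makes it usable for a single-chain (K = 1) run like the one in the script.

```python
# Split-R-hat sketch for head leaves of shape (K, N, *leaf_shape).
import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Split-R-hat for one scalar parameter; `draws` has shape (K, N)."""
    half = draws.shape[1] // 2
    # Split every chain in two, giving 2K chains of length `half`.
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    w = chains.var(axis=1, ddof=1).mean()        # mean within-chain variance
    b = n * chains.mean(axis=1).var(ddof=1)      # between-chain variance
    var_hat = (n - 1) / n * w + b / n            # pooled variance estimate
    return float(np.sqrt(var_hat / w))


def max_split_rhat(leaf_chain: np.ndarray) -> float:
    """Worst split-R-hat over all scalar entries of one head leaf."""
    k, n = leaf_chain.shape[:2]
    flat = np.asarray(leaf_chain).reshape(k, n, -1)
    return max(split_rhat(flat[:, :, j]) for j in range(flat.shape[2]))


# Usage on synthetic draws.
rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 400, 16))             # 4 well-mixed chains
stuck = mixed + np.arange(4.0)[:, None, None]     # chains stuck at different levels
single = rng.normal(size=(1, 400, 16))            # K = 1, as in the script
print(max_split_rhat(mixed), max_split_rhat(stuck), max_split_rhat(single))
```

Values close to 1 indicate agreement between (split) chains; the offset chains produce a value well above 1.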

What's the same as Pattern A

The mask construction, the chain shape (K, N, *full_param), the .posterior_samples accessor, and crux.eval([net(x)], samples="auto") all behave exactly as in Pattern A. Migrating is a drop-in change: keep the .mask(head_mask).bayesian(...) line as it is and add net.optimizer(...) before it.

When to use Pattern B vs full-net Bayesian

  • Full-net Bayesian (Tutorial 07) — every weight sampled by SGLD. Honest uncertainty propagation through the entire model, but long mixing time for high-dimensional posteriors.
  • Pattern A (Tutorial 10) — random-feature Bayesian regression with a frozen body. Useful as a baseline / demonstration; rarely optimal in practice.
  • Pattern B (this tutorial) — Bayesian Last Layer: body Adam-trained for fast convergence on the feature map; head sampled for predictive uncertainty. Often the best speed/quality tradeoff for B-PINNs and Bayesian regression on neural-network-induced features.

References

  • Snoek, J., Rippel, O., Swersky, K., Kiros, R., Satish, N., Sundaram, N., Patwary, M. M. A., Prabhat, & Adams, R. P. (2015). Scalable Bayesian Optimization Using Deep Neural Networks. §3 (Adaptive basis regression with Bayesian linear regression on the last layer). ICML 2015. arXiv:1502.05700
  • Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., & Hennig, P. (2021). Laplace Redux — Effortless Bayesian Deep Learning. §3 (last-layer Laplace). NeurIPS 2021. arXiv:2106.14806
  • Cappé, O., & Moulines, E. (2009). On-line Expectation-Maximization Algorithm for Latent Data Models. §3 (SAEM convergence theory). Journal of the Royal Statistical Society, Series B, 71(3), 593-613.

Script

"""14 — Pattern B: Bayesian last layer (head sampled, body Adam-trained)"""

from pathlib import Path

import blackjax
import equinox as eqx
import foundax
import jax
import jax.numpy as jnp
import numpy as np
import optax

import jno

# ── Synthetic data: sin(πx) target as in T07 / T10 (32 evenly spaced points) ──
sigma_obs = 0.05
rng_np = np.random.default_rng(0)
x_train_np = np.linspace(-1.0, 1.0, 32).astype(np.float32)
y_clean = np.sin(np.pi * x_train_np)
y_train_np = (y_clean + sigma_obs * rng_np.normal(size=x_train_np.shape)).astype(np.float32)

train_dom = jno.domain.from_array({"train": x_train_np.reshape(-1, 1)})
x_train_var, _ = train_dom.variable("train")
if "train" in train_dom.context:
    train_dom.context["train"] = train_dom.context["train"].astype(np.float32)
y_train_const = jnp.asarray(y_train_np).reshape(-1, 1)

# ── MLP with head Bayesian, body Adam-trained ───────────────────────────────
u_net = jno.nn.wrap(
    foundax.mlp(
        in_features=1,
        hidden_dims=16,
        num_layers=2,
        key=jax.random.PRNGKey(0),
    )
)

# Build the head mask: True only under module.output_layer; False on body.
all_false = jax.tree_util.tree_map(lambda _: False, u_net.module)
head_all_true = jax.tree_util.tree_map(lambda _: True, u_net.module.output_layer)
head_mask = eqx.tree_at(lambda m: m.output_layer, all_false, replace=head_all_true)
n_head = sum(int(b) for b in jax.tree_util.tree_leaves(head_mask))
n_total = len(jax.tree_util.tree_leaves(head_mask))
print(f"[pattern-b] head leaves: {n_head} / {n_total}")

# Pattern B: body trained via Adam, head sampled via SGLD.  Phase 15
# enables both to coexist on the same model.
u_net.optimizer(optax.adam(5e-3))  # ← body (the unmasked complement)
u_net.mask(head_mask).bayesian(  # ← head
    blackjax.sgld,
    step_size=5e-4,
    warmup=1500,
    keep=400,
    thin=2,
)

# Likelihood with σ-scaled residuals.
y_pred = u_net(x_train_var)
residual = (y_pred - y_train_const) / sigma_obs

crux = jno.core([residual.mse])
crux.solve(2300)

# ── Diagnostics ─────────────────────────────────────────────────────────────
chain = u_net.posterior_samples  # (1, keep, *full pytree)
# Head leaves vary across the chain; body leaves were Adam-trained
# during sampling so they ALSO vary across the chain (each sample
# captures the body state at sample time).  The headline contract
# we check is just that the chain is populated and the head leaves
# vary.
head_idxs = [i for i, m in enumerate(jax.tree_util.tree_leaves(head_mask)) if m]
chain_leaves = jax.tree_util.tree_leaves(chain)
head_var = max(float(jnp.mean(jnp.var(chain_leaves[i], axis=1))) for i in head_idxs)
print(f"[pattern-b] head leaf max var-along-N : {head_var:.3e}  (should be > 0)")

# ── Predictive bands on a dense eval grid ───────────────────────────────────
eval_dom = jno.domain.line(x_range=(-1.0, 1.0), mesh_size=0.02)
x_eval, _ = eval_dom.variable("interior")

u_chain = crux.eval([u_net(x_eval)], domain=eval_dom)  # (1, keep, n_eval, 1)
u_mean = jnp.mean(u_chain, axis=(0, 1))
u_lo, u_hi = jnp.quantile(u_chain, jnp.array([0.05, 0.95]), axis=(0, 1))
band = np.asarray(u_hi - u_lo).reshape(-1)
x_eval_np = np.asarray(eval_dom.context["interior"]).reshape(-1)
u_truth = np.sin(np.pi * x_eval_np)
rel_l2 = float(np.linalg.norm(u_mean.reshape(-1) - u_truth) / (np.linalg.norm(u_truth) + 1e-8))
band_median = float(np.median(band))

print(f"[pattern-b] posterior-mean rel-L2 vs sin(πx) : {rel_l2:.4f}")
print(f"[pattern-b] band width median               : {band_median:.4f}")

results_file = Path(__file__).parent.parent.parent / "tutorial_results.txt"
with open(results_file, "a") as f:
    f.write(
        f"10_bayesian_pinns/14_pattern_b_bnn_head.py | epochs=2300 | "
        f"rel_L2={rel_l2:.4f} | band_median={band_median:.4f} | head_var_max={head_var:.3e}\n"
    )

# Asserts: the Pattern B contract is the same as Tutorial 10's (head varies,
# fit recovers sin(πx)). The thresholds are deliberately loose; the tighter
# fit from training the body shows up in the printed rel-L2, not in an assert.
assert head_var > 1e-8, f"head should vary across the chain; got var={head_var:.3e}"
assert rel_l2 < 1.0, f"posterior-mean rel-L2 too high: {rel_l2:.3e}"
assert band_median > 1e-4, f"posterior band degenerate (median {band_median:.3e})"