Bayesian sampling — .bayesian(...) per parameter

model.bayesian(kernel_factory, **kw) replaces a model's per-step gradient update with one transition of a blackjax MCMC kernel. The configurator mirrors .optimizer(...) — each parameter in your script can independently choose to be optimised (point estimate via optax) or sampled (posterior chain via blackjax). Mix freely; crux.solve() dispatches per-model.

After training the chain is available on the model as model.posterior_samples, and on the crux itself via crux.eval([...], samples="chain") which vmaps the evaluator over the chain so nonlinear predictions push forward correctly.

Quick example — Bayesian inverse problem

Recover (A, B) in d(x) = A·sin(πx) + B·cos(πx) from noisy observations, with credible intervals:

import blackjax, jax, jno
import jax.numpy as jnp

π = jno.np.pi
dom = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _ = dom.variable("interior")

A_true, B_true = 3.14, -2.71
target = A_true * jno.np.sin(π * x) + B_true * jno.np.cos(π * x) \
       + jno.noise.gaussian(std=0.1)

k1, k2 = jax.random.split(jax.random.PRNGKey(0), 2)
a = jno.np.parameter((1,), key=k1, name="a")
b = jno.np.parameter((1,), key=k2, name="b")

for p in [a, b]:
    # inverse_mass_matrix defaults to identity of the right shape;
    # adapt=True (default) tunes step_size and IMM via blackjax.window_adaptation.
    p.bayesian(blackjax.nuts, step_size=1e-2, warmup=500, keep=1000)

residual = a * jno.np.sin(π * x) + b * jno.np.cos(π * x) - target
crux = jno.core([residual.mse])
crux.solve(1500)

# Raw chains — leading axis = sample
a_chain = a.posterior_samples           # (1000, 1)
A_mean = jnp.mean(a_chain, axis=0)
A_lo, A_hi = jnp.quantile(a_chain, jnp.array([0.05, 0.95]), axis=0)
print(f"A = {A_mean[0]:.3f}  [{A_lo[0]:.3f}, {A_hi[0]:.3f}]")

API

Model.bayesian(
    kernel_factory,             # e.g. blackjax.nuts, blackjax.sgld
    *,
    prior=None,                 # callable: pytree -> log p(θ);  default Gaussian(σ=10)
    warmup=500,                 # outer epochs to discard before collecting
    keep=1000,                  # number of post-warmup samples to retain
    thin=1,                     # keep one sample every `thin` post-warmup steps
    **kernel_kwargs,            # forwarded to kernel_factory; `step_size=` is required
                                # EXCEPT for HMC-family kernels with adapt=True + warmup>0
                                # (window adaptation chooses one — defaults to 1.0).
)
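To make the interplay of warmup, keep and thin concrete, the sketch below lists which step indices would be retained under the classic "discard the first warmup steps, then thin" reading of the comments above. The helper names retained_steps and steps_needed, and the choice to keep the last step of each thinning window, are illustrative assumptions rather than jno's implementation.

```python
def retained_steps(warmup, keep, thin=1):
    """Return the 0-based step indices whose samples are retained.

    Assumption: the first `warmup` steps are discarded and the last
    step of every `thin`-long window after warmup is kept.
    """
    return [warmup + thin * (i + 1) - 1 for i in range(keep)]


def steps_needed(warmup, keep, thin=1):
    """Total number of steps required to fill a chain of `keep` samples."""
    return warmup + keep * thin


kept = retained_steps(warmup=500, keep=1000, thin=1)
total = steps_needed(warmup=500, keep=1000, thin=1)
print(len(kept), kept[0], kept[-1], total)
```

Under these assumptions warmup=500, keep=1000, thin=1 needs 1500 steps, which matches the crux.solve(1500) call in the quick example.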

kernel_factory is duck-typed at solve time:

| First parameter of factory | Family | Examples |
| --- | --- | --- |
| logdensity_fn | Full-data MCMC | blackjax.nuts, blackjax.hmc, blackjax.mala |
| grad_estimator | Stochastic-gradient | blackjax.sgld, blackjax.sghmc |

jno builds the appropriate closure from the live loss + context and rebuilds the kernel inside the JIT graph each step.
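Dispatching on the name of the first parameter can be sketched with the standard library's inspect module. This is only an illustration of the idea: the function kernel_family, the two toy factories, and the label "stochastic" are hypothetical, and jno's actual detection logic may differ.

```python
import inspect


def kernel_family(kernel_factory):
    """Classify a factory by the name of its first parameter."""
    params = list(inspect.signature(kernel_factory).parameters)
    first = params[0] if params else None
    if first == "logdensity_fn":
        return "full"          # full-data MCMC
    if first == "grad_estimator":
        return "stochastic"    # stochastic-gradient MCMC
    raise TypeError(f"unsupported kernel factory: first parameter is {first!r}")


# Stand-in factories that only carry the two supported signatures.
def toy_full_data(logdensity_fn, step_size):
    ...


def toy_stochastic(grad_estimator, step_size):
    ...


print(kernel_family(toy_full_data), kernel_family(toy_stochastic))
```

Keying on the parameter name rather than on a base class is what lets a user-written factory participate without importing anything from jno.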

Custom kernel factories — (logdensity_fn, **kw) → SamplingAlgorithm

You are not limited to the blackjax kernels listed above. kernel_factory is anything that returns a blackjax.SamplingAlgorithm (a NamedTuple(init, step)). jno detects which family it belongs to by inspecting the first parameter name:

  • logdensity_fn → full-data MCMC; jno passes a closure θ → log p(data | θ) + log p(θ).
  • grad_estimator → stochastic-gradient MCMC; jno passes a mini-batch gradient closure.

A barebones random-walk Metropolis factory, for illustration:

import blackjax, jax
import jax.numpy as jnp

def my_rwm(logdensity_fn, step_size):
    def init_fn(position):
        return {"position": position, "logdensity": logdensity_fn(position)}

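    def step_fn(rng_key, state):
        k1, k2 = jax.random.split(rng_key)
        # Use one independent key per leaf so leaves do not share proposal noise.
        leaves, treedef = jax.tree.flatten(state["position"])
        leaf_keys = jax.random.split(k1, len(leaves))
        prop = jax.tree.unflatten(treedef, [
            p + step_size * jax.random.normal(k, p.shape)
            for p, k in zip(leaves, leaf_keys)
        ])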
        new_logd = logdensity_fn(prop)
        accept = jnp.log(jax.random.uniform(k2)) < new_logd - state["logdensity"]
        new_state = {
            "position":  jax.tree.map(lambda x, y: jnp.where(accept, y, x),
                                      state["position"], prop),
            "logdensity": jnp.where(accept, new_logd, state["logdensity"]),
        }
        return new_state, {"accepted": accept}

    return blackjax.SamplingAlgorithm(init_fn, step_fn)

net.bayesian(my_rwm, step_size=1e-2, warmup=500, keep=1000)

The built-in kernels in the table above all match this protocol — your factory plugs in the same way and goes through the same warmup / thin / keep pipeline.

Output — chains by default

crux.eval(...) is auto-chain-aware per expression: if an expression's dependency graph touches a model with posterior_samples set, the evaluator is vmap-ped over that chain. Otherwise the expression is evaluated at the point value as before. No samples= argument is needed for the common case.

| Read | .optimizer() (point) | .bayesian() (chain) |
| --- | --- | --- |
| crux.eval([m]) | point value | (n_kept, *m_shape) chain |
| crux.eval([expr]) | (n_points, …) point | (n_kept, n_points, …) chain (auto) |
| m.posterior_samples | None | stacked module pytree (or array for parameter) |

No .mean() / .std() / .quantile() helpers are provided — compute whatever summary you need from the chain with jnp.mean, jnp.quantile, arviz, or your favourite plotting library.

Nonlinear pushforward — handled automatically

For predictions through a neural network, the posterior mean over outputs is not the output at the posterior mean of the weights. This used to require an explicit samples="chain"; the default now does the right thing:

u_chain = crux.eval([u])                          # (n_kept, n_points, 1)
u_mean = jnp.mean(u_chain, axis=0)
u_lo, u_hi = jnp.quantile(u_chain, jnp.array([0.05, 0.95]), axis=0)
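The gap between "mean of the outputs" and "output at the mean" can be seen on a toy problem without any network. The sketch below uses plain NumPy, a made-up forward model exp(θ·x), and an evenly spaced stand-in for the chain; none of these names come from jno.

```python
import numpy as np

# Stand-in "chain": 1001 equally spaced draws of a scalar weight in [0.5, 1.5].
theta_chain = np.linspace(0.5, 1.5, 1001)          # (n_kept,)
x = np.linspace(0.0, 1.0, 5)                       # (n_points,)


def predict(theta, x):
    """Toy nonlinear forward model u(x; theta) = exp(theta * x)."""
    return np.exp(theta * x)


# Pushforward: evaluate at every draw, then summarise over the sample axis.
u_chain = predict(theta_chain[:, None], x[None, :])   # (n_kept, n_points)
u_mean = u_chain.mean(axis=0)
u_lo, u_hi = np.quantile(u_chain, [0.05, 0.95], axis=0)

# Plug-in: evaluate once at the posterior mean of the weight.
u_plugin = predict(theta_chain.mean(), x)

print(np.round(u_mean - u_plugin, 4))
```

Because exp is convex, the pushforward mean lies above the plug-in value wherever x > 0 (Jensen's inequality). That difference is the bias a point evaluation would introduce.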

Escape hatches

crux.eval([u], samples="chain")    # force chain (raises if no Bayesian deps)
crux.eval([u], samples="point")    # force point: evaluate at last sample,
                                   # skips the vmap. Quick debugging / sanity.

The samples="point" mode returns a single sample at the model's current position — useful when you just want a quick number, but not a substitute for the posterior summary on nonlinear outputs.

Mixed: optimised + sampled

Different parameters use different update rules; they coexist in one crux.solve():

encoder = jno.nn.wrap(foundax.mlp(2, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(0)))
head    = jno.nn.wrap(foundax.mlp(1, hidden_dims=32, num_layers=1, key=jax.random.PRNGKey(1)))

encoder.optimizer(optax.adam(1e-3))                          # point estimate
head.bayesian(blackjax.sgld, step_size=1e-5)                 # SGLD chain

Only head.posterior_samples is populated; encoder.posterior_samples is None.

Bayesian PINN — predictive bands

import blackjax, foundax, jax, jno
import jax.numpy as jnp

π = jno.np.pi
dom = jno.domain(constructor=jno.domain.line(mesh_size=0.01))
x, _  = dom.variable("interior")
xb, _ = dom.variable("boundary")

net = jno.nn.wrap(foundax.mlp(1, hidden_dims=32, num_layers=3,
                              key=jax.random.PRNGKey(0)))
net.bayesian(blackjax.sgld, step_size=1e-5, warmup=2000, keep=1000)

u    = net(x)
u_xx = jno.diff(u, x, order=2)
pde  = u_xx + (π ** 2) * jno.np.sin(π * x)        # u'' = -π² sin(πx)
bc   = net(xb) - 0.0

crux = jno.core([pde.mse, bc.mse])
crux.solve(3000)

u_chain = crux.eval([u], samples="chain")
u_mean     = jnp.mean(u_chain, axis=0)
u_lo, u_hi = jnp.quantile(u_chain, jnp.array([0.05, 0.95]), axis=0)

Priors — built-in factories

The prior= argument takes any pytree → float returning the log-prior density. Four built-in factories live at jno.bayesian.priors.*; each one returns a callable that obeys the contract above.

| Factory | Form | When to use |
| --- | --- | --- |
| priors.gaussian(sigma=10.0, fan_in_aware=False) | \(-\lVert\theta\rVert^2 / (2\sigma^2)\) | Wide default (σ=10) is "effectively flat"; pass smaller σ for shrinkage. fan_in_aware=True scales σ by 1/√fan_in per weight tensor. |
| priors.laplace(scale=1.0) | \(-\lVert\theta\rVert_1 / \text{scale}\) | Sparse-friendly: encourages many components near zero. |
| priors.student_t(df=4.0, scale=1.0) | \(-\frac{df+1}{2} \sum \log\big(1 + (\theta/\text{scale})^2 / df\big)\) | Heavy-tailed alternative to Gaussian; practical substitute for horseshoe on individual weights. df must be > 2 for finite variance. |
| priors.layerwise_gaussian(base_sigma=1.0, default_sigma=1.0, fan_in_aware=True) | Per-leaf \(N(0, \sigma_\text{leaf}^2)\) with \(\sigma_\text{weight} = \text{base}/\sqrt{\text{fan\_in}}\), \(\sigma_\text{bias} = \text{default}\) | The standard BNN-PINN prior (Sun et al. 2019, Wenzel et al. 2020). |

import jno

# Wide Gaussian, σ=10 — the historical default
a.bayesian(blackjax.nuts, step_size=1e-2,
           prior=jno.bayesian.priors.gaussian(sigma=10.0))

# Sparse coefficient — Laplace
a.bayesian(blackjax.nuts, step_size=1e-2,
           prior=jno.bayesian.priors.laplace(scale=0.5))

# BNN head with fan-in-aware layer-wise priors
head.mask(M).bayesian(blackjax.sgld, step_size=1e-3,
                     prior=jno.bayesian.priors.layerwise_gaussian())

When prior=None the internal default is priors.gaussian(sigma=10.0) — effectively flat at typical parameter scales, but for BNN weights at scale 0.01 it's overly wide and for outputs at scale 100 it's overly tight. Prefer one of the named factories above for non-trivial problems.
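The unnormalised forms in the table can be written out in a few lines. The following is a NumPy sketch over a dict-of-arrays stand-in for a pytree, not the code behind jno.bayesian.priors. The function names are hypothetical, and treating the last axis of a 2-D leaf as fan_in is an assumption about weight layout.

```python
import numpy as np


def _leaves(params):
    """Flatten a dict-of-arrays 'pytree' into a list of float arrays."""
    return [np.asarray(v, dtype=float) for v in params.values()]


def gaussian_log_prior(params, sigma=10.0):
    return -sum(np.sum(l**2) for l in _leaves(params)) / (2.0 * sigma**2)


def laplace_log_prior(params, scale=1.0):
    return -sum(np.sum(np.abs(l)) for l in _leaves(params)) / scale


def student_t_log_prior(params, df=4.0, scale=1.0):
    total = sum(np.sum(np.log1p((l / scale) ** 2 / df)) for l in _leaves(params))
    return -0.5 * (df + 1.0) * total


def layerwise_gaussian_log_prior(params, base_sigma=1.0, default_sigma=1.0):
    """Weights (2-D leaves) get sigma = base / sqrt(fan_in); other leaves get default."""
    total = 0.0
    for leaf in _leaves(params):
        # Assumption: fan_in is the size of the last axis of a weight matrix.
        sigma = base_sigma / np.sqrt(leaf.shape[-1]) if leaf.ndim == 2 else default_sigma
        total += np.sum(leaf**2) / (2.0 * sigma**2)
    return -total


net = {"w": np.ones((2, 4)), "b": np.zeros(2)}
print(gaussian_log_prior(net), layerwise_gaussian_log_prior(net))
```

All four return a single float and drop additive normalising constants, which is sufficient for MCMC because only differences of log-density matter.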

Custom priors — pytree → float contract

Any pytree → float callable works. For example, an L½ prior:

def l_half_prior(p, scale=1.0):
    return -sum(
        jnp.sum(jnp.sqrt(jnp.abs(leaf)))
        for leaf in jax.tree_util.tree_leaves(p)
        if hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.floating)
    ) / scale

a.bayesian(blackjax.nuts, step_size=1e-2, prior=l_half_prior)

Masked priors see only the masked subset

When configured via .mask(M).bayesian() / .mask(M).vi(), the prior closure receives the masked subset of the position (whatever the kernel sees) — not the full model pytree. The built-in factories iterate over the leaves they're handed, so they handle masked and unmasked solves identically. Custom priors should be aware that p is whatever subset the kernel operates on; if you need the full pytree (e.g. for a hierarchical prior coupling masked and unmasked leaves), use a global .bayesian() rather than .mask(M).bayesian().

Adaptation (NUTS / HMC)

For HMC-family kernels, adapt=True (default) runs blackjax.window_adaptation for the first warmup steps before the main solve loop. Step size and inverse mass matrix are tuned automatically; the loop then collects keep samples from epoch 0.

a.bayesian(blackjax.nuts, step_size=1.0, warmup=500, keep=1000)   # adapt=True default
# → window adaptation tunes step_size + inverse_mass_matrix in 500 steps,
#   loop collects 1000 samples from the adapted state.

For non-adaptive kernels (mala, sgld, sghmc) adapt= is silently ignored and warmup=N retains the classic "discard the first N samples" meaning.

Mixed mode (Bayesian + optax)

Window adaptation runs once at the start, with all non-Bayesian models at their initial weights. In mixed setups (e.g. a Bayesian coefficient + an optax-trained surrogate) the adapted step size is tuned against the untrained surrogate's logdensity and is typically wrong for the actual joint problem. Set adapt=False and pick step_size by hand in that case.

Logdensity-aware initializers (.initialize() extension)

Model.initialize(...) already takes a path, a pytree, or a stateless (shape, dtype, key) -> array callable. A fourth form is now accepted: any object with requires_logdensity = True whose __call__ runs inside solve() with access to the loss-derived log-density. Pathfinder (Zhang et al. 2022) is the first concrete implementation.

import jno, blackjax

a.initialize(jno.bayesian.pathfinder(maxiter=30, num_samples=200))
a.bayesian(blackjax.nuts, step_size=1e-2, warmup=100, adapt=True)

Pathfinder runs L-BFGS on the log-density and turns the inverse-Hessian trajectory into a normal approximation to the posterior. From the fitted q jno extracts (a) a warm starting position (the MAP for K=1 chains; K i.i.d. samples for K>1 — proper over-dispersion) and (b) a diagonal inverse_mass_matrix from the per-dimension variance of M draws. .bayesian()'s warmup and adapt then apply after pathfinder.

Behaviour matrix

| .initialize(pathfinder(...)) | adapt | What runs |
| --- | --- | --- |
| not set | True | window adaptation from the user's init (today's default) |
| not set | False | user's init, user's step_size — no warmup |
| set | True | pathfinder → window: warm position + IMM, then window refines step_size |
| set | False | pathfinder only: warm position + pathfinder's IMM, user's step_size kept |

The protocol — _BayesianInitializer

Minimal, single-method contract. Any class with requires_logdensity = True is detected by Model.initialize and dispatched the same way at solve time:

class _BayesianInitializer:
    requires_logdensity: ClassVar[bool] = True

    def __call__(self, rng_key, logdensity_fn, position, num_chains):
        # → (new_position, extra_kwargs_update)
        ...

  • logdensity_fn is already mask-wrapped for .mask(M).bayesian() groups — subclasses see only the masked subset.
  • For num_chains > 1, return a (K, *leaf)-leading pytree (one warm position per chain).
  • extra_kwargs_update is merged into the kernel handle's extra_kwargs; keys the kernel doesn't accept (e.g. an IMM update against a MALA kernel) are silently dropped.
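As an illustration of the four-argument contract, here is a deliberately simple initializer that scores random candidates with the log-density and returns the best K as warm positions. It is a NumPy sketch under loud assumptions. The class RandomSearchInit is hypothetical, the position is a flat array rather than a pytree, and rng_key is treated as an integer seed, whereas jno would hand over a JAX PRNG key.

```python
import numpy as np


class RandomSearchInit:
    """Hypothetical initializer following the (rng_key, logdensity_fn, position, num_chains) contract."""

    requires_logdensity = True

    def __init__(self, num_candidates=64, radius=1.0):
        self.num_candidates = num_candidates
        self.radius = radius

    def __call__(self, rng_key, logdensity_fn, position, num_chains):
        # Assumption: rng_key is an integer seed in this stand-alone sketch.
        rng = np.random.default_rng(rng_key)
        position = np.asarray(position, dtype=float)
        noise = rng.standard_normal((self.num_candidates,) + position.shape)
        # Include the current position so the result is never worse than the start.
        candidates = np.concatenate([position[None], position + self.radius * noise])
        scores = np.array([logdensity_fn(c) for c in candidates])
        order = np.argsort(scores)[::-1]
        best = candidates[order[:num_chains]]              # (K, *leaf)
        # Diagonal IMM guess: per-dimension variance of the better half.
        top = candidates[order[: len(candidates) // 2]]
        imm = top.reshape(len(top), -1).var(axis=0)
        new_position = best if num_chains > 1 else best[0]
        return new_position, {"inverse_mass_matrix": imm}


def toy_logdensity(theta):
    return -0.5 * float(np.sum((theta - 3.0) ** 2))


start = np.zeros(2)
warm, extra = RandomSearchInit()(0, toy_logdensity, start, num_chains=4)
print(warm.shape, extra["inverse_mass_matrix"].shape)
```

The returned dict plays the role of extra_kwargs_update. A kernel without an inverse_mass_matrix argument would, per the bullets above, simply ignore that key.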

Composition

  • Masks: .mask(M).bayesian() + pathfinder works. Pathfinder runs on the masked subset's log-density; the unmasked complement stays at init.
  • Multi-chain — pathfinder samples K distinct starting positions from the fitted q. init_jitter is silently overridden when an initializer is set.
  • Non-IMM kernels (MALA / SGLD / SGHMC) — warm position is applied; the IMM update is silently dropped (signature gate).
  • substeps= + initializer → clear error (the initializer runs on the full loss; substep kernels see only substep-local constraints).
  • .vi(...) + initializer → clear error (VI initialises its own variational distribution from the position; warm-start is redundant).

Future initializers (same hook, no further core changes)

| Slot | Algorithm | Notes |
| --- | --- | --- |
| jno.bayesian.pathfinder(...) | blackjax pathfinder | This release. See Tutorial 11. |
| jno.bayesian.laplace(...) | MAP via optax + jax.hessian (diagonal or full) | This release. See Tutorial 12. MacKay 1992 §6; Daxberger et al. 2021 §2. |
| jno.bayesian.svgd(...) | blackjax svgd | This release. See Tutorial 13. K particles → K chain inits. Liu & Wang 2016 §3. |
| jno.bayesian.map(...) | Fixed-step optax warm-start | Future. No IMM output; user keeps step_size. |
| User-written subclass | anything that fits the contract | See Tutorial 11. |

A worked example lives at Tutorial 11.

Multiple chains

Pass num_chains=K (default 1) to run K independent MCMC chains in parallel via jax.vmap:

a.bayesian(
    blackjax.nuts,
    step_size=1e-2,
    num_chains=4,
    init_jitter=0.1,   # per-chain Gaussian perturbation of the initial position
    warmup=300,
    keep=400,
)

After crux.solve(), a.posterior_samples has shape (K, N, *param) — the canonical arviz layout (chain, draw, *). All Bayesian models in a single solve() must share the same num_chains; mismatched values raise at solve-start.

init_jitter > 0 over-disperses the K starting positions so R-hat is conservative (chains forced apart at start; if they reconverge to the same distribution, that's strong evidence of convergence).

A single window-adaptation sweep runs at start (PyMC convention) and its adapted step-size + inverse mass matrix are broadcast to all K chains; per-chain adaptation can be enabled in a follow-up if needed.

Convergence diagnostics

Two pure-JAX helpers on jno.bayesian operate directly on the (K, N, *param) chain layout, with no arviz dependency:

| Helper | What it computes | Threshold |
| --- | --- | --- |
| jno.bayesian.rhat(chain) | Vehtari et al. 2021 split, rank-normalised, folded R-hat | < 1.01 (strict) or < 1.05 (lenient) → converged |
| jno.bayesian.ess(chain) | Effective sample size via FFT-based autocorrelation + Geyer 1992 truncation | > 100 per parameter typically sufficient |

Both return arrays of shape *param, one diagnostic per parameter component.

Example:

chain = a.posterior_samples              # (K, N, 1)
r = jno.bayesian.rhat(chain)             # → (1,)
e = jno.bayesian.ess(chain)              # → (1,)
print(f"A: R-hat = {float(r[0]):.4f}, ESS = {float(e[0]):.1f}")
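For intuition about what an FFT-based ESS with Geyer truncation computes, here is a single-chain NumPy sketch. It is a simplified stand-in (one scalar chain, Geyer's initial positive sequence only), not the implementation behind jno.bayesian.ess. The function name ess_1d is hypothetical.

```python
import numpy as np


def ess_1d(x):
    """Effective sample size of one scalar chain."""
    x = np.asarray(x, dtype=float)
    n = x.size
    x = x - x.mean()
    # Autocovariance via zero-padded FFT (padding avoids circular wrap-around).
    spec = np.fft.rfft(x, n=2 * n)
    acov = np.fft.irfft(spec * np.conj(spec), n=2 * n)[:n] / n
    rho = acov / acov[0]
    # Geyer's initial positive sequence: sum adjacent lag pairs until one is <= 0.
    pairs = rho[: n - n % 2].reshape(-1, 2).sum(axis=1)
    nonpos = np.flatnonzero(pairs <= 0)
    cutoff = nonpos[0] if nonpos.size else pairs.size
    tau = -1.0 + 2.0 * pairs[:cutoff].sum()      # integrated autocorrelation time
    return n / tau


rng = np.random.default_rng(0)
iid = rng.standard_normal(4000)

# AR(1) chain with strong positive autocorrelation.
noise = rng.standard_normal(4000)
ar = np.empty(4000)
ar[0] = noise[0]
for t in range(1, 4000):
    ar[t] = 0.9 * ar[t - 1] + noise[t]

print(round(ess_1d(iid)), round(ess_1d(ar)))
```

Independent draws give an ESS close to the chain length, while the autocorrelated chain is worth far fewer effective samples. This is why the "> 100 per parameter" threshold refers to ESS and not to keep.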

rhat strategies — when K=1 falls back silently

rhat(chain, strategy=...) controls what happens for single-chain input. Three values:

| strategy | K==1 behaviour | K>=2 behaviour |
| --- | --- | --- |
| "auto" (default) | Split-R-hat on the two halves (Gelman et al. 2014 BDA3 §11.4) | Multichain R-hat |
| "multichain" | Raises ValueError — loud failure when you expected multiple chains | Multichain R-hat |
| "split" | Split-R-hat on the two halves | Split every chain in half → 2K chains, then multichain R-hat (extra stationarity check) |

Use "multichain" when you've explicitly configured num_chains>=2 and want a hard failure if the chain layout doesn't carry them (catches bugs where you thought you had 4 chains but posterior_samples came back as (1, N, *param)).
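The three strategies can be mimicked with the classic (non-rank-normalised) potential scale reduction factor. The sketch below swaps in that simpler BDA3-style statistic for the rank-normalised, folded variant described above, and it handles a scalar (K, N) chain only. rhat_sketch and its helpers are hypothetical names.

```python
import numpy as np


def _multichain_rhat(chains):
    """Classic potential scale reduction for an (M, n) array of chains."""
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()
    between = n * chains.mean(axis=1).var(ddof=1)
    var_plus = (n - 1) / n * within + between / n
    return float(np.sqrt(var_plus / within))


def _split(chains):
    """Split every chain into its first and last half: (K, N) -> (2K, N // 2)."""
    half = chains.shape[1] // 2
    return np.concatenate([chains[:, :half], chains[:, -half:]], axis=0)


def rhat_sketch(chain, strategy="auto"):
    chain = np.asarray(chain, dtype=float)           # (K, N)
    k = chain.shape[0]
    if strategy == "multichain":
        if k < 2:
            raise ValueError("multichain R-hat needs at least 2 chains")
        return _multichain_rhat(chain)
    if strategy == "split" or (strategy == "auto" and k == 1):
        return _multichain_rhat(_split(chain))
    if strategy == "auto":
        return _multichain_rhat(chain)
    raise ValueError(f"unknown strategy {strategy!r}")


rng = np.random.default_rng(1)
good = rng.standard_normal((4, 1000))                # four chains, same target
bad = good + 3.0 * np.arange(4)[:, None]             # chains stuck in different places
print(round(rhat_sketch(good), 3), round(rhat_sketch(bad), 3))
```

Chains that agree give a value near 1, while chains centred in different places give a value well above the 1.05 threshold.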

See Tutorial 08 for a worked end-to-end example.

Per-step kernel diagnostics — model.posterior_diagnostics

Every blackjax kernel call returns an info NamedTuple alongside the new state. jno captures the fields that matter for the kernel family in use and aggregates them across the chain into model.posterior_diagnostics, a {field: (K, N) array} dict:

| Kernel family | Captured fields |
| --- | --- |
| NUTS / HMC | is_divergent (bool), acceptance_rate (float), energy (float) |
| MALA | acceptance_rate only |
| SGLD / SGHMC (SG-MCMC) | None — these kernels have no info NamedTuple |
| Mean-field VI | None — track ELBO via history.total_loss instead |

a.bayesian(blackjax.nuts, step_size=1e-2, warmup=500, keep=1000)
crux.solve(...)

diag = a.posterior_diagnostics              # {"is_divergent": (K,N), ...}
n_divergent = int(diag["is_divergent"].sum())
acc = float(diag["acceptance_rate"].mean())  # target: 0.6–0.8 for NUTS

is_divergent is the most informative single signal of an unhealthy NUTS / HMC run. More than ~1% divergent transitions almost always means the integrator's step_size is too large for the local posterior curvature. Reduce step_size, tune inverse_mass_matrix to match the geometry, or run window adaptation (adapt=True).
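The ~1% rule of thumb is easy to check by hand from the diagnostics dict. The helper below is a NumPy sketch with a hypothetical name (summarize_diagnostics) that works on any {field: (K, N) array} mapping. It is not the code jno uses to produce its own log lines.

```python
import numpy as np


def summarize_diagnostics(name, diag, max_divergent_frac=0.01):
    """Return a one-line summary and a flag for too many divergences."""
    div = np.asarray(diag["is_divergent"], dtype=bool)
    n_div, n_total = int(div.sum()), div.size
    frac = n_div / n_total
    line = f"Model {name!r}: {n_div}/{n_total} divergent ({100 * frac:.2f}%)"
    if "acceptance_rate" in diag:
        line += f", mean_accept={float(np.mean(diag['acceptance_rate'])):.2f}"
    if "energy" in diag:
        line += f", mean_energy={float(np.mean(diag['energy'])):.1f}"
    return line, frac > max_divergent_frac


# Synthetic diagnostics for one chain of 1000 draws, 12 of them divergent.
is_div = np.zeros((1, 1000), dtype=bool)
is_div[0, :12] = True
diag = {
    "is_divergent": is_div,
    "acceptance_rate": np.full((1, 1000), 0.81),
    "energy": np.full((1, 1000), 42.3),
}
line, too_many = summarize_diagnostics("a", diag)
print(line, too_many)
```

Fields that a kernel family does not report (for example energy under MALA) are skipped.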

The same information surfaces in three other places so problems are loud:

  1. wandb (per print-rate chunk) — posterior/<name>/n_is_divergent, posterior/<name>/mean_acceptance_rate, posterior/<name>/mean_energy.
  2. Solve-end summary — one log line per Bayesian model:

    INFO: Model 'a': 12/1000 divergent (1.20%), mean_accept=0.81, mean_energy=42.3

  3. Handle-creation log — each .bayesian() configuration logs the diagnostic schema it'll track:

    INFO: Model 1: Bayesian sampling via 'nuts' (kind=full, ..., diagnostics=is_divergent, acceptance_rate, energy)

Kernels that surface no info object (SG-MCMC, VI) are flagged explicitly at handle-creation time (diagnostics=none (kernel API has no info object)) — never silently downgraded.

Pure-Bayesian fastpath (automatic)

A solve() call qualifies as pure-Bayesian when all of the following hold: every Bayesian model uses the same num_chains, warmup, keep, and thin; no .optimizer() models are in the same solve; no substeps=; no offload_data=True; no trackers; no adaptive resampling; and inner_steps == 1. When it qualifies, solve() auto-dispatches to a scan-based fastpath that closes three performance gaps in the per-epoch Python loop:

  1. No outer value_and_grad. The slow path runs jax.value_and_grad(loss_wrapper)(trainable) every step and discards the gradients when no optax models are present. The fastpath omits that pass entirely — the MCMC kernels still compute their own gradients via the existing logdensity / grad-estimator closures.
  2. One XLA dispatch per print_rate steps. warmup steps run in a jax.lax.fori_loop with no sample accumulation; the post-warmup phase runs in a jax.lax.scan chunked at print_rate outer iterations, with thin inner steps per outer iteration in a nested fori_loop. Typical solve with print_rate ≈ keep / 10: ~10 XLA dispatches instead of epochs of them.
  3. One host transfer per chunk. Samples are stacked inside XLA and returned as a single (chunk_keep, K, *param) tensor per Bayesian model.

The dispatch is fully automatic — there is no kwarg to set. At solve-start jno logs a single line so the decision is visible:

INFO: MCMC fastpath: scan over 500 samples × thin=1 + 0 warmup, chunked at print_rate=80.

If your solve doesn't qualify (mixed-mode, substeps, streaming, …) the per-epoch Python loop runs exactly as before with all its features intact. The fastpath's output (posterior_samples, wandb metrics, history) matches the slow path's at the same print_rate cadence, just with fewer datapoints between print boundaries.
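The loop structure described above (a fori_loop for warmup, then a scan over kept samples with a nested fori_loop for thinning) can be sketched in plain JAX. This is an illustration under stated assumptions. It runs a single chain, uses a toy random-walk Metropolis step on a standard normal instead of a blackjax kernel, and the names rwm_step, advance and sample_chunk are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def rwm_step(key, x, step_size=0.5):
    """One random-walk Metropolis transition targeting a standard normal."""
    k_prop, k_acc = jax.random.split(key)
    prop = x + step_size * jax.random.normal(k_prop, x.shape)
    log_ratio = 0.5 * (jnp.sum(x**2) - jnp.sum(prop**2))
    accept = jnp.log(jax.random.uniform(k_acc)) < log_ratio
    return jnp.where(accept, prop, x)


def advance(carry, n_steps):
    """Run n_steps transitions in one fori_loop, keeping no samples."""
    def body(_, c):
        key, x = c
        key, sub = jax.random.split(key)
        return key, rwm_step(sub, x)
    return jax.lax.fori_loop(0, n_steps, body, carry)


@partial(jax.jit, static_argnames=("n_keep", "thin"))
def sample_chunk(carry, n_keep, thin):
    """Collect n_keep samples, with `thin` inner steps per kept sample."""
    def outer(c, _):
        c = advance(c, thin)
        return c, c[1]                      # emit the position after thinning
    return jax.lax.scan(outer, carry, None, length=n_keep)


carry = (jax.random.PRNGKey(0), jnp.zeros(2))
carry = jax.jit(advance, static_argnums=1)(carry, 100)      # warmup, nothing stored
chunks = []
for _ in range(3):                                          # one compiled call per chunk
    carry, samples = sample_chunk(carry, n_keep=50, thin=2)
    chunks.append(samples)                                  # one host-side handle per chunk
chain = jnp.concatenate(chunks, axis=0)
print(chain.shape)
```

The Python loop runs only three times here, regardless of how many transitions each chunk contains. This is the source of the reduction in dispatches.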

Variational Inference (mean-field)

Model.vi(...) fits a variational approximation to the posterior through the same crux.solve() driver as .bayesian() — but optimises the evidence lower bound (ELBO) of a Gaussian product q(θ) = ∏_i N(μ_i, σ_i) instead of running an MCMC chain. After solve(), posterior_draws i.i.d. samples are drawn from the fitted q and stored on the model as posterior_samples in the same (1, N, *param) layout as the MCMC path:

import blackjax, optax

a.vi(
    blackjax.meanfield_vi,
    optimizer=optax.adam(1e-3),
    num_samples=8,           # MC samples per ELBO eval
    posterior_draws=500,     # draws from fitted q for posterior_samples
)
crux.solve(2000)              # 2000 ELBO optimisation steps
chain = a.posterior_samples   # (1, 500, *param) — draws from fitted q

Configuration mirrors the MCMC .bayesian() API: same downstream crux.eval(samples="auto") plumbing, same jno.bayesian.{rhat, ess} helpers (which trivially report ≈ 1 and ≈ N respectively for VI draws — see caveat below).

| Aspect | MCMC | Mean-field VI |
| --- | --- | --- |
| Mechanism | per-step Metropolis-Hastings / Langevin / leapfrog | per-step ELBO optimisation |
| Cost | High (many forward passes per sample) | Low (one MC ELBO eval per step) |
| Calibration | Asymptotically exact | Diagonal-covariance lower bound |
| Multi-modal | Multi-chain reveals modes | Captures one mode |
| posterior_samples shape | (K, N, *param) from collected chain | (1, posterior_draws, *param) drawn from fitted q |

Two overrides on blackjax's defaults at init_state time make VI converge usefully on non-trivial models. Both are exposed as kwargs on Model.vi(...):

  • init_mu_at_position=True (default) — state.mu starts at the model's initial weights instead of blackjax's zeros. Matches the numpyro autoguide convention; pass False to restore blackjax's zero start.
  • init_log_std=-3.0 (default → σ ≈ 0.05) — state.rho starts small everywhere instead of blackjax's broader init (σ ≈ 1). Keeps the initial MC ELBO sample close to the mean so the gradient estimator is low-variance from the start. The optimiser then grows rho where the posterior is genuinely wide. Pass init_log_std=0.0 (σ ≈ 1) to restore blackjax's default.

a.vi(
    blackjax.meanfield_vi,
    optimizer=optax.adam(1e-3),
    init_log_std=-3.0,       # tight initial q
    init_mu_at_position=True, # mu = current weights
)
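To see what init_log_std controls, the sketch below evaluates a Monte Carlo ELBO for a diagonal Gaussian q with σ = exp(log_std), using the reparameterisation θ = μ + σ·ε. It is a NumPy illustration with a hypothetical helper (elbo_estimate) and a toy standard-normal target, not blackjax's or jno's estimator.

```python
import numpy as np


def elbo_estimate(log_joint, mu, log_std, num_samples, seed=0):
    """Monte Carlo ELBO for q = N(mu, exp(log_std)^2) with a diagonal covariance."""
    rng = np.random.default_rng(seed)
    mu = np.asarray(mu, dtype=float)
    eps = rng.standard_normal((num_samples,) + mu.shape)
    theta = mu + np.exp(log_std) * eps                     # reparameterised draws
    log_q = np.sum(-0.5 * eps**2 - log_std - 0.5 * np.log(2 * np.pi), axis=-1)
    log_p = np.array([log_joint(t) for t in theta])
    return float(np.mean(log_p - log_q))


def std_normal_log_joint(theta):
    """Toy target: normalised standard normal in every dimension."""
    return float(np.sum(-0.5 * theta**2 - 0.5 * np.log(2 * np.pi)))


mu = np.zeros(2)
sigma_tight = np.exp(-3.0)                                  # about 0.0498
elbo_match = elbo_estimate(std_normal_log_joint, mu, 0.0, num_samples=256)
elbo_tight = elbo_estimate(std_normal_log_joint, mu, -3.0, num_samples=256)
print(round(sigma_tight, 4), round(elbo_match, 4), round(elbo_tight, 2))
```

When q equals the target the ELBO is zero. With the tight start it is clearly negative, and this gap is what the optimiser closes by growing the standard deviation where the posterior is wide.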

Likelihood scaling for VI — likelihood_scale=

The canonical Gaussian-noise log-likelihood is a sum over data points, but jno's residual.mse returns the mean — so by default the likelihood term in the ELBO is N× too small and the prior dominates, leaving VI stuck near initialisation.

Pass likelihood_scale=N_obs (or N_obs / sigma**2 more generally) on .vi(...) so the ELBO uses the correct magnitude:

a.vi(
    blackjax.meanfield_vi,
    optimizer=optax.adam(1e-3),
    likelihood_scale=N_obs,   # canonical sum-over-data weighting
)

The same kwarg works on .bayesian(...) for MCMC kernels — less critical (HMC's geometry is more robust to magnitude) but still correct. Available on both Pattern A (global) and Pattern B/D (masked) configurators.
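The factor of N can be verified numerically. The NumPy sketch below compares the sum-over-data Gaussian log-likelihood (up to an additive constant) with the same quantity rebuilt from the mean squared residual. The residual vector is synthetic. Whether the factor 1/2 is folded into jno's loss-to-log-density conversion is not stated here, so the exact constant is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 0.1
residual = sigma * rng.standard_normal(200)      # stand-in for prediction - target
n_obs = residual.size

mse = np.mean(residual**2)

# Canonical Gaussian log-likelihood, dropping the theta-independent constant.
log_lik_sum = -0.5 * np.sum(residual**2) / sigma**2

# The same number expressed through the mean: multiply by N_obs / sigma**2.
log_lik_from_mse = -0.5 * (n_obs / sigma**2) * mse

# Using the bare mean instead understates the likelihood term by that factor.
log_lik_unscaled = -0.5 * mse

print(round(log_lik_sum, 3), round(log_lik_from_mse, 3), round(log_lik_unscaled, 6))
```

With 200 observations and σ = 0.1 the unscaled term is 20 000 times too small, which is why a prior of ordinary width can dominate it.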

See Tutorial 09 for a worked end-to-end example.

Diagnostics caveat

jno.bayesian.rhat and jno.bayesian.ess still run on VI posterior_samples, but they trivially report ≈ 1 and ≈ N respectively because the draws are i.i.d. from the fitted q. For VI convergence monitoring, watch the ELBO trajectory in history (via crux.solve(...).total_loss per-chunk values) instead.

Mutual exclusion

A single model has either .bayesian() or .vi(), never both, unless the two are configured on disjoint masks (Pattern E under "What's supported"). Setting one after the other on the same parameters raises a clear error. Models with VI and models with MCMC can coexist in the same solve call — each runs its own paradigm during the step loop.

Composable per-mask backends (v1)

.mask(M) followed by .bayesian(...) (or .vi(...)) restricts the posterior to the subset of the model's parameter pytree where M is True. Leaves outside the mask stay at their initial value throughout solve() — they are not updated.

import equinox as eqx

# Mark only the output layer ("head") as Bayesian; body stays at init.
all_false     = jax.tree_util.tree_map(lambda _: False, net.module)
head_all_true = jax.tree_util.tree_map(lambda _: True, net.module.output_layer)
head_mask     = eqx.tree_at(lambda m: m.output_layer, all_false,
                            replace=head_all_true)

net.mask(head_mask).bayesian(blackjax.sgld, step_size=1e-3,
                              warmup=1500, keep=400, thin=2)

After solve, net.posterior_samples stores the full pytree at every sample. Leaves inside the mask vary across the chain; leaves outside the mask are constant. crux.eval([net(x)], samples="auto"), jno.bayesian.{rhat, ess}, and wandb stats all work uniformly — no special case for masked solves.
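The "masked leaves vary, unmasked leaves stay constant" behaviour can be mimicked with a dict of arrays and a per-leaf boolean mask. This is a NumPy sketch with hypothetical helpers (masked_update, masked_subset), not jno's pytree machinery.

```python
import numpy as np


def masked_update(params, proposal, mask):
    """Take `proposal` where the leaf's mask is True, keep `params` elsewhere."""
    return {k: np.where(mask[k], proposal[k], params[k]) for k in params}


def masked_subset(params, mask):
    """The leaves a kernel (and its prior) would see."""
    return {k: v for k, v in params.items() if np.any(mask[k])}


params = {"body": np.ones(3), "head": np.zeros(2)}
mask = {"body": False, "head": True}

rng = np.random.default_rng(0)
chain = []
current = params
for _ in range(4):
    proposal = {k: v + rng.standard_normal(v.shape) for k, v in current.items()}
    current = masked_update(current, proposal, mask)
    chain.append(current)

# Stack into the full-pytree layout: every sample carries both leaves.
stacked = {k: np.stack([s[k] for s in chain]) for k in params}
print(stacked["body"].std(axis=0), stacked["head"].std(axis=0))
```

Storing the unchanged body at every sample is what makes downstream evaluation uniform, at the price of the memory overhead of full-pytree storage.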

A worked example lives at Tutorial 10.

What's supported

  • Pattern A: .mask(M).bayesian(...) (or .vi(...)) on a model with no global .optimizer(...). Body stays at init; masked subset is the posterior.
  • Pattern B (Phase 15): .mask(M).bayesian(...) + global .optimizer(...) on the same model. Body is Adam-trained; head is MCMC-sampled. K=1 and K>1 both supported: for K>1 the body's gradient is computed at the chain-0 representative head (SAEM simplification). Tutorial: T14.

!!! warning "Pattern B + K>1 — SAEM chain-0 representative"
    When num_chains>1 with Pattern B, only one of the K head chains influences the body's optax update each step (chain 0). All K chains still explore the head's posterior, but the body sees a single representative — equivalent to SAEM-style joint inference, not K independent head+body solves. jno emits a one-line WARNING at solve-start naming the trade-off so it is visible without reading the docs. Pass num_chains=1 if you want independent head+body runs (one solve per chain).

  • Pattern D (Phase 16): multiple disjoint .mask().bayesian() groups on the same model. Each group's kernel state lives at its own composite key ("<lid>.<group_idx>") in opt_states; the step loop iterates groups in sorted order (Metropolis-within-Gibbs cycle for K=1; SAEM-style chain-0 representative for K>1).
  • Pattern E (Phase 16): mixed VI + MCMC on disjoint masks of the same model. MCMC accumulates per-step samples; at solve end the VI handle draws posterior_draws i.i.d. samples from its fitted distribution and splices them into the MCMC chain at the VI mask's leaves. Strict matching: VI's posterior_draws must equal the MCMC group's keep (validated at solve start).
  • Masked + num_chains > 1 (Phase 15): masked Bayesian solves with K parallel chains work for Pattern A, B, and D.
  • .lora() + .bayesian() (no mask) — the LoRA partition already restricts trainable parameters to the LoRA adapters; .bayesian() samples that restricted subset.

State storage (composite keys)

Internally, opt_states uses two key formats:

  • Bare "<lid>" — optax states (one per layer with .optimizer(...)).
  • Composite "<lid>.<group_idx>" — Bayesian / VI kernel states (one per masked group; bare-.bayesian() layers use "<lid>.0").

A Pattern B + D layer therefore carries entries like {"1": optax_state, "1.0": kernel_g0, "1.1": kernel_g1} simultaneously. The helpers jno.core._lid_of(k), _group_idx_of(k), _bay_key(lid, gi) parse / build these keys.
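The two key formats are simple enough to parse with string operations. The helpers below are stand-alone illustrations with hypothetical names. They are not the private jno.core functions mentioned above, and they assume a layer id never contains a dot.

```python
def bay_key(lid, group_idx=0):
    """Build a composite key for a Bayesian / VI kernel state."""
    return f"{lid}.{group_idx}"


def lid_of(key):
    """Layer id of either a bare or a composite key."""
    return key.split(".", 1)[0]


def group_idx_of(key):
    """Group index of a composite key, or None for a bare (optax) key."""
    _, sep, idx = key.partition(".")
    return int(idx) if sep else None


# Toy opt_states for a layer with an optimiser and two masked Bayesian groups.
opt_states = {"1": "optax_state", bay_key("1", 0): "kernel_g0", bay_key("1", 1): "kernel_g1"}

kernel_keys = sorted(k for k in opt_states if group_idx_of(k) is not None)
print(kernel_keys, [lid_of(k) for k in opt_states])
```

Sorting the composite keys gives the group order in which a Metropolis-within-Gibbs style cycle would visit the kernels.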

What's still blocked

The composite-key scheme generalises further (multiple VI groups beyond one per layer, overlapping masks, etc.) but the current implementation only validates the patterns above. Any pattern outside the supported list raises a clear NotImplementedError or ValueError at solve start.

What posterior_samples looks like

The chain stores the full model pytree (both masked and unmasked leaves) with leading axes (K, N, *param_shape), where K=1 for single-chain solves. Unmasked leaves are constant across all samples; masked leaves vary. This full-pytree storage has a memory cost: for very narrow masks on wide models, sparse storage (varying leaves only + a frozen-leaves snapshot, reassembled lazily) is documented as a v2 follow-up. The user-facing API will not change.

Wandb integration

When a wandb run is active, per-Bayesian-model statistics are logged at the same print-rate cadence as the rest of the training metrics:

| Key | Meaning |
| --- | --- |
| posterior/<name>/n_samples | Number of samples collected in the chain so far |
| posterior/<name>/n_chains | num_chains for this model (1 for the default) |
| posterior/<name>/mean | Running posterior mean (scalar parameters only) |
| posterior/<name>/n_is_divergent | Running divergent-transition count (NUTS / HMC only) |
| posterior/<name>/mean_acceptance_rate | Running mean MH acceptance rate (NUTS / HMC / MALA) |
| posterior/<name>/mean_energy | Running mean Hamiltonian energy (NUTS / HMC only) |

<name> comes from the name= argument of jno.np.parameter(...) or jno.nn.wrap(..., name=...). Multi-leaf modules (MLPs) only get the chain length — full per-leaf statistics are out of scope for jno-side logging; use arviz against model.posterior_samples instead.

Memory

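A full chain costs ~num_chains × keep × #params × 4 bytes per Bayesian model (float32 leaves). For large BNN PINNs, decrease keep= to stay within GPU/CPU memory, and raise thin= if the retained samples should still span the same number of post-warmup steps.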
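A quick back-of-the-envelope helper makes the estimate concrete. The function name chain_bytes and the example parameter count are illustrative assumptions. The num_chains factor follows from the (K, N, *param) layout and the default of 4 bytes assumes float32 leaves.

```python
def chain_bytes(keep, n_params, num_chains=1, bytes_per_value=4):
    """Approximate memory of one model's stored chain, in bytes."""
    return num_chains * keep * n_params * bytes_per_value


single = chain_bytes(keep=1000, n_params=10_000)
multi = chain_bytes(keep=1000, n_params=10_000, num_chains=4)
print(f"{single / 1e6:.0f} MB, {multi / 1e6:.0f} MB")
```

A 10 000-parameter network with keep=1000 therefore needs about 40 MB for one chain and about 160 MB for four, before any intermediate buffers.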

References

  • NUTS — Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research, 15(1), 1593–1623.
  • SGLD — Welling, M., & Teh, Y. W. (2011). Bayesian Learning via Stochastic Gradient Langevin Dynamics. ICML 2011, 681–688.

The kernels themselves come from blackjax; jno only wires their step(rng_key, state) → (state, info) interface into the per-model step dispatch.

Limitations (this release)

  • VI (blackjax.vi.*) has different mechanics (ELBO optimisation) and is not routed through .bayesian(); configure it with Model.vi(...) as described under "Variational Inference (mean-field)".
  • Multi-chain adaptation: chains launched with num_chains=K share a single window-adaptation sweep; per-chain adaptation is not yet available (see "Multiple chains").
  • Discrete posteriors (e.g. over Choice selections) need SMC and are out of scope.
  • Custom forward models outside the jNO tracer (FEM solver, ODE integrator, finite volume) can't be wrapped in .bayesian() directly — the API expects the forward to be expressible as a jNO Placeholder expression. For those cases use blackjax directly with jNO supplying the differentiable forward; see Inverse FEM Diffusivity for the pattern.

Combining with substeps=

substeps= is supported with .bayesian() to enable the two-stage decoupled inference pattern: substep 0 trains a surrogate via optax, substep 1 runs one NUTS proposal on a Bayesian coefficient with the surrogate stop_gradient-ed. See the section index of the Bayesian PINNs tutorial chapter. When using substeps you must set adapt=False on the Bayesian model — window adaptation runs against the full loss but the kernel only sees the substep-local constraint set.