W&B Integration and Explainability Callbacks

This tutorial shows how to connect a jNO training run to Weights & Biases and use the built-in explainability callbacks to understand what is happening inside the training loop.

What gets logged

| Source | W&B keys |
|---|---|
| GradientNormsCallback | explainability/gradient_norm/constraint_0, …/constraint_N |
| CosSimilarityCallback | explainability/cos_sim/0_1, …, explainability/cos_sim_matrix (heatmap) |
| GradientAlignmentCallback | explainability/gradient_alignment |
| LossLandscapeCallback | explainability/loss_landscape (heatmap image) |
| CheckpointCallback | versioned checkpoint artifact with total_loss, individual_losses, checkpoint_dir |

Step 1: Enable W&B in jno.setup

import jno

dire = jno.setup(__file__, wandb=True)

This initialises a W&B run (project name defaults to the script filename stem) and also calls weave.init("armbrul/jNO") for Weave tracing if the weave package is installed.

Pass a dict to forward any wandb.init kwargs:

jno.setup(__file__, wandb={"project": "jNO", "tags": ["poisson", "1d"]})

Step 2: Define the Problem

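We solve the 1-D Poisson equation \(-u''(x) = \sin(\pi x)\) with \(u = 0\) on the boundary, enforced as a soft (penalty) constraint. The solver therefore minimises two separate loss terms, the PDE residual and the boundary residual. The cosine-similarity and alignment metrics compare gradients across loss terms, so they need at least two terms to be meaningful.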

import foundax
import jax
import optax

π = jno.np.pi  # pi, used in the source term below

domain = jno.domain.line(mesh_size=0.05)
x,  _ = domain.variable("interior")
xb, _ = domain.variable("boundary")

u_net = jno.nn.wrap(
    foundax.mlp(in_features=1, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(0))
).optimizer(optax.adam(optax.exponential_decay(1e-3, 1_000, 0.5, end_value=1e-5)))

u  = u_net(x)
ub = u_net(xb)

pde = -u.d2(x) - jno.np.sin(π * x)
bc  = ub   # u = 0 on boundary
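As a quick sanity check on the problem statement, the exact solution used for validation in the full script, \(u(x) = \sin(\pi x)/\pi^2\), should make the PDE residual vanish and satisfy the boundary condition. The NumPy sketch below checks this with a central finite difference. It assumes the domain is the unit interval \([0, 1]\), which is not stated explicitly above.

```python
import numpy as np

def u_exact(x):
    # Exact solution of -u'' = sin(pi x) with u(0) = u(1) = 0
    return np.sin(np.pi * x) / np.pi**2

def pde_residual(u, x, h=1e-3):
    # Central second difference: u''(x) ~ (u(x+h) - 2u(x) + u(x-h)) / h^2
    d2u = (u(x + h) - 2.0 * u(x) + u(x - h)) / h**2
    # Same sign convention as `pde` above: -u'' - sin(pi x)
    return -d2u - np.sin(np.pi * x)

x = np.linspace(0.05, 0.95, 19)   # interior points at spacing 0.05 (assumed unit interval)
res = pde_residual(u_exact, x)
bc = u_exact(np.array([0.0, 1.0]))

print(f"max |PDE residual| = {np.abs(res).max():.2e}")   # finite-difference error only
print(f"boundary values    = {bc}")
```

The residual is not exactly zero only because of the \(O(h^2)\) finite-difference error; the same check applied to a trained network is what the relative L² error at the end of the script measures more directly.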

Step 3: Create the Explainability Callbacks

All four callbacks share the same interface: interval sets how often they run (every interval epochs), and the optional mask restricts gradient computations to a subset of parameters.

Gradient norms

cb_norms = jno.callbacks.gradient_norms(interval=50)

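Tracks \(\|\nabla L_i\|_2\), the L2 norm of the gradient of constraint \(i\)'s loss with respect to the network parameters, for each constraint \(i\). A norm much larger than the others, or a sudden spike, usually signals a constraint that is dominating the update.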
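To make the quantity concrete, here is a small NumPy sketch that computes one gradient norm per constraint. It is not jNO's implementation, which differentiates the real network with JAX; instead it uses two hypothetical least-squares losses \(L_i(\theta) = \tfrac{1}{m}\|A_i\theta - b_i\|^2\), whose gradient \(\tfrac{2}{m}A_i^\top(A_i\theta - b_i)\) is known in closed form.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=5)                 # stand-in for the flattened network parameters

# Two hypothetical constraints (think "PDE residual" and "boundary residual")
A = [rng.normal(size=(20, 5)), rng.normal(size=(4, 5))]
b = [rng.normal(size=20), 10.0 * rng.normal(size=4)]   # larger targets for the second term

def grad_mse(A_i, b_i, theta):
    # Closed-form gradient of mean((A_i @ theta - b_i)**2) with respect to theta
    m = A_i.shape[0]
    return 2.0 / m * A_i.T @ (A_i @ theta - b_i)

grads = np.stack([grad_mse(A_i, b_i, theta) for A_i, b_i in zip(A, b)])  # (N, P)
norms = np.linalg.norm(grads, axis=1)      # one L2 norm per constraint, shape (N,)
print("gradient norms per constraint:", norms)
```

Each row of norms over the sampled epochs corresponds to one point on the per-constraint curves the callback logs.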

Cosine similarity matrix

cb_cos = jno.callbacks.cos_similarity(interval=50)

Computes the \(N \times N\) matrix of pairwise cosine similarities between the constraint gradients, where \(N\) is the number of constraints. When W&B is active, the matrix is uploaded as a heatmap image.

\[\text{sim}_{ij} = \frac{\nabla L_i \cdot \nabla L_j}{\|\nabla L_i\| \|\nabla L_j\|}\]
| Value | Meaning |
|---|---|
| \(\approx +1\) | Constraints reinforce each other |
| \(\approx 0\) | Independent: no interaction |
| \(\approx -1\) | Gradient conflict: one constraint hurts the other |
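Given one flattened gradient vector per constraint, the whole matrix is a normalise-then-Gram computation. A minimal NumPy sketch, assuming the per-constraint gradients are already stacked into an \(N \times P\) array (jNO computes them from the network with JAX; this only illustrates the formula):

```python
import numpy as np

def cos_sim_matrix(grads, eps=1e-12):
    # grads: (N, P) array, one flattened gradient per constraint
    unit = grads / (np.linalg.norm(grads, axis=1, keepdims=True) + eps)
    return unit @ unit.T  # entry (i, j) = cosine of the angle between grad_i and grad_j

grads = np.array([
    [1.0, 0.0, 0.0],    # constraint 0
    [2.0, 0.0, 0.0],    # constraint 1: same direction as constraint 0
    [0.0, 3.0, 0.0],    # constraint 2: orthogonal to constraints 0 and 1
    [-1.0, 0.0, 0.0],   # constraint 3: opposite to constraint 0
])
print(np.round(cos_sim_matrix(grads), 3))
```

Because each row is normalised first, the magnitude differences tracked by the gradient-norm callback drop out, and only directions are compared.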

Total gradient alignment

cb_align = jno.callbacks.gradient_alignment(interval=50)

A single scalar in \([-1, 1]\) measuring global agreement across all gradients (Eq. 3.1, [2502.00604]):

\[\text{alignment} \;=\; 2\left\|\frac{1}{N}\sum_{i=1}^{N} \frac{\nabla L_i}{\|\nabla L_i\|}\right\|^2 - 1\]

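Near \(+1\), all loss terms pull in the same direction; near \(-1\), the normalised gradients cancel out (destructive interference). For \(N = 2\) the scalar reduces to the cosine similarity of the two gradients, so \(0\) means orthogonal; for \(N > 2\), mutually orthogonal gradients give \(2/N - 1\) rather than \(0\).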
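The alignment formula is a few lines of NumPy. This sketch is an illustration of Eq. 3.1, not jNO's code. It also exercises two properties of the formula: for two gradients the score equals their cosine similarity, and for \(N\) mutually orthogonal gradients the mean unit vector has squared norm \(1/N\), giving \(2/N - 1\).

```python
import numpy as np

def gradient_alignment(grads, eps=1e-12):
    # grads: (N, P) array, one flattened gradient per constraint
    unit = grads / (np.linalg.norm(grads, axis=1, keepdims=True) + eps)
    mean_dir = unit.mean(axis=0)               # (1/N) * sum of unit gradients
    return 2.0 * np.dot(mean_dir, mean_dir) - 1.0

g1 = np.array([1.0, 0.0])
g2 = np.array([1.0, 1.0])
cos12 = g1 @ g2 / (np.linalg.norm(g1) * np.linalg.norm(g2))

print(gradient_alignment(np.stack([g1, g2])), cos12)   # equal for N = 2
print(gradient_alignment(np.stack([g1, -g1])))         # opposite gradients cancel
print(gradient_alignment(np.eye(3)))                   # 3 mutually orthogonal gradients
```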

2-D loss landscape

cb_landscape = jno.callbacks.loss_landscape(
    interval=200,   # expensive — n_grid² forward passes per call
    n_grid=11,
    alpha_range=0.5,
)

The callback samples two random, filter-normalised directions in parameter space and evaluates the total loss on an \(n_{\text{grid}} \times n_{\text{grid}}\) grid around the current parameters, so each call costs \(n_{\text{grid}}^2\) loss evaluations (121 with n_grid=11). The grid is logged to W&B as a heatmap image. A smooth bowl shape is a sign of a well-conditioned optimisation landscape; sharp ridges or irregular bumps indicate ill-conditioning.
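The procedure can be sketched in NumPy for a toy model. This is an assumption-laden illustration rather than jNO's implementation: the parameters are a hypothetical list of arrays with weight matrices stored as (out, in); directions are filter-normalised in the style of Li et al. (2018), so each row of a weight matrix is rescaled to the norm of the matching row of the weights; 1-D parameters such as biases simply get a zero direction here; and alpha_range is taken to be the half-width of the grid along both directions.

```python
import numpy as np

def filter_normalised_direction(params, rng):
    # One random direction with the same structure as `params` (a list of arrays)
    direction = []
    for p in params:
        if p.ndim >= 2:
            d_rows = rng.normal(size=p.shape).reshape(p.shape[0], -1)
            p_rows = p.reshape(p.shape[0], -1)
            # Rescale each row ("filter") to the norm of the matching parameter row
            scale = np.linalg.norm(p_rows, axis=1, keepdims=True) / (
                np.linalg.norm(d_rows, axis=1, keepdims=True) + 1e-10
            )
            direction.append((d_rows * scale).reshape(p.shape))
        else:
            direction.append(np.zeros_like(p))  # biases: not perturbed in this sketch
    return direction

def loss_landscape(loss_fn, params, n_grid=11, alpha_range=0.5, seed=0):
    rng = np.random.default_rng(seed)
    d1 = filter_normalised_direction(params, rng)
    d2 = filter_normalised_direction(params, rng)
    alphas = np.linspace(-alpha_range, alpha_range, n_grid)
    grid = np.empty((n_grid, n_grid))
    for i, s in enumerate(alphas):
        for j, t in enumerate(alphas):
            shifted = [p + s * u + t * v for p, u, v in zip(params, d1, d2)]
            grid[i, j] = loss_fn(shifted)   # n_grid**2 loss evaluations in total
    return grid

# Hypothetical one-hidden-layer tanh model fitted to sin(pi x)
x = np.linspace(0.0, 1.0, 21)[:, None]
y = np.sin(np.pi * x)
rng = np.random.default_rng(1)
params = [rng.normal(size=(8, 1)), np.zeros(8), rng.normal(size=(1, 8)), np.zeros(1)]

def mse(ps):
    W1, b1, W2, b2 = ps
    hidden = np.tanh(x @ W1.T + b1)         # (21, 8)
    pred = hidden @ W2.T + b2               # (21, 1)
    return float(np.mean((pred - y) ** 2))

Z = loss_landscape(mse, params)
print(Z.shape, Z[5, 5], mse(params))       # the grid centre is the unperturbed loss
```

Filter normalisation makes the step size along each direction proportional to the scale of the weights it perturbs, which is what makes heatmaps from different epochs roughly comparable.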

Reducing cost

Pass mask=bool_pytree (a boolean pytree marking which parameters to include) to restrict perturbations and gradient computations to a small subset of parameters, for example only the output layer. This can reduce the cost of these diagnostics considerably, and the output-layer weights often still give a useful proxy signal.
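The exact mask format is not spelled out here. As a sketch of the idea, assuming the mask mirrors the parameter structure leaf for leaf, the NumPy code below uses a hypothetical nested-dict parameter tree (the layer names and sizes are illustrative, not jNO's), selects only the output layer, counts how many entries remain, and restricts a gradient norm to the selected leaves.

```python
import numpy as np

def leaves(tree):
    # Flatten a nested dict/list of arrays (or bools) in a fixed, deterministic order
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in leaves(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for sub in tree for leaf in leaves(sub)]
    return [tree]

# Hypothetical MLP parameters: 1 -> 32 -> 32 -> 32 -> 1, weights stored as (out, in)
sizes = [1, 32, 32, 32, 1]
params = {
    f"layer_{i}": {"W": np.ones((n_out, n_in)), "b": np.ones(n_out)}
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:]))
}
# Boolean mask with the same structure: only the output layer is selected
last = f"layer_{len(sizes) - 2}"
mask = {name: {"W": name == last, "b": name == last} for name in params}

selected = [p for p, keep in zip(leaves(params), leaves(mask)) if keep]
n_total = sum(p.size for p in leaves(params))
n_selected = sum(p.size for p in selected)
print(f"{n_selected} of {n_total} parameters selected")

def masked_grad_norm(grads, mask):
    # L2 norm of a gradient pytree, restricted to the leaves selected by `mask`
    kept = [g for g, keep in zip(leaves(grads), leaves(mask)) if keep]
    return float(np.sqrt(sum(np.sum(g**2) for g in kept)))
```

In JAX the saving would typically come from differentiating only with respect to the selected leaves, for example by partitioning the parameter tree before calling jax.grad or jax.jacrev; how jNO does this internally is not shown here.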


Step 4: Checkpoint with W&B Artifact

cb_ckpt = jno.callbacks.checkpoint(
    directory=f"{dire}/checkpoints",
    save_interval_epochs=500,
    max_to_keep=3,
    best_fn=lambda m: m["total_loss"],
)

When W&B is active, each checkpoint saved to disk is also uploaded to W&B as a versioned checkpoint artifact. The artifact metadata includes:

{
    "epoch": 500,
    "total_loss": 0.0023,
    "individual_losses": [0.0019, 0.0004],
    "checkpoint_dir": "/path/to/checkpoints/500",
    "timestamp": 1717000000.0,
}

Step 5: Solve

crux = jno.core([pde.mse, bc.mse])
crux.solve(
    2_000,
    callbacks=[cb_norms, cb_cos, cb_align, cb_landscape, cb_ckpt],
)
cb_ckpt.close()

All callbacks hook into on_solve_begin (called once, after the initial JIT compilation) to pre-compile their JAX functions against the live parameter shapes. As a result, the first time a callback fires (at an epoch where epoch % interval == 0), it runs an already-compiled XLA executable rather than triggering a compilation mid-training.


Step 6: Read Results Locally

Even without W&B, every callback stores its history as numpy arrays:

# Gradient norms: shape (n_samples, n_constraints)
norms = cb_norms.result["norms"]

# Cosine similarity: shape (n_samples, n_constraints, n_constraints)
cos_mat = cb_cos.result["cos_sim_matrix"]

# Alignment: shape (n_samples,)
alignment = cb_align.result["alignment"]

# Landscapes: shape (n_samples, n_grid, n_grid)
landscapes = cb_landscape.result["landscapes"]
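Because the results are plain arrays, post-processing is ordinary NumPy. The sketch below uses synthetic arrays with the documented shapes (the values are made up) to extract the off-diagonal cosine similarity between constraints 0 and 1 over time, and to find the first sampled epoch at which the alignment turns negative. The epochs array here is hypothetical; the script reads a per-sample epoch list from the gradient-norm result as result["epochs"].

```python
import numpy as np

# Synthetic stand-ins with the documented shapes (values are made up)
n_samples, n_constraints = 40, 2
epochs = np.arange(n_samples) * 50                        # hypothetical sampled epochs
cos_mat = np.empty((n_samples, n_constraints, n_constraints))
cos_mat[:] = np.eye(n_constraints)
c01 = np.linspace(0.9, -0.3, n_samples)                   # slowly growing conflict
cos_mat[:, 0, 1] = cos_mat[:, 1, 0] = c01
alignment = c01.copy()   # for two constraints, alignment equals their cosine similarity

# Off-diagonal entry over time: how the PDE and BC gradients interact
pde_vs_bc = cos_mat[:, 0, 1]

# First sampled epoch at which the normalised gradients start to oppose each other
conflict = np.flatnonzero(alignment < 0.0)
first_conflict_epoch = int(epochs[conflict[0]]) if conflict.size else None
print("first epoch with negative alignment:", first_conflict_epoch)
```

With real results, replace the synthetic arrays with cos_mat, alignment and the sampled epochs returned by the callbacks.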

What To Notice

  • All W&B calls are no-ops when jno.setup is called without wandb=True — no behaviour change in scripts that do not need W&B.
  • The explainability callbacks use jacrev internally; they are independent of the training step and do not affect the parameter updates.
  • For large models, always provide a mask to limit which parameters are differentiated. The output-layer weights often give a good proxy signal at a fraction of the cost.
  • A gradient alignment scalar that drops during training is a useful early warning of constraint conflict.

Script Snippet

"""09 — W&B + Weave integration with explainability callbacks"""

import foundax
import jax
import jax.numpy as jnp
import optax

import jno

π = jno.np.pi

# ── Setup: creates the run directory and initialises W&B + Weave ──────────────
# Set wandb=True to push metrics to Weights & Biases.
# The project name defaults to the script filename stem ("wandb_integration").
# Override with:  jno.setup(__file__, name="my_run", wandb={"project": "jNO"})
dire = jno.setup(__file__, wandb=True)

# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain.line(mesh_size=0.05)
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")

# ── Exact solution (validation only) ──────────────────────────────────────────
u_exact = jno.np.sin(π * x) / π**2

# ── Network ───────────────────────────────────────────────────────────────────
u_net = jno.nn.wrap(
    foundax.mlp(
        in_features=1,
        hidden_dims=32,
        num_layers=3,
        key=jax.random.PRNGKey(0),
    )
).optimizer(optax.adam(optax.exponential_decay(1e-3, 1_000, 0.5, end_value=1e-5)))

u = u_net(x)
ub = u_net(xb)

# ── Constraints ───────────────────────────────────────────────────────────────
pde = -u.d2(x) - jno.np.sin(π * x)  # PDE residual
bc = ub  # u = 0 on boundary

# ── Explainability callbacks ───────────────────────────────────────────────────
# All four are logged to W&B when the run is active.
# Use interval= to control how often each is computed.
# The loss landscape is the most expensive — keep its interval large for real runs.

cb_norms = jno.callbacks.gradient_norms(
    interval=50,
    # mask=...  # optional: restrict to a subset of parameters (faster)
)

cb_cos = jno.callbacks.cos_similarity(
    interval=50,
)

cb_align = jno.callbacks.gradient_alignment(
    interval=50,
)

cb_landscape = jno.callbacks.loss_landscape(
    interval=200,  # expensive — n_grid² forward passes per call
    n_grid=11,
    alpha_range=0.5,
)

cb_residual = jno.callbacks.residual_stats(
    interval=50,
)

# Input saliency: |∂u/∂x| at the interior collocation points.
# Any jno placeholder expression compiles — try Jacobian(u, [x, y]) for 2-D problems.
cb_saliency = jno.callbacks.input_sensitivity(
    u.d(x),
    interval=50,
)

# Empirical NTK spectrum: diagnoses spectral bias.
cb_ntk = jno.callbacks.ntk_spectrum(
    u.grad(u_net),
    n_points=32,
    top_k=5,
    interval=200,
)

# Hessian eigenspectrum: top-k eigenvalues + sharpness via Lanczos w/ HVPs.
cb_hess = jno.callbacks.hessian_spectrum(
    k=5,
    n_iter=15,
    interval=400,
)

# ── Checkpoint callback ────────────────────────────────────────────────────────
# Saves to disk every 500 epochs; uploads a versioned artifact to W&B.
# Artifact metadata includes epoch, total_loss, individual_losses, checkpoint_dir.
cb_ckpt = jno.callbacks.checkpoint(
    directory=f"{dire}/checkpoints",
    save_interval_epochs=500,
    max_to_keep=3,
    best_fn=lambda m: m["total_loss"],
)

# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, bc.mse])
crux.solve(
    2_000,
    callbacks=[
        cb_norms, cb_cos, cb_align, cb_landscape,
        cb_residual, cb_saliency, cb_ntk, cb_hess, cb_ckpt,
    ],
)
cb_ckpt.close()

# ── Inspect callback results locally ──────────────────────────────────────────
norms_result = cb_norms.result
print("\n── Gradient norms ────────────────────────────────────────────────────")
print(f"  sampled epochs : {norms_result['epochs']}")
print(f"  norms shape    : {norms_result['norms'].shape}   (samples × constraints)")
print(f"  final norms    : {norms_result['norms'][-1]}")

cos_result = cb_cos.result
print("\n── Cosine similarity (final sample) ──────────────────────────────────")
print(f"  matrix:\n{cos_result['cos_sim_matrix'][-1]}")

align_result = cb_align.result
print("\n── Gradient alignment ────────────────────────────────────────────────")
print(f"  values : {align_result['alignment']}")
print("  (+1 = all gradients aligned, -1 = destructive interference)")

land_result = cb_landscape.result
print("\n── Loss landscape ────────────────────────────────────────────────────")
print(f"  grid shape : {land_result['landscapes'].shape}")
print(f"  final min  : {land_result['landscapes'][-1].min():.4e}")
print(f"  final max  : {land_result['landscapes'][-1].max():.4e}")

residual_result = cb_residual.result
print("\n── Residual statistics ───────────────────────────────────────────────")
print(f"  means shape : {residual_result['means'].shape}   (samples × constraints)")
print(f"  final mean  : {residual_result['means'][-1]}")
print(f"  final max   : {residual_result['maxes'][-1]}")
print(f"  final p99   : {residual_result['p99'][-1]}")

saliency_result = cb_saliency.result
print("\n── Input sensitivity ─────────────────────────────────────────────────")
print(f"  values shape : {saliency_result['values'].shape}")
final_abs = jnp.abs(saliency_result["values"][-1])
print(f"  final mean|∂u/∂x| : {float(final_abs.mean()):.4e}")
print(f"  final max |∂u/∂x| : {float(final_abs.max()):.4e}")

ntk_result = cb_ntk.result
print("\n── NTK spectrum ──────────────────────────────────────────────────────")
print(f"  top-{ntk_result['eigvals_topk'].shape[1]} eigvals (final): {ntk_result['eigvals_topk'][-1]}")
print(f"  λ_max (final)        : {ntk_result['lambda_max'][-1]:.4e}")
print(f"  condition (final)    : {ntk_result['condition_number'][-1]:.4e}")

hess_result = cb_hess.result
print("\n── Hessian eigenspectrum ─────────────────────────────────────────────")
print(f"  top-{hess_result['eigvals'].shape[1]} eigvals (final): {hess_result['eigvals'][-1]}")
print(f"  sharpness (final)    : {hess_result['sharpness'][-1]:.4e}")

# ── Validate solution ──────────────────────────────────────────────────────────
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jnp.linalg.norm(_u - _u_exact) / (jnp.linalg.norm(_u_exact) + 1e-8))
print(f"\nRelative L² error: {rel_l2:.3e}")
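The Hessian callback in the script is described only as "top-k eigenvalues + sharpness via Lanczos w/ HVPs". As an illustration of that technique, not jNO's code, the NumPy sketch below runs Lanczos with full reorthogonalisation using nothing but a Hessian-vector-product function, then compares the result with a dense eigensolver on a small synthetic symmetric matrix. In a real network the HVP would come from automatic differentiation (for example forward-over-reverse in JAX) instead of an explicit matrix, and "sharpness" is taken here to mean the largest eigenvalue, the usual definition but an assumption about jNO.

```python
import numpy as np

def lanczos_topk(hvp, dim, k=5, n_iter=15, seed=0):
    # Top-k eigenvalues of a symmetric operator, given only the map v -> H @ v
    rng = np.random.default_rng(seed)
    n_iter = min(n_iter, dim)
    Q = np.zeros((n_iter, dim))          # Lanczos basis vectors (rows)
    alpha = np.zeros(n_iter)
    beta = np.zeros(n_iter)
    q = rng.normal(size=dim)
    q /= np.linalg.norm(q)
    m = n_iter
    for j in range(n_iter):
        Q[j] = q
        w = hvp(q)
        alpha[j] = q @ w
        # Full reorthogonalisation against all basis vectors so far (two passes)
        for _ in range(2):
            w -= Q[: j + 1].T @ (Q[: j + 1] @ w)
        beta[j] = np.linalg.norm(w)
        if beta[j] < 1e-10:              # invariant subspace found: stop early
            m = j + 1
            break
        q = w / beta[j]
    # Eigenvalues of the m x m tridiagonal matrix approximate those of H
    T = np.diag(alpha[:m]) + np.diag(beta[: m - 1], 1) + np.diag(beta[: m - 1], -1)
    ritz = np.linalg.eigvalsh(T)         # ascending order
    return ritz[::-1][:k]

rng = np.random.default_rng(42)
B = rng.normal(size=(12, 12))
H = (B + B.T) / 2.0                      # synthetic symmetric "Hessian"
top = lanczos_topk(lambda v: H @ v, dim=12, k=5, n_iter=15)
exact = np.sort(np.linalg.eigvalsh(H))[::-1][:5]
print("Lanczos top-5:", top)
print("dense   top-5:", exact)
print("sharpness (largest eigenvalue):", top[0])
```

Here n_iter is capped at the matrix size, so the tridiagonal matrix reproduces the spectrum exactly; for a network with many parameters and n_iter=15, only the extreme eigenvalues are expected to be accurate.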