W&B Integration and Explainability Callbacks
This tutorial shows how to connect a jNO training run to Weights & Biases and use the built-in explainability callbacks to understand what is happening inside the training loop.
What gets logged
| Source | W&B keys |
|---|---|
| `GradientNormsCallback` | `explainability/gradient_norm/constraint_0`, …, `…/constraint_N` |
| `CosSimilarityCallback` | `explainability/cos_sim/0_1`, …, `explainability/cos_sim_matrix` (heatmap) |
| `GradientAlignmentCallback` | `explainability/gradient_alignment` |
| `LossLandscapeCallback` | `explainability/loss_landscape` (heatmap image) |
| `CheckpointCallback` | versioned checkpoint artifact with `total_loss`, `individual_losses`, `checkpoint_dir` |
Step 1: Enable W&B in jno.setup
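Call `dire = jno.setup(__file__, wandb=True)` at the top of the script. This creates the run directory `dire` and initialises a W&B run (the project name defaults to the script filename stem). If the `weave` package is installed, it also calls `weave.init("armbrul/jNO")` for Weave tracing.
Pass a dict instead of `True` to forward any `wandb.init` kwargs, for example `jno.setup(__file__, name="my_run", wandb={"project": "jNO"})`.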
Step 2: Define the Problem
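We solve the 1-D Poisson equation \(-u''(x) = \sin(\pi x)\) with \(u = 0\) on the boundary. The boundary condition is imposed softly, so the solver has two separate loss terms (PDE residual and boundary). At least two loss terms are needed for the explainability metrics to be meaningful, because they compare gradients across constraints.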
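```python
import foundax
import jax
import optax
import jno

π = jno.np.pi

domain = jno.domain.line(mesh_size=0.05)
x, _ = domain.variable("interior")   # interior collocation points
xb, _ = domain.variable("boundary")  # boundary points

# Learning rate starts at 1e-3, halves every 1,000 steps (smoothly), floored at 1e-5
u_net = jno.nn.wrap(
    foundax.mlp(in_features=1, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(0))
).optimizer(optax.adam(optax.exponential_decay(1e-3, 1_000, 0.5, end_value=1e-5)))

u = u_net(x)
ub = u_net(xb)
pde = -u.d2(x) - jno.np.sin(π * x)  # residual of -u'' = sin(πx)
bc = ub  # u = 0 on boundary
```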
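As a sanity check on the problem statement, the exact solution used later for validation, \(u(x) = \sin(\pi x)/\pi^2\), satisfies \(-u'' = \sin(\pi x)\) and vanishes at both ends. The NumPy sketch below verifies this with a central finite difference. It assumes the line domain is \([0, 1]\), which the tutorial does not state explicitly.

```python
import numpy as np

# Assumption: the line domain is [0, 1]; the tutorial does not state the interval.
x = np.linspace(0.0, 1.0, 201)
h = x[1] - x[0]

u_exact = np.sin(np.pi * x) / np.pi**2

# Second derivative at interior points: (u[i-1] - 2 u[i] + u[i+1]) / h^2
d2u = (u_exact[:-2] - 2.0 * u_exact[1:-1] + u_exact[2:]) / h**2

# Same residual form as `pde`: -u'' - sin(pi x)
residual = -d2u - np.sin(np.pi * x[1:-1])

max_residual = np.abs(residual).max()  # only O(h^2) truncation error remains
boundary_values = (u_exact[0], u_exact[-1])
print(f"max |residual| = {max_residual:.2e}")
print(f"u(0), u(1) = {boundary_values}")
```

The remaining residual is the \(O(h^2)\) discretisation error of the difference stencil, not a flaw in the exact solution.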
Step 3: Create the Explainability Callbacks
All four callbacks share the same interface: `interval` controls how often they run, and the optional `mask` lets you restrict gradient computations to a subset of parameters.
Gradient norms
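Created with `jno.callbacks.gradient_norms(interval=50)`, this callback tracks \(\|\nabla L_i\|_2\) for each constraint \(i\). A norm that suddenly jumps, or sits far above the others, usually signals a constraint that is dominating the update.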
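To make the metric concrete, here is a NumPy sketch of a per-constraint gradient norm when the parameters form a pytree: each constraint's gradient is flattened across all leaves and its Euclidean norm is taken. The parameter names and gradient values are made-up placeholders, and this is not jNO's implementation, which computes the gradients with JAX.

```python
import numpy as np


def flatten_tree(tree):
    """Concatenate every array leaf of a (possibly nested) dict into one 1-D vector."""
    if isinstance(tree, dict):
        # Sorted keys so every gradient is flattened in the same order
        return np.concatenate([flatten_tree(tree[k]) for k in sorted(tree)])
    return np.ravel(np.asarray(tree, dtype=float))


def gradient_norms(grads_per_constraint):
    """Return ||grad L_i||_2 for each constraint i."""
    return np.array([np.linalg.norm(flatten_tree(g)) for g in grads_per_constraint])


# Hypothetical gradients of two losses (PDE residual, boundary) for a tiny model
grad_pde = {"dense": {"w": np.array([[3.0, 0.0]]), "b": np.array([4.0])}}
grad_bc = {"dense": {"w": np.array([[0.0, 0.3]]), "b": np.array([0.4])}}

norms = gradient_norms([grad_pde, grad_bc])
print(norms)  # the PDE gradient is 10x larger than the boundary gradient here
```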
Cosine similarity matrix
Created with `jno.callbacks.cos_similarity(interval=50)`, this callback computes the full \(N \times N\) matrix of pairwise cosine similarities between the constraint gradients. When W&B is active the matrix is uploaded as a heatmap image.
| Value | Meaning |
|---|---|
| \(\approx +1\) | Constraints reinforce each other |
| \(\approx 0\) | Independent — no interaction |
| \(\approx -1\) | Gradient conflict — one constraint hurts the other |
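Each entry is \(C_{ij} = \langle \nabla L_i, \nabla L_j \rangle / (\|\nabla L_i\|_2 \, \|\nabla L_j\|_2)\). A minimal NumPy sketch, assuming each constraint's gradient has already been flattened into one row of a matrix; the example vectors are invented so that row 0 of the result hits each case in the table.

```python
import numpy as np


def cos_sim_matrix(G, eps=1e-12):
    """Pairwise cosine similarity between the rows of G (one flattened gradient per row)."""
    G = np.asarray(G, dtype=float)
    norms = np.linalg.norm(G, axis=1, keepdims=True)
    U = G / np.maximum(norms, eps)  # unit-normalise each gradient; eps guards zero gradients
    return U @ U.T


G = np.array([
    [1.0, 0.0],   # constraint 0
    [2.0, 0.0],   # constraint 1: same direction as 0 -> +1
    [0.0, 1.0],   # constraint 2: orthogonal to 0     ->  0
    [-1.0, 0.0],  # constraint 3: opposite to 0       -> -1
])
C = cos_sim_matrix(G)
print(np.round(C[0], 3))
```

Only directions matter: constraint 1 has twice the magnitude of constraint 0 but still scores \(+1\), which is why this matrix complements the gradient norms rather than replacing them.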
Total gradient alignment
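Created with `jno.callbacks.gradient_alignment(interval=50)`, this callback reduces all constraint gradients to a single scalar in \([-1, 1]\) that measures their global agreement (Eq. 3.1 in [2502.00604]).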
Near \(+1\) means all loss terms pull in the same direction; \(0\) means orthogonal; near \(-1\) means anti-aligned (destructive interference).
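For reference, our reading of the alignment score in the cited paper, for \(n\) gradients \(g_1, \dots, g_n\), is shown below. Check it against jNO's implementation before relying on exact values.

\[
\mathcal{A}(g_1, \dots, g_n) = 2 \left\| \frac{1}{n} \sum_{i=1}^{n} \frac{g_i}{\|g_i\|_2} \right\|_2^2 - 1
\]

It equals \(+1\) when all gradients point the same way and \(-1\) when the unit gradients cancel. For \(n = 2\) it reduces to the cosine similarity of the two gradients. A NumPy sketch of that formula:

```python
import numpy as np


def alignment_score(G, eps=1e-12):
    """2 * ||mean of unit-normalised rows of G||^2 - 1 (rows are flattened gradients)."""
    G = np.asarray(G, dtype=float)
    U = G / np.maximum(np.linalg.norm(G, axis=1, keepdims=True), eps)
    mean_unit = U.mean(axis=0)
    return 2.0 * float(mean_unit @ mean_unit) - 1.0


g_pde = np.array([1.0, 2.0, -0.5])
g_bc = np.array([0.5, -1.0, 2.0])

# For two gradients the score matches their cosine similarity
cos = g_pde @ g_bc / (np.linalg.norm(g_pde) * np.linalg.norm(g_bc))
print(alignment_score([g_pde, g_bc]), cos)

print(alignment_score([g_pde, 3.0 * g_pde]))  # parallel gradients give +1
print(alignment_score([g_pde, -g_pde]))       # opposite gradients give -1
```

Under this definition, \(n\) mutually orthogonal gradients score \(2/n - 1\), which is exactly 0 only for \(n = 2\).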
2-D loss landscape
```python
cb_landscape = jno.callbacks.loss_landscape(
    interval=200,  # expensive — n_grid² forward passes per call
    n_grid=11,
    alpha_range=0.5,
)
```
The callback samples two random filter-normalised directions and evaluates the total loss on an \(n_\text{grid} \times n_\text{grid}\) grid around the current parameters; the result is logged as a heatmap image in W&B. A smooth bowl shape is a sign of a well-conditioned optimisation landscape; sharp ridges or irregular bumps indicate ill-conditioning.
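The procedure can be sketched in NumPy on a toy least-squares model. Filter normalisation is approximated here by rescaling each row of a random direction to the norm of the matching weight row. How jNO groups parameters into "filters", how it draws the directions, and the assumption that `alpha_range` is the half-width of the grid are all guesses for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model (assumption): linear least squares with weights W of shape (2, 3)
X = rng.normal(size=(64, 3))
W_true = rng.normal(size=(2, 3))
Y = X @ W_true.T
W = W_true + 0.1 * rng.normal(size=W_true.shape)  # "current parameters", near the optimum


def loss(W):
    return float(np.mean((X @ W.T - Y) ** 2))


def filter_normalised_direction(W, rng):
    """Random direction whose rows (one per output unit) have the same norm as W's rows."""
    D = rng.normal(size=W.shape)
    D *= np.linalg.norm(W, axis=1, keepdims=True) / np.linalg.norm(D, axis=1, keepdims=True)
    return D


def loss_landscape(W, rng, n_grid=11, alpha_range=0.5):
    d1 = filter_normalised_direction(W, rng)
    d2 = filter_normalised_direction(W, rng)
    alphas = np.linspace(-alpha_range, alpha_range, n_grid)
    # n_grid**2 loss evaluations, one per (alpha, beta) grid point
    return np.array([[loss(W + a * d1 + b * d2) for b in alphas] for a in alphas])


Z = loss_landscape(W, rng)
print(Z.shape, Z.min(), Z[5, 5])  # Z[5, 5] is the grid centre, i.e. the current parameters
```

Normalising each direction to the scale of the weights is what makes landscapes comparable across layers and across training time; without it, a direction along a tiny-weight layer would look artificially flat.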
Reducing cost
Pass `mask=bool_pytree` to restrict perturbations and gradient computations to a small subset of parameters, e.g. only the output layer. For large models this can reduce the cost by orders of magnitude while still preserving most of the diagnostic signal.
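A sketch of why the mask helps, assuming it is a pytree of booleans with the same structure as the parameters (`True` means "include this leaf"). The layer names and shapes are hypothetical, chosen to resemble a small MLP. Only masked leaves enter the flattened gradient vectors, so every norm and dot product the callbacks compute shrinks accordingly.

```python
import numpy as np

# Hypothetical parameter pytree for a 1 -> 32 -> 32 -> 32 -> 1 MLP
params = {
    "layer_0": {"w": np.zeros((1, 32)), "b": np.zeros(32)},
    "layer_1": {"w": np.zeros((32, 32)), "b": np.zeros(32)},
    "layer_2": {"w": np.zeros((32, 32)), "b": np.zeros(32)},
    "output": {"w": np.zeros((32, 1)), "b": np.zeros(1)},
}

# Boolean pytree with the same structure: keep only the output layer
mask = {name: {leaf: name == "output" for leaf in layer} for name, layer in params.items()}


def masked_leaves(tree, mask):
    """Yield the array leaves whose mask entry is True (same nesting as `tree`)."""
    for key in sorted(tree):
        if isinstance(tree[key], dict):
            yield from masked_leaves(tree[key], mask[key])
        elif mask[key]:
            yield tree[key]


n_total = sum(v.size for layer in params.values() for v in layer.values())
n_masked = sum(v.size for v in masked_leaves(params, mask))
print(f"{n_masked} of {n_total} parameters ({n_masked / n_total:.1%})")
```

Even for this small network, the output layer holds about 1.5% of the parameters; in wider or deeper models the ratio falls further.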
Step 4: Checkpoint with W&B Artifact
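```python
# `dire` is the run directory returned by jno.setup (Step 1)
cb_ckpt = jno.callbacks.checkpoint(
    directory=f"{dire}/checkpoints",
    save_interval_epochs=500,           # write a checkpoint every 500 epochs
    max_to_keep=3,                      # keep at most 3 checkpoints on disk
    best_fn=lambda m: m["total_loss"],  # metric used to rank checkpoints
)
```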
Each time a checkpoint is saved while W&B is active, jNO uploads it to W&B as a versioned checkpoint artifact. The artifact metadata looks like this (example values):
```python
{
    "epoch": 500,
    "total_loss": 0.0023,
    "individual_losses": [0.0019, 0.0004],
    "checkpoint_dir": "/path/to/checkpoints/500",
    "timestamp": 1717000000.0,
}
```
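Because every artifact carries this metadata, picking a checkpoint afterwards comes down to ranking the dicts. The sketch below works on plain dicts shaped like the example above (values invented) and assumes, as `best_fn=lambda m: m["total_loss"]` suggests, that a lower value is better. The tutorial does not say whether `max_to_keep` retains the most recent or the best checkpoints, so the retention rule here is only one possibility.

```python
# Hypothetical metadata records shaped like the artifact metadata above (values invented)
records = [
    {"epoch": 500, "total_loss": 2.3e-3, "checkpoint_dir": "ckpt/500"},
    {"epoch": 1000, "total_loss": 8.0e-4, "checkpoint_dir": "ckpt/1000"},
    {"epoch": 1500, "total_loss": 9.5e-4, "checkpoint_dir": "ckpt/1500"},
    {"epoch": 2000, "total_loss": 6.1e-4, "checkpoint_dir": "ckpt/2000"},
]


def best_fn(m):
    # Same ranking metric as the callback's best_fn
    return m["total_loss"]


def keep_best(records, max_to_keep=3):
    """One possible retention rule: keep the max_to_keep records with the lowest best_fn."""
    return sorted(records, key=best_fn)[:max_to_keep]


kept = keep_best(records)
best = kept[0]
print([r["epoch"] for r in kept], "best:", best["checkpoint_dir"])
```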
Step 5: Solve
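```python
# Two loss terms: mean squared PDE residual and mean squared boundary residual
crux = jno.core([pde.mse, bc.mse])
crux.solve(
    2_000,  # number of training epochs
    callbacks=[cb_norms, cb_cos, cb_align, cb_landscape, cb_ckpt],
)
cb_ckpt.close()  # close the checkpoint callback once training has finished
```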
All callbacks register themselves in `on_solve_begin` (called once, after the initial JIT compilation) to pre-compile their JAX functions against the live parameter shapes. As a result, the first time a callback fires (at `epoch % interval == 0`) it runs an already-compiled XLA kernel instead of paying the compilation cost mid-training.
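The interval logic can be pictured with a small, framework-free class. Apart from `on_solve_begin`, the hook names are hypothetical, as is the assumption that the epoch counter starts at 0; this is not jNO's callback API. With `interval=50` in a 2,000-epoch run the recorder fires on epochs 0, 50, ..., 1950 and stacks one row per firing, which would correspond to the `n_samples` axis of the arrays in Step 6.

```python
import numpy as np


class IntervalRecorder:
    """Hypothetical callback: evaluate metric_fn(params) every `interval` epochs."""

    def __init__(self, metric_fn, interval=50):
        self.metric_fn = metric_fn
        self.interval = interval
        self.epochs, self.values = [], []

    def on_solve_begin(self, params):
        # jNO pre-compiles the callback's JAX functions here; this sketch only resets history
        self.epochs, self.values = [], []

    def on_epoch_end(self, epoch, params):  # hypothetical hook name
        if epoch % self.interval == 0:
            self.epochs.append(epoch)
            self.values.append(self.metric_fn(params))

    @property
    def result(self):
        # One row per firing: the n_samples axis
        return {"epochs": np.array(self.epochs), "values": np.stack(self.values)}


# Toy "training": two parameters shrink a little each epoch; the metric returns one
# value per parameter, standing in for one value per constraint
cb = IntervalRecorder(lambda p: np.abs(p), interval=50)
params = np.array([1.0, 2.0])
cb.on_solve_begin(params)
for epoch in range(2_000):
    params = params * 0.999
    cb.on_epoch_end(epoch, params)

res = cb.result
print(res["epochs"][:3], res["values"].shape)
```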
Step 6: Read Results Locally
Even without W&B, every callback stores its history as numpy arrays:
```python
# Gradient norms: shape (n_samples, n_constraints)
norms = cb_norms.result["norms"]

# Cosine similarity: shape (n_samples, n_constraints, n_constraints)
cos_mat = cb_cos.result["cos_sim_matrix"]

# Alignment: shape (n_samples,)
alignment = cb_align.result["alignment"]

# Landscapes: shape (n_samples, n_grid, n_grid)
landscapes = cb_landscape.result["landscapes"]
```
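Once you have these arrays, a few lines of NumPy turn them into diagnostics. The helper below is a hypothetical post-processing step, not part of jNO. For each sample it reports the constraint with the largest gradient norm, the max/min norm ratio, and the most conflicting pair of constraints. It runs on small synthetic arrays with the shapes listed above; three constraints are used so that the pair search is non-trivial.

```python
import numpy as np


def summarise(norms, cos_mat):
    """Per sample: dominant constraint, max/min norm ratio, and most conflicting pair.

    norms:   (n_samples, n_constraints)
    cos_mat: (n_samples, n_constraints, n_constraints)
    """
    dominant = norms.argmax(axis=1)                # constraint with the largest gradient norm
    ratio = norms.max(axis=1) / norms.min(axis=1)  # >= 1; large values mean unbalanced gradients
    n_samples, n, _ = cos_mat.shape
    lifted = cos_mat + 2.0 * np.eye(n)             # lift the diagonal (always 1) out of the argmin
    flat_idx = lifted.reshape(n_samples, -1).argmin(axis=1)
    worst_pair = np.stack(np.unravel_index(flat_idx, (n, n)), axis=1)
    worst_cos = cos_mat.reshape(n_samples, -1)[np.arange(n_samples), flat_idx]
    return dominant, ratio, worst_pair, worst_cos


# Synthetic stand-ins for cb_norms.result["norms"] and cb_cos.result["cos_sim_matrix"]
norms_demo = np.array([[1.0, 0.5, 2.0],
                       [3.0, 0.3, 1.0]])
cos_demo = np.array([
    [[1.0, 0.3, -0.8], [0.3, 1.0, 0.1], [-0.8, 0.1, 1.0]],
    [[1.0, -0.2, 0.4], [-0.2, 1.0, 0.5], [0.4, 0.5, 1.0]],
])

dominant, ratio, worst_pair, worst_cos = summarise(norms_demo, cos_demo)
print(dominant, ratio, worst_pair.tolist(), worst_cos)
```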
What To Notice
- All W&B calls are no-ops when `jno.setup` is called without `wandb=True`, so scripts that do not need W&B behave exactly as before.
- The explainability callbacks use `jacrev` internally; they run outside the training step and do not affect the parameter updates.
- For large models, always provide a `mask` to limit which parameters are differentiated. The output-layer weights often give a good proxy signal at a fraction of the cost.
- A falling gradient alignment scalar during training is a useful early warning of constraint conflict.
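One simple way to act on the last point is to compare each new alignment sample with the best value seen so far and flag a sustained drop. The heuristic, its `drop` threshold and `patience` are illustrative choices, not something jNO provides, and the series below is synthetic.

```python
import numpy as np


def alignment_drop_epochs(epochs, alignment, drop=0.3, patience=2):
    """Epochs at which alignment has stayed `drop` below its running max for `patience` samples."""
    alignment = np.asarray(alignment, dtype=float)
    running_max = np.maximum.accumulate(alignment)
    below = alignment < running_max - drop
    flagged, streak = [], 0
    for epoch, is_below in zip(epochs, below):
        streak = streak + 1 if is_below else 0
        if streak == patience:  # report once, when the drop has persisted long enough
            flagged.append(int(epoch))
    return flagged


# Synthetic alignment history sampled every 50 epochs
epochs = np.arange(0, 500, 50)
alignment = np.array([0.2, 0.5, 0.7, 0.75, 0.72, 0.3, 0.25, 0.2, 0.6, 0.7])
print(alignment_drop_epochs(epochs, alignment))
```

Requiring `patience` consecutive low samples avoids raising an alarm on a single noisy measurement.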
Script Snippet
"""09 — W&B + Weave integration with explainability callbacks"""
import foundax
import jax
import jax.numpy as jnp
import optax
import jno
π = jno.np.pi
# ── Setup: creates the run directory and initialises W&B + Weave ──────────────
# Set wandb=True to push metrics to Weights & Biases.
# The project name defaults to the script filename stem ("wandb_integration").
# Override with: jno.setup(__file__, name="my_run", wandb={"project": "jNO"})
dire = jno.setup(__file__, wandb=True)
# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain.line(mesh_size=0.05)
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")
# ── Exact solution (validation only) ──────────────────────────────────────────
u_exact = jno.np.sin(π * x) / π**2
# ── Network ───────────────────────────────────────────────────────────────────
u_net = jno.nn.wrap(
foundax.mlp(
in_features=1,
hidden_dims=32,
num_layers=3,
key=jax.random.PRNGKey(0),
)
).optimizer(optax.adam(optax.exponential_decay(1e-3, 1_000, 0.5, end_value=1e-5)))
u = u_net(x)
ub = u_net(xb)
# ── Constraints ───────────────────────────────────────────────────────────────
pde = -u.d2(x) - jno.np.sin(π * x) # PDE residual
bc = ub # u = 0 on boundary
# ── Explainability callbacks ───────────────────────────────────────────────────
# All four are logged to W&B when the run is active.
# Use interval= to control how often each is computed.
# The loss landscape is the most expensive — keep its interval large for real runs.
cb_norms = jno.callbacks.gradient_norms(
interval=50,
# mask=... # optional: restrict to a subset of parameters (faster)
)
cb_cos = jno.callbacks.cos_similarity(
interval=50,
)
cb_align = jno.callbacks.gradient_alignment(
interval=50,
)
cb_landscape = jno.callbacks.loss_landscape(
interval=200, # expensive — n_grid² forward passes per call
n_grid=11,
alpha_range=0.5,
)
cb_residual = jno.callbacks.residual_stats(
interval=50,
)
# Input saliency: |∂u/∂x| at the interior collocation points.
# Any jno placeholder expression compiles — try Jacobian(u, [x, y]) for 2-D problems.
cb_saliency = jno.callbacks.input_sensitivity(
u.d(x),
interval=50,
)
# Empirical NTK spectrum: diagnoses spectral bias.
cb_ntk = jno.callbacks.ntk_spectrum(
u.grad(u_net),
n_points=32,
top_k=5,
interval=200,
)
# Hessian eigenspectrum: top-k eigenvalues + sharpness via Lanczos w/ HVPs.
cb_hess = jno.callbacks.hessian_spectrum(
k=5,
n_iter=15,
interval=400,
)
# ── Checkpoint callback ────────────────────────────────────────────────────────
# Saves to disk every 500 epochs; uploads a versioned artifact to W&B.
# Artifact metadata includes epoch, total_loss, individual_losses, checkpoint_dir.
cb_ckpt = jno.callbacks.checkpoint(
directory=f"{dire}/checkpoints",
save_interval_epochs=500,
max_to_keep=3,
best_fn=lambda m: m["total_loss"],
)
# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, bc.mse])
crux.solve(
2_000,
callbacks=[cb_norms, cb_cos, cb_align, cb_landscape, cb_residual, cb_saliency, cb_ntk, cb_hess, cb_ckpt],
)
cb_ckpt.close()
# ── Inspect callback results locally ──────────────────────────────────────────
norms_result = cb_norms.result
print("\n── Gradient norms ────────────────────────────────────────────────────")
print(f" sampled epochs : {norms_result['epochs']}")
print(f" norms shape : {norms_result['norms'].shape} (samples × constraints)")
print(f" final norms : {norms_result['norms'][-1]}")
cos_result = cb_cos.result
print("\n── Cosine similarity (final sample) ──────────────────────────────────")
print(f" matrix:\n{cos_result['cos_sim_matrix'][-1]}")
align_result = cb_align.result
print("\n── Gradient alignment ────────────────────────────────────────────────")
print(f" values : {align_result['alignment']}")
print(" (1.0 = perfect alignment, 0.0 = destructive interference)")
land_result = cb_landscape.result
print("\n── Loss landscape ────────────────────────────────────────────────────")
print(f" grid shape : {land_result['landscapes'].shape}")
print(f" final min : {land_result['landscapes'][-1].min():.4e}")
print(f" final max : {land_result['landscapes'][-1].max():.4e}")
residual_result = cb_residual.result
print("\n── Residual statistics ───────────────────────────────────────────────")
print(f" means shape : {residual_result['means'].shape} (samples × constraints)")
print(f" final mean : {residual_result['means'][-1]}")
print(f" final max : {residual_result['maxes'][-1]}")
print(f" final p99 : {residual_result['p99'][-1]}")
saliency_result = cb_saliency.result
print("\n── Input sensitivity ─────────────────────────────────────────────────")
print(f" values shape : {saliency_result['values'].shape}")
final_abs = jnp.abs(saliency_result["values"][-1])
print(f" final mean|∂u/∂x| : {float(final_abs.mean()):.4e}")
print(f" final max |∂u/∂x| : {float(final_abs.max()):.4e}")
ntk_result = cb_ntk.result
print("\n── NTK spectrum ──────────────────────────────────────────────────────")
print(f" top-{ntk_result['eigvals_topk'].shape[1]} eigvals (final): {ntk_result['eigvals_topk'][-1]}")
print(f" λ_max (final) : {ntk_result['lambda_max'][-1]:.4e}")
print(f" condition (final) : {ntk_result['condition_number'][-1]:.4e}")
hess_result = cb_hess.result
print("\n── Hessian eigenspectrum ─────────────────────────────────────────────")
print(f" top-{hess_result['eigvals'].shape[1]} eigvals (final): {hess_result['eigvals'][-1]}")
print(f" sharpness (final) : {hess_result['sharpness'][-1]:.4e}")
# ── Validate solution ──────────────────────────────────────────────────────────
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jnp.linalg.norm(_u - _u_exact) / (jnp.linalg.norm(_u_exact) + 1e-8))
print(f"\nRelative L² error: {rel_l2:.3e}")