Weights & Biases

jNO has first-class W&B support. Enabling it adds automatic metric logging, checkpoint artifacts, weight histograms, and Weave tracing with a single flag in jno.setup.


Enabling W&B

Pass wandb=True to jno.setup:

dire = jno.setup(__file__, wandb=True)

This calls wandb.init (project name defaults to the script filename stem), logs source code via run.log_code(), and initialises Weave tracing via weave.init("armbrul/jNO") if the weave package is installed.

To forward extra kwargs to wandb.init, pass a dict:

jno.setup(__file__, wandb={"project": "jNO", "tags": ["poisson", "1d"], "group": "sweep-01"})

Any key not supplied falls back to the default (project → script stem, dir → run directory).
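The fallback rule can be pictured as a plain dictionary merge. The sketch below is not jNO's implementation: `resolve_wandb_kwargs` and its arguments are hypothetical names, used only to illustrate how `wandb=True` and `wandb={...}` could resolve to one set of `wandb.init` keyword arguments.

```python
from pathlib import Path


def resolve_wandb_kwargs(script_path, run_dir, wandb=True):
    """Hypothetical helper: merge user kwargs over the documented defaults."""
    # Documented defaults: project -> script stem, dir -> run directory.
    defaults = {"project": Path(script_path).stem, "dir": str(run_dir)}
    if wandb is True:
        return defaults
    if wandb is False or wandb is None:
        return None  # W&B disabled
    # User-supplied keys win; anything missing falls back to the defaults.
    return {**defaults, **dict(wandb)}


print(resolve_wandb_kwargs("poisson_1d.py", "runs/poisson_1d"))
print(
    resolve_wandb_kwargs(
        "poisson_1d.py",
        "runs/poisson_1d",
        wandb={"project": "jNO", "tags": ["poisson", "1d"]},
    )
)
```

In the second call, `project` is overridden and `tags` is added, while `dir` still comes from the defaults.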


What gets logged automatically

| Source | W&B keys / type |
|---|---|
| Training loss (every step) | loss, constraint_0, constraint_1, … |
| CheckpointCallback | versioned checkpoint artifact |
| Weight histograms | weights/<model>/<layer> |
| GradientNormsCallback | explainability/gradient_norm/constraint_N |
| CosSimilarityCallback | explainability/cos_sim/i_j + heatmap image |
| GradientAlignmentCallback | explainability/gradient_alignment |
| LossLandscapeCallback | explainability/loss_landscape (heatmap image) |

Every row after the first is logged only when the corresponding callback is passed to solve(). See Explainability for details on the explainability callbacks.
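To make the key layout of the first row concrete, here is a minimal sketch of the per-step payload. `training_log_payload` is a hypothetical helper, and the mapping of the total loss to `loss` and of each individual loss to `constraint_<i>` is an assumption read off the table, not jNO's actual code.

```python
def training_log_payload(total_loss, individual_losses):
    """Hypothetical helper: build the metric dict for one training step."""
    payload = {"loss": float(total_loss)}
    # One key per constraint, numbered in the order the constraints were given.
    for i, value in enumerate(individual_losses):
        payload[f"constraint_{i}"] = float(value)
    return payload


print(training_log_payload(0.0023, [0.0019, 0.0004]))
```

A dict of this shape can be filtered or plotted in the W&B UI by key prefix, which is presumably why the explainability callbacks use the `explainability/` namespace.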


Checkpoint artifacts

When CheckpointCallback saves a checkpoint and a W&B run is active, it uploads the checkpoint directory as a versioned checkpoint artifact. The artifact metadata includes:

{
    "epoch": 500,
    "total_loss": 0.0023,
    "individual_losses": [0.0019, 0.0004],
    "checkpoint_dir": "/path/to/runs/checkpoints/500",
    "timestamp": 1717000000.0,
}

A checkpoint callback configured as follows produces these artifacts:

cb = jno.callbacks.checkpoint(
    directory=f"{dire}/checkpoints",
    save_interval_epochs=500,
    max_to_keep=3,
    best_fn=lambda m: m["total_loss"],
)
crux.solve(5000, callbacks=[cb])
cb.close()
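For readers who want to see what such an upload looks like in plain W&B terms, the following is a rough sketch, not the code CheckpointCallback runs. `build_checkpoint_metadata` and `upload_checkpoint` are hypothetical names, and the artifact name and type (`"checkpoint"`) are assumptions. The W&B calls themselves (`wandb.Artifact`, `Artifact.add_dir`, `Run.log_artifact`) are part of the public wandb API.

```python
import time


def build_checkpoint_metadata(epoch, total_loss, individual_losses, checkpoint_dir):
    """Hypothetical helper: assemble the metadata fields listed above."""
    return {
        "epoch": int(epoch),
        "total_loss": float(total_loss),
        "individual_losses": [float(v) for v in individual_losses],
        "checkpoint_dir": str(checkpoint_dir),
        "timestamp": time.time(),
    }


def upload_checkpoint(run, checkpoint_dir, metadata, name="checkpoint"):
    """Upload a directory as a W&B artifact; does nothing without a run."""
    if run is None:
        return None
    import wandb  # imported lazily so the sketch loads without wandb installed

    # Name and type are assumptions; W&B assigns v0, v1, ... per artifact name.
    artifact = wandb.Artifact(name=name, type="checkpoint", metadata=metadata)
    artifact.add_dir(str(checkpoint_dir))
    return run.log_artifact(artifact)


meta = build_checkpoint_metadata(500, 0.0023, [0.0019, 0.0004], "/path/to/runs/checkpoints/500")
print(sorted(meta))
```

Logging repeatedly under one artifact name is what yields the version history in the W&B UI.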

Alerts

Send a W&B alert from anywhere in your script:

from jno.utils.config import wandb_alert

wandb_alert("NaN detected", f"Loss exploded at epoch {epoch}", level="WARN")

level is one of "INFO", "WARN", "ERROR". The call is a no-op when no W&B run is active.
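One way to decide when to alert, and at which level, is a small health check on the loss value. This is only an illustrative sketch: `loss_alert` and its `blowup_threshold` parameter are hypothetical and not part of jNO.

```python
import math


def loss_alert(epoch, loss, blowup_threshold=1e6):
    """Hypothetical helper: return (title, text, level), or None if the loss looks healthy."""
    if math.isnan(loss) or math.isinf(loss):
        return ("NaN detected", f"Loss is {loss} at epoch {epoch}", "ERROR")
    if abs(loss) > blowup_threshold:
        return ("Loss exploded", f"Loss reached {loss:.3e} at epoch {epoch}", "WARN")
    return None


message = loss_alert(epoch=120, loss=float("nan"))
if message is not None:
    title, text, level = message
    # In a jNO script this is where the documented call would go:
    #   wandb_alert(title, text, level=level)
    print(level, title, text)
```

Because wandb_alert is a no-op without an active run, a check like this can stay in the script whether or not W&B is enabled.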


Helper functions

jno.utils.config exposes three thin wrappers used internally; you can call them directly if you need fine-grained control:

from jno.utils.config import get_wandb_run, wandb_log, wandb_log_model

# Check whether a run is active
run = get_wandb_run()   # returns the wandb.Run or None

# Log arbitrary metrics at a specific step
wandb_log({"my_metric": 0.42}, step=1000)

# Upload a model as an artifact
wandb_log_model(my_pytree, name="best_model")

All three are no-ops when get_wandb_run() returns None.
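The no-op behaviour follows a common guard pattern. The sketch below shows that pattern with the public wandb API (`wandb.run` is `None` when no run is active, and `Run.log` accepts a `step` argument). It is an assumption about the general shape of such wrappers, not jNO's source, and `active_run` and `log_if_active` are hypothetical names.

```python
try:
    import wandb
except ImportError:  # wandb is optional in this sketch
    wandb = None


def active_run():
    """Return the active wandb run, or None if wandb is missing or idle."""
    return None if wandb is None else wandb.run


def log_if_active(metrics, step=None):
    """Log metrics only when a run is active; report whether anything was sent."""
    run = active_run()
    if run is None:
        return False
    run.log(metrics, step=step)
    return True


sent = log_if_active({"my_metric": 0.42}, step=1000)
print("logged" if sent else "no active run, skipped")
```

Guarding on the run object rather than on a configuration flag keeps library code free of W&B-specific branches at every call site.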


Full example

A runnable script that combines the explainability callbacks, checkpointing, and W&B logging is available in the tutorial examples:

"""09 — W&B + Weave integration with explainability callbacks"""

import foundax
import jax
import jax.numpy as jnp
import optax

import jno

π = jno.np.pi

# ── Setup: creates the run directory and initialises W&B + Weave ──────────────
# Set wandb=True to push metrics to Weights & Biases.
# The project name defaults to the script filename stem ("wandb_integration").
# Override with:  jno.setup(__file__, name="my_run", wandb={"project": "jNO"})
dire = jno.setup(__file__, wandb=True)

# ── Domain ────────────────────────────────────────────────────────────────────
domain = jno.domain.line(mesh_size=0.05)
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")

# ── Exact solution (validation only) ──────────────────────────────────────────
u_exact = jno.np.sin(π * x) / π**2

# ── Network ───────────────────────────────────────────────────────────────────
u_net = jno.nn.wrap(
    foundax.mlp(
        in_features=1,
        hidden_dims=32,
        num_layers=3,
        key=jax.random.PRNGKey(0),
    )
).optimizer(optax.adam(optax.exponential_decay(1e-3, 1_000, 0.5, end_value=1e-5)))

u = u_net(x)
ub = u_net(xb)

# ── Constraints ───────────────────────────────────────────────────────────────
pde = -u.d2(x) - jno.np.sin(π * x)  # PDE residual
bc = ub  # u = 0 on boundary

# ── Explainability callbacks ───────────────────────────────────────────────────
# The first four (gradient norms, cosine similarity, gradient alignment,
# loss landscape) are logged to W&B when the run is active.
# Use interval= to control how often each is computed.
# The loss landscape is the most expensive — keep its interval large for real runs.

cb_norms = jno.callbacks.gradient_norms(
    interval=50,
    # mask=...  # optional: restrict to a subset of parameters (faster)
)

cb_cos = jno.callbacks.cos_similarity(
    interval=50,
)

cb_align = jno.callbacks.gradient_alignment(
    interval=50,
)

cb_landscape = jno.callbacks.loss_landscape(
    interval=200,  # expensive — n_grid² forward passes per call
    n_grid=11,
    alpha_range=0.5,
)

cb_residual = jno.callbacks.residual_stats(
    interval=50,
)

# Input saliency: |∂u/∂x| at the interior collocation points.
# Any jno placeholder expression compiles — try Jacobian(u, [x, y]) for 2-D problems.
cb_saliency = jno.callbacks.input_sensitivity(
    u.d(x),
    interval=50,
)

# Empirical NTK spectrum: diagnoses spectral bias.
cb_ntk = jno.callbacks.ntk_spectrum(
    u.grad(u_net),
    n_points=32,
    top_k=5,
    interval=200,
)

# Hessian eigenspectrum: top-k eigenvalues + sharpness via Lanczos w/ HVPs.
cb_hess = jno.callbacks.hessian_spectrum(
    k=5,
    n_iter=15,
    interval=400,
)

# ── Checkpoint callback ────────────────────────────────────────────────────────
# Saves to disk every 500 epochs; uploads a versioned artifact to W&B.
# Artifact metadata includes epoch, total_loss, individual_losses, checkpoint_dir.
cb_ckpt = jno.callbacks.checkpoint(
    directory=f"{dire}/checkpoints",
    save_interval_epochs=500,
    max_to_keep=3,
    best_fn=lambda m: m["total_loss"],
)

# ── Solve ─────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, bc.mse])
crux.solve(
    2_000,
    callbacks=[cb_norms, cb_cos, cb_align, cb_landscape, cb_residual, cb_saliency, cb_ntk, cb_hess, cb_ckpt],
)
cb_ckpt.close()

# ── Inspect callback results locally ──────────────────────────────────────────
norms_result = cb_norms.result
print("\n── Gradient norms ────────────────────────────────────────────────────")
print(f"  sampled epochs : {norms_result['epochs']}")
print(f"  norms shape    : {norms_result['norms'].shape}   (samples × constraints)")
print(f"  final norms    : {norms_result['norms'][-1]}")

cos_result = cb_cos.result
print("\n── Cosine similarity (final sample) ──────────────────────────────────")
print(f"  matrix:\n{cos_result['cos_sim_matrix'][-1]}")

align_result = cb_align.result
print("\n── Gradient alignment ────────────────────────────────────────────────")
print(f"  values : {align_result['alignment']}")
print("  (1.0 = perfect alignment, 0.0 = destructive interference)")

land_result = cb_landscape.result
print("\n── Loss landscape ────────────────────────────────────────────────────")
print(f"  grid shape : {land_result['landscapes'].shape}")
print(f"  final min  : {land_result['landscapes'][-1].min():.4e}")
print(f"  final max  : {land_result['landscapes'][-1].max():.4e}")

residual_result = cb_residual.result
print("\n── Residual statistics ───────────────────────────────────────────────")
print(f"  means shape : {residual_result['means'].shape}   (samples × constraints)")
print(f"  final mean  : {residual_result['means'][-1]}")
print(f"  final max   : {residual_result['maxes'][-1]}")
print(f"  final p99   : {residual_result['p99'][-1]}")

saliency_result = cb_saliency.result
print("\n── Input sensitivity ─────────────────────────────────────────────────")
print(f"  values shape : {saliency_result['values'].shape}")
final_abs = jnp.abs(saliency_result["values"][-1])
print(f"  final mean|∂u/∂x| : {float(final_abs.mean()):.4e}")
print(f"  final max |∂u/∂x| : {float(final_abs.max()):.4e}")

ntk_result = cb_ntk.result
print("\n── NTK spectrum ──────────────────────────────────────────────────────")
print(f"  top-{ntk_result['eigvals_topk'].shape[1]} eigvals (final): {ntk_result['eigvals_topk'][-1]}")
print(f"  λ_max (final)        : {ntk_result['lambda_max'][-1]:.4e}")
print(f"  condition (final)    : {ntk_result['condition_number'][-1]:.4e}")

hess_result = cb_hess.result
print("\n── Hessian eigenspectrum ─────────────────────────────────────────────")
print(f"  top-{hess_result['eigvals'].shape[1]} eigvals (final): {hess_result['eigvals'][-1]}")
print(f"  sharpness (final)    : {hess_result['sharpness'][-1]:.4e}")

# ── Validate solution ──────────────────────────────────────────────────────────
_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jnp.linalg.norm(_u - _u_exact) / (jnp.linalg.norm(_u_exact) + 1e-8))
print(f"\nRelative L² error: {rel_l2:.3e}")