Gradient and Sensitivity Analysis
This tutorial shows how to use u.grad(net), the parameter Jacobian, to monitor what a PINN is learning during training. The key technique is to compute the cosine similarity between the gradient directions of two domain regions and log it with .tracker(), so you can spot gradient conflict before the solve finishes.
Concepts
The Parameter Jacobian
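Every network, trained or not, defines a mapping \(\mathbf{u}(\mathbf{x}; \theta)\) from spatial points to outputs. Evaluated at \(N\) collocation points \(\mathbf{x}_1, \dots, \mathbf{x}_N\), its parameter Jacobian is the \(N \times P\) matrix
\[
J_{ij} = \frac{\partial \mathbf{u}(\mathbf{x}_i; \theta)}{\partial \theta_j}, \qquad i = 1, \dots, N, \quad j = 1, \dots, P,
\]
where \(P\) is the number of (selected) network parameters.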
u.grad(net) returns a symbolic NetworkGradient expression. It is traced like any other Placeholder: you can pass it to jno.np.function, include it in a tracker, or evaluate it with crux.eval.
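To make the object concrete outside jno, here is a minimal plain-JAX sketch of the same matrix. It assumes a hypothetical 1 → 8 → 1 tanh MLP (init_params and u are illustrative names, not jno or foundax API) and flattens the parameters with jax.flatten_util.ravel_pytree so that each row of J is a length-P vector.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

WIDTH = 8  # hypothetical hidden width


def init_params(key):
    # Hypothetical 1 -> 8 -> 1 tanh MLP standing in for the wrapped network.
    k_w, k_v = jax.random.split(key)
    return {
        "w": jax.random.normal(k_w, (WIDTH,)),                     # hidden weights
        "b": jnp.zeros(WIDTH),                                     # hidden biases
        "v": jax.random.normal(k_v, (WIDTH,)) / jnp.sqrt(WIDTH),   # output weights
        "c": jnp.zeros(()),                                        # output bias
    }


params = init_params(jax.random.PRNGKey(0))
theta, unravel = ravel_pytree(params)  # flat parameter vector of length P = 25


def u(theta, x):
    """Scalar network output u(x; theta) for a scalar coordinate x."""
    p = unravel(theta)
    return jnp.dot(p["v"], jnp.tanh(p["w"] * x + p["b"])) + p["c"]


xs = jnp.linspace(0.0, 1.0, 5)  # N = 5 collocation points
# Row i holds d u(x_i; theta) / d theta, so J has shape (N, P).
J = jax.vmap(jax.grad(u), in_axes=(None, 0))(theta, xs)
print(J.shape)  # (5, 25)
```

Building J row by row with jax.vmap(jax.grad(...)) gives the same matrix as applying jax.jacrev to the vector of all N outputs; the row-wise form maps directly onto the definition of \(J_{ij}\).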
Gradient Cosine Similarity
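To compare how two groups of collocation points \(A\) and \(B\) interact during training, compress each group's Jacobian rows into a single sensitivity direction by averaging them:
\[
\mathbf{g}_A = \frac{1}{|A|} \sum_{i \in A} J_{i,:}, \qquad \mathbf{g}_B = \frac{1}{|B|} \sum_{i \in B} J_{i,:}.
\]
Then compute the cosine similarity between \(\mathbf{g}_A\) and \(\mathbf{g}_B\):
\[
\cos(\mathbf{g}_A, \mathbf{g}_B) = \frac{\mathbf{g}_A \cdot \mathbf{g}_B}{\lVert \mathbf{g}_A \rVert \, \lVert \mathbf{g}_B \rVert} \in [-1, 1].
\]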
| Value | Meaning |
|---|---|
| \(\approx +1\) | Aligned — learning from group \(A\) also helps group \(B\) |
| \(\approx 0\) | Orthogonal — the two groups are independent |
| \(\approx -1\) | Conflict — improving group \(A\) hurts group \(B\) |
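The three regimes in the table can be reproduced with hand-made Jacobians. The sketch below is plain JAX; group_cosine is a hypothetical helper (not part of jno) that averages the rows of each group and returns their cosine, and the toy matrices with N = 4 points and P = 2 parameters are illustrative only.

```python
import jax.numpy as jnp


def group_cosine(J, idx_a, idx_b, eps=1e-12):
    """Cosine similarity between the mean Jacobian rows of two point groups.

    J     : (N, P) parameter Jacobian, row i = d u(x_i) / d theta
    idx_a : integer indices of the points in group A
    idx_b : integer indices of the points in group B
    """
    g_a = J[idx_a].mean(axis=0)  # sensitivity direction of group A
    g_b = J[idx_b].mean(axis=0)  # sensitivity direction of group B
    return jnp.dot(g_a, g_b) / (jnp.linalg.norm(g_a) * jnp.linalg.norm(g_b) + eps)


# Toy Jacobians (N = 4 points, P = 2 parameters) for the three regimes.
aligned = jnp.array([[1.0, 0.0], [1.0, 0.1], [1.0, 0.0], [1.0, -0.1]])
orthogonal = jnp.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
conflict = jnp.array([[1.0, 1.0], [1.0, 1.0], [-1.0, -1.0], [-1.0, -1.0]])

left, right = jnp.arange(0, 2), jnp.arange(2, 4)  # group A = first two points
for name, J in [("aligned", aligned), ("orthogonal", orthogonal), ("conflict", conflict)]:
    print(name, float(group_cosine(J, left, right)))
```

The small eps in the denominator only guards against a zero-norm direction; it does not change the value when both directions are well defined.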
Why use a sparse mask?
Computing the full Jacobian over all \(P\) parameters is expensive: it has \(N \times P\) entries. For in-training monitoring you only need a signal, not the exact answer. Restricting the Jacobian to the output-layer weights gives a cheap proxy for the gradient direction at a fraction of the cost.
net.mask(bool_pytree) stores the selection; the next call to u.grad(net.mask(...)) reads it:
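```python
import equinox as eqx
import jax

# Start from an all-False pytree, then flip only the output-layer weight to True.
all_false = jax.tree_util.tree_map(lambda _: False, u_net.module)
output_mask = eqx.tree_at(lambda m: m.output_layer.weight, all_false, True)
J = u.grad(u_net.mask(output_mask))  # traced; shape (N, P_out) at eval time
```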
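jno's mask mechanism itself is not reproduced here. As a plain-JAX sketch of the same idea, assuming a hypothetical 1 → 16 → 1 network whose output weight is stored under the key out_w, restricting the Jacobian to one parameter group just means differentiating with respect to that group while holding the others fixed. The result is a column subset of the full Jacobian.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter layout (not foundax's): 1 -> 16 -> 1 tanh network.
key_h, key_o = jax.random.split(jax.random.PRNGKey(1))
width = 16
params = {
    "hidden_w": jax.random.normal(key_h, (width,)),
    "hidden_b": jnp.zeros(width),
    "out_w": jax.random.normal(key_o, (width,)) / jnp.sqrt(width),
    "out_b": jnp.zeros(()),
}


def mlp(p, x):
    h = jnp.tanh(p["hidden_w"] * x + p["hidden_b"])
    return jnp.dot(p["out_w"], h) + p["out_b"]


xs = jnp.linspace(0.0, 1.0, 50)
n = xs.shape[0]

# Full Jacobian: gradient w.r.t. every leaf, flattened per point, keys in sorted order
# (hidden_b, hidden_w, out_b, out_w) -> 16 + 16 + 1 + 16 = 49 columns.
full = jax.vmap(jax.grad(mlp), in_axes=(None, 0))(params, xs)
J_full = jnp.concatenate([full[k].reshape(n, -1) for k in sorted(full)], axis=1)


# "Masked" Jacobian: differentiate only w.r.t. out_w; the other leaves are constants.
def mlp_out_w(out_w, x):
    return mlp({**params, "out_w": out_w}, x)


J_out = jax.vmap(jax.grad(mlp_out_w), in_axes=(None, 0))(params["out_w"], xs)
print(J_full.shape, J_out.shape)  # (50, 49) (50, 16)
```

For this architecture the output-weight rows are simply the hidden activations \(\tanh(w x + b)\), which is why they are so cheap: no backpropagation through the hidden layers is needed.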
Problem Setup
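We solve the familiar 1D Poisson equation on the unit interval with homogeneous Dirichlet boundary conditions:
\[
-u''(x) = \sin(\pi x), \quad x \in (0, 1), \qquad u(0) = u(1) = 0.
\]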
Exact solution: \(u(x) = \sin(\pi x) / \pi^2\).
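A quick way to double-check the exact solution is to differentiate it with autodiff. This plain-JAX check (independent of jno) confirms that \(u(x) = \sin(\pi x)/\pi^2\) satisfies \(-u'' = \sin(\pi x)\) and vanishes at both ends, up to float32 round-off.

```python
import jax
import jax.numpy as jnp


def u_exact(x):
    # Exact solution of -u'' = sin(pi x) on (0, 1) with u(0) = u(1) = 0.
    return jnp.sin(jnp.pi * x) / jnp.pi**2


u_xx = jax.grad(jax.grad(u_exact))  # second derivative d^2 u / dx^2
xs = jnp.linspace(0.0, 1.0, 101)
residual = -jax.vmap(u_xx)(xs) - jnp.sin(jnp.pi * xs)

print(float(jnp.max(jnp.abs(residual))))
print(float(u_exact(0.0)), float(u_exact(1.0)))
```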
```python
import foundax
import jax
import optax

import jno

π = jno.np.pi

domain = jno.domain.line(mesh_size=0.1)
x, _ = domain.variable("interior")
u_net = jno.nn.wrap(
    foundax.mlp(in_features=1, hidden_dims=32, num_layers=3, key=jax.random.PRNGKey(0))
).optimizer(optax.adam(optax.exponential_decay(1e-3, 10, 0.5, end_value=1e-5)))
u = u_net(x) * x * (1 - x)  # hard BC: u(0) = u(1) = 0
u_xx = u.d2(x)
pde = -u_xx - jno.np.sin(π * x)  # PDE residual; should be 0
```
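The hard-BC line multiplies the network output by \(x(1-x)\). A plain-JAX sketch, assuming a hypothetical untrained network net(x) in place of u_net, shows why this works: the boundary values vanish for any weights, and the residual is still obtained by differentiating through the product (the plain-JAX analogue of u.d2(x)).

```python
import jax
import jax.numpy as jnp

# Hypothetical untrained network net(x) standing in for u_net; any weights will do.
k_w, k_b, k_v = jax.random.split(jax.random.PRNGKey(0), 3)
w = jax.random.normal(k_w, (32,))
b = jax.random.normal(k_b, (32,))
v = jax.random.normal(k_v, (32,))


def net(x):
    return jnp.dot(v, jnp.tanh(w * x + b))


def u(x):
    # Hard-BC ansatz: the factor x (1 - x) is zero at x = 0 and x = 1.
    return net(x) * x * (1.0 - x)


u_xx = jax.grad(jax.grad(u))  # second derivative through the product


def pde_residual(x):
    # Same residual as the tutorial: -u'' - sin(pi x)
    return -u_xx(x) - jnp.sin(jnp.pi * x)


xs = jnp.linspace(0.0, 1.0, 11)
residuals = jax.vmap(pde_residual)(xs)
print(float(u(0.0)), float(u(1.0)))
```

Because the boundary conditions hold by construction, the tutorial needs no boundary loss term; only the PDE residual is trained.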
Step 1: Build the Sparse Mask
```python
import equinox as eqx
import jax

all_false = jax.tree_util.tree_map(lambda _: False, u_net.module)
output_mask = eqx.tree_at(lambda m: m.output_layer.weight, all_false, True)
# Symbolic Jacobian: only the output-layer weight, shape (N, P_out_weight)
J = u.grad(u_net.mask(output_mask))
```
J is a symbolic NetworkGradient node. Nothing is computed yet — it just records the network and the mask.
Step 2: Define the Cosine Similarity as a Tracker
Wrap the cosine similarity calculation in a plain JAX function, use jno.np.function to lift it into the symbolic graph, and attach it as a non-loss tracker:

```python
import jax.numpy as jnp
from jno.numpy import tracker


def _cos_sim_halves(J):
    # J: (N, P_out_weight). Assumes rows are ordered by x, so the first half
    # of the rows is the left half of the domain.
    N = J.shape[0]
    mid = N // 2
    g_left = J[:mid].mean(axis=0)  # mean sensitivity of the left half
    g_right = J[mid:].mean(axis=0)  # mean sensitivity of the right half
    denom = jnp.linalg.norm(g_left) * jnp.linalg.norm(g_right) + 1e-12
    return jnp.dot(g_left, g_right) / denom


cos_tracker = tracker(jno.np.function(_cos_sim_halves, [J]), interval=200)
```
interval=200 means it is evaluated and logged every 200 epochs without contributing to the gradient.
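To illustrate what "logged without contributing to the gradient" means, here is a plain-JAX training loop that does the same thing by hand. It is a sketch under several assumptions: a hypothetical 1 → 16 → 1 tanh network with the hard-BC factor, plain gradient descent instead of Adam, and the output-layer weight v playing the role of the mask. None of these names are jno API.

```python
import jax
import jax.numpy as jnp

# Hypothetical settings; plain gradient descent stands in for jno's optimizer.
width, interval, steps, lr = 16, 200, 600, 1e-3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w": jax.random.normal(k1, (width,)),
    "b": jax.random.normal(k2, (width,)),
    "v": jax.random.normal(k3, (width,)) / jnp.sqrt(width),  # output-layer weight
}
xs = jnp.linspace(0.0, 1.0, 34)[1:-1]  # 32 interior points, sorted left to right


def u(p, x):
    # Hard-BC ansatz u = N(x) x (1 - x)
    return jnp.dot(p["v"], jnp.tanh(p["w"] * x + p["b"])) * x * (1.0 - x)


def loss(p):
    d2u = jax.grad(jax.grad(u, argnums=1), argnums=1)
    u_xx = jax.vmap(d2u, in_axes=(None, 0))(p, xs)
    return jnp.mean((-u_xx - jnp.sin(jnp.pi * xs)) ** 2)


def cos_halves(p):
    # Tracker only: d u / d v per point (the "masked" Jacobian), left vs right half.
    J = jax.vmap(jax.grad(lambda v, x: u({**p, "v": v}, x)), in_axes=(None, 0))(p["v"], xs)
    mid = J.shape[0] // 2
    g_left, g_right = J[:mid].mean(axis=0), J[mid:].mean(axis=0)
    return jnp.dot(g_left, g_right) / (jnp.linalg.norm(g_left) * jnp.linalg.norm(g_right) + 1e-12)


@jax.jit
def step(p):
    # Only the PDE loss is differentiated; the tracker never enters this function.
    g = jax.grad(loss)(p)
    return jax.tree_util.tree_map(lambda a, b: a - lr * b, p, g)


log = []
for epoch in range(steps):
    if epoch % interval == 0:
        log.append((epoch, float(loss(params)), float(cos_halves(params))))
        print(log[-1])
    params = step(params)
```

Keeping the tracker outside the differentiated function is the whole point: it is evaluated on the current parameters for inspection only, so the parameter updates are identical with or without it.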
Step 3: Solve
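Pass cos_tracker to jno.core together with pde.mse and call crux.solve(...), in the same way as the full script at the end of this page. During training the log then shows the cosine similarity alongside the PDE loss. For a smooth, symmetric solution like \(\sin(\pi x)/\pi^2\) you would expect the value to stay positive (ideally \(>0.5\)) throughout, meaning both halves of the domain reinforce the same parameter updates.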
A value that drops toward zero or turns negative during training is a warning: the network is beginning to represent the two halves in nearly-orthogonal (or conflicting) parts of parameter space.
What J measures during training
u.grad(net) gives \(\partial u / \partial \theta\), the Jacobian of the network output w.r.t. parameters.
A PDE loss, however, penalizes derivatives of \(u\) (here \(u_{xx}\)), so its gradient with respect to \(\theta\) is assembled from residual-weighted terms in \(\partial u_{xx} / \partial \theta\), not from \(\partial u / \partial \theta\) directly. The cosine similarity of J therefore gives a correct picture of output-level sensitivity, but treat it as a diagnostic signal rather than an exact measure of loss-gradient conflict.
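The chain-rule argument can be checked numerically. The plain-JAX sketch below uses a hypothetical 1 → 8 → 1 hard-BC network (not jno) and computes both \(\partial u/\partial\theta\), which is what u.grad(net) measures, and \(\partial u_{xx}/\partial\theta\). With \(L = \frac{1}{N}\sum_i r_i^2\) and \(r_i = -u_{xx}(x_i) - \sin(\pi x_i)\), the gradient is \(\nabla_\theta L = -\frac{2}{N}\sum_i r_i\,\partial u_{xx}(x_i)/\partial\theta\), and the code confirms this matches jax.grad of the loss.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical 1 -> 8 -> 1 tanh network with hard BCs, flattened to a vector theta.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(2), 3)
params = {
    "w": jax.random.normal(k1, (8,)),
    "b": jax.random.normal(k2, (8,)),
    "v": jax.random.normal(k3, (8,)) / jnp.sqrt(8.0),
}
theta, unravel = ravel_pytree(params)  # P = 24
xs = jnp.linspace(0.05, 0.95, 19)      # N = 19 interior points


def u(theta, x):
    p = unravel(theta)
    return jnp.dot(p["v"], jnp.tanh(p["w"] * x + p["b"])) * x * (1.0 - x)


u_xx = jax.grad(jax.grad(u, argnums=1), argnums=1)  # d^2 u / dx^2


def residual(theta):
    return -jax.vmap(u_xx, in_axes=(None, 0))(theta, xs) - jnp.sin(jnp.pi * xs)


def loss(theta):
    return jnp.mean(residual(theta) ** 2)


# What u.grad(net) measures: d u / d theta, shape (N, P)
J_u = jax.vmap(jax.grad(u), in_axes=(None, 0))(theta, xs)
# What the PDE loss gradient is built from: d u_xx / d theta, shape (N, P)
J_uxx = jax.vmap(jax.grad(u_xx), in_axes=(None, 0))(theta, xs)

r = residual(theta)
grad_chain = -2.0 * (r @ J_uxx) / xs.shape[0]  # chain rule by hand
grad_auto = jax.grad(loss)(theta)               # autodiff reference

g_u, g_uxx = J_u.mean(axis=0), J_uxx.mean(axis=0)
cos = jnp.dot(g_u, g_uxx) / (jnp.linalg.norm(g_u) * jnp.linalg.norm(g_uxx))
print(float(jnp.max(jnp.abs(grad_chain - grad_auto))), float(cos))
```

The cosine between the mean rows of J_u and J_uxx generally differs from 1, which is why a tracker built on \(\partial u/\partial\theta\) is a proxy for, not a measurement of, loss-gradient alignment.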
Step 4: Post-training Analysis
After solving, evaluate the Jacobian directly and compute the final cosine similarity. When crux.eval receives a single expression, it returns the raw array without a batch dimension, so the one-element result list unpacks straight into the Jacobian:

```python
import jax.numpy as jnp

[J_sparse] = crux.eval([J])  # (N, P_out_weight)
N = J_sparse.shape[0]
g_left = J_sparse[:N // 2].mean(axis=0)
g_right = J_sparse[N // 2:].mean(axis=0)
cos_sim = float(
    jnp.dot(g_left, g_right)
    / (jnp.linalg.norm(g_left) * jnp.linalg.norm(g_right) + 1e-12)
)
print(f"cos_sim (left vs right) = {cos_sim:.4f}")
```
Step 5: Neural Tangent Kernel (Full Jacobian)
For a deeper analysis, evaluate the full Jacobian (all parameters) after training:
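```python
import jax.numpy as jnp

# Clear the output-layer mask first: the selection set earlier is stored on u_net.
[J_full] = crux.eval([u.grad(u_net.mask(None))])  # (N, P_total)
K = J_full @ J_full.T  # empirical NTK, (N, N)
# K is positive semi-definite; clip small negative eigenvalues from round-off.
eigvals = jnp.maximum(jnp.sort(jnp.linalg.eigvalsh(K))[::-1], 0.0)
eff_rank = float(jnp.sum(eigvals) ** 2 / (jnp.sum(eigvals**2) + 1e-12))
cond = float(eigvals[0] / (eigvals[-1] + 1e-12))
print(f"Effective rank = {eff_rank:.2f}")
print(f"Condition number = {cond:.1f}")
```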
The effective rank (participation ratio) tells you how many independent learning modes the network uses. The NTK condition number measures how uniformly different spatial patterns are learned: a high condition number means some patterns converge much slower than others.
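To build intuition for the two numbers, here is a plain-JAX sketch that applies the same formulas, effective rank \((\sum_k \lambda_k)^2 / \sum_k \lambda_k^2\) and condition number \(\lambda_{\max}/\lambda_{\min}\), to hypothetical toy Jacobians with N = 4 points. ntk_stats is an illustrative helper, not part of jno.

```python
import jax.numpy as jnp


def ntk_stats(J, eps=1e-12):
    """Effective rank (participation ratio) and condition number of K = J J^T."""
    K = J @ J.T
    # K is positive semi-definite; clip tiny negative eigenvalues caused by round-off.
    eigvals = jnp.maximum(jnp.sort(jnp.linalg.eigvalsh(K))[::-1], 0.0)
    eff_rank = jnp.sum(eigvals) ** 2 / (jnp.sum(eigvals**2) + eps)
    cond = eigvals[0] / (eigvals[-1] + eps)
    return float(eff_rank), float(cond)


# Toy Jacobians with N = 4 points (hypothetical, for intuition only).
J_iso = jnp.eye(4)                                             # four equal, independent modes
J_rank1 = jnp.outer(jnp.ones(4), jnp.array([1.0, 2.0, 3.0]))   # every point has the same gradient
J_skew = jnp.diag(jnp.array([10.0, 1.0, 1.0, 0.1]))            # one fast mode, one very slow mode

stats = {name: ntk_stats(J) for name, J in
         [("isotropic", J_iso), ("rank-1", J_rank1), ("skewed", J_skew)]}
for name, (eff_rank, cond) in stats.items():
    print(f"{name:10s} eff_rank = {eff_rank:.3f}  cond = {cond:.3g}")
```

The rank-1 case also shows a general caveat: K has rank at most min(N, P), so whenever the Jacobian is rank-deficient (for example N > P), \(\lambda_{\min}\) is zero up to round-off and the condition number mostly reflects the 1e-12 guard rather than a meaningful learning speed.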
What To Notice
- The cosine similarity during training can flag gradient conflict early, often before the loss plateaus.
- A sparse mask (output layer only) makes the tracker cheap enough to run every 200 epochs even on large networks.
- crux.eval([single_expr]) returns the array without a batch dimension; you don't need to strip a leading [0].
- A low effective rank (close to 1) means the network is learning in a near-1D subspace; widen the network or increase collocation points.
- Cosine similarity \(< 0.3\) between two groups that should behave similarly is a warning: consider adding more collocation points or using adaptive resampling.
Script Snippet

```python
"""07 — Gradient and sensitivity analysis with u.grad(net)"""
import equinox as eqx
import foundax
import jax
import jax.numpy as jnp
import optax

import jno

π = jno.np.pi

# ── Domain ─────────────────────────────────────────────────────────────────────
domain = jno.domain.line(mesh_size=0.001)
x, _ = domain.variable("interior")
xb, _ = domain.variable("boundary")

# ── Exact solution (for validation only, not used in training) ─────────────────
u_exact = jno.np.sin(π * x) / π**2

# ── Network (BCs enforced softly through a boundary loss) ──────────────────────
u_net = jno.nn.wrap(
    foundax.mlp(
        in_features=1,
        hidden_dims=32,
        num_layers=3,
        key=jax.random.PRNGKey(0),
    )
).optimizer(optax.adam(optax.exponential_decay(1e-3, 10, 0.5, end_value=1e-5)))

u = u_net(x)
u_xx = u.d2(x)
pde = -u_xx - jno.np.sin(π * x)  # PDE residual; should be 0
ub = u_net(xb)  # boundary values; should be 0

# ── In-training gradient-alignment tracker ─────────────────────────────────────
# Build a boolean mask selecting only the output-layer weight matrix.
# This makes the gradients fast to compute, since P_out_weight ≪ P_total.
all_false = jax.tree_util.tree_map(lambda _: False, u_net.params)
output_mask = eqx.tree_at(lambda m: m.output_layer.weight, all_false, True)

# Gradients of each scalar loss w.r.t. the masked parameters only.
# Shape at eval time: (P_out_weight,)
J1 = pde.mse.grad(u_net.mask(output_mask))
J2 = ub.mse.grad(u_net.mask(output_mask))
# Log the unnormalized dot product every 100 epochs: > 0 means aligned,
# < 0 means conflicting. The normalized cosine is computed after training.
cos_tracker = jno.np.dot(J1, J2).tracker(100)

# ── Solve ──────────────────────────────────────────────────────────────────────
crux = jno.core([pde.mse, ub.mse, cos_tracker])
crux.solve(5000)

_u, _u_exact = crux.eval([u, u_exact])
rel_l2 = float(jnp.linalg.norm(_u - _u_exact) / (jnp.linalg.norm(_u_exact) + 1e-8))
print(f"Relative L² error: {rel_l2:.3e}")
assert rel_l2 < 1e-1, f"solution error too large: {rel_l2:.3e}"

# ── Post-training: gradient alignment between PDE and BC loss ──────────────────
# crux.eval([single_expr]) returns the raw array without a batch dimension.
[g_pde] = crux.eval([J1])  # (P_out_weight,): gradient of the PDE loss
[g_bc] = crux.eval([J2])  # (P_out_weight,): gradient of the BC loss

cos_sim = float(
    jnp.dot(g_pde, g_bc) / (jnp.linalg.norm(g_pde) * jnp.linalg.norm(g_bc) + 1e-12)
)
print("\nGradient alignment (PDE vs BC, output-layer params)")
print(f" dot(g_pde, g_bc) = {float(jnp.dot(g_pde, g_bc)):.4f}")
print(f" cos_sim = {cos_sim:.4f}")
if cos_sim > 0.5:
    print(" → Strongly aligned: PDE and BC losses reinforce each other.")
elif cos_sim > 0:
    print(" → Weakly aligned: compatible but partially independent.")
else:
    print(" → Conflict: PDE and BC losses are pulling parameters in opposite directions!")
# Sanity check: a cosine similarity must lie in [-1, 1] (small tolerance for round-off).
assert -1.0 - 1e-6 <= cos_sim <= 1.0 + 1e-6, f"cos_sim out of range: {cos_sim:.4f}"

# ── Post-training: full Jacobian + Neural Tangent Kernel ───────────────────────
# Clear the output-layer mask so we get the full (N, P_total) Jacobian.
[J_full] = crux.eval([u.grad(u_net.mask(None))])  # (N, P_total)
N, P_total = J_full.shape
print(f"\nFull Jacobian J shape: {J_full.shape} ({P_total} parameters)")

K = J_full @ J_full.T  # (N, N)
# Clip small negative eigenvalues (numerical noise from semi-definite K).
eigvals = jnp.maximum(jnp.sort(jnp.linalg.eigvalsh(K))[::-1], 0.0)
eff_rank = float(jnp.sum(eigvals) ** 2 / (jnp.sum(eigvals**2) + 1e-12))
cond = float(eigvals[0] / (eigvals[-1] + 1e-12))
print(f"\nNeural Tangent Kernel K ({N}×{N})")
print(f" λ_max = {float(eigvals[0]):.4f}")
print(f" λ_min = {float(eigvals[-1]):.4f}")
print(f" Eff. rank = {eff_rank:.2f} (trace² / ‖K‖²_F)")
print(f" Cond. number = {cond:.1f}")
```