Schedules
Learning Rate
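The simplest way to set a learning rate is to bake it into the optax constructor, for example by passing optax.adam(1e-3) to net.optimizer(...).
optax chains work the same way: net.optimizer(...) accepts any combination of optax transformations wrapped in optax.chain(...).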
For dynamic schedules, attach the schedule with .scale(...). Construct the optimizer with a placeholder rate of 1 so that .scale alone sets the effective learning rate. The schedule is re-evaluated every step:
import optax
from jno import LearningRateSchedule as lrs
# Placeholder rate of 1: the schedule passed to .scale(...) sets the actual rate
net.optimizer(optax.adam(1)).scale(lrs.exponential(1e-3, 0.9, 2000, 1e-5))
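The reason for the placeholder rate can be sketched in plain Python. This is only an illustration: it assumes that .scale(...) multiplies the optimizer's update by the schedule value at each step, which is an assumption about the mechanism and is not taken from the jno source. The names halving, sgd_update and scaled_update are hypothetical.

```python
def halving(step):
    # Hypothetical schedule: start at 1e-3 and halve every 1000 steps.
    return 1e-3 * 0.5 ** (step // 1000)

def sgd_update(grad, learning_rate):
    # Plain gradient-descent update: the step is proportional to the rate.
    return -learning_rate * grad

def scaled_update(grad, base_rate, schedule, step):
    # Assumed behaviour of .scale(...): multiply the update by schedule(step).
    return schedule(step) * sgd_update(grad, base_rate)

# With a base rate of 1 the effective rate is exactly the schedule value.
with_placeholder = scaled_update(grad=2.0, base_rate=1.0, schedule=halving, step=0)

# With a base rate of 1e-3 the two rates multiply, so the step is 1000x too small.
with_double_rate = scaled_update(grad=2.0, base_rate=1e-3, schedule=halving, step=0)
```

Under that assumption, any base rate other than 1 silently rescales every value the schedule produces.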
A loss-adaptive dlrs(...) (Dynamic Learning Rate Scheduler) plugs in the same way. On every step it raises or lowers the rate according to the recent loss trend.
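As a rough picture of what "adapting to the recent loss trend" can mean, the sketch below compares the mean of the most recent losses with the mean of the window before it. It is not the rule used by dlrs(...): the function name trend_adjusted_lr, the window size and the up/down factors are all hypothetical choices made for illustration.

```python
from statistics import mean

def trend_adjusted_lr(lr, loss_history, window=5, up=1.05, down=0.7,
                      min_lr=1e-6, max_lr=1e-1):
    """Return a new rate from the trend of the last 2 * window losses."""
    if len(loss_history) < 2 * window:
        return lr  # not enough history to estimate a trend yet
    older = mean(loss_history[-2 * window:-window])
    recent = mean(loss_history[-window:])
    # Falling loss: speed up slightly. Flat or rising loss: back off.
    factor = up if recent < older else down
    return min(max_lr, max(min_lr, lr * factor))

falling = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
rising = list(reversed(falling))
lr_after_falling = trend_adjusted_lr(1e-3, falling)
lr_after_rising = trend_adjusted_lr(1e-3, rising)
```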
Learning Rate Schedules
LearningRateSchedule wraps any callable (epoch, individual_losses) → scalar so it can be passed to .scale(...). Build your own:
from jno import LearningRateSchedule as lrs
# Any (epoch, losses) -> scalar callable is a schedule
lrs(lambda epoch, losses: 1e-4 * (0.9 ** (epoch / 500)))
# Adapt to a runtime signal: drop the LR once the PDE loss (losses[0]) falls to 1e-2 or below
lrs(lambda epoch, losses: 1e-3 if losses[0] > 1e-2 else 1e-5)
Built-in schedule factories
For common shapes, lrs ships these factories:
# Constant
lrs.constant(1e-3)
lrs(1e-3) # shorthand
# Exponential decay: lr(t) = max(lr_end, lr0 * decay_rate^(t/decay_steps))
lrs.exponential(lr0=1e-3, decay_rate=0.9, decay_steps=1000, lr_end=1e-5)
# Cosine decay
lrs.cosine(total_steps=5000, lr0=1e-3, lr_end=1e-6)
# Linear warm-up then cosine decay
lrs.warmup_cosine(total_steps=5000, warmup_steps=500, lr0=1e-3, lr_end=1e-6)
# Piecewise constant
lrs.piecewise_constant(
    boundaries=[1000, 3000],
    values=[1e-3, 5e-4, 1e-4],  # len(boundaries) + 1 values
)
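For reference, the shapes these factories describe can be written as plain functions of the step. The exponential form follows the formula in the comment above. The cosine, warm-up and piecewise forms are the standard textbook definitions and are assumed here, not taken from the jno source. In particular, the warm-up is assumed to ramp linearly from 0, and a piecewise boundary is assumed to take effect at the boundary step itself.

```python
import bisect
import math

def exponential(step, lr0, decay_rate, decay_steps, lr_end):
    # lr(t) = max(lr_end, lr0 * decay_rate^(t / decay_steps))
    return max(lr_end, lr0 * decay_rate ** (step / decay_steps))

def cosine(step, total_steps, lr0, lr_end):
    # Half a cosine period from lr0 down to lr_end, then held at lr_end.
    progress = min(step, total_steps) / total_steps
    return lr_end + 0.5 * (lr0 - lr_end) * (1.0 + math.cos(math.pi * progress))

def warmup_cosine(step, total_steps, warmup_steps, lr0, lr_end):
    if step < warmup_steps:
        # Assumed: linear ramp from 0 up to lr0.
        return lr0 * step / warmup_steps
    return cosine(step - warmup_steps, total_steps - warmup_steps, lr0, lr_end)

def piecewise_constant(step, boundaries, values):
    # bisect_right counts the boundaries that are <= step.
    return values[bisect.bisect_right(boundaries, step)]
```

Evaluating these at a few steps (for example step 0, the warm-up end, and total_steps) is a quick way to sanity-check a configuration before training.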
All factories — built-in and custom — accept min_lr and max_lr keyword arguments to clamp the output.
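The effect of clamping can be sketched as a wrapper around any (epoch, losses) -> scalar callable. The helper name clamped is hypothetical and only mirrors what the min_lr and max_lr arguments are described as doing.

```python
def clamped(schedule, min_lr=None, max_lr=None):
    """Wrap a schedule so its output stays within [min_lr, max_lr]."""
    def wrapped(epoch, losses):
        lr = schedule(epoch, losses)
        if min_lr is not None:
            lr = max(min_lr, lr)
        if max_lr is not None:
            lr = min(max_lr, lr)
        return lr
    return wrapped

# The custom decay from earlier, with a floor of 5e-5.
decay = clamped(lambda epoch, losses: 1e-4 * (0.9 ** (epoch / 500)), min_lr=5e-5)
```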
Adaptive Loss Weights
Loss weights are traced placeholders. Call an adaptive balancer with your losses, multiply the returned weights into those losses, and then pass the weighted losses to the solver (jno.core).
The weights are recomputed inside the compiled JAX function every step — no Python callback overhead.
Logging weights
Each weight placeholder exposes a .tracker() method that logs its value during training without contributing to the loss.
Available balancers
relobralo — Relative Loss Balancing with Random Lookback
Balances losses relative to their initial values, preventing any single loss from dominating as magnitudes change during training.
w0, w1 = jno.fn.adaptive.relobralo(
    [pde, bc],
    alpha=0.99,          # exponential moving average factor
    tau=0.1,             # temperature for softmax normalisation
    expected_rho=0.999,  # expected value of the random lookback variable rho
    seed=42,
)
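To make the roles of alpha, tau and expected_rho concrete, here is a plain-Python sketch of one ReLoBRaLo update as described in the published method. It is an illustration, not the jno implementation, which runs inside the compiled JAX function. The function names and the eps guard are assumptions.

```python
import math
import random

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def relobralo_step(losses, prev_losses, init_losses, prev_weights,
                   alpha=0.99, tau=0.1, expected_rho=0.999, rng=random, eps=1e-12):
    n = len(losses)
    # Balance relative to the previous step and relative to the initial losses.
    bal_prev = [n * w for w in softmax(
        [l / (tau * p + eps) for l, p in zip(losses, prev_losses)])]
    bal_init = [n * w for w in softmax(
        [l / (tau * p + eps) for l, p in zip(losses, init_losses)])]
    # Random lookback: rho = 1 keeps the running weights,
    # rho = 0 falls back to the balance against the initial losses.
    rho = 1.0 if rng.random() < expected_rho else 0.0
    return [alpha * (rho * w + (1.0 - rho) * bi) + (1.0 - alpha) * bp
            for w, bi, bp in zip(prev_weights, bal_init, bal_prev)]

weights = relobralo_step(
    losses=[0.5, 2.0], prev_losses=[1.0, 2.0], init_losses=[1.0, 2.0],
    prev_weights=[1.0, 1.0], rng=random.Random(42),
)
```

In this example the second loss has not improved while the first has halved, so the second loss receives the larger weight. The weights always sum to the number of losses.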
softadapt — SoftAdapt
Weights losses by the softmax of their recent rate of change — losses improving fastest are downweighted.
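A minimal sketch of that rule, assuming the rate of change is approximated by the difference between the current and previous loss values and that beta is a positive sharpness parameter. The function name and the default beta are illustrative, and the jno version may smooth the rate over more steps.

```python
import math

def softadapt_weights(losses, prev_losses, beta=0.1):
    # Rate of change per loss: negative means the loss is improving.
    rates = [l - p for l, p in zip(losses, prev_losses)]
    scaled = [beta * r for r in rates]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# The first loss dropped by 0.5, the second only by 0.1.
w_fast, w_slow = softadapt_weights([0.5, 0.9], [1.0, 1.0])
```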
dwa — Dynamic Weight Average
Weights by the ratio of each loss's current value to its previous-step value, smoothed by a temperature.
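The description above corresponds to the following sketch. The function name is hypothetical, and the default temperature of 2.0 and the rescaling so that the weights sum to the number of losses follow the original DWA formulation. Both are assumptions about jno's defaults.

```python
import math

def dwa_weights(losses, prev_losses, temperature=2.0):
    k = len(losses)
    # Ratio > 1: the loss got worse. Ratio < 1: the loss improved.
    ratios = [l / p for l, p in zip(losses, prev_losses)]
    exps = [math.exp(r / temperature) for r in ratios]
    total = sum(exps)
    # Rescale so the weights sum to the number of losses.
    return [k * e / total for e in exps]

# The first loss halved, the second barely moved.
w_improving, w_stalled = dwa_weights([0.5, 0.95], [1.0, 1.0])
```

A larger temperature flattens the weights towards 1 each, while a smaller one emphasises the stalled losses more strongly.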
lbpinns_loss_balancing — LbPINNs
Learnable log-variance weights updated via an internal Adam step each iteration.
w0, w1 = jno.fn.adaptive.lbpinns_loss_balancing(
    [pde, bc],
    init_s=0.0,  # initial log-variance (per loss, or scalar broadcast)
    lr_s=1e-2,   # learning rate for the log-variance parameters
)
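The idea behind init_s and lr_s can be sketched in plain Python. The sketch assumes the usual uncertainty-weighting objective, sum_i 0.5 * (exp(-s_i) * L_i + s_i), and returns exp(-s_i) as the weight. The class name, the exact objective and the Adam hyperparameters are assumptions, not details confirmed by the jno source.

```python
import math

class LogVarianceBalancer:
    def __init__(self, n_losses, init_s=0.0, lr_s=1e-2, b1=0.9, b2=0.999, eps=1e-8):
        self.s = [float(init_s)] * n_losses  # log-variance per loss
        self.lr, self.b1, self.b2, self.eps = lr_s, b1, b2, eps
        self.m = [0.0] * n_losses  # Adam first-moment estimates
        self.v = [0.0] * n_losses  # Adam second-moment estimates
        self.t = 0

    def weights(self):
        return [math.exp(-s) for s in self.s]

    def step(self, losses):
        """One Adam step on s for fixed loss values, then return the weights."""
        self.t += 1
        for i, loss in enumerate(losses):
            # d/ds of 0.5 * (exp(-s) * loss + s)
            g = 0.5 * (1.0 - math.exp(-self.s[i]) * loss)
            self.m[i] = self.b1 * self.m[i] + (1.0 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * g * g
            m_hat = self.m[i] / (1.0 - self.b1 ** self.t)
            v_hat = self.v[i] / (1.0 - self.b2 ** self.t)
            self.s[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)
        return self.weights()

balancer = LogVarianceBalancer(n_losses=2)
w_large, w_small = balancer.step([4.0, 0.01])
```

Under this objective a loss larger than exp(s) pushes its log-variance up and its weight down, so large losses are damped and small ones are amplified.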
rlw — Random Loss Weighting
Draws weights from a Dirichlet distribution each step, providing stochastic regularisation across losses.
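A Dirichlet draw can be produced by normalising independent Gamma samples, which is what the stdlib-only sketch below does. The function name, the symmetric concentration of 1.0 and the choice to leave the weights summing to 1 are assumptions. jno may use a different concentration or rescale the result.

```python
import random

def dirichlet_weights(n_losses, concentration=1.0, rng=random):
    # Normalised Gamma(concentration, 1) draws follow a symmetric Dirichlet.
    draws = [rng.gammavariate(concentration, 1.0) for _ in range(n_losses)]
    total = sum(draws)
    return [d / total for d in draws]

# A seeded generator makes the sequence of weights reproducible.
weights_a = dirichlet_weights(3, rng=random.Random(0))
weights_b = dirichlet_weights(3, rng=random.Random(0))
```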
Loss preprocessing (mode)
All balancers accept a mode keyword that normalises the raw loss values before computing weights:
| mode | Effect |
|---|---|
| "raw" (default) | Use loss values as-is |
| "minmax" | Scale each loss to [0, 1] over its observed range |
| "l2" | Normalise by the L2 norm of the loss vector |
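The three modes in the table can be sketched as follows. The function name preprocess and the explicit observed_min / observed_max arguments are hypothetical. In particular, how jno tracks each loss's observed range is not specified here, so the caller supplies it in this sketch.

```python
import math

def preprocess(losses, mode="raw", observed_min=None, observed_max=None):
    if mode == "raw":
        return list(losses)
    if mode == "minmax":
        out = []
        for loss, lo, hi in zip(losses, observed_min, observed_max):
            span = hi - lo
            # A loss with no observed spread maps to 0.
            out.append((loss - lo) / span if span > 0 else 0.0)
        return out
    if mode == "l2":
        norm = math.sqrt(sum(loss * loss for loss in losses))
        return [loss / norm if norm > 0 else 0.0 for loss in losses]
    raise ValueError(f"unknown mode: {mode!r}")

l2_scaled = preprocess([3.0, 4.0], mode="l2")
minmax_scaled = preprocess([2.0, 50.0], mode="minmax",
                           observed_min=[1.0, 0.0], observed_max=[3.0, 100.0])
```

Normalising first matters most when the losses differ by orders of magnitude, since the softmax-based balancers are sensitive to the absolute scale of their inputs.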