Callbacks
Callbacks hook into the training loop without modifying the solver itself. Pass them to `solve()` as a list through the `callbacks` keyword argument, for example `crux.solve(10_000, callbacks=[cb])`.
Each callback's `on_epoch_end` is called after every outer training step and can signal early termination by returning `True`.
Build your own callback
Any subclass of `Callback` is a valid callback. Override one or more of its three hooks:

| Hook | Returns | When it fires |
|---|---|---|
| `on_solve_begin(**kw)` | `None` | Once, after `solve()` finishes JIT setup, before the loop |
| `on_epoch_end(**kw)` | `bool` (`True` to stop training) | After every outer training step |
| `on_training_end(**kw)` | `None` | Once, after the loop finishes |
The `**kw` passed to each hook is documented in the base class source. Inside `on_epoch_end`, the most useful keys are `epoch`, `total_loss`, `individual_losses`, `trainable`, `rng`, and `log`.
```python
from jno.utils.adaptive.callbacks import Callback


class LossPrinter(Callback):
    """Print the total loss every `every` outer training steps."""

    def __init__(self, every: int = 100):
        self.every = every

    def on_epoch_end(self, **kw) -> bool:
        if kw["epoch"] % self.every == 0:
            print(f"epoch {kw['epoch']}: loss = {float(kw['total_loss']):.4e}")
        return False  # never request early stop


# crux is the solver object whose training loop the callback hooks into
crux.solve(10_000, callbacks=[LossPrinter(every=500)])
```
Hooks you don't need can simply be omitted, because the base class supplies no-op defaults. The built-in callbacks below also subclass `Callback`; they are convenience helpers, not the only kind of callback the system supports.
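To make the call order in the table concrete, here is a small, self-contained sketch of how a training loop could drive the three hooks. Everything in it is hypothetical: `ToyCallback` stands in for the real `Callback` base class, and `run_loop` with its fake loss is not jno's solver loop, which passes more keywords than the two shown here.

```python
class ToyCallback:
    """Hypothetical stand-in for a callback base class: every hook is a no-op."""

    def on_solve_begin(self, **kw) -> None:
        pass

    def on_epoch_end(self, **kw) -> bool:
        return False  # default: never request an early stop

    def on_training_end(self, **kw) -> None:
        pass


class Recorder(ToyCallback):
    """Records which hook fired, in order."""

    def __init__(self):
        self.events = []

    def on_solve_begin(self, **kw) -> None:
        self.events.append("begin")

    def on_epoch_end(self, **kw) -> bool:
        self.events.append(f"epoch {kw['epoch']}")
        return False

    def on_training_end(self, **kw) -> None:
        self.events.append("end")


class StopBelow(ToyCallback):
    """Requests early termination once the loss drops below a threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def on_epoch_end(self, **kw) -> bool:
        return kw["total_loss"] < self.threshold


def run_loop(num_epochs: int, callbacks: list) -> int:
    """Hypothetical training loop; returns how many epochs actually ran."""
    for cb in callbacks:
        cb.on_solve_begin()  # once, before the loop
    epochs_run = 0
    for epoch in range(num_epochs):
        total_loss = 1.0 / (epoch + 1)  # fake, steadily decreasing loss
        epochs_run += 1
        # Call every callback (no short-circuit), then stop if any returned True.
        flags = [cb.on_epoch_end(epoch=epoch, total_loss=total_loss) for cb in callbacks]
        if any(flags):
            break
    for cb in callbacks:
        cb.on_training_end()  # once, after the loop
    return epochs_run


rec = Recorder()
print(run_loop(100, [rec, StopBelow(0.095)]))  # 11
print(rec.events[:2], rec.events[-1])  # ['begin', 'epoch 0'] end
```

Every callback's `on_epoch_end` runs before the stop decision is made, so a logging callback still sees the final epoch even when another callback ends training. Whether jno orders these calls the same way is not stated here; treat this as one reasonable design.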
Built-in callbacks
For common cases, jno ships ready-made callbacks. Each is a `jno.callbacks.*` factory that returns a pre-configured `Callback` instance.
Early Stopping
Stop training automatically when a monitored metric stops improving.
```python
import jno

cb = jno.callbacks.early_stopping(
    patience=1000,   # epochs with no improvement before stopping
    min_delta=1e-6,  # minimum change to count as improvement
    mode="min",      # "min", "max", or "rel"
)
crux.solve(100_000, callbacks=[cb])
print(cb.stopped_epoch)  # epoch at which training halted (None if not triggered)
print(cb.best_metric)    # best metric value observed
```
print(cb.best_metric) # best metric value observed
Modes
| `mode` | Stops when |
|---|---|
| `"min"` | the metric hasn't dropped by more than `min_delta` for `patience` epochs |
| `"max"` | the metric hasn't risen by more than `min_delta` for `patience` epochs |
| `"rel"` | the metric hasn't improved by at least a fraction `min_delta` of the best value for `patience` epochs |

`"rel"` is useful when loss magnitudes vary across runs: `min_delta=0.01` means "stop if the loss hasn't improved by at least 1%".
Monitoring a custom metric
By default early stopping watches the total loss. Pass metric_fn to monitor anything available at the end of each step:
```python
cb = jno.callbacks.early_stopping(
    patience=500,
    metric_fn=lambda **kw: float(kw["individual_losses"][1]),  # watch constraint #1 only
)
```
Inside `metric_fn`, the available keyword arguments are `epoch`, `total_loss`, `individual_losses`, `trainable`, `opt_states`, `rng`, and `log`.
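Because `metric_fn` receives everything as keyword arguments, it can combine several of them. The sketch below watches the worst constraint instead of a single one. `worst_constraint` is a hypothetical name, it assumes `individual_losses` can be iterated with one entry per constraint (as the `[1]` indexing above suggests), and the fake keyword values only stand in for what the solver passes.

```python
def worst_constraint(**kw) -> float:
    """Hypothetical metric_fn: monitor the largest individual loss term."""
    # float() turns the selected scalar into a plain Python float.
    return float(max(kw["individual_losses"]))


# Fake keyword arguments standing in for what the solver passes at the end of a step.
fake_kw = {"epoch": 10, "total_loss": 0.6, "individual_losses": [0.1, 0.4, 0.1]}
print(worst_constraint(**fake_kw))  # 0.4
```

It would then be passed like the lambda above, as `metric_fn=worst_constraint`.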
Starting from a baseline
Pass `baseline` to require the monitored metric to get below a fixed value:

```python
cb = jno.callbacks.early_stopping(
    patience=500,
    baseline=1e-3,  # stops if metric never gets below 1e-3
)
```
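How `patience`, `min_delta`, and `baseline` interact can be modelled in a few lines. This sketch covers `"min"` mode only and assumes the baseline simply seeds the best value, so nothing counts as an improvement until the metric beats it; `stop_epoch` is hypothetical and jno's own bookkeeping may differ.

```python
from typing import Optional, Sequence


def stop_epoch(metrics: Sequence[float], patience: int, min_delta: float = 0.0,
               baseline: Optional[float] = None) -> Optional[int]:
    """Hypothetical "min"-mode early stopping.

    Returns the epoch at which training would halt, or None if it never triggers.
    """
    # Assumption: a baseline seeds the best value, so the metric must beat it.
    best = baseline if baseline is not None else float("inf")
    wait = 0
    for epoch, metric in enumerate(metrics):
        if metric < best - min_delta:
            best, wait = metric, 0  # improvement: reset the patience counter
        else:
            wait += 1  # no improvement this epoch
            if wait >= patience:
                return epoch
    return None


losses = [0.5, 0.4, 0.3, 0.2, 0.1]
print(stop_epoch(losses, patience=3))                 # None: steady improvement
print(stop_epoch(losses, patience=3, baseline=1e-3))  # 2: never beats the baseline
```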
Checkpointing
Save model weights, optimizer states, and PRNG key to disk at regular intervals.
```python
cb = jno.callbacks.checkpoint(
    directory="runs/my_experiment/checkpoints",
    save_interval_epochs=500,  # save every 500 outer steps
    max_to_keep=3,             # keep only the 3 most recent checkpoints
)
crux.solve(10_000, callbacks=[cb])
```
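To see which checkpoints survive a run, the rotation can be simulated in plain Python. `retained_steps` is a hypothetical model that assumes epochs are counted from 1, a save happens on every multiple of `save_interval_epochs`, and the oldest checkpoint is deleted once more than `max_to_keep` exist; jno may, for instance, also save at step 0.

```python
from collections import deque


def retained_steps(num_epochs: int, save_interval_epochs: int, max_to_keep: int) -> list:
    """Hypothetical model of interval saving with a rolling window of checkpoints."""
    kept = deque(maxlen=max_to_keep)  # appending past maxlen drops the oldest entry
    for epoch in range(1, num_epochs + 1):
        if epoch % save_interval_epochs == 0:
            kept.append(epoch)
    return list(kept)


print(retained_steps(10_000, save_interval_epochs=500, max_to_keep=3))  # [9000, 9500, 10000]
```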
Keeping the best checkpoint
Pass best_fn to always retain the checkpoint with the lowest returned value, regardless of max_to_keep:
```python
cb = jno.callbacks.checkpoint(
    save_interval_epochs=200,
    max_to_keep=2,
    best_fn=lambda m: m["total_loss"],  # keep the checkpoint with lowest total loss
)
```
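The retention rule with `best_fn` can be pictured the same way. In this hypothetical model, `retained_with_best` keeps the newest `max_to_keep` checkpoints and additionally pins the one with the lowest `best_fn` value. The sketch does not count that pinned checkpoint toward `max_to_keep`; whether jno does is an assumption left open. Each metrics dict mirrors what `best_fn` receives above.

```python
def retained_with_best(history: list, max_to_keep: int, best_fn) -> list:
    """Hypothetical retention rule: the newest `max_to_keep` checkpoints plus
    the one whose metrics give the lowest `best_fn` value.

    `history` holds (step, metrics) pairs in the order they were saved.
    """
    recent = {step for step, _ in history[-max_to_keep:]} if max_to_keep > 0 else set()
    best_step, _ = min(history, key=lambda item: best_fn(item[1]))
    return sorted(recent | {best_step})


history = [
    (200, {"total_loss": 0.9}),
    (400, {"total_loss": 0.2}),  # best overall, but no longer among the 2 newest
    (600, {"total_loss": 0.5}),
    (800, {"total_loss": 0.4}),
]
print(retained_with_best(history, max_to_keep=2, best_fn=lambda m: m["total_loss"]))
# [400, 600, 800]
```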
Restoring a checkpoint
```python
state = cb.restore()           # latest checkpoint
state = cb.restore(step=2000)  # specific step
# state keys: "trainable", "opt_states", "rng", "metadata"
print(state["metadata"])  # {"epoch": 2000, "total_loss": ..., "timestamp": ...}
```
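The restore semantics (latest by default, or a specific `step`) can be illustrated with a generic round trip that uses only the standard library. `save_state`, `restore_state`, the `step_<n>.pkl` file layout, and the use of `pickle` are all hypothetical; jno's on-disk format is not described here and is likely different. The state dict reuses the four keys listed above, with placeholder values.

```python
import os
import pickle
import tempfile
from typing import Optional


def save_state(directory: str, step: int, state: dict) -> None:
    """Hypothetical: write one checkpoint per step as step_<n>.pkl."""
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, f"step_{step}.pkl"), "wb") as f:
        pickle.dump(state, f)


def restore_state(directory: str, step: Optional[int] = None) -> dict:
    """Hypothetical: load a specific step, or the latest one when step is None."""
    if step is None:
        steps = [int(name[len("step_"):-len(".pkl")]) for name in os.listdir(directory)
                 if name.startswith("step_") and name.endswith(".pkl")]
        step = max(steps)  # "latest" = highest step number on disk
    with open(os.path.join(directory, f"step_{step}.pkl"), "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    for step, loss in [(1000, 0.3), (2000, 0.1)]:
        save_state(tmp, step, {
            "trainable": {"w": [1.0, 2.0]},  # placeholder parameters
            "opt_states": {"count": step},   # placeholder optimizer state
            "rng": [0, step],                # placeholder PRNG key
            "metadata": {"epoch": step, "total_loss": loss},
        })
    print(restore_state(tmp)["metadata"])             # {'epoch': 2000, 'total_loss': 0.1}
    print(restore_state(tmp, step=1000)["metadata"])  # {'epoch': 1000, 'total_loss': 0.3}
```

Parsing the step number as an integer matters: comparing file names as strings would rank `step_900` above `step_2000`.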
To resume training from a restored checkpoint, reload the solver and re-attach the restored parameters (`state["trainable"]`), optimizer states (`state["opt_states"]`), and PRNG key (`state["rng"]`).
Async checkpointing
Checkpoints are written in a background thread by default (`async_checkpointing=True`). Pass `async_checkpointing=False` for synchronous writes if you need every checkpoint to be fully on disk before the process exits.
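The trade-off between asynchronous and synchronous writes can be shown with Python's standard `concurrent.futures`. This is a generic illustration, not jno's implementation: `slow_write` is a hypothetical stand-in for writing a checkpoint, and a list stands in for the files on disk.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

written: list = []  # stands in for checkpoint files on disk
lock = threading.Lock()


def slow_write(step: int) -> None:
    """Hypothetical checkpoint write that takes some time to reach disk."""
    time.sleep(0.05)
    with lock:
        written.append(step)


# Asynchronous style: submit() returns at once, so the training loop is not blocked.
executor = ThreadPoolExecutor(max_workers=1)  # one background writer, FIFO order
for step in (500, 1000, 1500):
    executor.submit(slow_write, step)
    # ... training would continue here while the write runs in the background ...

# Block until outstanding writes finish. Without an explicit wait, a process that
# is killed or exits abruptly could lose writes that were still in flight.
executor.shutdown(wait=True)
print(written)  # [500, 1000, 1500]

# Synchronous style: the caller waits for each write before moving on.
slow_write(2000)
print(written[-1])  # 2000
```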
Explainability callbacks
jNO also provides callbacks for analysing gradient conflict, cosine similarity, and the loss landscape during training. See Explainability.