Callback
Override any of these. Defaults are no-ops so callbacks don't have to implement every hook.
Hook signatures are positional + typed (not a generic event-bag) so IDE autocomplete and static type checking still work. Each hook fires at exactly one moment in the loop - see the docstrings.
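As a minimal sketch of the "override only what you need" pattern: the `Callback` class below is a hypothetical stand-in for the library's base class (its import path is not shown on this page), and the `"loss"` metric key is an assumption. Only `on_step_end` is overridden; the other hooks keep their no-op defaults.

```python
class Callback:
    """Hypothetical stand-in for the library base class: every hook is a no-op."""

    def on_train_start(self, state) -> None:
        pass

    def on_step_end(self, state, step_idx, batch, metrics) -> None:
        pass

    def on_train_end(self, state, artifacts) -> None:
        pass


class LossHistory(Callback):
    """Overrides a single hook; the remaining hooks are inherited no-ops."""

    def __init__(self) -> None:
        self.rows = []  # list of (step_idx, loss) pairs

    def on_step_end(self, state, step_idx, batch, metrics) -> None:
        # metrics is documented as dict[str, float]; the "loss" key is assumed.
        if "loss" in metrics:
            self.rows.append((step_idx, metrics["loss"]))


cb = LossHistory()
cb.on_train_start(state=None)                 # inherited no-op
cb.on_step_end(None, 0, None, {"loss": 1.5})  # recorded
cb.on_step_end(None, 1, None, {"lr": 1e-3})   # no "loss" key, skipped
cb.on_train_end(state=None, artifacts=None)   # inherited no-op
print(cb.rows)
```

Because the arguments are positional and named, a type checker can flag a misspelled parameter in the override, which a generic event-bag signature would not allow.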
All hooks are sync (def, not async). If you need to schedule
async work, spawn asyncio.create_task(...) from inside the hook.
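The sketch below illustrates scheduling async work from a sync hook. It assumes the loop calls the hook on a thread that already has a running event loop, since `asyncio.create_task` raises `RuntimeError` otherwise. The `Callback` stand-in, the `AsyncUploader` class, its `drain` helper, and the use of a plain string for `row` are all hypothetical.

```python
import asyncio


class Callback:
    """Hypothetical stand-in for the library base class."""

    def on_checkpoint(self, state, row) -> None:
        pass


class AsyncUploader(Callback):
    def __init__(self) -> None:
        self.tasks = set()   # strong references so tasks are not garbage-collected
        self.uploaded = []

    async def _upload(self, name: str) -> None:
        await asyncio.sleep(0)  # placeholder for real async I/O
        self.uploaded.append(name)

    def on_checkpoint(self, state, row) -> None:
        # Sync hook: schedule the coroutine and return immediately.
        # Requires a running event loop on the current thread.
        task = asyncio.create_task(self._upload(str(row)))
        self.tasks.add(task)
        task.add_done_callback(self.tasks.discard)

    async def drain(self) -> None:
        # Hypothetical helper: wait for everything scheduled so far.
        await asyncio.gather(*self.tasks)


async def main():
    cb = AsyncUploader()
    for row in ("ckpt-1", "ckpt-2"):  # stand-ins for ManifestRow objects
        cb.on_checkpoint(state=None, row=row)
    await cb.drain()
    return cb.uploaded


uploaded = asyncio.run(main())
print(sorted(uploaded))
```

Keeping a reference to each task matters: the event loop holds only weak references to tasks, so a fire-and-forget task with no other reference can be collected before it finishes.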
Functions
func on_train_start(self, state) -> None
Fires once before the for-loop starts. Open log files, register a wandb run, snapshot the config - anything that should happen before the first step.
param self
param state LoopState
Returns
None
func on_train_end(self, state, artifacts) -> None
Fires once after the loop completes (including the final checkpoint save). Flush summary writes, close files.
param self
param state LoopState
param artifacts 'LoopArtifacts'
Returns
None
func on_step_end(self, state, step_idx, batch, metrics) -> None
Fires after every train step's metric row is written. The universal "do something per step" hook (printing, plotting, custom metric derivations, gradient debugging).
param self
param state LoopState
param step_idx int
param batch 'TrainingBatch'
param metrics dict[str, float]
Returns
None
func on_checkpoint(self, state, row) -> None
Fires after each checkpoint manifest row is recorded. Useful for shipping to S3, pruning old checkpoints, or kicking off a side eval.
param self
param state LoopState
param row 'ManifestRow'
Returns
None
func on_eval(self, state, step_idx, eval_name, metrics) -> None
Fires per evaluator after each in-loop eval completes. Useful for pushing to a dashboard, plotting val curves, driving early-stopping decisions. (Logger callbacks that also need the rollout predictions should use on_benchmark_eval, which the loop fires for benchmark evaluators with the full payload.)
param self
param state LoopState
param step_idx int
param eval_name str
param metrics dict[str, float]
Returns
None
func on_experiment_start(self, ctx) -> None
Fires once at the start of an experiment, before any arm. A logger creates the experiment record here (ctx.ids['experiment_id'] = ...).
param self
param ctx LogContext
Returns
None
func on_group_start(self, ctx, group_name) -> None
Fires when a new run-group is needed (n_repeats replicates or continual stages). A logger creates the group (ctx.ids[f'group:{group_name}'] = ...).
param self
param ctx LogContext
param group_name str
Returns
None
func on_run_start(self, ctx) -> None
Fires per arm, before training. ctx.run_config is set. A logger opens its run-scoped sink (wandb.init / create_run → ctx.ids['run_id']), reading ctx.ids['experiment_id'] / ctx.ids[f'group:{ctx.group_name}'] to parent it.
param self
param ctx LogContext
Returns
None
func on_benchmark_eval(self, ctx, eval_result, predictions, *, step=None) -> None
Fires per benchmark scored - in-loop (step = the train step) or post-training (step=None). Carries metrics + breakdowns + tags (on eval_result) AND the per-task prediction rows. A logger creates one eval row per (eval_result.name, step) and persists the predictions.
param self
param ctx LogContext
param eval_result 'EvalResult'
param predictions list[dict]
param step int | None = None
Returns
None
func on_run_end(self, ctx, run_result, arm) -> None
Fires per arm, after eval, before the run is marked completed. A logger flushes/closes its run-scoped sink (wandb.finish) and records the final status (update_run).
param self
param ctx LogContext
param run_result 'RunResult'
param arm 'ArmResult'
Returns
None
func on_experiment_end(self, ctx, result) -> None
Fires once at the end of the experiment. Final summary / flush.
param self
param ctx LogContext
param result 'ExperimentResult'
Returns
None
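The experiment-level hooks above pass identifiers to each other through `ctx.ids`. The following is a rough sketch of that hand-off under stated assumptions: `LogContext` here is a hypothetical stand-in carrying only the attributes named in the docstrings (`ids`, `group_name`, `run_config`), `InMemoryLogger` and its record layout are invented for illustration, and the driver at the bottom sets `ctx.group_name` by hand where the real loop would presumably do so.

```python
from dataclasses import dataclass, field
from itertools import count
from types import SimpleNamespace
from typing import Optional


@dataclass
class LogContext:
    """Hypothetical stand-in; the real LogContext may carry more fields."""
    ids: dict = field(default_factory=dict)
    group_name: Optional[str] = None
    run_config: Optional[dict] = None


class InMemoryLogger:
    """Illustrative logger that stores records in a dict instead of a backend."""

    def __init__(self) -> None:
        self._next = count(1)
        self.records = {}

    def _create(self, kind, **fields):
        rec_id = f"{kind}-{next(self._next)}"
        self.records[rec_id] = {"kind": kind, **fields}
        return rec_id

    def on_experiment_start(self, ctx) -> None:
        ctx.ids["experiment_id"] = self._create("experiment")

    def on_group_start(self, ctx, group_name) -> None:
        ctx.ids[f"group:{group_name}"] = self._create(
            "group", parent=ctx.ids["experiment_id"], name=group_name)

    def on_run_start(self, ctx) -> None:
        # Parent the run to its group if one exists, else to the experiment.
        parent = ctx.ids.get(f"group:{ctx.group_name}", ctx.ids["experiment_id"])
        ctx.ids["run_id"] = self._create(
            "run", parent=parent, config=ctx.run_config, status="running")

    def on_benchmark_eval(self, ctx, eval_result, predictions, *, step=None) -> None:
        # One eval row per (eval_result.name, step), with the predictions attached.
        self._create("eval", parent=ctx.ids["run_id"], name=eval_result.name,
                     step=step, predictions=list(predictions))

    def on_run_end(self, ctx, run_result, arm) -> None:
        self.records[ctx.ids["run_id"]]["status"] = "completed"

    def on_experiment_end(self, ctx, result) -> None:
        self.records[ctx.ids["experiment_id"]]["closed"] = True


# Hand-driven walk through the hook order described above.
ctx, logger = LogContext(), InMemoryLogger()
logger.on_experiment_start(ctx)
logger.on_group_start(ctx, "seed-sweep")
ctx.group_name, ctx.run_config = "seed-sweep", {"lr": 1e-3}
logger.on_run_start(ctx)
logger.on_benchmark_eval(ctx, SimpleNamespace(name="bench-a"), [{"task": "t0"}], step=10)
logger.on_run_end(ctx, run_result=None, arm=None)
logger.on_experiment_end(ctx, result=None)
print(sorted(ctx.ids))
```

A real logger would replace `_create` with calls into its backend (for example `wandb.init` in `on_run_start` and `wandb.finish` in `on_run_end`, as the docstrings suggest); the id hand-off through `ctx.ids` stays the same.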