SFT
Attributes
attribute name: str = 'sft'
attribute Config: type = SFTConfig

Functions

func _check_inputs(self, ctx) -> None
param self
param ctx: RunContext
Returns
None

func setup(self, ctx, backend) -> None
param self
param ctx: RunContext
param backend: TinkerBackend
Returns
None

func build_batch(self, step_idx) -> TrainingBatch
Slice batch_size Datums for step_idx, wrapping the dataset
when the slice straddles the end.
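
The wrap-around slicing described for build_batch could look roughly like the sketch below. It is not the library's implementation: the helper name wrapped_slice is hypothetical, and the start offset (step_idx * batch_size modulo the dataset length) is an assumption, since the source only says that the slice wraps when it straddles the end.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def wrapped_slice(items: Sequence[T], step_idx: int, batch_size: int) -> list[T]:
    """Return batch_size items for step_idx, wrapping past the end of items.

    Hypothetical helper. The start offset is an assumption, not taken
    from the documented build_batch.
    """
    if not items:
        raise ValueError("dataset is empty")
    n = len(items)
    # Assumed convention: step k starts where step k-1 stopped.
    start = (step_idx * batch_size) % n
    # Index modulo n so a slice that runs off the end continues at index 0.
    return [items[(start + k) % n] for k in range(batch_size)]
```

With a five-item dataset and batch_size 2, step 2 starts at index 4 and wraps to index 0, so every step still yields a full batch.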
param self
param step_idx: int
Returns
evsys_sdk.training.loop.TrainingBatch

func step_metrics(self, step_idx, batch, fb_result) -> dict[str, float]
train_mean_nll from the per-position logprobs of each Datum,
weighted by the loss mask.
Tinker's cross_entropy loss returns loss_fn_outputs[i]["logprobs"]:
a per-position vector of log-probabilities of the target token (a
"perfect" prediction has logprob 0; otherwise negative). The mean NLL
is -sum(logprob * weight) / sum(weight) over the loss-mask
positions, averaged across the batch.
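
As a rough illustration of the formula above, the sketch below computes a weighted mean NLL per Datum and then averages those values across the batch. It works on plain Python lists rather than on Tinker result objects. The function names are hypothetical, and reading "averaged across the batch" as an unweighted mean of per-Datum values is an assumption, as is raising an error when a Datum has zero total weight.

```python
from typing import Sequence


def datum_mean_nll(logprobs: Sequence[float], weights: Sequence[float]) -> float:
    """-sum(logprob * weight) / sum(weight) for a single Datum."""
    if len(logprobs) != len(weights):
        raise ValueError("logprobs and weights must have the same length")
    total_weight = sum(weights)
    if total_weight == 0:
        # Assumption: a Datum with an all-zero loss mask is treated as an error.
        raise ValueError("loss mask has zero total weight")
    return -sum(lp * w for lp, w in zip(logprobs, weights)) / total_weight


def batch_mean_nll(
    batch_logprobs: Sequence[Sequence[float]],
    batch_weights: Sequence[Sequence[float]],
) -> float:
    """Unweighted mean of the per-Datum NLLs (assumed reading of the text)."""
    if not batch_logprobs:
        raise ValueError("batch is empty")
    per_datum = [
        datum_mean_nll(lp, w) for lp, w in zip(batch_logprobs, batch_weights)
    ]
    return sum(per_datum) / len(per_datum)
```

Positions with weight 0 drop out of both the numerator and the denominator, so only loss-mask positions affect the result, and a Datum whose masked positions all have logprob 0 contributes an NLL of 0.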
param self
param step_idx: int
param batch: TrainingBatch
param fb_result: Any
Returns
dict[str, float]

func _hyperparams_extra(self) -> dict[str, Any]
param self
Returns
dict[str, typing.Any]