yamle.trainers.calibration module

class yamle.trainers.calibration.CalibrationTrainer(calibration_epochs, *args, **kwargs)

Bases: BaseTrainer

This class defines a temperature-calibration trainer, which first trains the model and then calibrates it.

Training is performed on the training set, and calibration is performed on the calibration set.

Parameters:

calibration_epochs (int) – The number of epochs for which the model is calibrated.
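The following is a minimal, stdlib-only sketch of how a number of calibration epochs could drive a calibration step. It assumes that calibration means temperature scaling (as the class description's mention of a temperature suggests) and that one calibration epoch is one full-batch gradient step on the calibration set. The functions `softmax`, `nll` and `fit_temperature` and the learning rate `lr` are hypothetical illustrations, not yamle API. The temperature is optimized in log space so that it stays positive.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float) -> List[float]:
    """Numerically stable softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def nll(logits_batch: Sequence[Sequence[float]], labels: Sequence[int],
        temperature: float) -> float:
    """Mean negative log-likelihood of the labels under temperature scaling."""
    loss = 0.0
    for logits, y in zip(logits_batch, labels):
        loss -= math.log(softmax(logits, temperature)[y])
    return loss / len(labels)


def fit_temperature(logits_batch: Sequence[Sequence[float]],
                    labels: Sequence[int],
                    calibration_epochs: int,
                    lr: float = 0.1) -> float:
    """Hypothetical sketch: fit one temperature T by gradient descent on log T.

    Each epoch is one full-batch gradient step over the calibration set.
    """
    log_t = 0.0  # start at T = 1, i.e. the uncalibrated model
    for _ in range(calibration_epochs):
        t = math.exp(log_t)
        grad = 0.0
        for logits, y in zip(logits_batch, labels):
            p = softmax(logits, t)
            # dNLL/dT for one sample is (z_y - sum_k p_k z_k) / T^2.
            expected = sum(pk * zk for pk, zk in zip(p, logits))
            grad_t = (logits[y] - expected) / t ** 2
            # Chain rule: dT/d(log T) = T.
            grad += grad_t * t
        log_t -= lr * grad / len(labels)
    return math.exp(log_t)


# Overconfident toy model: always predicts class 0 with a large margin,
# but is right only 3 times out of 5.
logits = [[4.0, 0.0]] * 5
labels = [0, 0, 0, 1, 1]
t = fit_temperature(logits, labels, calibration_epochs=50)
print(t > 1.0, nll(logits, labels, t) < nll(logits, labels, 1.0))  # prints True True
```

Because the toy model is overconfident, the fitted temperature rises above 1, which softens its predictions and lowers the negative log-likelihood on the calibration data. More calibration epochs move the temperature closer to the optimum.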

fit(train_dataloader, validation_dataloader)

This method trains the model and then calibrates it.

Parameters:
  • train_dataloader (DataLoader) – The dataloader to be used for training.

  • validation_dataloader (DataLoader) – The dataloader to be used for validation.

Return type:

float
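The ordering that `fit` follows, training first and calibration afterwards, can be sketched as below. This is a hypothetical illustration rather than yamle's code: the callables `train` and `calibrate` stand in for the real training loop and calibration step, and the explicit `calibration_data` argument is an assumption, since this section does not say how the calibration set is obtained. The meaning of the returned `float` is also not documented here, so the sketch returns elapsed seconds purely as a placeholder.

```python
import time
from typing import Callable, Iterable


def fit_then_calibrate(
    train: Callable[[Iterable, Iterable], None],
    calibrate: Callable[[Iterable, int], None],
    train_data: Iterable,
    validation_data: Iterable,
    calibration_data: Iterable,
    calibration_epochs: int,
) -> float:
    """Hypothetical sketch: run training, then calibration, in that order.

    Returns elapsed seconds as a placeholder for the float returned by
    the real `fit`.
    """
    start = time.perf_counter()
    # Phase 1: ordinary training, monitored on the validation set.
    train(train_data, validation_data)
    # Phase 2: calibration on held-out data, only after training has finished.
    calibrate(calibration_data, calibration_epochs)
    return time.perf_counter() - start


# Record the order in which the two phases run.
calls = []
elapsed = fit_then_calibrate(
    lambda tr, va: calls.append("train"),
    lambda ca, n: calls.append(("calibrate", n)),
    train_data=[1], validation_data=[2], calibration_data=[3],
    calibration_epochs=5,
)
print(calls)  # prints ['train', ('calibrate', 5)]
```

Keeping calibration strictly after training matters because a calibration parameter such as a temperature is fitted to the outputs of the already-trained model.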

static add_specific_args(parent_parser)

This method adds the trainer-specific arguments to the given parser.

Parameters:

parent_parser (ArgumentParser) – The parser to which the arguments should be added.

Return type:

ArgumentParser
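A static method with this signature typically registers options on the parser it receives and returns it. The sketch below uses the standard-library `argparse` module. The flag name `--trainer_calibration_epochs` and its default of 10 are hypothetical, not yamle's actual values, and the real method may also need to include arguments inherited from `BaseTrainer`, which this sketch omits.

```python
from argparse import ArgumentParser


def add_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
    """Hypothetical sketch: register the calibration option on a parser."""
    # The flag name and default below are illustrative, not yamle's actual ones.
    parent_parser.add_argument(
        "--trainer_calibration_epochs",
        type=int,
        default=10,
        help="The number of epochs for which the model is calibrated.",
    )
    return parent_parser


parser = add_specific_args(ArgumentParser())
args = parser.parse_args(["--trainer_calibration_epochs", "5"])
print(args.trainer_calibration_epochs)  # prints 5
```

Returning the same parser lets several components add their options in turn before the command line is parsed once.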