yamle.trainers.trainer module

class yamle.trainers.trainer.BaseTrainer(save_path, datamodule, epochs, accelerator, devices, precision, method, gradient_clip_norm_value=0.0, gradient_clip_value=5.0, mode='train', st_checkpoint_dir=None, debug=False, compile=True, task='classification', no_saving=False, no_initial_saving=True, no_validation_saving=True, no_every_train_epoch_saving=True, no_augmentation_testing=True, profiler=None)

Bases: object

This class defines a base trainer which given a method and data loaders performs training and evaluation.

Parameters:
  • save_path (str) – The path to the experiment folder.

  • datamodule (BaseDataModule) – The datamodule to be used for training or evaluation.

  • epochs (int) – The number of epochs to train for.

  • accelerator (str) – The accelerator to be used for training or evaluation. (cpu, gpu, ddp, ddp2, ddp_spawn, auto)

  • devices (List[int]) – The devices to be used for training or evaluation.

  • precision (int) – The precision to be used for training or evaluation.

  • method (BaseMethod) – The method to be used for training or evaluation.

  • gradient_clip_norm_value (float) – The maximum gradient norm when clipping by norm. Defaults to 0.0.

  • gradient_clip_value (float) – The maximum absolute gradient value when clipping by value. Defaults to 5.0.

  • mode (str) – The mode of the trainer. (train, eval, tune) Defaults to 'train'.

  • st_checkpoint_dir (Optional[str]) – The path to the Syne-Tune checkpoint directory. Defaults to None.

  • debug (bool) – Whether to run in debug mode. Defaults to False.

  • compile (bool) – Whether to compile the model. Defaults to True.

  • task (str) – The task to be performed. Defaults to 'classification'.

  • no_saving (bool) – Whether to disable all saving. Defaults to False.

  • no_initial_saving (bool) – Whether to skip saving the initial model. Defaults to True.

  • no_validation_saving (bool) – Whether to skip saving the validation model. Defaults to True.

  • no_every_train_epoch_saving (bool) – Whether to skip saving the model after every training epoch. Defaults to True.

  • no_augmentation_testing (bool) – Whether to skip testing with augmentation. Defaults to True.

  • profiler (str) – The profiler to be used for debugging. (simple, advanced, None) Defaults to None.
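The two clipping parameters correspond to two standard strategies. The following is a generic sketch on a plain list of floats, not yamle's actual code path; the helper names clip_by_value and clip_by_norm are hypothetical. Clipping by value clamps each entry independently, while clipping by norm rescales the whole vector so that its L2 norm does not exceed the threshold, which preserves the gradient's direction.

```python
import math
from typing import List


def clip_by_value(grads: List[float], clip_value: float) -> List[float]:
    """Hypothetical helper: clamp every entry to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def clip_by_norm(grads: List[float], max_norm: float) -> List[float]:
    """Hypothetical helper: rescale the vector so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm == 0.0 or norm <= max_norm:
        return list(grads)  # already within the threshold, leave unchanged
    scale = max_norm / norm
    return [g * scale for g in grads]


grads = [3.0, -4.0]  # L2 norm is exactly 5.0
print(clip_by_value(grads, 3.5))  # [3.0, -3.5]
print(clip_by_norm(grads, 2.5))   # [1.5, -2.0]
```

Note how clipping by value changes the ratio between the two entries, whereas clipping by norm halves both and keeps their ratio intact.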

fit(results=None)

This method trains the method passed to the trainer and its embedded model.

Returns the time it took to train the model.

Return type:

float
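fit, test, calibrate and fine_tune all return the elapsed time as a float. A minimal sketch of that pattern, assuming wall-clock seconds measured with time.perf_counter (the units are not stated in this reference); returns_elapsed and ToyTrainer are hypothetical names, not part of yamle.

```python
import functools
import time
from typing import Any, Callable


def returns_elapsed(fn: Callable[..., Any]) -> Callable[..., float]:
    """Hypothetical decorator: calling the wrapped function returns the seconds it took."""
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> float:
        start = time.perf_counter()
        fn(*args, **kwargs)
        return time.perf_counter() - start
    return wrapper


class ToyTrainer:
    """Hypothetical stand-in, not yamle's BaseTrainer."""

    def __init__(self, epochs: int) -> None:
        self.epochs = epochs
        self.epochs_run = 0

    @returns_elapsed
    def fit(self) -> None:
        for _ in range(self.epochs):
            self.epochs_run += 1  # placeholder for one training epoch


trainer = ToyTrainer(epochs=3)
elapsed = trainer.fit()
print(f"training took {elapsed:.6f} s")
```

Keeping the timing in a decorator leaves the training logic itself free of bookkeeping.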

property interrupted: bool

This property returns whether the training was interrupted.

test(results)

This method tests the method passed to the trainer and its embedded model.

The test results are stored in the given results dictionary. Returns the time it took to test the model.

Return type:

float
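A sketch of the caller-owned results dictionary, assuming test writes its metrics into the dictionary in place and returns only the duration. evaluate_into and the test_accuracy key are hypothetical stand-ins, not yamle's metric names.

```python
import time
from typing import Dict, List


def evaluate_into(results: Dict[str, float], predictions: List[int], targets: List[int]) -> float:
    """Hypothetical test routine: write metrics into `results`, return seconds taken."""
    start = time.perf_counter()
    correct = sum(int(p == t) for p, t in zip(predictions, targets))
    results["test_accuracy"] = correct / len(targets)  # mutates the caller's dict
    return time.perf_counter() - start


results: Dict[str, float] = {}
elapsed = evaluate_into(results, predictions=[1, 0, 1, 1], targets=[1, 0, 0, 1])
print(results)  # {'test_accuracy': 0.75}
```

Because the dictionary is passed in, the same object can accumulate metrics from several calls (for example training and testing) before being saved.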

calibrate()

This is a helper function which runs calibration data through the model.

Note that the Trainer is not used here and the model is not trained: no gradients are updated. Instead, the calibration data is run manually through the method's validation_step.

Returns the time it took to calibrate the model.

Return type:

float
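The manual loop described above can be sketched as follows. ToyMethod and calibrate here are hypothetical stand-ins; the only behaviour borrowed from the text is that each calibration batch goes through validation_step and no optimizer step is taken. In a real PyTorch setting the loop would typically also be wrapped in torch.no_grad(), which is omitted to keep the sketch dependency-free.

```python
import time
from typing import Any, Iterable, List, Tuple


class ToyMethod:
    """Hypothetical stand-in for a method exposing a Lightning-style validation_step."""

    def __init__(self) -> None:
        self.seen: List[Tuple[int, Any]] = []

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        # Only records the batch: no backward pass and no optimizer step.
        self.seen.append((batch_idx, batch))


def calibrate(method: ToyMethod, calibration_batches: Iterable[Any]) -> float:
    """Run every calibration batch through validation_step by hand; return seconds taken."""
    start = time.perf_counter()
    for batch_idx, batch in enumerate(calibration_batches):
        method.validation_step(batch, batch_idx)
    return time.perf_counter() - start


method = ToyMethod()
elapsed = calibrate(method, ["batch_a", "batch_b"])
print(method.seen)  # [(0, 'batch_a'), (1, 'batch_b')]
```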

fine_tune(epochs)

This method fine-tunes the method and the embedded model.

It runs for the given number of additional epochs using a new trainer. The fine-tuning is done on the calibration data.

Returns the time it took to fine-tune the model.

Return type:

float
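One reading of "additional epochs using a new trainer" is that the original training settings stay untouched and a separate configuration is derived for fine-tuning. A minimal sketch under that assumption, using dataclasses.replace; TrainerConfig, data_split and fine_tune_config are hypothetical names.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TrainerConfig:
    """Hypothetical subset of the settings a trainer is built from."""

    epochs: int
    data_split: str = "train"


def fine_tune_config(base: TrainerConfig, epochs: int) -> TrainerConfig:
    """Derive settings for a fresh fine-tuning trainer; `base` is left unchanged."""
    return replace(base, epochs=epochs, data_split="calibration")


base = TrainerConfig(epochs=100)
ft_config = fine_tune_config(base, epochs=5)
print(base)       # TrainerConfig(epochs=100, data_split='train')
print(ft_config)  # TrainerConfig(epochs=5, data_split='calibration')
```

Freezing the dataclass makes it impossible to mutate the original settings by accident; replace always returns a new object.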

static add_specific_args(parser)

This method adds trainer arguments to the given parser.

Parameters:

parser (ArgumentParser) – The parser to which the arguments should be added.

Return type:

ArgumentParser
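A minimal sketch of the add_specific_args pattern using argparse argument groups. The flag names below are hypothetical and only mirror a few constructor parameters; the defaults are taken from the BaseTrainer signature, and epochs is made required because the signature gives it no default.

```python
import argparse


def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Hypothetical sketch: register a few trainer options and return the parser."""
    group = parser.add_argument_group("Trainer")
    group.add_argument("--trainer_epochs", type=int, required=True,
                       help="The number of epochs to train for.")
    group.add_argument("--trainer_gradient_clip_value", type=float, default=5.0,
                       help="The gradient clipping value when clipping by value.")
    group.add_argument("--trainer_mode", type=str, default="train",
                       choices=["train", "eval", "tune"],
                       help="The mode of the trainer.")
    return parser


parser = add_specific_args(argparse.ArgumentParser())
args = parser.parse_args(["--trainer_epochs", "10"])
print(args.trainer_epochs, args.trainer_gradient_clip_value, args.trainer_mode)  # 10 5.0 train
```

Returning the parser, as the documented return type suggests, lets several components chain their own add_specific_args calls on the same parser.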