yamle.trainers.trainer module
- class yamle.trainers.trainer.BaseTrainer(save_path, datamodule, epochs, accelerator, devices, precision, method, gradient_clip_norm_value=0.0, gradient_clip_value=5.0, mode='train', st_checkpoint_dir=None, debug=False, compile=True, task='classification', no_saving=False, no_initial_saving=True, no_validation_saving=True, no_every_train_epoch_saving=True, no_augmentation_testing=True, profiler=None)
Bases: object
This class defines a base trainer which, given a method and data loaders, performs training and evaluation.
- Parameters:
save_path (str) – The path to the experiment folder.
datamodule (BaseDataModule) – The datamodule to be used for training or evaluation.
epochs (int) – The number of epochs to train for.
accelerator (str) – The accelerator to be used for training or evaluation: cpu, gpu, ddp, ddp2, ddp_spawn, or auto.
devices (List[int]) – The devices to be used for training or evaluation.
precision (int) – The precision to be used for training or evaluation.
method (BaseMethod) – The method to be used for training or evaluation.
gradient_clip_norm_value (float) – The gradient clipping threshold when clipping by norm. Defaults to 0.0.
gradient_clip_value (float) – The gradient clipping threshold when clipping by value. Defaults to 5.0.
mode (str) – The mode of the trainer: train, eval, or tune. Defaults to 'train'.
st_checkpoint_dir (Optional[str]) – The path to the Syne-Tune checkpoint directory. Defaults to None.
debug (bool) – Whether to run in debug mode. Defaults to False.
compile (bool) – Whether to compile the model. Defaults to True.
task (str) – The task to be performed. Defaults to 'classification'.
no_saving (bool) – Whether to disable all saving. Defaults to False.
no_initial_saving (bool) – Whether to skip saving the initial model. Defaults to True.
no_validation_saving (bool) – Whether to skip saving the model during validation. Defaults to True.
no_every_train_epoch_saving (bool) – Whether to skip saving the model after every training epoch. Defaults to True.
no_augmentation_testing (bool) – Whether to skip testing with augmentation. Defaults to True.
profiler (Optional[str]) – The profiler to be used for debugging: simple, advanced, or None. Defaults to None.
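The two clipping parameters correspond to two common gradient clipping techniques: clipping by value clamps each gradient component independently, while clipping by norm rescales the whole gradient vector so that its L2 norm does not exceed the threshold. The pure-Python sketch below illustrates both on a flat list of gradients. It is not YAML's implementation; the helper names clip_by_value and clip_by_norm are hypothetical, and it assumes that a non-positive threshold disables clipping (consistent with the 0.0 default for gradient_clip_norm_value, but not confirmed by this page).

```python
import math
from typing import List


def clip_by_value(grads: List[float], clip_value: float) -> List[float]:
    """Clamp every gradient component into [-clip_value, clip_value]."""
    if clip_value <= 0.0:  # assumption: a non-positive threshold disables clipping
        return list(grads)
    return [max(-clip_value, min(clip_value, g)) for g in grads]


def clip_by_norm(grads: List[float], max_norm: float) -> List[float]:
    """Rescale the whole gradient vector so its L2 norm is at most max_norm."""
    if max_norm <= 0.0:  # assumption: 0.0 (the documented default) disables clipping
        return list(grads)
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)  # already within the threshold, leave untouched
    scale = max_norm / norm
    return [g * scale for g in grads]  # direction preserved, magnitude shrunk


grads = [3.0, -4.0, 12.0]  # L2 norm = 13.0
print(clip_by_value(grads, 5.0))  # [3.0, -4.0, 5.0]
print(clip_by_norm(grads, 6.5))   # [1.5, -2.0, 6.0]
```

Note the design difference: value clipping can change the direction of the gradient (only the last component was clamped above), whereas norm clipping preserves its direction.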
- fit(results=None)
This method trains the method and the embedded model.
Returns the time it took to train the model.
- Return type:
float
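fit, test, and calibrate each report how long they took. A minimal sketch of that timing pattern, assuming wall-clock seconds measured with time.perf_counter (this page states neither the unit nor the clock); timed and dummy_training_loop are hypothetical names:

```python
import time
from typing import Any, Callable, Tuple


def timed(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Run fn and return (its result, wall-clock seconds it took)."""
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return out, elapsed


def dummy_training_loop(epochs: int) -> int:
    """Placeholder for real training work; returns the number of epochs run."""
    for _ in range(epochs):
        sum(range(10_000))  # stand-in for one epoch of computation
    return epochs


epochs_run, seconds = timed(dummy_training_loop, 3)
print(f"ran {epochs_run} epochs in {seconds:.4f} s")
```

time.perf_counter is a monotonic, high-resolution clock, which makes it a better fit for measuring durations than time.time.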
- property interrupted: bool
This property returns whether the training was interrupted.
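One plausible way to maintain such a flag is to catch KeyboardInterrupt inside the training loop and record it instead of crashing. The SketchTrainer class below is hypothetical and only illustrates the idea; this page does not say what counts as an interruption in YAML.

```python
from typing import Callable


class SketchTrainer:
    """Hypothetical stand-in illustrating an `interrupted` flag; not YAML's code."""

    def __init__(self) -> None:
        self._interrupted = False

    @property
    def interrupted(self) -> bool:
        return self._interrupted

    def fit(self, train_step: Callable[[int], None], epochs: int) -> int:
        """Run train_step once per epoch; return the number of completed epochs."""
        completed = 0
        try:
            for epoch in range(epochs):
                train_step(epoch)
                completed += 1
        except KeyboardInterrupt:
            # Stop gracefully (e.g. on Ctrl+C) and remember that we did.
            self._interrupted = True
        return completed


def step(epoch: int) -> None:
    if epoch == 2:
        raise KeyboardInterrupt  # simulate the user pressing Ctrl+C


trainer = SketchTrainer()
print(trainer.fit(step, epochs=5), trainer.interrupted)  # 2 True
```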
- test(results)
This method tests the method and the embedded model.
The results are stored in the given dictionary. Returns the time it took to test the model.
- Return type:
float
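The signatures above imply a simple call sequence: the caller owns a results dictionary, fit and test return durations, and test writes its metrics into that dictionary in place. The sketch below mirrors that contract with a hypothetical StubTrainer; the key names, the toy accuracy computation, and checking interrupted before testing are all assumptions for illustration, not documented YAML behavior.

```python
import time
from typing import Any, Dict, List, Optional


class StubTrainer:
    """Hypothetical stand-in mimicking the documented fit/test contract; not YAML code."""

    def __init__(self, preds: List[int], labels: List[int]) -> None:
        self.preds = preds
        self.labels = labels
        self.interrupted = False

    def fit(self, results: Optional[Dict[str, Any]] = None) -> float:
        start = time.perf_counter()
        # ... training would happen here ...
        return time.perf_counter() - start

    def test(self, results: Dict[str, Any]) -> float:
        start = time.perf_counter()
        correct = sum(p == y for p, y in zip(self.preds, self.labels))
        # Metrics go into the caller's dictionary in place.
        results["test_accuracy"] = correct / len(self.labels)
        return time.perf_counter() - start


results: Dict[str, Any] = {}
trainer = StubTrainer(preds=[0, 1, 0, 0], labels=[0, 1, 1, 0])
results["train_time"] = trainer.fit(results)
if not trainer.interrupted:
    results["test_time"] = trainer.test(results)
print(sorted(results))  # ['test_accuracy', 'test_time', 'train_time']
```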
- calibrate()
This helper method runs the calibration data through the model.
Note that the Trainer is not used here and the model is not trained: the gradients are not updated. Instead, the calibration data is passed manually through the validation_step method of the method.
Returns the time it took to calibrate the model.
- Return type:
float
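The calibration pass described above amounts to a plain loop over the calibration batches that calls validation_step on each one, with no optimizer step. A pure-Python sketch, in which ToyMethod, the calibration_loader argument, and the returned unit (seconds) are assumptions; in a PyTorch setting such a loop would typically also run under torch.no_grad() with the model in eval mode, which this dependency-free sketch does not show.

```python
import time
from typing import Iterable, List, Sequence


class ToyMethod:
    """Hypothetical method exposing a Lightning-style validation_step."""

    def __init__(self) -> None:
        self.seen_batches: List[int] = []

    def validation_step(self, batch: Sequence[float], batch_idx: int) -> float:
        # Record the batch; a real method would run the model and update
        # calibration statistics here without changing any weights.
        self.seen_batches.append(batch_idx)
        return sum(batch) / len(batch)


def calibrate(method: ToyMethod, calibration_loader: Iterable[Sequence[float]]) -> float:
    """Feed every calibration batch through validation_step; return seconds taken."""
    start = time.perf_counter()
    for batch_idx, batch in enumerate(calibration_loader):
        method.validation_step(batch, batch_idx)  # forward only, no parameter update
    return time.perf_counter() - start


method = ToyMethod()
elapsed = calibrate(method, [[0.1, 0.2], [0.3, 0.4], [0.5]])
print(method.seen_batches)  # [0, 1, 2]
```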