yamle.methods.method module
- class yamle.methods.method.BaseMethod(model, loss, regularizer, learning_rate=0.001, regularizer_weight=0.0, momentum=0.9, task='classification', optimizer='adam', scheduler='step', scheduler_step_size=10, scheduler_gamma=0.1, scheduler_factor=0.1, scheduler_patience=10, seed=42, inputs_dim=(1, 28, 28), inputs_dtype=torch.float32, outputs_dim=10, targets_dim=1, outputs_dtype=torch.float32, datamodule=None, plotting_training=False, plotting_testing=False, save_path=None, save_test_predictions=False, metrics=None, metrics_kwargs={}, model_kwargs={}, **kwargs)
Bases: LightningModule
This class is the base class for all methods in the project.
It assumes that, during training, the model output has shape (batch_size, 1, num_classes); the singleton second dimension corresponds to a single Monte Carlo sample.
- Parameters:
model (nn.Module) – The model to be trained.
loss (nn.Module) – The loss function to be used.
regularizer (BaseRegularizer) – The regularizer to be used.
learning_rate (float) – The learning rate to be used for training.
regularizer_weight (float) – The weight of the regularizer.
momentum (float) – The momentum to be used for training.
task (str) – The task to be performed, for example classification or regression; the tasks attribute lists all supported values.
optimizer (str) – The optimizer to be used for training. Can be either adam or sgd.
scheduler (str) – The learning rate scheduler to be used for training.
scheduler_step_size (int) – The step size to be used for the learning rate scheduler.
scheduler_gamma (float) – The gamma to be used for the learning rate scheduler.
scheduler_factor (float) – The factor to be used for the learning rate scheduler.
scheduler_patience (int) – The patience to be used for the learning rate scheduler.
seed (int) – The seed to be used for training.
inputs_dim (Tuple[int, ...]) – The shape of the inputs to the model.
inputs_dtype (torch.dtype) – The dtype of the inputs to the model.
outputs_dim (int) – The number of outputs of the model.
targets_dim (int) – The feature dimension of the targets.
outputs_dtype (torch.dtype) – The dtype of the outputs of the model.
datamodule (Optional[BaseDataModule]) – The datamodule to be used for training or testing.
plotting_training (bool) – Whether to plot sanity checks during training.
plotting_testing (bool) – Whether to plot sanity checks during testing.
save_path (Optional[str]) – The path to save files to.
save_test_predictions (bool) – Whether to save the test predictions.
metrics_kwargs (Dict[str, Any]) – The keyword arguments to be passed to the metrics.
model_kwargs (Dict[str, Any]) – The keyword arguments to be passed to the model.
- tasks = ['regression', 'classification', 'text_classification', 'segmentation', 'depth_estimation', 'pre_training', 'reconstruction']
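The task argument is presumably expected to be one of the entries in tasks. The following is a minimal sketch of such a check in plain Python; validate_task is a hypothetical helper and is not part of yamle.

```python
from typing import List

# Copied from the tasks attribute documented above.
TASKS: List[str] = [
    "regression", "classification", "text_classification", "segmentation",
    "depth_estimation", "pre_training", "reconstruction",
]


def validate_task(task: str, supported: List[str] = TASKS) -> str:
    """Hypothetical helper: return task unchanged if supported, else raise ValueError."""
    if task not in supported:
        raise ValueError(f"Unsupported task {task!r}; expected one of {supported}")
    return task


validate_task("segmentation")  # accepted
```

Failing early with the list of valid values makes a misspelled task easier to spot than an error raised deep inside training.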
- on_fit_start()
This method is used to set the metrics to the correct device.
- Return type:
None
- on_test_start()
This method is used to set the metrics to the correct device.
- Return type:
None
- training_step(batch, batch_idx)
This method is used to perform a single training step.
This method should not be overridden; it wraps _training_step and can catch exceptions raised inside it.
- Return type:
Dict[str, Any]
- validation_step(batch, batch_idx)
This method is used to perform a single validation step.
This method should not be overridden; it wraps _validation_step and can catch exceptions raised inside it.
- Return type:
Dict[str, Any]
- test_step(batch, batch_idx)
This method is used to perform a single test step.
This method should not be overridden; it wraps _test_step and can catch exceptions raised inside it.
- Return type:
Dict[str, Any]
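The three step methods above share one pattern: a public hook that should not be overridden wraps a private hook and can catch exceptions raised inside it. The framework-free sketch below illustrates that pattern only; it does not reproduce yamle's code, and the class name StepWrapperSketch and the "loss"/"exception" keys of the returned dictionary are assumptions made for illustration.

```python
from typing import Any, Dict


class StepWrapperSketch:
    """Hypothetical stand-in for the wrapper pattern; not yamle's BaseMethod."""

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        # Public hook: subclasses customise _training_step, not this method.
        try:
            return self._training_step(batch, batch_idx)
        except Exception as exc:
            # Record the failure instead of crashing the training loop.
            return {"loss": None, "exception": exc}

    def _training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        # Toy "loss": the mean of the batch values (fails on an empty batch).
        return {"loss": sum(batch) / len(batch), "exception": None}


method = StepWrapperSketch()
print(method.training_step([1.0, 3.0], 0))  # {'loss': 2.0, 'exception': None}
print(type(method.training_step([], 1)["exception"]).__name__)  # ZeroDivisionError
```

Keeping the try/except in the public hook means every subclass gets the same error handling without having to repeat it.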
- on_before_backward(loss)
This method is called after the loss has been computed but before the backward pass.
By default, the regularizer term is added to the loss.
- Return type:
None
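One reading of "the regularizer term is added to the loss", given the regularizer_weight argument, is total = loss + regularizer_weight * R(parameters). The sketch below assumes an L2 penalty for R and represents parameters as plain Python lists; both are illustrative choices, not necessarily what BaseRegularizer implements. With the default regularizer_weight=0.0 the term vanishes.

```python
from typing import List


def l2_penalty(params: List[List[float]]) -> float:
    # Sum of squared values over all parameter groups (flat lists here).
    return sum(w * w for p in params for w in p)


def regularized_loss(loss: float, params: List[List[float]], regularizer_weight: float) -> float:
    # Assumed combination: data loss plus the weighted regularizer term.
    return loss + regularizer_weight * l2_penalty(params)


params = [[1.0, -2.0], [0.5]]
print(regularized_loss(0.75, params, regularizer_weight=0.5))  # 3.375
```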
- optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)
This method is used to perform the optimizer step.
The optimizer is not stepped if an exception was raised during the training step.
- Return type:
None
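The skip-on-exception behaviour can be illustrated without PyTorch. In the sketch below, ToyOptimizer stands in for a torch.optim optimizer, and optimizer_step_sketch and the "exception" key of the step output are hypothetical; how yamle actually passes the error from the training step to optimizer_step is not documented here.

```python
from typing import Any, Dict


class ToyOptimizer:
    """Hypothetical stand-in for an optimizer with one scalar parameter."""

    def __init__(self, lr: float) -> None:
        self.lr = lr
        self.param = 1.0
        self.steps = 0

    def step(self, grad: float) -> None:
        # Plain gradient descent update on the single parameter.
        self.param -= self.lr * grad
        self.steps += 1


def optimizer_step_sketch(optimizer: ToyOptimizer, step_output: Dict[str, Any], grad: float) -> bool:
    """Step only if the training step did not record an exception; return whether it stepped."""
    if step_output.get("exception") is not None:
        return False  # leave the parameters untouched for this batch
    optimizer.step(grad)
    return True


opt = ToyOptimizer(lr=0.5)
optimizer_step_sketch(opt, {"exception": None}, grad=1.0)  # steps: param 1.0 -> 0.5
optimizer_step_sketch(opt, {"exception": RuntimeError("nan loss")}, grad=1.0)  # skipped
print(opt.param, opt.steps)  # 0.5 1
```

Skipping the update keeps a single failed batch (for example one with a non-finite loss) from corrupting the parameters.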
- on_train_batch_end(outputs, batch, batch_idx)
This method is used to update the metrics at the end of each training batch.
If weight decay is selected, it is also applied at the end of each training batch.
- Return type:
None
- on_validation_batch_end(outputs, batch, batch_idx)
This method is used to update the metrics at the end of each validation batch.
- Return type:
None
- on_test_batch_end(outputs, batch, batch_idx)
This method is used to update the metrics at the end of each test batch.
- Return type:
None
- on_train_epoch_start()
This method is used to set the model in training mode at the beginning of each training epoch.
- Return type:
None
- on_validation_epoch_start()
This method is used to set the model in evaluation mode at the beginning of each validation epoch.
- Return type:
None
- on_test_epoch_start()
This method is used to set the model in evaluation mode at the beginning of each test epoch.
- Return type:
None
- on_train_epoch_end()
This method is used to reset the model at the end of each training epoch, step the learning rate schedulers if automatic optimization is not selected, plot the training results if plotting is selected, and apply the regularizer if one is selected.
- Return type:
None
- on_validation_epoch_end()
This method is used to reset the model at the end of each validation epoch.
- Return type:
None
- on_test_epoch_end()
This method is used to reset the model at the end of each test epoch.
- Return type:
None
- reset_metrics(prefix, complete=False)
This method is used to reset the metrics. The metrics are not reset at the end of training, validation and testing, because they are logged externally.
- get_parameters(recurse=True)
This method is used to get the parameters of the model.
- Return type:
List[Parameter]
- get_named_parameters(recurse=True)
This method is used to get the named parameters of the model.
- Return type:
List[Tuple[str, Parameter]]
- configure_optimizers()
This method is used to configure the optimizers to be used for training.
Additionally, it is used to configure the learning rate schedulers.
- Return type:
Tuple[List[Optimizer], List[_LRScheduler]]
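The scheduler_* arguments of BaseMethod suggest two kinds of schedule: a step decay controlled by scheduler_step_size and scheduler_gamma (as in PyTorch's StepLR) and a reduce-on-plateau schedule controlled by scheduler_factor and scheduler_patience (as in ReduceLROnPlateau). This mapping is an assumption based on the parameter names. The plain-Python sketch below shows how such values shape the learning rate; step_lr and PlateauSketch are hypothetical, and the plateau logic omits PyTorch's threshold and cooldown options.

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 10, gamma: float = 0.1) -> float:
    """Step decay: multiply the learning rate by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


class PlateauSketch:
    """Simplified reduce-on-plateau for a metric to minimise (no threshold, no cooldown)."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 10) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> float:
        # Count consecutive epochs without improvement.
        if metric < self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Once patience is exceeded, shrink the learning rate and start counting again.
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


print([step_lr(1.0, e, step_size=10, gamma=0.5) for e in (0, 9, 10, 25)])  # [1.0, 1.0, 0.5, 0.25]
plateau = PlateauSketch(lr=1.0, factor=0.5, patience=2)
print([plateau.step(m) for m in (1.0, 1.0, 1.0, 1.0)])  # [1.0, 1.0, 1.0, 0.5]
```

The defaults of step_lr and PlateauSketch mirror the signature defaults above (step size 10, gamma 0.1, factor 0.1, patience 10).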
- state_dict()
This method is used to get the state dict of the method.
- Return type:
Dict[str, Any]
- load_state_dict(state_dict)
This method is used to load the state dict of the method.
- Return type:
None
- analyse(save_path)
This method is used to analyse the method.
The analysis is run on a trained method before evaluation. Implement any analysis that should be performed on the trained method/model here; save_path is the directory in which the analysis results should be saved.
- Return type:
None
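Because analyse receives a directory, an override would typically write its results there. The sketch below uses a hypothetical stand-in class instead of subclassing BaseMethod (which would require yamle and Lightning); the file name analysis.json and the parameter-count summary are illustrative choices.

```python
import json
import os
import tempfile
from typing import Dict, List


class AnalysingMethodSketch:
    """Hypothetical stand-in for a BaseMethod subclass that overrides analyse()."""

    def __init__(self, named_parameters: Dict[str, List[float]]) -> None:
        self.named_parameters = named_parameters

    def analyse(self, save_path: str) -> None:
        # Summarise the trained parameters and save the result under save_path.
        summary = {name: len(values) for name, values in self.named_parameters.items()}
        os.makedirs(save_path, exist_ok=True)
        with open(os.path.join(save_path, "analysis.json"), "w") as f:
            json.dump(summary, f)


with tempfile.TemporaryDirectory() as tmp:
    AnalysingMethodSketch({"fc.weight": [0.1, 0.2], "fc.bias": [0.0]}).analyse(tmp)
    with open(os.path.join(tmp, "analysis.json")) as f:
        print(json.load(f))  # {'fc.weight': 2, 'fc.bias': 1}
```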
- backward(loss, *args, **kwargs)
This method is used to perform the backward pass.
- Parameters:
loss (torch.Tensor) – The loss to be used for the backward pass.
- Return type:
None
- on_after_backward()
This method is used to perform any operations after the backward pass.
For example, a regularizer might need to perform some operations at this point.
- Return type:
None
- property evaluation
This property indicates whether the method is in evaluation mode.