yamle.methods.method module

class yamle.methods.method.BaseMethod(model, loss, regularizer, learning_rate=0.001, regularizer_weight=0.0, momentum=0.9, task='classification', optimizer='adam', scheduler='step', scheduler_step_size=10, scheduler_gamma=0.1, scheduler_factor=0.1, scheduler_patience=10, seed=42, inputs_dim=(1, 28, 28), inputs_dtype=torch.float32, outputs_dim=10, targets_dim=1, outputs_dtype=torch.float32, datamodule=None, plotting_training=False, plotting_testing=False, save_path=None, save_test_predictions=False, metrics=None, metrics_kwargs={}, model_kwargs={}, **kwargs)

Bases: LightningModule

This class is the base class for all methods in the project.

It assumes that, during training, the model output has shape (batch_size, 1, num_classes), where the singleton second dimension corresponds to a single Monte Carlo sample.
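For illustration only, the sketch below replaces tensors with nested Python lists of shape (batch_size, num_samples, num_classes) and averages predictions over the sample dimension (dimension 1). With the single training-time Monte Carlo sample, the average is simply that sample. The helper mean_over_samples is hypothetical and not part of yamle.

```python
from typing import List

# Nested lists stand in for a tensor of shape (batch_size, num_samples, num_classes).
Outputs = List[List[List[float]]]


def mean_over_samples(outputs: Outputs) -> List[List[float]]:
    """Average predictions over the sample dimension (dim 1).

    Hypothetical helper for illustration, not part of yamle.
    """
    averaged = []
    for per_sample in outputs:  # per_sample has shape (num_samples, num_classes)
        num_samples = len(per_sample)
        num_classes = len(per_sample[0])
        averaged.append(
            [sum(s[c] for s in per_sample) / num_samples for c in range(num_classes)]
        )
    return averaged


# Training-time output: batch_size=2, one Monte Carlo sample, num_classes=3.
train_outputs = [
    [[0.7, 0.2, 0.1]],
    [[0.1, 0.1, 0.8]],
]
print(mean_over_samples(train_outputs))  # the single sample is returned unchanged
```

At test time, a method that draws several samples would produce a larger second dimension, and the same reduction would average over them.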

Parameters:
  • model (nn.Module) – The model to be trained.

  • loss (nn.Module) – The loss function to be used.

  • regularizer (BaseRegularizer) – The regularizer to be used.

  • learning_rate (float) – The learning rate to be used for training.

  • regularizer_weight (float) – The weight of the regularizer.

  • momentum (float) – The momentum to be used for training.

  • task (str) – The task to be performed. Must be one of the values listed in tasks.

  • optimizer (str) – The optimizer to be used for training. Can be either adam or sgd.

  • scheduler (str) – The learning rate scheduler to be used for training.

  • scheduler_step_size (int) – The step size to be used for the learning rate scheduler.

  • scheduler_gamma (float) – The gamma to be used for the learning rate scheduler.

  • scheduler_factor (float) – The factor to be used for the learning rate scheduler.

  • scheduler_patience (int) – The patience to be used for the learning rate scheduler.

  • seed (int) – The seed to be used for training.

  • inputs_dim (Tuple[int, ...]) – The shape of the inputs to the model.

  • inputs_dtype (torch.dtype) – The dtype of the inputs to the model.

  • outputs_dim (int) – The number of outputs of the model.

  • targets_dim (int) – The feature dimension of the targets.

  • outputs_dtype (torch.dtype) – The dtype of the outputs of the model.

  • datamodule (Optional[BaseDataModule]) – The datamodule to be used for training or testing.

  • plotting_training (bool) – Whether to plot sanity checks during training.

  • plotting_testing (bool) – Whether to plot sanity checks during testing.

  • save_path (Optional[str]) – The path to save files to.

  • save_test_predictions (bool) – Whether to save the test predictions or not.

  • metrics_kwargs (Dict[str, Any]) – The keyword arguments to be passed to the metrics.

  • model_kwargs (Dict[str, Any]) – The keyword arguments to be passed to the model.

tasks = ['regression', 'classification', 'text_classification', 'segmentation', 'depth_estimation', 'pre_training', 'reconstruction']
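The task argument is expected to be one of these values. A minimal sketch of such a check, assuming a plain membership test that raises ValueError on failure; check_task is a hypothetical name, and the class itself may validate the task differently.

```python
TASKS = [
    "regression", "classification", "text_classification", "segmentation",
    "depth_estimation", "pre_training", "reconstruction",
]


def check_task(task: str) -> str:
    """Hypothetical helper: reject a task name that is not listed in TASKS."""
    if task not in TASKS:
        raise ValueError(f"Unknown task {task!r}; expected one of {TASKS}")
    return task


check_task("classification")  # accepted
try:
    check_task("detection")  # not in TASKS
except ValueError as err:
    print(err)
```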
on_fit_start()

This method is used to set the metrics to the correct device.

Return type:

None

on_test_start()

This method is used to set the metrics to the correct device.

Return type:

None

training_step(batch, batch_idx)

This method is used to perform a single training step.

This method should not be overridden. It wraps _training_step and can catch exceptions raised inside it.

Return type:

Dict[str, Any]

validation_step(batch, batch_idx)

This method is used to perform a single validation step.

This method should not be overridden. It wraps _validation_step and can catch exceptions raised inside it.

Return type:

Dict[str, Any]

test_step(batch, batch_idx)

This method is used to perform a single test step.

This method should not be overridden. It wraps _test_step and can catch exceptions raised inside it.

Return type:

Dict[str, Any]
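The training, validation and test steps share the same wrapper pattern: a public hook that is not overridden delegates to a private method that subclasses implement. A minimal sketch of that pattern, with plain Python values standing in for tensors; the catch_exceptions switch and the last_exception attribute are hypothetical names, not taken from yamle.

```python
from typing import Any, Dict, Optional


class SketchMethod:
    """Hypothetical stand-in for BaseMethod illustrating the wrapper pattern."""

    def __init__(self, catch_exceptions: bool = True) -> None:
        self.catch_exceptions = catch_exceptions  # hypothetical switch
        self.last_exception: Optional[BaseException] = None

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        # Public hook: not meant to be overridden.
        self.last_exception = None
        try:
            return self._training_step(batch, batch_idx)
        except Exception as exc:
            if not self.catch_exceptions:
                raise
            # Record the failure and return an empty result instead of crashing.
            self.last_exception = exc
            return {}

    def _training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        # Subclasses implement the actual step here.
        raise NotImplementedError


class MeanMethod(SketchMethod):
    def _training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        return {"loss": sum(batch) / len(batch)}


method = MeanMethod()
print(method.training_step([1.0, 2.0, 3.0], 0))  # {'loss': 2.0}
print(method.training_step([], 1))  # ZeroDivisionError is caught, so {} is returned
```

Keeping the error handling in the public hook means every subclass gets the same behaviour without repeating it in its own _training_step.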

on_before_backward(loss)

This method is called before the backward pass, but after the loss has been computed.

By default, the regularizer term is added to the loss.

Return type:

None
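A minimal sketch of adding a weighted regularizer term to an already computed loss. It assumes that regularizer_weight scales the regularizer term linearly and uses an L2 penalty over a flat list of parameter values as a stand-in for a BaseRegularizer; the regularizers in yamle may compute something different, and both helper names are hypothetical.

```python
from typing import Sequence


def l2_penalty(params: Sequence[float]) -> float:
    """Hypothetical regularizer: sum of squared parameter values."""
    return sum(p * p for p in params)


def regularized_loss(loss: float, params: Sequence[float], regularizer_weight: float) -> float:
    """Add the weighted regularizer term to the data loss (assumed linear weighting)."""
    return loss + regularizer_weight * l2_penalty(params)


# 0.5 + 0.1 * (1.0**2 + (-2.0)**2) = 0.5 + 0.1 * 5.0 = 1.0
print(regularized_loss(0.5, [1.0, -2.0], regularizer_weight=0.1))
```

With the default regularizer_weight=0.0, the term vanishes and the loss is unchanged.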

optimizer_step(epoch, batch_idx, optimizer, optimizer_closure=None)

This method is used to perform the optimizer step.

The optimizer is not stepped if an exception is raised during the training step.

Return type:

None
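A minimal sketch of skipping the optimizer update after a failed training step. It assumes that the closure runs the training step and that the step records its failure in a flag; CountingOptimizer, StepSkippingMethod and step_failed are hypothetical names, and the counting optimizer only stands in for a real torch optimizer.

```python
from typing import Callable, Optional


class CountingOptimizer:
    """Hypothetical stand-in for a torch optimizer that only counts its steps."""

    def __init__(self) -> None:
        self.num_steps = 0

    def step(self) -> None:
        self.num_steps += 1


class StepSkippingMethod:
    """Hypothetical sketch: skip the update when the training step raised."""

    def __init__(self) -> None:
        self.step_failed = False

    def training_step(self, batch: list) -> None:
        self.step_failed = False
        try:
            if not batch:
                raise ValueError("empty batch")
            # ... forward pass, loss and backward pass would go here ...
        except ValueError:
            self.step_failed = True

    def optimizer_step(self, epoch: int, batch_idx: int, optimizer: CountingOptimizer,
                       optimizer_closure: Optional[Callable[[], None]] = None) -> None:
        if optimizer_closure is not None:
            optimizer_closure()  # assumption: the closure runs the training step
        if self.step_failed:
            return  # parameters are left unchanged for this batch
        optimizer.step()


method = StepSkippingMethod()
optimizer = CountingOptimizer()
method.optimizer_step(0, 0, optimizer, lambda: method.training_step([1.0, 2.0]))
method.optimizer_step(0, 1, optimizer, lambda: method.training_step([]))
print(optimizer.num_steps)  # 1
```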

on_train_batch_end(outputs, batch, batch_idx)

This method is used to update the metrics at the end of each training batch.

If selected, weight decay is also applied at the end of each training batch.

Return type:

None

on_validation_batch_end(outputs, batch, batch_idx)

This method is used to update the metrics at the end of each validation batch.

Return type:

None

on_test_batch_end(outputs, batch, batch_idx)

This method is used to update the metrics at the end of each test batch.

Return type:

None

on_train_epoch_start()

This method is used to set the model in training mode at the beginning of each training epoch.

Return type:

None

on_validation_epoch_start()

This method is used to set the model in evaluation mode at the beginning of each validation epoch.

Return type:

None

on_test_epoch_start()

This method is used to set the model in evaluation mode at the beginning of each test epoch.

Return type:

None
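The three epoch-start hooks switch the model between training and evaluation mode. A minimal sketch of that switching, using a hypothetical SketchModel whose train(mode=True) and eval() mirror the semantics of torch.nn.Module; tying the evaluation property to the model's training flag is an assumption, not confirmed by this page.

```python
class SketchModel:
    """Hypothetical stand-in for the training flag of torch.nn.Module."""

    def __init__(self) -> None:
        self.training = True

    def train(self, mode: bool = True) -> "SketchModel":
        self.training = mode
        return self

    def eval(self) -> "SketchModel":
        return self.train(False)


class ModeSwitchingMethod:
    def __init__(self) -> None:
        self.model = SketchModel()

    def on_train_epoch_start(self) -> None:
        self.model.train()

    def on_validation_epoch_start(self) -> None:
        self.model.eval()

    def on_test_epoch_start(self) -> None:
        self.model.eval()

    @property
    def evaluation(self) -> bool:
        # Assumption: "evaluation mode" means the model is not in training mode.
        return not self.model.training


method = ModeSwitchingMethod()
method.on_validation_epoch_start()
print(method.evaluation)  # True
```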

on_train_epoch_end()

This method is used to:

  • reset the model at the end of each training epoch;

  • step the learning rate schedulers if automatic optimization is not selected;

  • plot the training results if plotting is selected;

  • apply the regularizer at the end of each training epoch if a regularizer is selected.

Return type:

None

on_validation_epoch_end()

This method is used to reset the model at the end of each validation epoch.

Return type:

None

on_test_epoch_end()

This method is used to reset the model at the end of each test epoch.

Return type:

None

on_after_model_load()

This method is called after the model is loaded.

Return type:

None

on_before_model_load()

This method is called before the model is loaded.

Return type:

None

on_before_method_load()

This method is called before the method is loaded.

Return type:

None

on_after_method_load()

This method is called after the method is loaded.

Return type:

None

reset_metrics(prefix, complete=False)

This method is used to reset the metrics. The metrics are not reset at the end of training, validation and testing, because they are logged externally.

Parameters:
  • prefix (str) – The prefix of the metrics to be reset.

  • complete (bool) – If True, the metrics are reset completely. If False, only the values are reset.

Return type:

None

get_parameters(recurse=True)

This method is used to get the parameters of the model.

Return type:

List[Parameter]

get_named_parameters(recurse=True)

This method is used to get the named parameters of the model.

Return type:

List[Tuple[str, Parameter]]

configure_optimizers()

This method is used to configure the optimizers to be used for training.

Additionally, it is used to configure the learning rate schedulers.

Return type:

Tuple[List[Optimizer], List[_LRScheduler]]
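With the defaults scheduler='step', scheduler_step_size=10 and scheduler_gamma=0.1, a step scheduler in the style of PyTorch's StepLR multiplies the learning rate by gamma once every step_size epochs, so the rate after a given epoch has a closed form. The sketch below computes it; that 'step' maps to this behaviour is an assumption, and step_lr is a hypothetical helper.

```python
def step_lr(learning_rate: float, epoch: int, step_size: int = 10, gamma: float = 0.1) -> float:
    """Learning rate after `epoch` epochs under a step decay schedule.

    lr = learning_rate * gamma ** (epoch // step_size)
    """
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    return learning_rate * gamma ** (epoch // step_size)


# With learning_rate=0.001: epochs 0-9 use 1e-3, epochs 10-19 use 1e-4, epochs 20-29 use 1e-5.
for epoch in (0, 9, 10, 25):
    print(epoch, step_lr(0.001, epoch))
```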

state_dict()

This method is used to get the state dict of the method.

Return type:

Dict[str, Any]

load_state_dict(state_dict)

This method is used to load the state dict of the method.

Return type:

None

analyse(save_path)

This method is used to analyse the method.

The analysis is run on a trained method before evaluation. Implement here any analysis that should be performed on the trained method or model. save_path is the directory in which the analysis results should be saved.

Return type:

None

backward(loss, *args, **kwargs)

This method is used to perform the backward pass.

Parameters:
  • loss (torch.Tensor) – The loss to be used for the backward pass.

Return type:

None

on_after_backward()

This method is used to perform any operations after the backward pass.

For example, a regularizer might need to perform some operations after the backward pass.

Return type:

None

property evaluation

This property indicates whether the method is in evaluation mode.

static add_specific_args(parent_parser)

This method is used to add method specific arguments to the parent parser.

Return type:

ArgumentParser
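A common way to implement a hook like this with the standard library is to build a child ArgumentParser that inherits the parent's arguments through parents=[...] and then return it. The flag names below are hypothetical and only mirror a few constructor defaults; the arguments actually registered by BaseMethod may differ.

```python
from argparse import ArgumentParser


def add_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
    """Sketch of a method-specific argument hook (hypothetical flag names)."""
    # add_help=False avoids a clash with the -h/--help option inherited from the parent.
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument("--method_learning_rate", type=float, default=0.001)
    parser.add_argument("--method_optimizer", type=str, default="adam", choices=["adam", "sgd"])
    parser.add_argument("--method_momentum", type=float, default=0.9)
    return parser


parser = add_specific_args(ArgumentParser())
args = parser.parse_args(["--method_optimizer", "sgd"])
print(args.method_optimizer, args.method_learning_rate)  # sgd 0.001
```

Because the returned parser contains the parent's arguments as well, several such hooks can be chained, each one extending the parser produced by the previous one.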