yamle.methods.swag module

class yamle.methods.swag.SWAGMethod(covariance, fullrank, apply_to_normalisation, scale, num_members, epochs_to_collect, *args, **kwargs)

Bases: BaseMethod

This class extends the base method with SWAG (Stochastic Weight Averaging-Gaussian).

This method was described in the paper “A Simple Baseline for Bayesian Uncertainty in Deep Learning”: https://arxiv.org/pdf/1902.02476.pdf.

Parameters:
  • covariance (bool) – Whether to estimate the full covariance matrix.

  • fullrank (bool) – Whether to use the full rank covariance matrix.

  • apply_to_normalisation (bool) – Whether to apply the method to the normalisation layers.

  • scale (float) – The scale of the sampling.

  • num_members (int) – The number of samples to take when sampling the weights during testing.

  • epochs_to_collect (List[int]) – The epochs to collect the weights from.
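The parameters above map onto the general SWAG recipe from the cited paper: collect weight snapshots at chosen epochs, keep running moments, then sample weights at test time. The following is a minimal, pure-Python sketch of the diagonal-covariance case only. It is not yamle's implementation: `DiagonalSwagSketch` and `collect_and_sample` are hypothetical names, weights are assumed to be a flat list of floats, and `scale` is assumed to multiply the variance of the sampling distribution.

```python
import math
import random
from typing import List, Sequence, Tuple


class DiagonalSwagSketch:
    """Running first and second moments of a flat weight vector (hypothetical helper)."""

    def __init__(self, num_weights: int) -> None:
        self.mean = [0.0] * num_weights
        self.sq_mean = [0.0] * num_weights
        self.num_collected = 0

    def collect(self, weights: Sequence[float]) -> None:
        # Incremental average: new = old + (x - old) / (n + 1).
        n = self.num_collected
        for i, w in enumerate(weights):
            self.mean[i] += (w - self.mean[i]) / (n + 1)
            self.sq_mean[i] += (w * w - self.sq_mean[i]) / (n + 1)
        self.num_collected = n + 1

    def variance(self) -> List[float]:
        # Diagonal variance estimate, clamped at zero against rounding error.
        return [max(s - m * m, 0.0) for m, s in zip(self.mean, self.sq_mean)]

    def sample(self, scale: float, rng: random.Random) -> List[float]:
        # Draw one weight vector from N(mean, scale * diag(variance)).
        return [
            m + math.sqrt(scale * v) * rng.gauss(0.0, 1.0)
            for m, v in zip(self.mean, self.variance())
        ]


def collect_and_sample(
    snapshots: Sequence[Sequence[float]],
    epochs_to_collect: Sequence[int],
    scale: float,
    num_members: int,
    seed: int = 0,
) -> Tuple[DiagonalSwagSketch, List[List[float]]]:
    """Collect the chosen epochs, then draw num_members weight samples."""
    stats = DiagonalSwagSketch(len(snapshots[0]))
    for epoch, weights in enumerate(snapshots):
        if epoch in epochs_to_collect:
            stats.collect(weights)
    rng = random.Random(seed)
    members = [stats.sample(scale, rng) for _ in range(num_members)]
    return stats, members
```

In this sketch, `snapshots[e]` stands in for the model weights at the end of epoch `e`. Setting `scale` to 0 collapses every sample onto the averaged weights, which is a quick way to check the moment bookkeeping.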

state_dict()

Returns the state dictionary of the method.

Return type:

Dict[str, Any]

load_state_dict(state_dict)

Loads the state dictionary of the method.

Return type:

None

on_train_epoch_end()

Collects the model weights at the end of a training epoch.

Return type:

None

on_validation_epoch_start()

Caches the training weights at the start of a validation epoch.

Return type:

None

on_train_epoch_start()

Sets the model to training mode at the start of a training epoch.

Return type:

None
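The page does not say what happens to the cached training weights afterwards. One plausible reading, shown here purely as an assumption, is a swap pattern: the training weights are set aside before validation so that other weights (for example averaged or sampled ones) can be evaluated, and are put back when training resumes. `WeightSwapSketch` is a hypothetical class, not part of yamle, and weights are modelled as a plain dictionary of lists.

```python
import copy
from typing import Dict, List, Optional

Weights = Dict[str, List[float]]


class WeightSwapSketch:
    """Hypothetical holder that swaps evaluation weights in and out."""

    def __init__(self, weights: Weights) -> None:
        self.weights = weights
        self._cached: Optional[Weights] = None

    def on_validation_epoch_start(self, eval_weights: Weights) -> None:
        # Keep a deep copy of the training weights, then switch to the
        # weights that should be evaluated.
        self._cached = copy.deepcopy(self.weights)
        self.weights = copy.deepcopy(eval_weights)

    def on_train_epoch_start(self) -> None:
        # Put the training weights back if a cached copy exists.
        if self._cached is not None:
            self.weights = self._cached
            self._cached = None
```

The deep copies keep later in-place updates to one set of weights from leaking into the other.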

static add_specific_args(parent_parser)

Adds the arguments specific to this class to parent_parser.

Return type:

ArgumentParser
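A static method of this shape usually registers its options on the parser it receives and returns that parser. The sketch below shows the pattern with the standard `argparse` module. The flag names, defaults, and the function name `add_swag_args_sketch` are assumptions chosen to mirror the constructor parameters; the real flags may differ.

```python
import argparse


def add_swag_args_sketch(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register hypothetical SWAG options on an existing parser and return it."""
    group = parent_parser.add_argument_group("SWAG (illustrative flags)")
    group.add_argument("--covariance", action="store_true")
    group.add_argument("--fullrank", action="store_true")
    group.add_argument("--apply_to_normalisation", action="store_true")
    group.add_argument("--scale", type=float, default=1.0)
    group.add_argument("--num_members", type=int, default=1)
    # One or more epoch indices, e.g. --epochs_to_collect 5 8
    group.add_argument("--epochs_to_collect", type=int, nargs="+", default=[])
    return parent_parser
```

For example, parsing `["--covariance", "--scale", "0.5", "--epochs_to_collect", "5", "8"]` with the returned parser yields a namespace whose values could be passed on to the constructor.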

test_name: Optional[str]
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool