yamle.methods.swag module
- class yamle.methods.swag.SWAGMethod(covariance, fullrank, apply_to_normalisation, scale, num_members, epochs_to_collect, *args, **kwargs)
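Bases: BaseMethod
This class extends the base method with SWA-Gaussian (SWAG), which fits a Gaussian approximate posterior over the weights using the iterates collected as in stochastic weight averaging (SWA).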
This method was described in the paper “A Simple Baseline for Bayesian Uncertainty in Deep Learning”: https://arxiv.org/pdf/1902.02476.pdf.
- Parameters:
covariance (bool) – Whether to estimate the full covariance matrix.
fullrank (bool) – Whether to use the full-rank covariance matrix.
apply_to_normalisation (bool) – Whether to apply the method to the normalisation layers.
scale (float) – The scale applied when sampling the weights.
num_members (int) – The number of weight samples drawn at test time.
epochs_to_collect (List[int]) – The epochs from which the weights are collected.
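As a rough illustration of how `scale`, `num_members` and `epochs_to_collect` fit together, the sketch below fits the diagonal variant of SWAG (SWAG-Diagonal) to weight snapshots and draws samples from it. It is plain Python on flattened weight lists, not the yamle implementation: `swag_diagonal_samples` is a hypothetical helper, the low-rank covariance term of full SWAG is left out, and treating `scale` as a multiplier on the variance is an assumption.

```python
import math
import random
from typing import List


def swag_diagonal_samples(
    snapshots: List[List[float]],
    scale: float,
    num_members: int,
    seed: int = 0,
) -> List[List[float]]:
    """Draw num_members weight vectors from a diagonal Gaussian fitted to snapshots.

    Hypothetical helper: each snapshot is a flattened weight vector collected
    at one of the epochs in epochs_to_collect.
    """
    rng = random.Random(seed)
    n = len(snapshots)
    dim = len(snapshots[0])
    # First and second moments of every weight across the collected snapshots.
    mean = [sum(s[i] for s in snapshots) / n for i in range(dim)]
    sq_mean = [sum(s[i] ** 2 for s in snapshots) / n for i in range(dim)]
    # Diagonal variance, clamped at zero to guard against round-off.
    var = [max(sq_mean[i] - mean[i] ** 2, 0.0) for i in range(dim)]
    samples = []
    for _ in range(num_members):
        # Assumption: scale multiplies the variance, so the std is scaled by sqrt(scale).
        samples.append(
            [mean[i] + math.sqrt(scale * var[i]) * rng.gauss(0.0, 1.0) for i in range(dim)]
        )
    return samples
```

With the two snapshots `[1.0, 2.0]` and `[3.0, 2.0]`, the mean is `[2.0, 2.0]` and the variance `[1.0, 0.0]`, so every sample keeps its second weight at exactly `2.0`. The predictions of the `num_members` sampled networks would then typically be averaged, as in Bayesian model averaging.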
- state_dict()
Returns the state dictionary of the method.
- Return type:
Dict[str,Any]
- load_state_dict(state_dict)
Loads the state dictionary of the method.
- Return type:
None
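The collected SWAG statistics are not ordinary network parameters, so it is reasonable to assume they need to travel with the checkpoint through `state_dict` and `load_state_dict`. A minimal sketch of such a round trip, using a hypothetical `SWAGStatistics` container that is not part of yamle:

```python
import copy
from typing import Any, Dict, List


class SWAGStatistics:
    """Hypothetical container for running SWAG moments (not the yamle class)."""

    def __init__(self, dim: int) -> None:
        self.num_collected = 0
        self.mean: List[float] = [0.0] * dim
        self.sq_mean: List[float] = [0.0] * dim

    def state_dict(self) -> Dict[str, Any]:
        # Deep-copy so later updates do not mutate an already saved checkpoint.
        return copy.deepcopy(
            {"num_collected": self.num_collected, "mean": self.mean, "sq_mean": self.sq_mean}
        )

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Copy the lists so the loaded object does not alias the checkpoint.
        self.num_collected = state_dict["num_collected"]
        self.mean = list(state_dict["mean"])
        self.sq_mean = list(state_dict["sq_mean"])
```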
- on_train_epoch_end()
Collects the model weights at the end of each training epoch.
- Return type:
None
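Collecting weights at the end of an epoch does not require keeping every snapshot: the SWA mean and the uncentred second moment can be updated online, as described in the SWAG paper, using mean ← (n · mean + θ) / (n + 1). The sketch assumes that only epochs listed in `epochs_to_collect` contribute and that weights arrive as a flattened list; `RunningMoments` is a hypothetical name.

```python
from typing import List, Sequence


class RunningMoments:
    """Hypothetical online first/second moment tracker for SWAG collection."""

    def __init__(self, epochs_to_collect: Sequence[int], dim: int) -> None:
        self.epochs_to_collect = set(epochs_to_collect)
        self.n = 0
        self.mean = [0.0] * dim
        self.sq_mean = [0.0] * dim

    def on_train_epoch_end(self, epoch: int, weights: List[float]) -> None:
        # Assumption: epochs outside epochs_to_collect are skipped.
        if epoch not in self.epochs_to_collect:
            return
        n = self.n
        # Running averages of the weights and of their squares.
        self.mean = [(n * m + w) / (n + 1) for m, w in zip(self.mean, weights)]
        self.sq_mean = [(n * s + w * w) / (n + 1) for s, w in zip(self.sq_mean, weights)]
        self.n = n + 1

    def variance(self) -> List[float]:
        # Diagonal variance E[w^2] - E[w]^2, clamped at zero against round-off.
        return [max(s - m * m, 0.0) for m, s in zip(self.mean, self.sq_mean)]
```

Full SWAG additionally keeps the deviations of the last K collected iterates from the running mean to build its low-rank covariance term; that part is omitted here.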
- on_validation_epoch_start()
Caches the current training weights.
- Return type:
None
- on_train_epoch_start()
Sets the model to training mode.
- Return type:
None
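Caching the training weights before validation suggests that validation runs with different weights (for example sampled or averaged ones) and that the training weights are put back before training resumes. The sketch below only illustrates that swap under this assumption; `WeightCache` and its methods are hypothetical, and real code would operate on PyTorch state dicts rather than plain lists.

```python
import copy
from typing import Dict, List, Optional


class WeightCache:
    """Hypothetical helper: stash training weights, evaluate others, restore."""

    def __init__(self) -> None:
        self._cached: Optional[Dict[str, List[float]]] = None

    def cache(self, weights: Dict[str, List[float]]) -> None:
        # Deep copy so that loading other weights in place cannot overwrite the cache.
        self._cached = copy.deepcopy(weights)

    def restore(self) -> Dict[str, List[float]]:
        if self._cached is None:
            raise RuntimeError("no cached weights to restore")
        # Hand back the cached weights and clear the cache.
        weights, self._cached = self._cached, None
        return weights
```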
- static add_specific_args(parent_parser)
Adds the method-specific arguments to the parser.
- Return type:
ArgumentParser
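`add_specific_args` takes a parent parser and returns an `ArgumentParser`, which matches the common pattern of extending a parent `argparse` parser. A sketch of that pattern for the constructor arguments listed above; the flag names, the comma-separated epoch format and the defaults are illustrative assumptions, not yamle's actual command-line interface.

```python
import argparse
from typing import List


def parse_epochs(value: str) -> List[int]:
    # Hypothetical format: comma-separated epoch indices, e.g. "8,9,10".
    return [int(v) for v in value.split(",") if v.strip()]


class SWAGArgsSketch:
    """Hypothetical stand-in for SWAGMethod.add_specific_args."""

    @staticmethod
    def add_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # add_help=False avoids a clash with the parent's -h/--help option.
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        # Booleans are passed as 0/1 because type=bool treats any non-empty string as True.
        parser.add_argument("--method_covariance", type=int, choices=[0, 1], default=1)
        parser.add_argument("--method_fullrank", type=int, choices=[0, 1], default=0)
        parser.add_argument("--method_apply_to_normalisation", type=int, choices=[0, 1], default=0)
        parser.add_argument("--method_scale", type=float, default=1.0)
        parser.add_argument("--method_num_members", type=int, default=10)
        parser.add_argument("--method_epochs_to_collect", type=parse_epochs, default=[])
        return parser
```

Passing booleans as integers keeps the command line unambiguous; `--method_covariance 0` is parsed as `0`, whereas `type=bool` would turn the string `"0"` into `True`.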
- test_name: Optional[str]
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool
- training: bool