yamle.methods.svi module

class yamle.methods.svi.SVIReparameterizationMethod(prior_mean, log_variance, prior_log_variance, p, mode, method, **kwargs)

Bases: SVIMethod

This class extends the base method for stochastic variational inference. It is implemented with the Local Reparameterization Trick (LRT) or the plain Reparameterization Trick (RT).

The posterior is assumed to be mean-field and the prior is assumed to be Gaussian.

Parameters:
  • prior_mean (float) – The mean of the prior. Only used if the method is lrt, rt or flipout_gaussian.

  • log_variance (float) – The initial value of the log variance of the weights. Only used if the method is lrt, rt or flipout_gaussian.

  • prior_log_variance (float) – The log variance of the prior. Only used if the method is lrt, rt or flipout_gaussian.

  • p (float) – The probability in the DropConnect layer. Only used if the method is flipout_dropconnect.

  • mode (str) – Whether only the last layer or all layers should be used for inference.

  • method (str) – The inference method to use: lrt, rt, flipout_dropconnect or flipout_gaussian.
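
As a rough illustration of how prior_mean, log_variance and prior_log_variance relate, the sketch below computes the closed-form KL divergence between a mean-field Gaussian posterior and a Gaussian prior, both parameterized by log variance. The function names `gaussian_kl` and `mean_field_kl` are hypothetical and not part of yamle; whether the library computes its regularizer exactly this way is an assumption.

```python
import math


def gaussian_kl(mean, log_variance, prior_mean, prior_log_variance):
    """KL( N(mean, exp(log_variance)) || N(prior_mean, exp(prior_log_variance)) )."""
    variance = math.exp(log_variance)
    prior_variance = math.exp(prior_log_variance)
    return 0.5 * (
        prior_log_variance
        - log_variance
        + (variance + (mean - prior_mean) ** 2) / prior_variance
        - 1.0
    )


def mean_field_kl(means, log_variances, prior_mean, prior_log_variance):
    """Mean-field assumption: the total KL is the sum over independent weights."""
    return sum(
        gaussian_kl(m, lv, prior_mean, prior_log_variance)
        for m, lv in zip(means, log_variances)
    )
```

Because the posterior factorizes over weights, the total divergence is a plain sum of per-weight terms, which is what makes the mean-field assumption cheap to work with.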

static add_specific_args(parent_parser)

Adds the arguments specific to this class to the parser.

Return type:

ArgumentParser
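
The pattern of extending a parent parser and returning an ArgumentParser can be sketched with the standard library alone. This is a minimal sketch, not yamle's implementation: the flag names, default values and the `last`/`all` choices for the mode are hypothetical, chosen only to mirror the constructor parameters listed above.

```python
from argparse import ArgumentParser


def add_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
    # Inherit every argument already registered on the parent parser.
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    # Hypothetical flag names and defaults, for illustration only.
    parser.add_argument("--prior_mean", type=float, default=0.0)
    parser.add_argument("--log_variance", type=float, default=-5.0)
    parser.add_argument("--prior_log_variance", type=float, default=0.0)
    parser.add_argument("--p", type=float, default=0.5)
    parser.add_argument("--mode", type=str, default="all", choices=["last", "all"])
    parser.add_argument(
        "--method",
        type=str,
        default="lrt",
        choices=["lrt", "rt", "flipout_dropconnect", "flipout_gaussian"],
    )
    return parser


parent = ArgumentParser(add_help=False)
args = add_specific_args(parent).parse_args(["--method", "rt", "--p", "0.1"])
```

Flags that are not passed on the command line keep their defaults, so `args.prior_mean` falls back to the value given in `add_argument`.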

test_name: Optional[str]
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool

class yamle.methods.svi.SVILRTMethod(**kwargs)

Bases: SVIReparameterizationMethod

This class extends the base method for stochastic variational inference, using the Local Reparameterization Trick (LRT) with a Gaussian prior.
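
To make the Local Reparameterization Trick concrete, here is a minimal pure-Python sketch of a single linear layer. It samples the pre-activations directly from their induced Gaussian instead of sampling each weight. The function `lrt_linear` and its arguments are hypothetical and are not yamle's layer API; the per-weight log-variance parameterization is an assumption based on the log_variance parameter of the parent class.

```python
import math
import random


def lrt_linear(x, weight_mean, weight_log_variance, rng):
    """Sample one output vector of a linear layer with the LRT.

    weight_mean and weight_log_variance are lists of rows, one row per output unit.
    """
    outputs = []
    for mean_row, log_var_row in zip(weight_mean, weight_log_variance):
        # Mean and variance of the pre-activation under a mean-field Gaussian.
        act_mean = sum(m * xi for m, xi in zip(mean_row, x))
        act_var = sum(math.exp(lv) * xi * xi for lv, xi in zip(log_var_row, x))
        # One noise draw per output unit rather than one per weight.
        outputs.append(act_mean + math.sqrt(act_var) * rng.gauss(0.0, 1.0))
    return outputs


rng = random.Random(0)
sample = lrt_linear([1.0, 2.0], [[0.5, -1.0], [2.0, 0.0]], [[-4.0, -4.0], [-4.0, -4.0]], rng)
```

The noise is drawn once per output unit, so the number of random draws scales with the layer's output size rather than with the number of weights.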

test_name: Optional[str]
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool

class yamle.methods.svi.SVILRTVDMethod(**kwargs)

Bases: SVIReparameterizationMethod

This class extends the base method for stochastic variational inference, using the Local Reparameterization Trick (LRT) with a Variational Dropout prior.

test_name: Optional[str]
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool

class yamle.methods.svi.SVIRTMethod(**kwargs)

Bases: SVIReparameterizationMethod

This class extends the base method for stochastic variational inference, using the Reparameterization Trick (RT) with a Gaussian prior.
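
For comparison, the plain Reparameterization Trick samples every weight explicitly as mean plus scaled standard normal noise and then applies the layer as usual. The sketch below is a hedged illustration only: `rt_linear` is a hypothetical name, and the log-variance parameterization is assumed from the parent class parameters rather than taken from yamle's source.

```python
import math
import random


def rt_linear(x, weight_mean, weight_log_variance, rng):
    """Sample one output vector of a linear layer with the RT.

    weight_mean and weight_log_variance are lists of rows, one row per output unit.
    """
    # w = mean + std * eps, with std = exp(0.5 * log_variance) and eps ~ N(0, 1).
    sampled_weight = [
        [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean_row, log_var_row)]
        for mean_row, log_var_row in zip(weight_mean, weight_log_variance)
    ]
    # Ordinary matrix-vector product with the sampled weights.
    return [sum(w * xi for w, xi in zip(row, x)) for row in sampled_weight]


rng = random.Random(0)
sample = rt_linear([1.0, 2.0], [[0.5, -1.0], [2.0, 0.0]], [[-4.0, -4.0], [-4.0, -4.0]], rng)
```

Here one noise value is drawn per weight, which is the main practical difference from the local variant.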

test_name: Optional[str]
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool

class yamle.methods.svi.SVIFlipOutRTMethod(**kwargs)

Bases: SVIReparameterizationMethod

This class implements the SVI method using the FlipOut trick with a Gaussian prior and the reparameterization trick.
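
The FlipOut idea can be sketched as follows: one weight perturbation is sampled for the whole batch, and each example multiplies it by its own random sign vectors on the input and output side, giving each example a different effective perturbation at low cost. The function `flipout_linear` is hypothetical, and the Gaussian perturbation with a log-variance parameterization is an assumption; this is not yamle's layer code.

```python
import math
import random


def flipout_linear(batch, weight_mean, weight_log_variance, rng):
    """Apply a linear layer to a batch with FlipOut-style perturbations.

    weight_mean and weight_log_variance are lists of rows, one row per output unit.
    """
    # A single Gaussian perturbation shared by all examples in the batch.
    delta = [
        [math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for lv in log_var_row]
        for log_var_row in weight_log_variance
    ]
    outputs = []
    for x in batch:
        # Per-example random signs for the input and output dimensions.
        in_signs = [rng.choice((-1.0, 1.0)) for _ in x]
        out_signs = [rng.choice((-1.0, 1.0)) for _ in weight_mean]
        signed_x = [s * xi for s, xi in zip(in_signs, x)]
        outputs.append([
            sum(m * xi for m, xi in zip(mean_row, x))
            + out_sign * sum(d * sx for d, sx in zip(delta_row, signed_x))
            for mean_row, delta_row, out_sign in zip(weight_mean, delta, out_signs)
        ])
    return outputs


rng = random.Random(0)
batch_out = flipout_linear(
    [[1.0, 2.0], [0.0, 1.0]],
    [[0.5, -1.0], [2.0, 0.0]],
    [[-4.0, -4.0], [-4.0, -4.0]],
    rng,
)
```

Since the perturbation distribution is symmetric around zero, flipping signs leaves the marginal distribution of each example's perturbation unchanged.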

test_name: Optional[str]
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool

class yamle.methods.svi.SVIFlipOutDropConnectMethod(**kwargs)

Bases: SVIReparameterizationMethod

This class implements the SVI method using the FlipOut trick with a DropConnect prior and the reparameterization trick.

test_name: Optional[str]
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool
training: bool