Source code for yamle.methods.rbnn

from typing import List, Dict, Any
import torch
import argparse
import torch.nn as nn

from yamle.methods.uncertain_method import SVIMethod
from yamle.models.specific.rbnn import LinearRBNN, Conv2dRBNN
from yamle.utils.operation_utils import repeat_inputs_in_batch, average_predictions
from yamle.defaults import LOSS_KEY, TARGET_KEY, PREDICTION_KEY, MEAN_PREDICTION_KEY


[docs] def replace_with_rbnn( model: nn.Module, num_members: int, prior_mean: float, log_variance, prior_log_variance: float, method: str, ) -> nn.Module: """This method is used to replace all the `nn.Linear`, `nn.Conv2d` layers with a `LinearRBNN` and `Conv2dRBNN` respectively. Args: model (nn.Module): The model to replace the layers in. num_members (int): The number of members in the ensemble. prior_mean (float): The mean of the prior distribution. log_variance (float): The initial value of the log of the standard deviation of the weights. prior_log_variance (float): The initial value of the log of the standard deviation of the prior distribution. method (str): The method whether `additive` or `multiplicative` to be used for the rank-1 approximation. """ for name, child in model.named_children(): if isinstance(child, nn.Linear): setattr( model, name, LinearRBNN( in_features=child.in_features, out_features=child.out_features, bias=child.bias is not None, num_members=num_members, prior_mean=prior_mean, log_variance=log_variance, prior_log_variance=prior_log_variance, method=method, ), ) elif isinstance(child, nn.Conv2d): setattr( model, name, Conv2dRBNN( in_channels=child.in_channels, out_channels=child.out_channels, kernel_size=child.kernel_size, stride=child.stride, padding=child.padding, dilation=child.dilation, groups=child.groups, bias=child.bias is not None, num_members=num_members, prior_mean=prior_mean, log_variance=log_variance, prior_log_variance=prior_log_variance, method=method, ), ) else: replace_with_rbnn( child, num_members, prior_mean, log_variance, prior_log_variance, method )
[docs] class RBNNMethod(SVIMethod): """This class is the extension of the base method for Rank-1 Bayesian Neural Networks. Args: num_members (int): The number of members in the ensemble. prior_mean (float): The mean of the prior distribution. log_variance (float): The initial value of the log of the standard deviation of the weights. prior_log_variance (float): The initial value of the log of the standard deviation of the prior distribution. method (str): The method whether `additive` or `multiplicative` to be used for the rank-1 approximation. """ def __init__( self, prior_mean: float = 1.0, log_variance: float = -3.0, prior_log_variance: float = -3.0, method: str = "additive", **kwargs ): super().__init__(**kwargs) assert method in ["additive", "multiplicative"] replace_with_rbnn( self.model, self._num_members, prior_mean, log_variance, prior_log_variance, method, ) def _validation_test_step( self, batch: List[torch.Tensor], batch_idx: int ) -> Dict[str, Any]: """This method is used to perform a single validation/test step. It assumes that the batch has a shape `(batch_size, num_features)`. It assumes that the output of the model has a shape `(batch_size, n_samples, num_classes)`. """ x, y = repeat_inputs_in_batch(batch, self._num_members) y_hat = self._predict(x) y_hat = y_hat.reshape( -1, self._num_members * self._num_members, *y_hat.shape[2:] ) loss = self._loss(y_hat, y) y_hat_mean = average_predictions(y_hat, self._task) output = { LOSS_KEY: loss, TARGET_KEY: y.detach(), PREDICTION_KEY: y_hat.detach(), MEAN_PREDICTION_KEY: y_hat_mean.detach(), } return output def _validation_step( self, batch: List[torch.Tensor], batch_idx: int ) -> Dict[str, Any]: """This method is used to perform a single validation step.""" return self._validation_test_step(batch, batch_idx) def _test_step(self, batch: List[torch.Tensor], batch_idx: int) -> Dict[str, Any]: """This method is used to perform a single test step.""" return self._validation_test_step(batch, batch_idx)
[docs] @staticmethod def add_specific_args( parent_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: """This method adds the specific arguments for the MIMO method.""" parser = super(RBNNMethod, RBNNMethod).add_specific_args(parent_parser) parser.add_argument( "--method_prior_mean", type=float, default=1.0, help="The mean of the prior distribution.", ) parser.add_argument( "--method_log_variance", type=float, default=-3.0, help="The initial value of the log of the standard deviation of the weights.", ) parser.add_argument( "--method_prior_log_variance", type=float, default=-3.0, help="The initial value of the log of the standard deviation of the prior distribution.", ) parser.add_argument( "--method_method", type=str, default="additive", help="The method whether `additive` or `multiplicative` to be used for the rank-1 approximation.", ) return parser