Source code for yamle.methods.be

from typing import List, Dict, Any
import torch
import torch.nn as nn

from yamle.methods.uncertain_method import MemberMethod
from yamle.models.specific.be import LinearBE, Conv2dBE
from yamle.utils.operation_utils import average_predictions, repeat_inputs
from yamle.defaults import LOSS_KEY, TARGET_KEY, PREDICTION_KEY, MEAN_PREDICTION_KEY, TARGET_PER_MEMBER_KEY, PREDICTION_PER_MEMBER_KEY, AVERAGE_WEIGHTS_KEY


def replace_with_be(model: nn.Module, num_members: int) -> None:
    """Swap every `nn.Linear` / `nn.Conv2d` in `model` for its BatchEnsemble version.

    The direct children of `model` are visited in turn. An `nn.Linear` child is
    replaced by a `LinearBE` and an `nn.Conv2d` child by a `Conv2dBE`; any other
    child is searched recursively. The model is modified in place.

    Args:
        model (nn.Module): The model to replace the layers in.
        num_members (int): The number of members in the ensemble.
    """
    for child_name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # The original layer's weight and bias are handed to the new layer.
            be_layer = LinearBE(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias,
                num_members=num_members,
                weight=module.weight,
            )
        elif isinstance(module, nn.Conv2d):
            be_layer = Conv2dBE(
                in_channels=module.in_channels,
                out_channels=module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
                num_members=num_members,
                weight=module.weight,
                bias=module.bias,
            )
        else:
            # Neither Linear nor Conv2d: look for replaceable layers further down.
            replace_with_be(module, num_members)
            continue

        setattr(model, child_name, be_layer)
class BEMethod(MemberMethod):
    """Extension of the base method for BatchEnsemble models.

    On construction, the `nn.Linear` and `nn.Conv2d` layers of the model are
    replaced through `replace_with_be`. During validation and testing the
    prediction differs from the base method: the inputs are repeated
    `num_members` times before the forward pass and the output is reshaped so
    that the members get their own dimension.

    Note that only Linear and Conv2d layers are supported, not the batch norm
    layers. In practice this is not a problem, see
    https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/normalization.py#L111
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        replace_with_be(self.model, self._num_members)

    def _predict(
        self, x: torch.Tensor, unsqueeze: bool = True, **forward_kwargs: Any
    ) -> torch.Tensor:
        """Perform a forward pass of the model.

        In evaluation mode the inputs are first passed through
        `repeat_inputs(x, num_members)` and the model output is reshaped to
        `(-1, num_members, *output.shape[1:])`. Otherwise the output is returned
        as is, optionally with a new dimension inserted at position 1.

        Args:
            x (torch.Tensor): The input to the model.
            unsqueeze (bool): Outside evaluation mode, whether to insert a
                dimension at position 1 of the output. Ignored in evaluation mode.
            **forward_kwargs (Any): The keyword arguments to be passed to the
                forward pass of the model.

        Returns:
            torch.Tensor: The model output, reshaped as described above.
        """
        model_input = repeat_inputs(x, self._num_members) if self.evaluation else x
        output = self.model(model_input, **forward_kwargs)

        if self.evaluation:
            # Split the leading dimension so that the members get their own axis.
            return output.reshape(-1, self._num_members, *output.shape[1:])
        if unsqueeze:
            return output.unsqueeze(1)
        return output

    def _validation_test_step(
        self, batch: List[torch.Tensor], batch_idx: int
    ) -> Dict[str, Any]:
        """Perform a single validation/test step.

        It assumes that the batch has a shape `(batch_size, num_features)` and
        that the output of the model has a shape
        `(batch_size, n_samples, num_classes)`.

        Args:
            batch (List[torch.Tensor]): The `(inputs, targets)` pair.
            batch_idx (int): The index of the batch; not used here.

        Returns:
            Dict[str, Any]: The loss together with the detached targets and
            predictions, both per member and averaged.
        """
        inputs, targets = batch
        member_predictions = self._predict(inputs)

        # Every member is scored against the same labels: repeat them
        # `num_members` times along a new dimension 1.
        member_targets = torch.stack([targets] * self._num_members, dim=1)

        loss = self._loss_per_member(member_predictions, member_targets)
        mean_prediction = average_predictions(member_predictions, self._task)

        return {
            LOSS_KEY: loss,
            TARGET_KEY: targets.detach(),
            PREDICTION_KEY: member_predictions.detach(),
            MEAN_PREDICTION_KEY: mean_prediction.detach(),
            TARGET_PER_MEMBER_KEY: member_targets.detach(),
            PREDICTION_PER_MEMBER_KEY: member_predictions.detach(),
            AVERAGE_WEIGHTS_KEY: None,
        }

    def _validation_step(
        self, batch: List[torch.Tensor], batch_idx: int
    ) -> Dict[str, Any]:
        """Perform a single validation step; delegates to `_validation_test_step`."""
        return self._validation_test_step(batch, batch_idx)

    def _test_step(self, batch: List[torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        """Perform a single test step; delegates to `_validation_test_step`."""
        return self._validation_test_step(batch, batch_idx)