Source code for yamle.methods.laplace

import argparse
import logging
from typing import Any, Dict

import torch.nn as nn
from backpack import backpack, extend, extensions
from tqdm import tqdm
import torch

from yamle.defaults import CLASSIFICATION_KEY, MEMBERS_DIM
from yamle.methods.uncertain_method import MCSamplingMethod
from yamle.models.specific.laplace import LaplaceLinear
from yamle.regularizers.weight import L2Regularizer
import math

# NOTE: this rebinds the name `logging` from the standard-library module to the
# `Logger` instance named "pytorch_lightning". Below this line, module-level
# helpers such as `logging.getLogger` are no longer reachable through this name;
# only `Logger` methods (e.g. `logging.info(...)`) are.
logging = logging.getLogger("pytorch_lightning")


class LaplaceMethod(MCSamplingMethod):
    """This class implements the Laplace method.

    It performs Laplace approximation in the neural network to approximate the
    posterior distribution. It is important the regularisation is the
    `L2Regularizer` and the `regularizer_weight` is non-zero. These act as the
    prior distribution.

    At the moment, it only supports the `last_layer` mode and the
    `classification` task. This is because the `backpack` library supports only
    the `CrossEntropyLoss` but not the `GaussianNLLLoss`.

    Args:
        mode (str): The mode of the method. It can be `last_layer` for now.
    """

    tasks = [CLASSIFICATION_KEY]

    def __init__(
        self,
        mode: str,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        assert mode in [
            "last_layer",
        ], f"Mode {mode} is not supported. It must be one of ['last_layer']."
        self._mode = mode

        # The message formats the same attribute that the condition checks, so a
        # failing assertion always surfaces as an `AssertionError`.
        assert isinstance(
            self._regularizer, L2Regularizer
        ), f"The regularizer must be of type `L2Regularizer`. It is {self._regularizer}."
        assert (
            self.hparams.regularizer_weight > 0
        ), f"The regularizer weight must be greater than 0. It is {self.hparams.regularizer_weight}."
        # The regularizer weight is reused as the prior precision in `on_fit_end`.
        self._precision = self.hparams.regularizer_weight

        assert isinstance(
            self.model._output, nn.Linear
        ), f"The output layer must be a linear layer. It is {self.model._output}."

        # Inverted, regularised Kronecker factors; both stay `None` until they are
        # computed in `on_fit_end` or restored through `load_state_dict`.
        self._U: torch.Tensor = None
        self._V: torch.Tensor = None

    @property
    def hessian_computed(self) -> bool:
        """This property returns whether the Hessian is computed."""
        return self._U is not None and self._V is not None

    def _predict(self, x: torch.Tensor, **forward_kwargs: Any) -> torch.Tensor:
        """This method is used to perform a forward pass of the model.

        The forward pass is repeated `training_num_members` times while training
        or while the Hessian factors are not available, and `_num_members` times
        otherwise. The individual outputs are concatenated along `MEMBERS_DIM`.

        Args:
            x (torch.Tensor): The input passed to the parent `_predict`.
            **forward_kwargs (Any): Extra keyword arguments forwarded unchanged.

        Returns:
            torch.Tensor: The outputs concatenated along `MEMBERS_DIM`.
        """
        if self.training or not self.hessian_computed:
            num_members = self.training_num_members
        else:
            num_members = self._num_members

        # Deliberately skip `MCSamplingMethod._predict` and call the next
        # implementation in the MRO, exactly as the explicit two-argument
        # `super` requests.
        base_predict = super(MCSamplingMethod, self)._predict
        outputs = [base_predict(x, **forward_kwargs) for _ in range(num_members)]
        return torch.cat(outputs, dim=MEMBERS_DIM)

    def _replace_output_layer(self) -> None:
        """This is a helper function to replace the output layer with the Laplace linear layer.

        Nothing happens when the current output layer is not an `nn.Linear`.
        """
        output_layer = self.model._output
        if not isinstance(output_layer, nn.Linear):
            return
        device = output_layer.weight.device
        self.model._output = LaplaceLinear(
            U=self._U,
            V=self._V,
            weight=output_layer.weight,
            bias=output_layer.bias,
        ).to(device)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """This function loads the state dictionary of the method.

        Besides the parent state, the `_U` and `_V` entries are restored.
        """
        super().load_state_dict(state_dict)
        self._U = state_dict["_U"]
        self._V = state_dict["_V"]

    def on_after_model_load(self) -> None:
        """This function is called after the model is loaded."""
        super().on_after_model_load()
        self._replace_output_layer()

    def state_dict(self) -> Dict[str, Any]:
        """This function returns the state dictionary of the method.

        The parent state is extended with the `_U` and `_V` entries.
        """
        state_dict = super().state_dict()
        state_dict["_U"] = self._U
        state_dict["_V"] = self._V
        return state_dict

    def on_fit_end(self) -> None:
        """This method is used to compute the factorised Hessian at the end of the training.

        It uses the K-FAC approximation (through `backpack`) to compute the
        factorised Hessian of the output layer's weight over the training data
        loader. The two factors are averaged over the number of batches,
        regularised with `sqrt(precision) * I`, inverted and stored in `_U` and
        `_V`. Finally the output layer is replaced by the Laplace linear layer.

        Raises:
            RuntimeError: If the training data loader yields no batches.
        """
        super().on_fit_end()
        loss = nn.CrossEntropyLoss(
            reduction="sum"
        )  # Very important to set the reduction to sum
        loss = extend(loss)
        extend(self.model._output)

        loader = self._datamodule.train_dataloader()
        W = self.model._output.weight

        A, B = None, None
        # Need to disable the output activation because the inputs are logits.
        self.model._output_activation.disable()
        try:
            for batch in tqdm(loader, desc="Computing the factorised Hessian"):
                x, y = batch
                x, y = x.to(self.device), y.to(self.device)
                output = self.model(x)
                output = output.view(output.shape[0], -1)
                self.model.zero_grad()
                with backpack(
                    extensions.KFAC(),
                ):
                    loss_value = loss(output, y.long())
                    loss_value.backward()
                A_, B_ = W.kfac
                if A is None:
                    # Clone so the running sums never alias (and mutate in place)
                    # the tensors that backpack attached to the parameter.
                    A = A_.clone()
                    B = B_.clone()
                else:
                    A += A_
                    B += B_
        finally:
            # Always restore the activation, even if the Hessian pass fails.
            self.model._output_activation.enable()

        if A is None:
            raise RuntimeError(
                "Cannot compute the factorised Hessian: the training data loader is empty."
            )

        # Average the factorised Hessian with the number of batches
        num_batches = len(loader)
        A = A / num_batches
        B = B / num_batches

        # The identity matrices must live on the same device (and use the same
        # dtype) as the factors, otherwise the addition fails on non-CPU devices.
        prior_scale = math.sqrt(self._precision)
        eye_A = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
        eye_B = torch.eye(B.shape[1], device=B.device, dtype=B.dtype)
        self._U = torch.inverse(A + prior_scale * eye_A)
        self._V = torch.inverse(B + prior_scale * eye_B)

        self._replace_output_layer()

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This function adds the arguments for the Laplace method.

        Args:
            parent_parser (argparse.ArgumentParser): The parser to extend.

        Returns:
            argparse.ArgumentParser: The parser returned by the parent class,
            extended with the `--method_mode` argument.
        """
        parser = super(LaplaceMethod, LaplaceMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_mode",
            type=str,
            default="last_layer",
            choices=["last_layer"],
            help="The mode of the method.",
        )
        return parser