from typing import Any, Dict, List, Union
import torch
import torch.nn as nn
import argparse
import copy
from yamle.methods.method import BaseMethod
from yamle.defaults import (
MEMBERS_DIM,
TRAIN_KEY,
VALIDATION_KEY,
TEST_KEY,
TINY_EPSILON,
)
from yamle.evaluation.metrics.algorithmic import metrics_factory
import logging
logging = logging.getLogger("pytorch_lightning")
[docs]
class SWAGMethod(BaseMethod):
    """Extension of the base method implementing SWAG (Stochastic Weight Averaging-Gaussian).

    This method was described in the paper "A Simple Baseline for Bayesian Uncertainty in Deep Learning":
    https://arxiv.org/pdf/1902.02476.pdf.

    Running first and second moments of the weights (and, optionally, the deviation of every
    collected model from the running mean) are accumulated at the requested epochs. Outside
    training mode, every prediction draws `num_members` weight samples from these statistics
    and concatenates the resulting outputs along the members dimension.

    Args:
        covariance (bool): Whether to additionally store the per-model deviations from the running mean
            (the `_cov_mat_sqrt` attribute of each parameter).
        fullrank (bool): Whether to sample all parameters jointly from the flattened statistics
            instead of sampling parameter by parameter.
        apply_to_normalisation (bool): Whether to apply the method to the normalisation layers.
        scale (float): The scale of the sampling.
        num_members (int): The number of samples to take when sampling the weights during testing.
        epochs_to_collect (List[int]): The epochs to collect the weights from.
    """
def __init__(
    self,
    covariance: bool,
    fullrank: bool,
    apply_to_normalisation: bool,
    scale: float,
    num_members: int,
    epochs_to_collect: List[int],
    *args: Any,
    **kwargs: Any,
) -> None:
    """Store the SWAG configuration and attach the running statistics to the model.

    Args:
        covariance (bool): Whether to store the per-model deviations in addition to the moments.
        fullrank (bool): Whether to sample all parameters jointly instead of parameter by parameter.
        apply_to_normalisation (bool): Whether normalisation layers take part in the method.
        scale (float): The scale of the sampling.
        num_members (int): The number of weight samples drawn per prediction outside training.
        epochs_to_collect (List[int]): The epochs to collect the weights from. Entries that
            arrive as numeric strings are converted to integers.
        *args (Any): Forwarded to the base method.
        **kwargs (Any): Forwarded to the base method.

    Raises:
        AssertionError: If fewer than two epochs are given or any epoch is negative.
        ValueError: If an entry of `epochs_to_collect` cannot be converted to an integer.
    """
    super().__init__(*args, **kwargs)
    self._covariance = covariance
    self._fullrank = fullrank
    self._apply_to_normalisation = apply_to_normalisation
    self._scale = scale
    self._num_members = num_members

    # The command line declares this option with `type=str`, so the entries may be strings.
    # Normalise them to integers; otherwise the comparison below raises a TypeError and
    # `current_epoch in self._epochs_to_collect` can never be true.
    epochs_to_collect = [int(epoch) for epoch in epochs_to_collect]
    assert (
        len(epochs_to_collect) > 1
    ), f"The number of epochs to collect must be greater than 1. Got {len(epochs_to_collect)}."
    assert (
        min(epochs_to_collect) >= 0
    ), f"The epochs to collect must be greater than or equal to 0. Got {min(epochs_to_collect)}."
    self._epochs_to_collect = epochs_to_collect
    self._max_num_of_collected_models = len(self._epochs_to_collect)
    self._num_of_collected_models = 0

    # This list stores the references to the parameters in the model.
    # We add attributes to the parameters to store the mean, square mean and covariance matrix.
    self._swag_parameters: List[nn.Parameter] = []
    self._initialise_swag_model()
[docs]
def state_dict(self) -> Dict[str, Any]:
    """Return the state dictionary of the method, including the SWAG statistics.

    On top of whatever the parent class returns, two entries are added:

    * `"swag_parameters"`: one dictionary per tracked parameter holding its `_mean`,
      `_sq_mean`, `_cov_mat_sqrt` and `_training` attributes (`None` when missing).
    * `"swag_num_of_collected_models"`: how many models have been folded into the running
      statistics so far. Without it the running averages would restart from scratch when
      training is resumed from a saved state.

    Returns:
        Dict[str, Any]: The parent state dictionary extended with the SWAG entries.
    """
    state_dict = super().state_dict()
    tracked_attributes = ("_mean", "_sq_mean", "_cov_mat_sqrt", "_training")
    # Store the actual values of the statistics attached to every tracked parameter.
    state_dict["swag_parameters"] = [
        {name: getattr(p, name, None) for name in tracked_attributes}
        for p in self._swag_parameters
    ]
    state_dict["swag_num_of_collected_models"] = self._num_of_collected_models
    return state_dict
[docs]
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Load the state dictionary of the method, including the SWAG statistics.

    The SWAG specific entries are removed from `state_dict` *before* it is handed to the
    parent class, so a parent loader that rejects unknown keys never sees them.

    Args:
        state_dict (Dict[str, Any]): The state to restore. Must contain `"swag_parameters"`;
            `"swag_num_of_collected_models"` is optional and, when absent, the current
            counter is left untouched. Both entries are popped from the dictionary.

    Raises:
        KeyError: If `"swag_parameters"` is missing.
        AssertionError: If the number of stored entries or the shape of a stored mean does
            not match the tracked parameters.
    """
    swag_states = state_dict.pop("swag_parameters")
    num_collected = state_dict.pop("swag_num_of_collected_models", None)
    super().load_state_dict(state_dict)

    # `zip` would silently drop surplus entries, so a mismatch is reported explicitly.
    assert len(swag_states) == len(
        self._swag_parameters
    ), f"The number of stored SWAG parameters does not match the model. Got {len(swag_states)} and {len(self._swag_parameters)}."

    # Load the actual values of the parameters
    for p, state in zip(self._swag_parameters, swag_states):
        assert (
            p.data.shape == state["_mean"].shape
        ), f"The shape of the mean is not the same as the parameter. Got {p.data.shape} and {state['_mean'].shape}."
        for name in ("_training", "_mean", "_sq_mean", "_cov_mat_sqrt"):
            setattr(p, name, state[name])

    # Restore the counter so that further collection continues the running averages.
    if num_collected is not None:
        self._num_of_collected_models = int(num_collected)
def _create_metrics(self, metrics_kwargs: Dict[str, Any]) -> None:
    """Build one metric collection per split (training, validation and testing).

    The predictions are Monte Carlo samples, so the individual members are not tracked
    and every collection is created with `per_member=False`.

    Args:
        metrics_kwargs (Dict[str, Any]): Keyword arguments forwarded to `metrics_factory`.
    """
    splits = (TRAIN_KEY, VALIDATION_KEY, TEST_KEY)
    self.metrics = {
        split: metrics_factory(**metrics_kwargs, per_member=False) for split in splits
    }
def _predict(self, x: torch.Tensor, **forward_kwargs: Any) -> torch.Tensor:
    """Run a forward pass, sampling the weights whenever the method is not training.

    Args:
        x (torch.Tensor): The input batch.
        **forward_kwargs (Any): Extra arguments forwarded to the parent prediction.

    Returns:
        torch.Tensor: The parent prediction in training mode; otherwise the predictions of
        `num_members` weight samples concatenated along `MEMBERS_DIM`.
    """
    if self.training:
        # Training always uses the weights that are currently being optimised.
        return super()._predict(x, **forward_kwargs)

    # Outside training every member is a fresh draw from the weight distribution.
    member_outputs = []
    for _ in range(self._num_members):
        self._sample()
        self._set_weights_to_sample()
        member_outputs.append(super()._predict(x, **forward_kwargs))
    return torch.cat(member_outputs, dim=MEMBERS_DIM)
def _sample(self) -> None:
    """Draw a new weight sample using the sampler selected by `fullrank`."""
    draw = self._sample_fullrank if self._fullrank else self._sample_blockwise
    draw()
def _set_weights_to_sample(self) -> None:
    """Overwrite every tracked parameter with a copy of its current `_sample`.

    The copy is moved to the device of the first parameter of the model.
    """
    if not self._swag_parameters:
        return
    target_device = next(self.model.parameters()).device
    for param in self._swag_parameters:
        drawn = getattr(param, "_sample")
        param.data = drawn.clone().to(target_device)
def _set_weights_to_training(self) -> None:
    """Overwrite every tracked parameter with a copy of its cached `_training` weights.

    The copy is moved to the device of the first parameter of the model.
    """
    if not self._swag_parameters:
        return
    target_device = next(self.model.parameters()).device
    for param in self._swag_parameters:
        cached = getattr(param, "_training")
        param.data = cached.clone().to(target_device)
def _update_training_weights(self) -> None:
    """Cache a CPU copy of the current weights in each parameter's `_training` attribute."""
    for param in self._swag_parameters:
        snapshot = param.data.clone().cpu()
        param._training = snapshot
def _sample_blockwise(self) -> None:
    """Draw a weight sample independently for every tracked parameter.

    Each sample is the running mean plus scaled diagonal noise whose variance is
    `sq_mean - mean**2`, clamped from below at `TINY_EPSILON`. The result is stored in
    the parameter's `_sample` attribute.

    When `covariance` is enabled, a parameter without `_cov_mat_sqrt` is skipped, and the
    low-rank term is only added if `fullrank` is also set.
    """
    # NOTE(review): `_sample` only dispatches here when `fullrank` is False, so through that
    # path the low-rank term is computed but never added - confirm this is intended.
    for param in self._swag_parameters:
        mean = getattr(param, "_mean", None)
        sq_mean = getattr(param, "_sq_mean", None)

        diagonal_noise = torch.randn_like(mean)
        variance = torch.clamp(sq_mean - mean**2, min=TINY_EPSILON)
        weights = mean + self._scale * torch.sqrt(variance) * diagonal_noise

        if self._covariance:
            deviations = getattr(param, "_cov_mat_sqrt", None)
            if deviations is None:
                continue
            low_rank_noise = deviations.new_empty((deviations.size(0), 1)).normal_()
            low_rank_scale = self._scale / (
                (self._max_num_of_collected_models - 1) ** 0.5
            )
            low_rank_term = low_rank_scale * deviations.t().matmul(
                low_rank_noise
            ).view_as(mean)
            if self._fullrank:
                weights = weights + low_rank_term

        param._sample = weights
@staticmethod
def _flatten_params(
    parameters: List[Union[nn.Parameter, torch.Tensor]]
) -> torch.Tensor:
    """Flatten a list of parameters or tensors into a single column vector.

    Every element is handled on its own: the `data` of an `nn.Parameter` is used, any other
    tensor is used directly. `reshape` is used rather than `view` so that non-contiguous
    tensors can be flattened too.

    Args:
        parameters (List[Union[nn.Parameter, torch.Tensor]]): The tensors to flatten.

    Returns:
        torch.Tensor: A tensor of shape `(N, 1)`, where `N` is the total number of elements,
        in the order of `parameters`.

    Raises:
        ValueError: If `parameters` is empty.
    """
    if len(parameters) == 0:
        raise ValueError("Cannot flatten an empty list of parameters.")
    columns = [
        (p.data if isinstance(p, nn.Parameter) else p).reshape(-1, 1)
        for p in parameters
    ]
    return torch.cat(columns, dim=0)
@staticmethod
def _unflatten_params_like(
    parameters: torch.Tensor, like: List[Union[nn.Parameter, torch.Tensor]]
) -> List[torch.Tensor]:
    """Split a flat tensor into pieces shaped like the tensors in `like`.

    The input is flattened to one dimension first, so a vector of shape `(N,)`, `(N, 1)`
    or `(1, N)` is accepted as long as it holds exactly the required number of elements.

    Args:
        parameters (torch.Tensor): The flattened values, in the order of `like`.
        like (List[Union[nn.Parameter, torch.Tensor]]): The tensors providing the shapes.

    Returns:
        List[torch.Tensor]: One tensor per element of `like`, with matching shape.

    Raises:
        ValueError: If the number of elements in `parameters` differs from the total
            number of elements in `like`.
    """
    flat = parameters.reshape(-1)
    expected = sum(p.numel() for p in like)
    if flat.numel() != expected:
        raise ValueError(
            f"Cannot unflatten {flat.numel()} values into tensors holding {expected} values."
        )
    pieces = []
    offset = 0
    for p in like:
        count = p.numel()
        pieces.append(flat[offset : offset + count].view(p.size()))
        offset += count
    return pieces
def _sample_fullrank(self) -> None:
    """Draw one joint weight sample across all tracked parameters.

    The running means and second moments are flattened into one-dimensional vectors, a
    diagonal Gaussian sample is drawn and, when `covariance` is enabled, a low-rank sample
    built from the stored deviations is added. The result is scaled by `sqrt(scale)`, added
    to the mean, split back into the shape of each parameter and stored in its `_sample`
    attribute.

    When `covariance` is enabled, parameters without `_cov_mat_sqrt` are left out; their
    `_sample` attribute is not touched.
    """
    scale_sqrt = self._scale**0.5
    # Keep track of the parameters that actually take part, so that the samples are
    # assigned back to the right parameters even if some of them are skipped.
    sampled_parameters: List[nn.Parameter] = []
    mean_list: List[torch.Tensor] = []
    sq_mean_list: List[torch.Tensor] = []
    cov_mat_sqrt_list: List[torch.Tensor] = []
    for p in self._swag_parameters:
        if self._covariance:
            cov_mat_sqrt = getattr(p, "_cov_mat_sqrt", None)
            if cov_mat_sqrt is None:
                continue
            cov_mat_sqrt_list.append(cov_mat_sqrt)
        sampled_parameters.append(p)
        mean_list.append(getattr(p, "_mean", None))
        sq_mean_list.append(getattr(p, "_sq_mean", None))

    if not sampled_parameters:
        # Nothing to sample.
        return

    # Work with 1-D vectors: adding a column vector to the 1-D low-rank sample below
    # would otherwise broadcast to a square matrix.
    mean = self._flatten_params(mean_list).reshape(-1)
    sq_mean = self._flatten_params(sq_mean_list).reshape(-1)

    # Draw diagonal variance sample
    var = torch.clamp(sq_mean - mean**2, min=TINY_EPSILON)
    rand_sample = var.sqrt() * torch.randn_like(var)

    # If covariance draw low rank sample
    if self._covariance:
        # One row per collected model, one column per flattened weight.
        cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1)
        cov_sample = cov_mat_sqrt.t().matmul(
            cov_mat_sqrt.new_empty((cov_mat_sqrt.size(0),)).normal_()
        )
        cov_sample = cov_sample / (self._max_num_of_collected_models - 1) ** 0.5
        rand_sample = rand_sample + cov_sample

    # Update sample with mean and scale
    sample = mean + scale_sqrt * rand_sample

    # Unflatten new sample like the mean sample
    samples_list = self._unflatten_params_like(sample, mean_list)
    for p, parameter_sample in zip(sampled_parameters, samples_list):
        setattr(p, "_sample", parameter_sample)
def _collect_model(self) -> None:
    """Fold the current weights into the running SWAG statistics.

    For every tracked parameter the running mean and running mean of squares are updated.
    When `covariance` is enabled, the deviation of the current weights from the *updated*
    mean is appended as a new row of `_cov_mat_sqrt`, and the oldest row is dropped once
    more than the maximum number of models has been collected. Finally the counter of
    collected models is incremented.
    """
    seen = self._num_of_collected_models
    total = seen + 1.0
    for param in self._swag_parameters:
        # One CPU copy of the weights is reused for all the statistics below.
        weights = param.data.detach().clone().cpu()

        # First moment
        running_mean = getattr(param, "_mean", None) * seen / total + weights / total
        # Second moment
        running_sq_mean = (
            getattr(param, "_sq_mean", None) * seen / total + weights**2 / total
        )

        if self._covariance:
            # Block covariance matrices, store deviation from current mean
            deviation_row = (weights - running_mean).view(1, -1)
            deviations = torch.cat(
                (getattr(param, "_cov_mat_sqrt", None), deviation_row), dim=0
            )
            # Remove the oldest row if we have stored too many models
            if seen + 1 > self._max_num_of_collected_models:
                deviations = deviations[1:, :]
            param._cov_mat_sqrt = deviations

        param._mean = running_mean
        param._sq_mean = running_sq_mean
    self._num_of_collected_models += 1
def _initialise_swag_model(self) -> None:
    """Attach the SWAG statistics to the parameters of the model and register them.

    Only leaf modules (modules without children) are visited, and normalisation layers are
    skipped unless `apply_to_normalisation` is set. Every selected parameter receives
    zero-initialised `_mean` and `_sq_mean` attributes, an empty `_cov_mat_sqrt` with one
    column per weight when `covariance` is enabled, and a `_training` copy of its current
    values. All of these are kept on the CPU.

    A parameter shared by several modules is registered only once; otherwise its statistics
    would be updated several times for each collected model.
    """
    normalisation_layers = (
        nn.BatchNorm1d,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.LayerNorm,
        nn.GroupNorm,
    )
    registered = {id(p) for p in self._swag_parameters}
    for module in self.model.modules():
        # Only leaf modules are considered; a module with children is skipped.
        if any(True for _ in module.children()):
            continue
        if not self._apply_to_normalisation and isinstance(
            module, normalisation_layers
        ):
            continue
        for p in module.parameters():
            if id(p) in registered:
                # Shared (tied) parameter that has already been set up.
                continue
            registered.add(id(p))
            data = p.data
            setattr(p, "_mean", torch.zeros(data.size(), dtype=data.dtype, device="cpu"))
            setattr(p, "_sq_mean", torch.zeros(data.size(), dtype=data.dtype, device="cpu"))
            if self._covariance:
                # No rows yet: one row is appended for every collected model.
                setattr(
                    p,
                    "_cov_mat_sqrt",
                    torch.zeros((0, data.numel()), dtype=data.dtype, device="cpu"),
                )
            setattr(p, "_training", data.detach().clone().cpu())
            self._swag_parameters.append(p)
[docs]
def on_train_epoch_end(self) -> None:
    """Collect the model at the end of a training epoch when there is no validation dataset.

    With a validation dataset present nothing is done here beyond the parent hook. Otherwise
    the model is collected if the current epoch is one of the epochs to collect, and the
    training weights are cached.
    """
    super().on_train_epoch_end()
    if self._datamodule.validation_dataset() is not None:
        return

    should_collect = self.current_epoch in self._epochs_to_collect
    if should_collect:
        logging.info(f"Collecting the model at epoch {self.current_epoch}.")
        self._collect_model()
    # The validation epoch start is not called if there is no validation dataset,
    # so the training weights have to be cached here instead.
    self._update_training_weights()
[docs]
def on_validation_epoch_start(self) -> None:
    """Collect the model if requested and cache the training weights before validation.

    Nothing beyond the parent hook happens while the trainer is sanity checking.
    """
    super().on_validation_epoch_start()
    if self.trainer.sanity_checking:
        return

    should_collect = self.current_epoch in self._epochs_to_collect
    if should_collect:
        logging.info(f"Collecting the model at epoch {self.current_epoch}.")
        self._collect_model()
    self._update_training_weights()
[docs]
def on_train_epoch_start(self) -> None:
    """Restore the cached training weights before a training epoch begins.

    Prediction outside training mode overwrites the weights with sampled ones, so the copy
    kept in each parameter's `_training` attribute is written back here.
    """
    super().on_train_epoch_start()
    # Undo any sampled weights left in the model by an earlier evaluation.
    self._set_weights_to_training()
[docs]
@staticmethod
def add_specific_args(
    parent_parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    """Register the command line arguments specific to this method.

    Args:
        parent_parser (argparse.ArgumentParser): The parser to extend.

    Returns:
        argparse.ArgumentParser: The parser returned by the parent class with the
        `--method_*` options of this method added.
    """
    parser = super(SWAGMethod, SWAGMethod).add_specific_args(parent_parser)

    # The three on/off switches share the same 0/1 integer encoding.
    switches = (
        ("--method_covariance", "Whether to estimate the full covariance matrix."),
        ("--method_fullrank", "Whether to use the full rank covariance matrix."),
        (
            "--method_apply_to_normalisation",
            "Whether to apply the method to the normalisation layers.",
        ),
    )
    for option, description in switches:
        parser.add_argument(
            option, type=int, choices=[0, 1], default=0, help=description
        )

    parser.add_argument(
        "--method_scale", type=float, default=1.0, help="The scale of the sampling."
    )
    parser.add_argument(
        "--method_num_members",
        type=int,
        default=1,
        help="The number of members to be used for the prediction. Default: 1.",
    )
    parser.add_argument(
        "--method_epochs_to_collect",
        type=str,
        nargs="+",
        help="The epochs to collect the weights from.",
    )
    return parser