"""Source code for yamle.methods.mimo."""

from yamle.utils.optimization_utils import (
    LinearScalarScheduler,
    AVAILABLE_SCALAR_SCHEDULERS,
)
from yamle.utils.operation_utils import average_predictions
from yamle.utils.regularizer_utils import disable_regularizer
from yamle.models.operations import (
    OutputActivation,
    ParallelModel,
    LinearExtractor,
    Conv2dExtractor,
    ReshapeInput,
    ReshapeOutput,
)
from yamle.models.specific.mixmo import MixMoBlock, UnmixingBlock
from yamle.models.specific.mixvit import MixVitWrapper
from yamle.models.specific.datamux import Multiplexer, Demultiplexer
from yamle.models.visual_transformer import SpatialPositionalEmbedding
from yamle.models.specific.mimmo import MIMMMOWrapper
from yamle.methods.uncertain_method import MemberMethod
from typing import List, Dict, Any, Tuple, Optional
from yamle.evaluation.metrics.algorithmic import metrics_factory
from yamle.defaults import (
    TINY_EPSILON,
    LOSS_KEY,
    TARGET_KEY,
    TARGET_PER_MEMBER_KEY,
    MEAN_PREDICTION_KEY,
    PREDICTION_KEY,
    PREDICTION_PER_MEMBER_KEY,
    TRAIN_KEY,
    VALIDATION_KEY,
    TEST_KEY,
    MEMBERS_DIM,
    INPUT_KEY,
    AVERAGE_WEIGHTS_KEY,
    MIN_TENDENCY,
)
import torch
import torch.nn as nn
import argparse
import logging
import torchmetrics
import copy

from yamle.utils.specific.mimo_experiments.plotting_utils import (
    plot_input_layer_norm_bar,
    plot_output_layer_norm_bar,
    plot_weight_trajectories,
    plot_overlap_between_members,
)

# Module-wide logger. NOTE: this rebinds the name `logging` (shadowing the stdlib module)
# to the "pytorch_lightning" logger, so the `logging.info` / `logging.warning` calls in the
# classes of this module are emitted through that named logger.
logging = logging.getLogger(name="pytorch_lightning")


class MIMOMethod(MemberMethod):
    """This class is the extension of the base method for MIMO methods.

    The difference is in having to change the prediction to concatenate the `num_members`
    dimension into the first feature dimension.

    Args:
        initialise_encoder_members_same (bool): Whether to initialise the members in the encoder
            with the same weights.
        num_batch_repetitions (int): The number of times some samples are repeated in the batch.
        input_repetition_probability (Optional[float]): The probability that the inputs are
            identical for the ensemble members. `None` is treated as `0.0`.
        repeat_evaluation (bool): Whether to repeat samples in the evaluation.
    """

    def __init__(
        self,
        initialise_encoder_members_same: bool = False,
        num_batch_repetitions: int = 1,
        input_repetition_probability: Optional[float] = None,
        repeat_evaluation: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        assert (
            num_batch_repetitions >= 1
        ), "The number of batch duplication should be greater than or equal to 1."
        assert input_repetition_probability is None or (
            0.0 <= input_repetition_probability <= 1.0
        ), f"The input repetition probability should be in [0, 1], got {input_repetition_probability}."
        self._num_batch_repetitions = num_batch_repetitions
        self._initialise_encoder_members_same = initialise_encoder_members_same
        self._input_repetition_probability = input_repetition_probability
        # Set before `_post_init` so that `_post_init` overrides can already rely on it.
        self._repeat_evaluation = repeat_evaluation

        # Remember all the input and output layer information.
        # The kwargs default to empty dicts so that an unsupported layer type reaches the
        # explicit `ValueError` in `_replace_input_layer` / `_replace_output_layer` instead of
        # failing earlier with an `AttributeError`.
        self._input_inputs_dim = None
        self._input_outputs_dim = None
        self._input_layer_kwargs: Dict[str, Any] = {}
        input_module = self.model._input
        if isinstance(input_module, nn.Linear):
            self._input_inputs_dim = input_module.in_features
            self._input_outputs_dim = input_module.out_features
            self._input_layer_kwargs = {
                "in_features": self._input_inputs_dim,
                "out_features": self._input_outputs_dim,
                "bias": input_module.bias is not None,
            }
        elif isinstance(input_module, nn.Conv2d):
            self._input_inputs_dim = input_module.in_channels
            self._input_outputs_dim = input_module.out_channels
            self._input_layer_kwargs = self._conv2d_constructor_kwargs(input_module)
        elif isinstance(input_module, SpatialPositionalEmbedding):
            # The weights that see the raw (member-concatenated) input live in the third
            # module of the patch embedding.
            patch_embedding = input_module._to_patch_embedding[2]
            self._input_inputs_dim = patch_embedding.in_features
            self._input_outputs_dim = patch_embedding.out_features
            self._input_layer_kwargs = {
                "patch_size": input_module._patch_size,
                "inputs_dim": input_module._inputs_dim,
                "embedding_dim": input_module._embedding_dim,
                "dropout": input_module._dropout,
                "num_cls_tokens": input_module._num_cls_tokens,
                "positional_embedding": input_module._positional_embedding is not None,
            }
        self._input_layer_type = type(input_module)

        self._output_inputs_dim = None
        self._output_outputs_dim = None
        self._output_layer_kwargs: Dict[str, Any] = {}
        output_module = self.model._output
        if isinstance(output_module, nn.Linear):
            self._output_inputs_dim = output_module.in_features
            self._output_outputs_dim = output_module.out_features
            self._output_layer_kwargs = {
                "in_features": self._output_inputs_dim,
                "out_features": self._output_outputs_dim,
                "bias": output_module.bias is not None,
            }
        elif isinstance(output_module, nn.Conv2d):
            self._output_inputs_dim = output_module.in_channels
            self._output_outputs_dim = output_module.out_channels
            self._output_layer_kwargs = self._conv2d_constructor_kwargs(output_module)
        self._output_layer_type = type(output_module)

        self._input_layer_overlap_container = torch.zeros((1, 1))
        self._output_layer_overlap_container = torch.zeros((1, 1))
        self._post_init()

    @staticmethod
    def _conv2d_constructor_kwargs(layer: nn.Conv2d) -> Dict[str, Any]:
        """Collect the constructor arguments needed to re-create a `torch.nn.Conv2d` layer."""
        return {
            "in_channels": layer.in_channels,
            "out_channels": layer.out_channels,
            "kernel_size": layer.kernel_size,
            "stride": layer.stride,
            "padding": layer.padding,
            "dilation": layer.dilation,
            "groups": layer.groups,
            "bias": layer.bias is not None,
            "padding_mode": layer.padding_mode,
        }

    def _create_metrics(self, metrics_kwargs: Dict[str, Any]) -> None:
        """This method is used to create the metrics to be used for training, validation and testing."""
        self.metrics = {
            TRAIN_KEY: metrics_factory(**metrics_kwargs, per_member=True),
            VALIDATION_KEY: metrics_factory(**metrics_kwargs, per_member=True),
            TEST_KEY: metrics_factory(**metrics_kwargs, per_member=True),
        }

    def _post_init(self) -> None:
        """This method is called after the initialisation of the method."""
        self._replace_input_layer(self._num_members)
        self._replace_output_layer(self._num_members)

    def _train_batch_repetition(
        self, batch: Tuple[torch.Tensor, torch.Tensor], num_members: int
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """A helper method to build the per-member training batch.

        Steps:
            Tile the sample indices `num_batch_repetitions` times and shuffle them once.
            A fraction `1 - input_repetition_probability` of these positions (the leading ones)
            is re-shuffled independently for every member; the remaining positions are shared,
            so the members see identical inputs there.
            Gather `x` and `y` for every member and stack them along `MEMBERS_DIM`.

        Returns:
            The new `(x, y)` batch and the gathered indices stacked along `MEMBERS_DIM`.
        """
        x, y = batch
        batch_size = x.shape[0]
        # `None` means that no positions are forced to be shared between the members.
        repetition_probability = (
            0.0
            if self._input_repetition_probability is None
            else self._input_repetition_probability
        )
        indices = torch.tile(
            torch.arange(batch_size), [self._num_batch_repetitions]
        ).to(x.device)
        # Shuffle the indices to create a new batch
        main_shuffle = indices[torch.randperm(indices.shape[0])]
        # Same float32 computation and truncation as before, converted to a Python int.
        num_to_shuffle = int(
            torch.tensor(len(main_shuffle)).float() * (1.0 - repetition_probability)
        )
        shuffle_indices = [
            torch.cat(
                [
                    main_shuffle[:num_to_shuffle][torch.randperm(num_to_shuffle)],
                    main_shuffle[num_to_shuffle:],
                ],
                dim=0,
            )
            for _ in range(num_members)
        ]
        x = torch.stack(
            [x[member_indices] for member_indices in shuffle_indices], dim=MEMBERS_DIM
        )
        y = torch.stack(
            [y[member_indices] for member_indices in shuffle_indices], dim=MEMBERS_DIM
        )
        return (x, y), torch.stack(shuffle_indices, dim=MEMBERS_DIM)

    def _validation_test_batch_repetition(
        self, batch: Tuple[torch.Tensor, torch.Tensor], num_members: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A helper method to repeat the samples in the `batch` `num_members` times.

        With `repeat_evaluation` every member receives the same samples in the same order;
        otherwise every member receives its own random permutation of the batch.
        """
        x, y = batch
        if self._repeat_evaluation:
            xs = [x] * num_members
            ys = [y] * num_members
        else:
            permutations = [torch.randperm(x.shape[0]) for _ in range(num_members)]
            xs = [x[permutation] for permutation in permutations]
            ys = [y[permutation] for permutation in permutations]
        return torch.stack(xs, dim=MEMBERS_DIM), torch.stack(ys, dim=MEMBERS_DIM)

    def _replace_input_layer(self, num_members: int) -> None:
        """A helper function to replace the first layer with one where the input dimension is
        multiplied by the number of members."""
        input_layer: nn.Module = None
        input_layer_kwargs = copy.deepcopy(self._input_layer_kwargs)
        if self._input_layer_type == nn.Linear:
            input_layer_kwargs["in_features"] = (
                input_layer_kwargs["in_features"] * num_members
            )
            input_layer = torch.nn.Linear(**input_layer_kwargs)
        elif self._input_layer_type == nn.Conv2d:
            input_layer_kwargs["in_channels"] = (
                input_layer_kwargs["in_channels"] * num_members
            )
            input_layer = torch.nn.Conv2d(**input_layer_kwargs)
        elif self._input_layer_type == SpatialPositionalEmbedding:
            input_layer_kwargs["inputs_dim"] = (
                input_layer_kwargs["inputs_dim"][0] * num_members,
                *input_layer_kwargs["inputs_dim"][1:],
            )
            input_layer_kwargs["num_cls_tokens"] = (
                input_layer_kwargs["num_cls_tokens"] * num_members
            )
            input_layer = SpatialPositionalEmbedding(**input_layer_kwargs)
        else:
            raise ValueError(
                f"Input layer type {self._input_layer_type} not supported."
            )
        self.model._input = nn.Sequential(ReshapeInput(), input_layer).to(
            next(self.model.parameters()).device
        )
        self._input_layer = self.model._input[1]
        if self._input_layer_type == SpatialPositionalEmbedding:
            self._input_layer = self.model._input[1]._to_patch_embedding[2]
            # Expose the embedding's helper on the wrapping `nn.Sequential`.
            self.model._input.get_cls_token_indices = self.model._input[
                1
            ].get_cls_token_indices
        if self._initialise_encoder_members_same:
            for member in range(1, num_members):
                self._initialise_input_layer_weights_same(
                    source_member=0, target_member=member
                )

    def _initialise_input_layer_weights_same(
        self, source_member: int, target_member: int
    ) -> None:
        """A helper method to initialise the weights of the input layer from the
        `source_member` to the `target_member`.

        The initialisation is done by copying the weights of the `source_member` slice of the
        input dimension to the `target_member` slice.
        """
        assert (
            source_member != target_member
        ), f"The source member and the target member should be different. Got {source_member} and {target_member}."
        assert self._input_layer_type in [
            nn.Linear,
            nn.Conv2d,
        ], f"The input layer type should be either `torch.nn.Linear` or `torch.nn.Conv2d`, but it is {self._input_layer_type}."
        width = self._input_inputs_dim
        self._input_layer.weight.data[
            :, target_member * width : (target_member + 1) * width
        ] = self._input_layer.weight.data[
            :, source_member * width : (source_member + 1) * width
        ].clone()

    def _replace_output_layer(self, num_members: int) -> None:
        """Replace the last layer with one where the output dimension is multiplied by the
        number of members."""
        output_layer: nn.Module = None
        output_layer_kwargs = copy.deepcopy(self._output_layer_kwargs)
        if self._output_layer_type == nn.Linear:
            output_layer_kwargs["out_features"] = (
                output_layer_kwargs["out_features"] * num_members
            )
            # If the input is SpatialPositionalEmbedding and the pooling type is `cls`,
            # then the output layer also needs to be multiplied with respect to
            # `num_cls_tokens`.
            if (
                self._input_layer_type == SpatialPositionalEmbedding
                and self.model._pooling == "cls"
            ):
                output_layer_kwargs["in_features"] = (
                    output_layer_kwargs["in_features"]
                    * self._input_layer_kwargs["num_cls_tokens"]
                    * num_members
                )
            output_layer = torch.nn.Linear(**output_layer_kwargs)
        elif self._output_layer_type == nn.Conv2d:
            output_layer_kwargs["out_channels"] = (
                output_layer_kwargs["out_channels"] * num_members
            )
            output_layer = torch.nn.Conv2d(**output_layer_kwargs)
        else:
            raise ValueError(
                f"The last layer of the model should be a `torch.nn.Linear` or a `torch.nn.Conv2d` layer, but it is a {self._output_layer_type}."
            )
        self.model._output = nn.Sequential(
            output_layer, ReshapeOutput(num_members=num_members)
        ).to(next(self.model.parameters()).device)
        self._output_layer = self.model._output[0]
        self.model._output_activation = OutputActivation(
            self.model._output_activation._task, dim=2
        )

    def _step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        optimizer_idx: Optional[int] = None,
        phase: str = TRAIN_KEY,
        num_members: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """A helper method to perform a single step whether training, validation or test."""
        x, y = batch
        y_hat = self._predict(x, unsqueeze=False)
        loss = self._loss_per_member(y_hat, y)
        y_permember = y
        y_hat_permember = y_hat
        if self.training:
            # In training every member has its own samples: flatten members into the batch.
            y_hat = y_hat.reshape(-1, 1, *y_hat.shape[2:])
            x = x.reshape(-1, *x.shape[2:])
        else:
            y = y[:, 0]  # They are all the same.
            x = x[:, 0]  # They are all the same.
        y = y.reshape(-1, *y.shape[2:])
        y_hat_mean = average_predictions(y_hat, self._task)
        return {
            LOSS_KEY: loss,
            PREDICTION_KEY: y_hat.detach(),
            TARGET_KEY: y.detach(),
            TARGET_PER_MEMBER_KEY: y_permember.detach(),
            PREDICTION_PER_MEMBER_KEY: y_hat_permember.detach(),
            MEAN_PREDICTION_KEY: y_hat_mean.detach(),
            INPUT_KEY: x.detach(),
        }

    def _validation_test_step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        optimizer_idx: int = 0,
        phase: str = VALIDATION_KEY,
        num_members: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """This method is used to perform a single validation or testing step.

        It assumes that the batch has a shape `(batch_size, num_features)`.
        The inputs are repeated `num_members` times along `MEMBERS_DIM` before the step.
        The output of the model has a shape `(batch_size, num_members, num_classes)`.
        """
        new_batch = self._validation_test_batch_repetition(
            batch, num_members=num_members
        )
        return self._step(
            new_batch,
            batch_idx,
            optimizer_idx=optimizer_idx,
            phase=phase,
            num_members=num_members,
        )

    def _training_step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        optimizer_idx: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """This method is used to perform a single training step.

        It assumes that the batch has a shape `(batch_size, num_features)`.
        It assumes that the output of the model has a shape `(batch_size, n_samples, num_classes)`.
        """
        new_batch, _ = self._train_batch_repetition(
            batch, num_members=self._num_members
        )
        return self._step(
            new_batch,
            batch_idx,
            optimizer_idx=optimizer_idx,
            phase=TRAIN_KEY,
            num_members=self._num_members,
        )

    def _validation_step(
        self, batch: List[torch.Tensor], batch_idx: int
    ) -> Dict[str, Any]:
        """This method is used to perform a single validation step.

        It assumes that the batch has a shape `(batch_size, num_features)`.
        It assumes that the output of the model has a shape `(batch_size, n_samples, num_classes)`.
        """
        return self._validation_test_step(
            batch,
            batch_idx,
            optimizer_idx=None,
            phase=VALIDATION_KEY,
            num_members=self._num_members,
        )

    def _test_step(self, batch: List[torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        """This method is used to perform a single test step.

        It assumes that the batch has a shape `(batch_size, num_features)`.
        It assumes that the output of the model has a shape `(batch_size, n_samples, num_classes)`.
        """
        return self._validation_test_step(
            batch,
            batch_idx,
            optimizer_idx=None,
            phase=TEST_KEY,
            num_members=self._num_members,
        )

    def _plots(self) -> None:
        """Plot per-member input/output weight norms and the overlap between members."""
        output_weights_per_member = []
        input_weights_per_member = []
        for member in range(self._num_members):
            output_weights_per_member.append(
                self._output_layer.weight.data[
                    member * self._output_outputs_dim : (member + 1)
                    * self._output_outputs_dim
                ]
            )
            input_weights_per_member.append(
                self._input_layer.weight.data[
                    :,
                    member * self._input_inputs_dim : (member + 1)
                    * self._input_inputs_dim,
                ]
            )
        input_overlap = plot_input_layer_norm_bar(
            input_weights_per_member, self._save_path, self.current_epoch
        )
        output_overlap = plot_output_layer_norm_bar(
            output_weights_per_member, self._save_path, self.current_epoch
        )
        self._input_layer_overlap_container = torch.cat(
            (
                self._input_layer_overlap_container,
                input_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        self._output_layer_overlap_container = torch.cat(
            (
                self._output_layer_overlap_container,
                output_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        plot_overlap_between_members(
            self._input_layer_overlap_container, self._save_path, input=True
        )
        plot_overlap_between_members(
            self._output_layer_overlap_container, self._save_path, input=False
        )

    def on_train_epoch_end(self) -> None:
        """This method is called at the end of the training epoch."""
        super().on_train_epoch_end()
        if self._plotting_training:
            self._plots()

    def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """This method returns the state dict of the MIMO method.

        Any arguments are forwarded to the parent implementation. The overlap containers are
        stored under additional, non-parameter keys.
        """
        state_dict = super().state_dict(*args, **kwargs)
        state_dict["input_layer_overlap_container"] = self._input_layer_overlap_container
        state_dict["output_layer_overlap_container"] = self._output_layer_overlap_container
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """This method loads the state dict of the MIMO method.

        The overlap containers are removed from a shallow copy before delegating to the parent,
        so they are never reported as unexpected keys; the caller's dict is left untouched.
        Extra arguments (e.g. `strict`) are forwarded and the parent's result is returned.
        Missing container keys (e.g. older checkpoints) keep the current containers.
        """
        # `copy.copy` keeps the `_metadata` attribute of an `OrderedDict` state dict.
        state_dict = copy.copy(state_dict)
        input_overlap = state_dict.pop(
            "input_layer_overlap_container", self._input_layer_overlap_container
        )
        output_overlap = state_dict.pop(
            "output_layer_overlap_container", self._output_layer_overlap_container
        )
        result = super().load_state_dict(state_dict, *args, **kwargs)
        self._input_layer_overlap_container = input_overlap
        self._output_layer_overlap_container = output_overlap
        return result

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method adds the specific arguments for the MIMO method."""
        parser = super(MIMOMethod, MIMOMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_initialise_encoder_members_same",
            type=int,
            choices=[0, 1],
            default=0,
            help="Whether to initialise the members of the ensemble with the same weights.",
        )
        parser.add_argument(
            "--method_num_batch_repetitions",
            type=int,
            default=2,
            help="The number of times some samples are repeated in the batch.",
        )
        parser.add_argument(
            "--method_input_repetition_probability",
            type=float,
            default=0.0,
            help="The value to start the input repetition probability.",
        )
        parser.add_argument(
            "--method_repeat_evaluation",
            type=int,
            choices=[0, 1],
            default=1,
            help="Whether to repeat the samples in the evaluation.",
        )
        return parser
class MIMMOMethod(MIMOMethod):
    """This class is the extension of the MIMO method in which we will try to find the depth
    for each ensemble member.

    Args:
        alpha (float): The alpha value to regularize the depth loss term.
        prior (str): The prior to use for the depth weights. Only `uniform` is implemented.
        do_not_optimize_depth_weights (bool): Whether to optimize the depth weights or not.
        additional_heads (bool): Whether to enable specific heads for each layer.
        available_heads (List[bool]): Toggles to hard enable or disable heads. A string such as
            `"[1, 0, 1]"` (e.g. coming from the command line) is parsed as a Python literal.
        warm_starting_epochs (int): The number of epochs to train the model without changing
            the member or depth weights.
    """

    def __init__(
        self,
        alpha: float = 1.0,
        prior: str = "uniform",
        do_not_optimize_depth_weights: bool = False,
        additional_heads: bool = False,
        available_heads: Optional[List[bool]] = None,
        warm_starting_epochs: int = 0,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Validate the model before its depth is used below.
        model = kwargs.get("model")
        if not hasattr(model, "_depth"):
            raise ValueError(
                "The model should have a `_depth` attribute which is the number of hidden layers."
            )
        model_depth = model._depth
        if available_heads is None:
            available_heads = [True] * model_depth
        elif isinstance(available_heads, str):
            # Iterating a string would yield characters (and `bool("0")` is True), so parse it.
            import ast

            parsed_heads = ast.literal_eval(available_heads)
            available_heads = (
                list(parsed_heads)
                if isinstance(parsed_heads, (list, tuple))
                else [parsed_heads]
            )
        available_heads = [bool(x) for x in available_heads]
        assert (
            len(available_heads) == model_depth
        ), f"The length of the available heads should be {model_depth} but it is {len(available_heads)}."
        # The predictions are evaluated per (enabled depth head, member) pair; `_step` flattens
        # them to `depth * num_members` with `depth = sum(available_heads)`.
        kwargs["metrics_kwargs"]["num_members"] = (
            sum(available_heads) * kwargs["metrics_kwargs"]["num_members"]
        )
        super().__init__(*args, **kwargs)

        self._depth = sum(available_heads)
        self._available_heads = available_heads
        self._additional_heads = additional_heads
        self._do_not_optimize_depth_weights = do_not_optimize_depth_weights
        self.model.add_method_specific_layers(
            method="mimmo",
            heads=additional_heads,
            num_members=self._num_members,
            available_heads=available_heads[:-1],
        )
        self._loss.set_reduction_per_member("sum")
        self._loss.set_reduction_per_sample("sum")
        logging.warning(
            f"The reduction per member is set to {self._loss._reduction_per_member} and the reduction per sample is set to {self._loss._reduction_per_sample}."
        )

        # Add the learnable depth weights and their prior to the model.
        if hasattr(self.model, "_depth_weights"):
            raise ValueError("The model should not have a `_depth_weights` attribute.")
        if hasattr(self.model, "_prior_depth_weights"):
            raise ValueError(
                "The model should not have a `_prior_depth_weights` attribute."
            )
        assert alpha >= 0.0, f"The alpha value needs to be non-negative, got {alpha}."
        self._alpha = alpha
        assert (
            warm_starting_epochs >= 0
        ), f"Warm starting epochs should be non-negative, got {warm_starting_epochs}."
        self._warm_starting_epochs = warm_starting_epochs
        self.model._depth_weights = torch.nn.Parameter(
            torch.zeros((self._depth, self._num_members)), requires_grad=True
        )
        disable_regularizer(self.model._depth_weights)
        if prior == "uniform":
            self.model.register_buffer(
                "_prior_depth_weights",
                torch.ones_like(self.model._depth_weights) / self._depth,
            )
        else:
            raise NotImplementedError(
                f"Prior {prior} not implemented. Only `uniform` is supported."
            )
        self._depth_weight_container = (
            self.model._prior_depth_weights.clone().unsqueeze(0)
        )
        self.model._available_heads = self._available_heads
        self.model._additional_heads = self._additional_heads
        self.model = MIMMMOWrapper(
            self.model, evaluation_depth_weights_function=self._evaluation_depth_weights
        )

    def _create_metrics(self, metrics_kwargs: Dict[str, Any]) -> None:
        """This method creates the metrics for the method."""
        super()._create_metrics(metrics_kwargs)
        self._add_additional_metrics(
            {
                f"{LOSS_KEY}_kl_depth": torchmetrics.MeanMetric(),
                f"{LOSS_KEY}_individual": torchmetrics.MeanMetric(),
            },
            tendencies=[MIN_TENDENCY, MIN_TENDENCY],
        )

    @property
    def _prior_depth_weights(self) -> torch.Tensor:
        """This method returns the prior depth weights."""
        return self.model._prior_depth_weights

    def _train_depth_weights(self) -> torch.Tensor:
        """Return the prior while warm starting in training (or when the depth weights are not
        optimised), otherwise the learned evaluation depth weights."""
        if (
            self.current_epoch < self._warm_starting_epochs and self.training
        ) or self._do_not_optimize_depth_weights:
            return self._prior_depth_weights
        return self._evaluation_depth_weights()

    def _evaluation_depth_weights(self) -> torch.Tensor:
        """Return the softmax of the learnable depth weights over the depth dimension."""
        return torch.softmax(self.model._depth_weights, dim=0)

    @property
    def _depth_weights(self) -> torch.Tensor:
        """This method returns the depth weights in the step function."""
        if self.training:
            return self._train_depth_weights()
        return self._evaluation_depth_weights()

    def _predict(self, x: torch.Tensor, **forward_kwargs: Any) -> torch.Tensor:
        """This method is used to perform a forward pass of the model.

        It is done with respect to the number of hidden layers or how hidden layers are being
        defined in the underlying `self.model`.
        """
        return self.model(x, **forward_kwargs)

    def _loss_per_depth_per_member(
        self, y_hat: torch.Tensor, y: torch.Tensor, depth_weights: torch.Tensor
    ) -> torch.Tensor:
        """Sum the loss of every (depth, member) pair weighted by `depth_weights[depth, member]`."""
        loss = 0.0
        for i in range(self._depth):
            for j in range(self._num_members):
                loss += (
                    self._loss(y_hat[:, i, j].unsqueeze(1), y[:, i, j])
                    * depth_weights[i, j]
                )
        return loss

    def _loss_kl_depth_weights(self) -> torch.Tensor:
        """This method computes the KL divergence loss for the depth weights."""
        if (
            self.current_epoch < self._warm_starting_epochs
            or self._do_not_optimize_depth_weights
        ):
            return torch.tensor(0.0, device=self.device)
        depth_weights = torch.softmax(self.model._depth_weights, dim=0)
        return torch.sum(
            depth_weights
            * torch.log(
                depth_weights / (self._prior_depth_weights + TINY_EPSILON)
                + TINY_EPSILON
            )
        )

    def _step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        optimizer_idx: Optional[int] = None,
        phase: str = TRAIN_KEY,
        num_members: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """A helper method to perform a single step whether training, validation or test."""
        output = {}
        x, y = batch
        batch_size = x.shape[0]
        # Repeat the labels for the depth dimension
        # The labels have the shape `(batch_size, depth, num_members)`
        # The predictions have the shape `(batch_size, depth, num_members, predictions)`
        y_depth = torch.stack([y] * self._depth, dim=1)
        y_hat = self._predict(x)
        # The depth weights have the shape `(depth, num_members)`
        depth_weights = self._depth_weights
        loss = -(1 / batch_size) * self._loss_per_depth_per_member(
            y_hat, y_depth, depth_weights
        )
        loss_kl_depth = self._loss_kl_depth_weights()
        output[f"{LOSS_KEY}_individual"] = -loss.detach()
        output[f"{LOSS_KEY}_kl_depth"] = (self._alpha * loss_kl_depth).detach()
        loss = loss - self._alpha * loss_kl_depth
        loss = -loss

        # Weight the predictions per depth and member: broadcastable to `y_hat`.
        depth_weights = depth_weights.unsqueeze(0)
        while len(depth_weights.shape) < len(y_hat.shape):
            depth_weights = depth_weights.unsqueeze(-1)
        # Normalise across the depth dimension so the weights of every member sum to 1
        depth_weights = depth_weights / torch.sum(depth_weights, dim=1, keepdim=True)
        # Divide the depth weights by the number of members
        depth_weights = depth_weights / self._num_members
        # Flatten to `(1, depth*num_members)` and `(batch_size, depth*num_members, ...)`
        depth_weights = depth_weights.reshape(
            1, depth_weights.shape[1] * depth_weights.shape[2]
        )
        y_hat = y_hat.reshape(
            y_hat.shape[0], y_hat.shape[1] * y_hat.shape[2], *y_hat.shape[3:]
        )
        y_depth = y_depth.reshape(
            y_depth.shape[0], y_depth.shape[1] * y_depth.shape[2], *y_depth.shape[3:]
        )
        y_permember = y_depth
        y_hat_permember = y_hat
        if self.training:
            # In training every member has its own samples: flatten members into the batch.
            y_hat = y_hat.reshape(-1, 1, *y_hat.shape[2:])
            y = y_depth.reshape(-1, 1, *y_depth.shape[2:])
            x = x.reshape(-1, *x.shape[2:])
            average_weights = None
        else:
            y = y[:, 0]  # They are all the same.
            x = x[:, 0]  # They are all the same.
            # Repeat the average weights for all samples in batch
            average_weights = depth_weights.repeat(batch_size, 1)
        y = y.reshape(-1, *y.shape[2:])
        y_hat_mean = average_predictions(y_hat, self._task, weights=average_weights)
        output.update(
            {
                LOSS_KEY: loss,
                PREDICTION_KEY: y_hat.detach(),
                TARGET_KEY: y.detach(),
                TARGET_PER_MEMBER_KEY: y_permember.detach(),
                PREDICTION_PER_MEMBER_KEY: y_hat_permember.detach(),
                MEAN_PREDICTION_KEY: y_hat_mean.detach(),
                INPUT_KEY: x.detach(),
                AVERAGE_WEIGHTS_KEY: average_weights.detach()
                if average_weights is not None
                else None,
            }
        )
        return output

    def _plots(self) -> None:
        """Plot the MIMO weight statistics and the trajectory of the depth weights."""
        super()._plots()
        plot_weight_trajectories(self._depth_weight_container, self._save_path)

    def on_train_epoch_end(self) -> None:
        """A helper method to log the weights of the ensemble."""
        # Detach so the log shows plain values rather than an autograd graph.
        depth_weights = self._evaluation_depth_weights().detach()
        logging.info(f"Depth weights at the end of training epoch: {depth_weights}")
        self._depth_weight_container = torch.cat(
            (self._depth_weight_container, depth_weights.unsqueeze(0).cpu()),
            dim=0,
        )
        super().on_train_epoch_end()

    def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """A helper method to save the weights of the ensemble.

        Any arguments are forwarded to the parent implementation.
        """
        state_dict = super().state_dict(*args, **kwargs)
        state_dict["depth_weight_container"] = self._depth_weight_container
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """A helper method to load the weights of the ensemble.

        The depth weight container is removed from a shallow copy before delegating to the
        parent; extra arguments (e.g. `strict`) are forwarded and the parent's result is
        returned. A missing key keeps the current container.
        """
        # `copy.copy` keeps the `_metadata` attribute of an `OrderedDict` state dict.
        state_dict = copy.copy(state_dict)
        depth_weight_container = state_dict.pop(
            "depth_weight_container", self._depth_weight_container
        )
        result = super().load_state_dict(state_dict, *args, **kwargs)
        self._depth_weight_container = depth_weight_container
        return result

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """The method to add additional arguments."""
        parser = super(MIMMOMethod, MIMMOMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_alpha",
            type=float,
            default=1.0,
            help="The alpha value.",
        )
        parser.add_argument(
            "--method_prior",
            type=str,
            default="uniform",
            help="The prior for the depth weights.",
        )
        parser.add_argument(
            "--method_additional_heads",
            type=int,
            default=1,
            choices=[0, 1],
            help="Whether to use a single head or a head per depth.",
        )
        parser.add_argument(
            "--method_do_not_optimize_depth_weights",
            type=int,
            default=0,
            choices=[0, 1],
            help="Whether to optimize the depth weights.",
        )
        parser.add_argument(
            "--method_available_heads",
            type=str,
            default=None,
            help="The available depth heads for the method.",
        )
        parser.add_argument(
            "--method_warm_starting_epochs",
            type=int,
            default=0,
            help="The number of epochs to warm start the model.",
        )
        return parser
class MixMoMethod(MIMOMethod):
    """This is a module which applies the MixMo regularization to a model.

    As proposed in: "MixMo: Mixing Multiple Inputs for Multiple Outputs via Deep Subnetworks".
    It shows that binary mixing in features - particularly with rectangular patches from
    CutMix - enhances results by making subnetworks stronger and more diverse.

    Args:
        alpha (float): The alpha parameter for the Dirichlet distribution.
        r (float): The `r` parameter applied to the weighting factor.
        initial_p (float): Initial probability of applying linear or cutmix augmentation.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        r: float = 0.5,
        initial_p: float = 0.5,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        assert 0 < alpha, f"The alpha parameter needs to be greater than 0. Got {alpha}."
        assert 0 < r, f"The r parameter needs to be greater than 0. Got {r}."
        assert (
            0 <= initial_p <= 1
        ), f"The p parameter needs to be between 0 and 1. Got {initial_p}."
        self._alpha = alpha
        self._r = r
        self._p = initial_p
        self._initial_p = initial_p
        # Symmetric Dirichlet over the members, sampled once per sample in training.
        self._dirichlet = torch.distributions.dirichlet.Dirichlet(
            torch.tensor([alpha] * self._num_members)
        )

    def _replace_input_layer(self, num_members: int) -> None:
        """Replace the first layer with `num_members` independent copies of it.

        The copies are built from the original layer's constructor arguments (their input
        dimension is not multiplied) and wrapped in a `MixMoBlock`, which is placed on the
        same device as the rest of the model.
        """
        if self._input_layer_type not in (nn.Linear, nn.Conv2d):
            raise ValueError(
                f"The first layer of the model should be either a `torch.nn.Linear` or a `torch.nn.Conv2d` layer, but it is a {type(self.model._input)}."
            )
        device = next(self.model.parameters()).device
        parallel_input = ParallelModel(
            [
                self._input_layer_type(**self._input_layer_kwargs)
                for _ in range(num_members)
            ]
        )
        self.model._input = MixMoBlock(num_members, parallel_input).to(device)

    def _plots(self) -> None:
        """Plot per-member input/output weight norms and the overlap between members."""
        output_weights_per_member = []
        input_weights_per_member = []
        for member in range(self._num_members):
            output_weights_per_member.append(
                self.model._output[0].weight.data[
                    member * self._output_outputs_dim : (member + 1)
                    * self._output_outputs_dim
                ]
            )
            input_weights_per_member.append(
                self.model._input._input[member].weight.data
            )
        input_overlap = plot_input_layer_norm_bar(
            input_weights_per_member, self._save_path, self.current_epoch
        )
        output_overlap = plot_output_layer_norm_bar(
            output_weights_per_member, self._save_path, self.current_epoch
        )
        self._input_layer_overlap_container = torch.cat(
            (
                self._input_layer_overlap_container,
                input_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        self._output_layer_overlap_container = torch.cat(
            (
                self._output_layer_overlap_container,
                output_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        plot_overlap_between_members(
            self._input_layer_overlap_container, self._save_path, input=True
        )
        plot_overlap_between_members(
            self._output_layer_overlap_container, self._save_path, input=False
        )

    def on_train_epoch_start(self) -> None:
        """A method which is called at the start of the training epoch.

        `p` stays at `initial_p` and decays linearly to 0 over the last twelfth of training.
        Without a finite positive `max_epochs` (`None` or `-1`) no decay is applied.
        """
        max_epochs = self.trainer.max_epochs
        has_finite_schedule = max_epochs is not None and max_epochs > 0
        if has_finite_schedule and self.current_epoch > (11 / 12) * max_epochs:
            self._p = (
                self._initial_p
                * (max_epochs - self.current_epoch)
                / (max_epochs * (1 / 12))
            )
        else:
            self._p = self._initial_p
        self.log("p", self._p, on_epoch=True, on_step=False)
        super().on_train_epoch_start()

    def _step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        optimizer_idx: Optional[int] = None,
        phase: str = TRAIN_KEY,
        num_members: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """A helper method to perform a single step whether training, validation or test."""
        x, y = batch
        K = None
        w_r = None
        # Sample the mixing ratios only in training.
        if self.training:
            K = self._dirichlet.sample((x.shape[0],)).to(x.device)
        input_kwargs = {"K": K, "p": self._p}
        y_hat = self._predict(x, input_kwargs=input_kwargs, unsqueeze=False)
        if K is not None:
            # Per-sample, per-member loss weights derived from the mixing ratios.
            K_r = K ** (1 / self._r)
            w_r = num_members * (
                K_r / torch.sum(K_r + TINY_EPSILON, dim=1, keepdim=True)
            )
        y_permember = y
        y_hat_permember = y_hat
        loss = self._loss_per_member(y_hat, y, weights_per_sample=w_r)
        if self.training:
            # Flatten the members into the batch, consistently with `MIMOMethod._step`.
            y_hat = y_hat.reshape(-1, 1, *y_hat.shape[2:])
            x = x.reshape(-1, *x.shape[2:])
        else:
            y = y[:, 0]  # They are all the same.
            x = x[:, 0]  # They are all the same.
        y = y.reshape(-1)
        y_hat_mean = average_predictions(y_hat, self._task)
        output = {
            LOSS_KEY: loss,
            PREDICTION_KEY: y_hat.detach(),
            TARGET_KEY: y.detach(),
            TARGET_PER_MEMBER_KEY: y_permember.detach(),
            PREDICTION_PER_MEMBER_KEY: y_hat_permember.detach(),
            MEAN_PREDICTION_KEY: y_hat_mean.detach(),
            INPUT_KEY: x.detach(),
        }
        if self.training:
            output["new_batch"] = batch
        return output

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method adds the specific arguments for the MixMo method."""
        parser = super(MixMoMethod, MixMoMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_alpha",
            type=float,
            default=2.0,
            help="The alpha parameter for the MixMo regularization.",
        )
        parser.add_argument(
            "--method_r",
            type=float,
            default=3,
            help="The r parameter for the MixMo regularization.",
        )
        parser.add_argument(
            "--method_initial_p",
            type=float,
            default=0.5,
            help="The initial value for the p parameter.",
        )
        return parser
class MixVitMethod(MIMOMethod):
    """This is a module which applies MixToken augmentation to a vision transformer.

    Only two members are supported: `num_members` is always forced to 2.

    Args:
        depth (int): The depth at which to add the source attribution.
    """

    def __init__(self, depth: int, *args: Any, **kwargs: Any) -> None:
        logging.warning(
            "MixVitMethod is only supported for 2 members. It is not trivial to extend it to more members."
        )
        kwargs["num_members"] = 2
        super().__init__(*args, **kwargs)
        self._depth = depth
        self.model = MixVitWrapper(self.model, self._depth)
        # Only the output-layer overlap is tracked in `_plots`; the input container is unused.
        self._input_layer_overlap_container = None

    def _post_init(self) -> None:
        """Deliberately a no-op (presumably overriding an inherited post-init hook)."""
        pass

    def _plots(self) -> None:
        """Plot per-member output layer weights and the history of their overlap.

        Member `m` owns the rows `[m * outputs_dim, (m + 1) * outputs_dim)` of the
        wrapped ViT's output layer weight.
        """
        outputs_dim = self._output_outputs_dim
        output_weight = self.model._vit._output.weight.data
        output_weights_per_member = [
            output_weight[member * outputs_dim : (member + 1) * outputs_dim]
            for member in range(self._num_members)
        ]
        output_overlap = plot_output_layer_norm_bar(
            output_weights_per_member, self._save_path, self.current_epoch
        )
        self._output_layer_overlap_container = torch.cat(
            (
                self._output_layer_overlap_container,
                output_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        plot_overlap_between_members(
            self._output_layer_overlap_container, self._save_path, input=False
        )

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method adds the specific arguments for the MixVit method."""
        parser = super(MixVitMethod, MixVitMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_depth",
            type=int,
            default=2,
            help="The depth at which to add the source attribution.",
        )
        return parser
class UnMixMoMethod(MixMoMethod):
    """This is a module which applies the unmixing regularization to the model in addition to MixMo.

    Instead of processing the inputs just through a convolutional or a linear layer, each
    input is mixed and then unmixed selectively depending on the mixing mask. Precisely, it
    focuses on `fadeout` unmixing, which has been shown to be more effective than full
    unmixing. The `fadeout` unmixing is a gradual unmixing of the inputs, controlled by the
    parameter `m`. `m` follows a linear schedule from `m_start_value` to `m_end_value`.

    Only two members are supported, and `initial_p` is always forced to 1.0.

    Args:
        m_start_value (float): The initial value for the `m` parameter. Defaults to 0.
        m_end_value (float): The final value for the `m` parameter. Defaults to 1.
        m_start_epoch (int): The epoch at which the `m` parameter starts to increase.
            Defaults to 0.
        m_end_epoch (int): The epoch at which the `m` parameter reaches its final value.
            Defaults to 100.
    """

    def __init__(
        self,
        m_start_value: float = 0,
        m_end_value: float = 1,
        m_start_epoch: int = 0,
        m_end_epoch: int = 100,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # 2 is the only supported value, so use it when the caller does not specify it.
        kwargs.setdefault("num_members", 2)
        assert (
            kwargs["num_members"] == 2
        ), "The number of members should be 2 for the UnMixMo method."
        kwargs["initial_p"] = 1.0
        logging.warning(
            "The UnMixMo method uses probability 1.0 for cutmix mixing, which is constant throughout training."
        )
        super().__init__(*args, **kwargs)
        assert (
            m_start_value <= m_end_value
        ), f"The start value {m_start_value} should be less than or equal to the end value {m_end_value}."
        self._m_scheduler = LinearScalarScheduler(
            start_value=m_start_value,
            end_value=m_end_value,
            start_epoch=m_start_epoch,
            end_epoch=m_end_epoch,
        )

    def _replace_output_layer(self, num_members: int) -> None:
        """Replace the last layer with an `UnmixingBlock` tied to the input `MixMoBlock`.

        If the last two entries of `model._layers` are `nn.AdaptiveAvgPool2d` followed by
        `nn.Flatten`, they are removed (presumably because the unmixing block needs the
        un-pooled features). The output activation is rebuilt to act on dim 2.

        Raises:
            ValueError: If the original output layer is not `nn.Linear`.
        """
        if self._output_layer_type != nn.Linear:
            raise ValueError(
                "The last layer of the model should be a `torch.nn.Linear` layer, but it is a "
                f"{self._output_layer_type}."
            )
        output_layer = UnmixingBlock(
            mixmo_block=self.model._input,
            in_features=self._output_layer_kwargs["in_features"],
            out_features=self._output_layer_kwargs["out_features"],
            num_members=num_members,
            outputs_dim=1,
        )
        layers = self.model._layers
        if (
            len(layers) >= 2
            and isinstance(layers[-1], nn.Flatten)
            and isinstance(layers[-2], nn.AdaptiveAvgPool2d)
        ):
            self.model._layers = layers[:-2]
        self.model._output = output_layer.to(next(self.model.parameters()).device)
        self.model._output_activation = OutputActivation(
            self.model._output_activation._task, dim=2
        )

    def _plots(self) -> None:
        """Plot per-member input/output layer weights and the history of their overlap."""
        output_weights_per_member = [
            self.model._output._output[member].weight.data
            for member in range(self._num_members)
        ]
        input_weights_per_member = [
            self.model._input._input[member].weight.data
            for member in range(self._num_members)
        ]
        input_overlap = plot_input_layer_norm_bar(
            input_weights_per_member, self._save_path, self.current_epoch
        )
        output_overlap = plot_output_layer_norm_bar(
            output_weights_per_member, self._save_path, self.current_epoch
        )
        self._input_layer_overlap_container = torch.cat(
            (
                self._input_layer_overlap_container,
                input_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        self._output_layer_overlap_container = torch.cat(
            (
                self._output_layer_overlap_container,
                output_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        plot_overlap_between_members(
            self._input_layer_overlap_container, self._save_path, input=True
        )
        plot_overlap_between_members(
            self._output_layer_overlap_container, self._save_path, input=False
        )

    def on_train_epoch_start(self) -> None:
        """Set the unmixing parameter `m` of the output block for the coming epoch.

        `MixMoMethod.on_train_epoch_start` is skipped on purpose (it would reschedule `p`,
        which UnMixMo keeps fixed at 1.0); the next hook in the MRO is called instead.
        """
        self.model._output.set_m(self._m_scheduler.get_value())
        super(MixMoMethod, self).on_train_epoch_start()

    def on_train_epoch_end(self) -> None:
        """Log the `m` used during this epoch, then advance the `m` schedule."""
        # Log before stepping so the logged value is the one set in `on_train_epoch_start`
        # for this epoch rather than the value for the next epoch.
        self.log("m", self._m_scheduler.get_value(), on_epoch=True, on_step=False)
        self._m_scheduler.step()
        # Bypass any `MixMoMethod` override, mirroring `on_train_epoch_start`.
        super(MixMoMethod, self).on_train_epoch_end()

    def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Return the module state dict extended with the `m` scheduler state.

        Arguments are forwarded to the parent `state_dict` so callers passing
        `destination`, `prefix` or `keep_vars` keep working.
        """
        state_dict = super().state_dict(*args, **kwargs)
        state_dict["m_scheduler"] = self._m_scheduler.state_dict()
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any) -> Any:
        """Load the module state and the `m` scheduler state.

        The `"m_scheduler"` entry is not a module parameter or buffer. It is therefore removed
        from a shallow copy before delegating, so that strict loading does not report it as an
        unexpected key. Extra arguments (e.g. `strict`, which Lightning passes when
        restoring checkpoints) are forwarded to the parent implementation.

        Returns:
            Whatever the parent `load_state_dict` returns.
        """
        m_scheduler_state = state_dict.get("m_scheduler")
        if "m_scheduler" in state_dict:
            # `copy.copy` keeps the `_metadata` attribute of an OrderedDict state dict and
            # leaves the caller's mapping untouched.
            state_dict = copy.copy(state_dict)
            del state_dict["m_scheduler"]
        result = super().load_state_dict(state_dict, *args, **kwargs)
        if m_scheduler_state is not None:
            self._m_scheduler.load_state_dict(m_scheduler_state)
        else:
            logging.warning(
                "No `m_scheduler` state found in the state dict; keeping the current m schedule."
            )
        return result

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method adds the specific arguments for the UnMixMo method."""
        parser = super(UnMixMoMethod, UnMixMoMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_m_start_value",
            type=float,
            default=0,
            help="The initial value for the m parameter.",
        )
        parser.add_argument(
            "--method_m_end_value",
            type=float,
            default=1,
            help="The final value for the m parameter.",
        )
        parser.add_argument(
            "--method_m_start_epoch",
            type=int,
            default=0,
            help="The epoch at which the m parameter starts to increase.",
        )
        parser.add_argument(
            "--method_m_end_epoch",
            type=int,
            default=100,
            help="The epoch at which the m parameter reaches its final value.",
        )
        return parser
class DataMUXMethod(MIMOMethod):
    """This is a module which applies the DataMUX multiplexing and demultiplexing to the
    input and output of the model.

    As proposed in: "DataMUX: Data Multiplexing for Neural Networks".

    Instead of processing the inputs just through a convolutional or a linear layer, each
    input has a dedicated encoder (coder), and the output is processed through a separate
    decoder per member to give predictions.

    Args:
        coder_expansion_factor (int): The expansion factor for the coder. Defaults to 1.
        coder_depth (int): The depth of the coder. Defaults to 1.
    """

    def __init__(
        self,
        coder_expansion_factor: int = 1,
        coder_depth: int = 1,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Stored before `super().__init__`, which presumably calls the `_replace_*_layer`
        # hooks below that read these attributes.
        self._coder_expansion_factor = coder_expansion_factor
        self._coder_depth = coder_depth
        super().__init__(*args, **kwargs)

    def _member_code_dim(self) -> int:
        """Width (features/channels) of one member's code produced by its coder."""
        return int(self._input_inputs_dim * self._coder_expansion_factor)

    def _replace_input_layer(self, num_members: int) -> None:
        """Replace the input layer with a multiplexer followed by a widened input layer.

        Every member gets its own coder (`LinearExtractor` or `Conv2dExtractor`) mapping
        `inputs_dim` features/channels to `int(inputs_dim * coder_expansion_factor)`. The
        `Multiplexer` (with `reduction="cat"`) combines the member codes. A copy of the
        original input layer, widened to accept the concatenated codes, follows it.

        Raises:
            ValueError: If the input layer is neither `nn.Linear` nor `nn.Conv2d`.
        """
        if self._input_layer_type not in (nn.Linear, nn.Conv2d):
            raise ValueError(
                f"The input layer should be either a Linear or a Conv2d layer. Got {self.model._input}."
            )
        inputs_dim = self._input_inputs_dim
        member_dim = self._member_code_dim()
        # Derive the total from the per-member width so both agree even for a
        # non-integer expansion factor.
        multiplexed_dim = member_dim * num_members
        input_layer_kwargs = copy.deepcopy(self._input_layer_kwargs)

        if self._input_layer_type == nn.Linear:
            coder = [
                LinearExtractor(
                    inputs_dim=inputs_dim,
                    expansion_factor=self._coder_expansion_factor,
                    depth=self._coder_depth,
                    norm=True,
                    outputs_dim=member_dim,
                    end_activation=True,
                    activation="ReLU",
                    end_normalization=True,
                )
                for _ in range(num_members)
            ]
            input_layer_kwargs["in_features"] = multiplexed_dim
            input_layer = nn.Linear(**input_layer_kwargs)
            # If the model has a `_flatten` layer, it needs to be removed. We assume that
            # the `_flatten` layer is the first layer before the input layer.
            if hasattr(self.model, "_flatten"):
                self.model._flatten = nn.Identity()
        else:
            coder = [
                Conv2dExtractor(
                    input_channels=inputs_dim,
                    expansion_factor=self._coder_expansion_factor,
                    depth=self._coder_depth,
                    norm=True,
                    output_channels=member_dim,
                    convolution="conv2d",
                    end_activation=True,
                    activation="ReLU",
                    end_normalization=True,
                )
                for _ in range(num_members)
            ]
            input_layer_kwargs["in_channels"] = multiplexed_dim
            input_layer = nn.Conv2d(**input_layer_kwargs)

        multiplexer = Multiplexer(
            precoder=None,
            coder=coder,
            postcoder=None,
            reduction="cat",
            reduction_normalization=None,
            feature_regularizer=None,
            inputs_dim=1,
            outputs_dim=1,
        )
        self.model._input = nn.Sequential(multiplexer, input_layer).to(
            next(self.model.parameters()).device
        )

    def _replace_output_layer(self, num_members: int) -> None:
        """Replace the output layer with a `Demultiplexer` holding one linear head per member.

        Raises:
            ValueError: If the output layer is not `nn.Linear`.
        """
        if self._output_layer_type != nn.Linear:
            raise ValueError(
                f"The output layer should be a Linear layer. Got {self.model._output}."
            )
        demultiplexer = Demultiplexer(
            parallel_layers=[
                LinearExtractor(
                    inputs_dim=self._output_inputs_dim,
                    expansion_factor=1,
                    depth=1,
                    norm=False,
                    end_normalization=False,
                    end_activation=False,
                    activation="ReLU",
                    outputs_dim=self._output_outputs_dim,
                )
                for _ in range(num_members)
            ],
            outputs_dim=1,
        )
        self.model._output = demultiplexer.to(next(self.model.parameters()).device)
        self.model._output_activation = OutputActivation(
            self.model._output_activation._task, dim=2
        )

    def _plots(self) -> None:
        """Plot per-member input/output layer weights and the history of their overlap.

        The widened input layer (`self.model._input[1]`) receives the concatenated member
        codes along weight dim 1. Member `m` is assumed to own the slice of width
        `int(inputs_dim * coder_expansion_factor)` starting at `m` times that width (i.e. the
        multiplexer concatenates codes in member order).
        """
        member_width = self._member_code_dim()
        input_weight = self.model._input[1].weight.data
        output_weights_per_member = [
            self.model._output._parallel_layers[member]._model[0].weight.data
            for member in range(self._num_members)
        ]
        input_weights_per_member = [
            input_weight[:, member * member_width : (member + 1) * member_width]
            for member in range(self._num_members)
        ]
        input_overlap = plot_input_layer_norm_bar(
            input_weights_per_member, self._save_path, self.current_epoch
        )
        output_overlap = plot_output_layer_norm_bar(
            output_weights_per_member, self._save_path, self.current_epoch
        )
        self._input_layer_overlap_container = torch.cat(
            (
                self._input_layer_overlap_container,
                input_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        self._output_layer_overlap_container = torch.cat(
            (
                self._output_layer_overlap_container,
                output_overlap.unsqueeze(0).unsqueeze(0).detach().cpu(),
            ),
            dim=0,
        )
        plot_overlap_between_members(
            self._input_layer_overlap_container, self._save_path, input=True
        )
        plot_overlap_between_members(
            self._output_layer_overlap_container, self._save_path, input=False
        )

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Adds the specific arguments for the DataMUXMethod."""
        parser = super(DataMUXMethod, DataMUXMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_coder_expansion_factor",
            type=int,
            default=1,
            help="The expansion factor for the coder. Defaults to 1.",
        )
        parser.add_argument(
            "--method_coder_depth",
            type=int,
            default=1,
            help="The depth of the coder. Defaults to 1.",
        )
        return parser