Source code for yamle.methods.pe

from typing import Any
import torch
import torch.nn as nn
import argparse

from yamle.methods.mimo import MIMOMethod
from yamle.models.operations import (
    Unsqueeze,
    OutputActivation,
    ReshapeOutput,
    ReshapeInput,
)


def replace_layers_with_grouped_convs(
    model: nn.Module, M: int, alpha: int, gamma: int
) -> nn.Module:
    """This method replaces all `nn.Linear` and `nn.Conv2d` layers with grouped versions.

    Each layer is replaced with a layer where the input and output dimensions are multiplied
    by the `alpha`. `nn.Linear` layers become 1x1 `nn.Conv2d` layers, `nn.Conv2d` layers keep
    their kernel size, stride, padding, dilation and padding mode, and `nn.BatchNorm2d` layers
    are widened by `alpha` while keeping their `eps`, `momentum`, `affine` and
    `track_running_stats` settings. Every new convolution uses `gamma * M` groups.

    A module carrying a `do_not_replace` attribute is never swapped out itself, but the
    function still recurses into its children. The model is modified in place. Newly created
    layers are freshly initialised; no weights are copied from the original layers.

    NOTE(review): the original convolution's own `groups` (e.g. depthwise convolutions) is
    not taken into account; every replaced convolution uses exactly `gamma * M` groups.

    Args:
        model (nn.Module): The model to replace the layers of.
        M (int): The number of members in the ensemble.
        alpha (int): The width multiplier.
        gamma (int): The subgroups multiplier.

    Returns:
        nn.Module: The same `model` object, modified in place.

    Raises:
        ValueError: Raised by `nn.Conv2d` if a widened channel count is not divisible by
            `gamma * M`.
    """
    groups = gamma * M
    for name, child in model.named_children():
        replaceable = not hasattr(child, "do_not_replace")
        if replaceable and isinstance(child, nn.Linear):
            # A fully connected layer is emulated by a 1x1 grouped convolution.
            new_child: nn.Module = nn.Conv2d(
                child.in_features * alpha,
                child.out_features * alpha,
                kernel_size=1,
                groups=groups,
                bias=child.bias is not None,
            )
        elif replaceable and isinstance(child, nn.Conv2d):
            # Keep every spatial hyper-parameter, otherwise e.g. a dilated convolution would
            # change its receptive field and output size after the replacement.
            new_child = nn.Conv2d(
                child.in_channels * alpha,
                child.out_channels * alpha,
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                dilation=child.dilation,
                groups=groups,
                bias=child.bias is not None,
                padding_mode=child.padding_mode,
            )
        elif replaceable and isinstance(child, nn.BatchNorm2d):
            new_child = nn.BatchNorm2d(
                child.num_features * alpha,
                eps=child.eps,
                momentum=child.momentum,
                affine=child.affine,
                track_running_stats=child.track_running_stats,
            )
        else:
            # Containers, unsupported layers and protected modules: descend into children.
            replace_layers_with_grouped_convs(child, M, alpha, gamma)
            continue
        setattr(model, name, new_child)
    return model
class PEMethod(MIMOMethod):
    """This class is the extension of the base method for packed-ensemble methods.

    Args:
        alpha (int): The expansion multiplier for the width of the model. It is used to
            multiply the number of input and output channels of each layer.
        gamma (int): The subgroups multiplier. It is used to multiply the number of groups in
            each convolutional layer together with the `num_members` parameter, such as
            `(num_members * gamma)`.
        *args (Any): Forwarded unchanged to the parent `MIMOMethod` constructor.
        **kwargs (Any): Forwarded unchanged to the parent `MIMOMethod` constructor.
    """

    def __init__(self, alpha: int, gamma: int, *args: Any, **kwargs: Any) -> None:
        assert (
            alpha >= 1
        ), f"The alpha parameter should be larger or equal to 1, but got {alpha}."
        assert (
            gamma >= 1
        ), f"The gamma parameter should be larger or equal to 1, but got {gamma}."
        assert (
            int(alpha) == alpha
        ), f"The alpha parameter should be an integer, but got {alpha}."
        assert (
            int(gamma) == gamma
        ), f"The gamma parameter should be an integer, but got {gamma}."
        # Stored before the parent constructor runs, presumably because code triggered from
        # there (e.g. `_post_init`) reads them — TODO confirm against `MIMOMethod`.
        self._alpha = alpha
        self._gamma = gamma
        super(PEMethod, self).__init__(*args, **kwargs)
        # Widen every remaining Linear/Conv2d/BatchNorm2d layer into its packed version.
        replace_layers_with_grouped_convs(
            self.model, self._num_members, self._alpha, self._gamma
        )

    def _post_init(self) -> None:
        """This method is called after the initialization of the method."""
        self._replace_input_and_output_layers()

    def analyse(self, save_path: str) -> None:
        """This method analyses the model and saves the results to a file."""
        pass

    def _replace_input_and_output_layers(self) -> None:
        """Replace the first and last layers of the model with their packed counterparts.

        The first layer becomes a grouped `nn.Conv2d`. Its input dimension is multiplied by
        the number of members and its output dimension by `alpha`. A `ReshapeInput` module is
        placed in front of it; presumably it folds the member dimension into the channels —
        TODO confirm. A first `nn.Linear` layer is additionally preceded by
        `Unsqueeze(shape_length=4)` and turned into a 1x1 convolution.

        The last `nn.Linear` layer becomes global average pooling followed by a grouped 1x1
        convolution with `out_features * num_members` outputs and a `ReshapeOutput`. The output
        activation is then set to act on dimension 2.

        The packed input/output convolutions are tagged with `do_not_replace` so that
        `replace_layers_with_grouped_convs` does not widen them a second time.

        Raises:
            ValueError: If the first layer is neither `nn.Linear` nor `nn.Conv2d`, or if the
                last layer is not `nn.Linear`.
        """
        old_input = self.model._input
        if isinstance(old_input, nn.Linear):
            input_conv = nn.Conv2d(
                in_channels=old_input.in_features * self._num_members,
                out_channels=old_input.out_features * self._alpha,
                kernel_size=1,
                groups=self._gamma * self._num_members,
                bias=old_input.bias is not None,
            )
            # Tag the convolution itself. Indexing the wrapping Sequential is error-prone:
            # here the conv sits at index 2, behind Unsqueeze and ReshapeInput.
            input_conv.do_not_replace = True
            self.model._input = nn.Sequential(
                Unsqueeze(shape_length=4),
                ReshapeInput(),
                input_conv,
            )
        elif isinstance(old_input, nn.Conv2d):
            input_conv = nn.Conv2d(
                in_channels=old_input.in_channels * self._num_members,
                out_channels=old_input.out_channels * self._alpha,
                kernel_size=old_input.kernel_size,
                stride=old_input.stride,
                padding=old_input.padding,
                dilation=old_input.dilation,
                groups=self._gamma * self._num_members,
                bias=old_input.bias is not None,
                padding_mode=old_input.padding_mode,
            )
            input_conv.do_not_replace = True
            self.model._input = nn.Sequential(ReshapeInput(), input_conv)
        else:
            raise ValueError(
                "The first layer of the model should be either a `torch.nn.Linear` or a "
                "`torch.nn.Conv2d`."
            )

        old_output = self.model._output
        if not isinstance(old_output, torch.nn.Linear):
            raise ValueError(
                "The last layer of the model should be a `torch.nn.Linear`."
            )
        output_conv = nn.Conv2d(
            old_output.in_features * self._alpha,
            old_output.out_features * self._num_members,
            kernel_size=1,
            groups=self._num_members * self._gamma,
            bias=old_output.bias is not None,
        )
        output_conv.do_not_replace = True
        self.model._output = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            output_conv,
            ReshapeOutput(self._num_members),
        )
        self.model._output_activation = OutputActivation(task=self._task, dim=2)

        # The packed head above already performs global average pooling on spatial feature
        # maps, so a trailing AdaptiveAvgPool2d + Flatten pair in the body is removed. The
        # length guard avoids an IndexError on bodies with fewer than two layers.
        layers = self.model._layers
        if (
            len(layers) >= 2
            and isinstance(layers[-2], nn.AdaptiveAvgPool2d)
            and isinstance(layers[-1], nn.Flatten)
        ):
            self.model._layers = layers[:-2]

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method adds the specific arguments for the packed-ensemble (PE) method.

        It extends the parser returned by the parent `MIMOMethod.add_specific_args` with
        `--method_alpha` and `--method_gamma`, both integers defaulting to 1.
        """
        parser = super(PEMethod, PEMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_alpha", type=int, default=1, help="The width multiplier."
        )
        parser.add_argument(
            "--method_gamma", type=int, default=1, help="The subgroups multiplier."
        )
        return parser