Extending Model

In this tutorial we demonstrate how to extend the BaseModel class to create a new model.

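import abc
import argparse
from typing import Any, Dict, Tuple, Union

import torch
import torch.nn as nn

# SUPPORTED_TASKS (the task keys YAMLE accepts) is defined elsewhere in YAMLE
# and must be imported from there.


class BaseModel(nn.Module, abc.ABC):
    """This is the base class for all the models.

    By default it should have an input and output layer in `_input` and `_output` respectively.
    All the intermediate layers should be in `_layers`.
    The depth of the model should be in `_depth`.

    Args:
        inputs_dim (Tuple[int,...]): The input dimensions.
        outputs_dim (int): The output dimension.
        task (str): The task to perform.
        seed (int): The random seed.
    """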

    tasks = SUPPORTED_TASKS

    def __init__(
        self,
        inputs_dim: Tuple[int, ...],
        outputs_dim: int,
        task: str,
        seed: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        self._inputs_dim = inputs_dim
        self._outputs_dim = outputs_dim
        assert (
            task in self.tasks
        ), f"The task {task} is not supported. Supported tasks are {self.tasks}."
        self._task = task
        self._output: nn.Module = None
        self._input: nn.Module = None
        self._output_activation: nn.Module = None
        self._layers: Union[nn.ModuleList, nn.Sequential] = None

        self._added_method_specific_layers = False
        self._method: str = None
        self._method_kwargs: Dict[str, Any] = None
        self._depth: int = None
        self._seed = seed

    @abc.abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """This method is used to perform a forward pass of the model."""
        raise NotImplementedError("The forward method must be implemented.")

    @abc.abstractmethod
    def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor:
        """This function is used to get the final layer output."""
        raise NotImplementedError("The final_layer method must be implemented.")

    @classmethod
    def add_specific_args(
        cls, parent_parser: argparse.ArgumentParser
    ) -> argparse.ArgumentParser:
        """This method adds model arguments to the given parser."""
        return argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    def reset(self) -> None:
        """This method is used to reset the model e.g. at the start of a new epoch."""
        pass

    def replace_layers_for_quantization(self) -> None:
        """Fuses all the operations in the network.

        In this function we only need to fuse layers that are not in the blocks.
        e.g. the reshaping layers added by the method.
        """
        pass

    def add_method_specific_layers(self, method: str, **kwargs: Any) -> None:
        """This method is used to add method specific layers to the model.

        Args:
            method (str): The method to use.
        """
        self._added_method_specific_layers = True
        self._method = method
        self._method_kwargs = kwargs

Each model added to YAMLE must inherit from the BaseModel class. BaseModel provides a number of methods through which the model interacts with a method (BaseMethod) and a datamodule (BaseDataModule).

Each model must accept the inputs_dim, outputs_dim and task arguments (as well as seed). The input and output dimensions are supplied by the datamodule and determine the number of inputs and outputs of the model. The task is a string that identifies the type of problem the model is used for, and it usually determines the output activation: for example, a softmax for classification, or an exponential applied to one of the outputs in regression to model the variance.
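To make the role of task concrete, the sketch below shows one way a task-dependent output activation could be written. It is a hypothetical stand-in, not YAMLE's OutputActivation: it assumes the task strings "classification" and "regression", and, for regression, that the last output column holds an unconstrained log-variance.

```python
import torch
import torch.nn as nn


class SketchOutputActivation(nn.Module):
    """Hypothetical task-dependent output activation (not YAMLE's OutputActivation)."""

    def __init__(self, task: str) -> None:
        super().__init__()
        if task not in ("classification", "regression"):
            raise ValueError(f"Unsupported task {task}.")
        self.task = task

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x has shape (batch_size, outputs_dim).
        if self.task == "classification":
            # Turn logits into class probabilities along the feature dimension.
            return torch.softmax(x, dim=1)
        # Regression: keep the mean columns unchanged and apply exp to the last
        # column (assumed to be a log-variance) so that the variance is positive.
        mean, log_var = x[:, :-1], x[:, -1:]
        return torch.cat([mean, torch.exp(log_var)], dim=1)


logits = torch.tensor([[1.0, 2.0, 3.0]])
probs = SketchOutputActivation("classification")(logits)
print(probs.sum(dim=1))  # each row sums to 1

raw = torch.tensor([[0.5, 0.0]])  # one mean column, one log-variance column
print(SketchOutputActivation("regression")(raw))  # mean 0.5, variance exp(0) = 1.0
```

Keeping the activation as a separate module, rather than inside the output layer, is what lets a method reach the raw outputs when it needs them.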

The very first learnable layer is expected to be stored in the _input attribute and the very last learnable layer in the _output attribute. The output activation is expected to be stored in the _output_activation attribute. This makes it easy for the underlying BaseMethod to extract the model's input and output layers and its output activation when it needs them.
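Because of this convention, a method can work with the first and last learnable layers without knowing anything else about the architecture. The sketch below illustrates the idea with a hypothetical TinyModel and two hypothetical helpers, freeze_input_layer and replace_output_layer; it does not reproduce any particular YAMLE method.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical model following the _input/_output/_output_activation convention."""

    def __init__(self, inputs_dim: int, hidden_dim: int, outputs_dim: int) -> None:
        super().__init__()
        self._input = nn.Linear(inputs_dim, hidden_dim)  # first learnable layer
        self._output = nn.Linear(hidden_dim, outputs_dim)  # last learnable layer
        self._output_activation = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self._input(x))
        return self._output_activation(self._output(x))


def freeze_input_layer(model: nn.Module) -> None:
    """Hypothetical method-side helper: stop training the first layer."""
    for p in model._input.parameters():
        p.requires_grad_(False)


def replace_output_layer(model: nn.Module, outputs_dim: int) -> None:
    """Hypothetical method-side helper: swap the last layer for a new output size."""
    in_features = model._output.in_features
    # Assigning a new nn.Module attribute registers it as a submodule.
    model._output = nn.Linear(in_features, outputs_dim)


model = TinyModel(inputs_dim=4, hidden_dim=8, outputs_dim=3)
freeze_input_layer(model)
replace_output_layer(model, outputs_dim=5)
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 5])
```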

BaseModel also provides optional hooks: replace_layers_for_quantization defines how the model is prepared for quantisation, reset resets the model, e.g. at the start of each training epoch, and add_method_specific_layers adds method-specific layers while keeping the backbone of the model unchanged. All of them have default implementations and only need to be overridden when required.
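As an illustration of the reset hook, the sketch below assumes a model that keeps some per-epoch state, here a hypothetical counter of samples seen, and a simple training-style loop that calls reset at the start of every epoch. Both the counter and the loop are assumptions for the example; YAMLE's own training loop is not shown here.

```python
import torch
import torch.nn as nn


class CountingModel(nn.Module):
    """Hypothetical model that tracks how many samples it has seen this epoch."""

    def __init__(self) -> None:
        super().__init__()
        self._output = nn.Linear(2, 1)
        self.samples_seen = 0

    def reset(self) -> None:
        # Mirrors the BaseModel.reset hook: clear the per-epoch state.
        self.samples_seen = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.samples_seen += x.shape[0]
        return self._output(x)


model = CountingModel()
batches = [torch.randn(4, 2), torch.randn(4, 2)]
for epoch in range(3):
    model.reset()  # called once at the start of each epoch
    for batch in batches:
        model(batch)
    print(epoch, model.samples_seen)  # 8 every epoch, thanks to the reset
```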

The most important methods are forward and final_layer, both of which are abstract and must be implemented. forward specifies the complete forward pass of the model, while final_layer processes the last hidden features with the output layer and the output activation.
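The split between the two methods can be sketched with a self-contained stand-in for BaseModel. MiniBaseModel and MiniFC below are hypothetical and keep only the abstract forward/final_layer contract. The point is that forward runs the whole network, while final_layer can be reused on hidden features that a method has already computed.

```python
import abc

import torch
import torch.nn as nn


class MiniBaseModel(nn.Module, abc.ABC):
    """Hypothetical stand-in that keeps only BaseModel's abstract methods."""

    @abc.abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor: ...

    @abc.abstractmethod
    def final_layer(self, x: torch.Tensor) -> torch.Tensor: ...


class MiniFC(MiniBaseModel):
    def __init__(self, inputs_dim: int, hidden_dim: int, outputs_dim: int) -> None:
        super().__init__()
        self._input = nn.Linear(inputs_dim, hidden_dim)
        self._output = nn.Linear(hidden_dim, outputs_dim)
        self._output_activation = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = torch.relu(self._input(x))
        return self.final_layer(hidden)

    def final_layer(self, x: torch.Tensor) -> torch.Tensor:
        # Last hidden features -> output layer -> output activation.
        return self._output_activation(self._output(x))


try:
    MiniBaseModel()  # abstract methods are not implemented
except TypeError as err:
    print("abstract:", err)

model = MiniFC(4, 8, 3)
x = torch.randn(2, 4)
hidden = torch.relu(model._input(x))
print(torch.allclose(model(x), model.final_layer(hidden)))  # True
```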

A concrete example is FCModel, a fully connected network with multiple hidden layers:

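import argparse
import math
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.ao.quantization import DeQuantStub, QuantStub, fuse_modules

# BaseModel, LinearNormActivation, Normalization, OutputActivation, ReshapeOutput,
# CLASSIFICATION_KEY and REGRESSION_KEY are YAMLE internals and must be imported
# from the yamle package.


class FCModel(BaseModel):
    """This class is used to create a FC model with the given parameters.

    Args:
        hidden_dim (int): The dimensions of the hidden layers.
        width_multiplier (int): The width multiplier for the hidden layers. Default: 1.
        depth (int): The number of hidden layers.
        normalization (Optional[str]): The normalization to use. Either 'batch', 'instance' or `None`.
        activation (Optional[str]): The activation to use. Either 'relu', 'linear', 'sigmoid', 'tanh' or `None`.
    """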

    tasks = [CLASSIFICATION_KEY, REGRESSION_KEY]

    def __init__(
        self,
        hidden_dim: int,
        width_multiplier: int,
        depth: int,
        normalization: str,
        activation: str,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super(FCModel, self).__init__(*args, **kwargs)
        self._inputs_dim = np.prod(self._inputs_dim)
        self._hidden_dim = hidden_dim * width_multiplier

        self._flatten = nn.Flatten()
        self._layers = nn.ModuleList()
        self._input = nn.Linear(self._inputs_dim, self._hidden_dim)
        self._relu = nn.ReLU()
        for i in range(depth):
            self._layers.append(
                LinearNormActivation(
                    self._hidden_dim,
                    self._hidden_dim,
                    normalization=normalization,
                    activation=activation,
                )
            )
        self._normalization = normalization
        self._output = nn.Linear(self._hidden_dim, self._outputs_dim)
        self._output_activation = OutputActivation(self._task, dim=1)
        self._depth = depth

    def forward(
        self,
        x: torch.Tensor,
        staged_output: bool = False,
        input_kwargs: Dict[str, Any] = {},
        output_kwargs: Dict[str, Any] = {},
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """This method is used to perform a forward pass through the model.

        The input is expected to be of shape `(batch_size, *inputs_dim)`; it is flattened
        to `(batch_size, prod(inputs_dim))` before the input layer.
        The output is of shape `(batch_size, outputs_dim)`.

        Args:
            x (torch.Tensor): The input to the model.
            staged_output (bool): If True, the output is a tuple of the last layer and the hidden layers.
            input_kwargs (Dict[str, Any]): The kwargs for the input layer.
            output_kwargs (Dict[str, Any]): The kwargs for the output layer.
        """
        layers_outputs = []
        x = self._flatten(x)
        x = self._input(x, **input_kwargs)
        x = self._relu(x)
        for layer in self._layers:
            x = layer(x)
            if staged_output:
                layers_outputs.append(x)

        x = self.final_layer(x, **output_kwargs)
        if staged_output:
            return x, layers_outputs
        return x

    def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor:
        """This function is used to get the final layer output."""
        x = self._output(x, **output_kwargs)
        return self._output_activation(x)

    def add_method_specific_layers(self, method: str, **kwargs: Any) -> None:
        """This method is used to add method specific layers to the model.

        Args:
            method (str): The method to use.
        """
        super().add_method_specific_layers(method, **kwargs)
        norm_kwargs = {"affine": True}
        if method in ["dun", "mimmo"]:
            self._reshaping_layers = nn.ModuleList()
            available_heads = [True] * (self._depth - 1)
            if "heads" in kwargs and kwargs["heads"]:
                self._heads = nn.ModuleList()
            if "available_heads" in kwargs:
                available_heads = kwargs["available_heads"]
            for i, available_head in enumerate(available_heads):
                if not available_head:
                    continue
                layers = []
                layers.append(nn.Linear(self._hidden_dim, self._hidden_dim))
                layers.append(
                    Normalization(
                        norm=self._normalization,
                        dimension=1,
                        norm_kwargs={**norm_kwargs, "num_features": self._hidden_dim},
                    )
                )
                layers.append(nn.ReLU())
                self._reshaping_layers.append(nn.Sequential(*layers))

                if "heads" in kwargs and kwargs["heads"]:
                    head = []
                    head.append(
                        nn.Linear(self._hidden_dim, self._output[0].out_features)
                    )
                    head.append(ReshapeOutput(num_members=kwargs["num_members"]))
                    self._heads.append(nn.Sequential(*head))

        elif method in ["early_exit"]:
            gamma = kwargs["gamma"]
            self._reshaping_layers = nn.ModuleList()
            hidden_feature_size_output = (
                self._output.in_features
                if method == "early_exit"
                else self._output[0].in_features
            )
            size_output = (
                self._output.out_features
                if method == "early_exit"
                else self._output[0].out_features
            )
            heads = [1] * self._depth
            if "heads" in kwargs and kwargs["heads"] is not None:
                heads = kwargs["heads"]
                assert (
                    len(heads) == self._depth
                ), "Number of heads should be equal to number of layers."
            for i in range(1, self._depth):
                if not heads[i - 1]:
                    continue
                sequence = []
                hidden_feature_size = int(
                    math.sqrt(1 + gamma) ** (self._depth - i)
                    * hidden_feature_size_output
                )
                if gamma > 0:
                    sequence.append(nn.Linear(self._hidden_dim, hidden_feature_size))
                    sequence.append(nn.ReLU())
                    sequence.append(nn.Linear(hidden_feature_size, size_output))
                else:
                    sequence.append(nn.Linear(self._hidden_dim, size_output))
                self._reshaping_layers.append(nn.Sequential(*sequence))
        else:
            raise ValueError(f"Method {method} is not supported.")

    def replace_layers_for_quantization(self) -> None:
        """Fuses all the operations in the network.

        In this function we only need to fuse layers that are not in the blocks.
        e.g. the reshaping layers added by the method.
        """
        if self._added_method_specific_layers:
            if self._method in ["dun", "mimmo"]:
                for i in range(len(self._reshaping_layers)):
                    self._reshaping_layers[i] = fuse_modules(
                        self._reshaping_layers[i], [["0", "1._norm"]]
                    )

                if "heads" in self._method_kwargs and self._method_kwargs["heads"]:
                    for i in range(len(self._heads)):
                        self._heads[i] = nn.Sequential(self._heads[i], DeQuantStub())

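            elif self._method in ["early_exit"]:
                # `_reshaping_layers` only holds the heads that were actually created,
                # so iterate over it directly instead of re-indexing by layer.
                for i in range(len(self._reshaping_layers)):
                    if self._method_kwargs["gamma"] > 0:
                        # Each head is Linear ("0") -> ReLU ("1") -> Linear ("2");
                        # fuse the first Linear with the ReLU that follows it.
                        self._reshaping_layers[i] = fuse_modules(
                            self._reshaping_layers[i], [["0", "1"]]
                        )
                    self._reshaping_layers[i] = nn.Sequential(
                        self._reshaping_layers[i], DeQuantStub()
                    )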

        # Add quantization stubs to the input and dequantization stubs to the output.
        self._input = nn.Sequential(QuantStub(), self._input)
        self._output = nn.Sequential(self._output, DeQuantStub())

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add the model specific arguments to the parent parser."""
        parser = super(FCModel, FCModel).add_specific_args(parent_parser)
        parser.add_argument(
            "--model_hidden_dim",
            type=int,
            default=128,
            help="The dimensions of the hidden layers.",
        )
        parser.add_argument(
            "--model_width_multiplier",
            type=int,
            default=1,
            help="The width multiplier for the hidden layers.",
        )
        parser.add_argument(
            "--model_depth", type=int, default=3, help="The number of hidden layers."
        )
        parser.add_argument(
            "--model_normalization",
            type=str,
            default=None,
            choices=["batch", "instance", None],
            help="The normalization to use.",
        )
        parser.add_argument(
            "--model_activation",
            type=str,
            default="relu",
            choices=["relu", "linear", "sigmoid", "tanh", None],
            help="The activation to use.",
        )
        return parser

Notice that the _input and _output layers automatically take into account the input and output dimensions passed down by the datamodule chosen for the experiment. The _output_activation is likewise chosen automatically based on the task.
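For example, with an image datamodule inputs_dim is a multi-dimensional tuple, and FCModel collapses it with np.prod before building its input layer. The sketch below reproduces just that step for an assumed input shape of (1, 28, 28) and 10 outputs; the shape and the layer sizes are illustrative only.

```python
import numpy as np
import torch
import torch.nn as nn

inputs_dim = (1, 28, 28)  # assumed (channels, height, width) from a datamodule
outputs_dim = 10          # assumed number of classes
hidden_dim = 128

flat_dim = int(np.prod(inputs_dim))  # 1 * 28 * 28 = 784
flatten = nn.Flatten()               # keeps the batch dimension, flattens the rest
input_layer = nn.Linear(flat_dim, hidden_dim)
output_layer = nn.Linear(hidden_dim, outputs_dim)

x = torch.randn(32, *inputs_dim)
h = torch.relu(input_layer(flatten(x)))
y = output_layer(h)
print(flat_dim, y.shape)  # 784 torch.Size([32, 10])
```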

The forward method also accepts extra keyword arguments, for example staged_output, which additionally returns the hidden representation produced by each hidden layer. Certain methods rely on this, together with add_method_specific_layers, which adds the extra layers required by specific methods such as dun, mimmo or early_exit.
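The staged_output pattern can be sketched in isolation. StagedFC below is a hypothetical, simplified model modelled on FCModel.forward (no normalisation, no method-specific layers). It returns the final output together with the list of hidden representations, which an early-exit style method could feed to extra heads; those heads are also assumptions for illustration.

```python
from typing import List, Tuple, Union

import torch
import torch.nn as nn


class StagedFC(nn.Module):
    """Hypothetical FC model with a staged_output switch."""

    def __init__(self, inputs_dim: int, hidden_dim: int, outputs_dim: int, depth: int) -> None:
        super().__init__()
        self._input = nn.Linear(inputs_dim, hidden_dim)
        self._layers = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) for _ in range(depth)
        )
        self._output = nn.Linear(hidden_dim, outputs_dim)

    def forward(
        self, x: torch.Tensor, staged_output: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        hidden: List[torch.Tensor] = []
        x = torch.relu(self._input(x))
        for layer in self._layers:
            x = layer(x)
            if staged_output:
                hidden.append(x)  # keep every hidden representation
        out = self._output(x)
        return (out, hidden) if staged_output else out


model = StagedFC(inputs_dim=4, hidden_dim=16, outputs_dim=3, depth=3)
out, hidden = model(torch.randn(5, 4), staged_output=True)
# Hypothetical early-exit heads: one linear head per intermediate layer.
heads = nn.ModuleList(nn.Linear(16, 3) for _ in hidden[:-1])
exits = [head(h) for head, h in zip(heads, hidden[:-1])]
print(len(hidden), [e.shape for e in exits])  # 3 hidden tensors, 2 early exits
```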

To specify the arguments of the model, override the add_specific_args method, which adds the model's arguments to the experiment's command-line ArgumentParser. FCModel uses it to expose, for example, the number of hidden layers, the activation and the width of the network.
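The parser pattern used by add_specific_args can be tried on its own with the standard argparse module. The sketch copies the parents=[...], add_help=False construction from BaseModel and two of FCModel's options. The final step, which strips the model_ prefix to build constructor keyword arguments, is a hypothetical illustration and not necessarily how YAMLE passes the arguments to the model.

```python
import argparse


def add_model_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    # Same construction as BaseModel.add_specific_args: inherit the parent's arguments.
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument("--model_hidden_dim", type=int, default=128)
    parser.add_argument("--model_depth", type=int, default=3)
    return parser


parent = argparse.ArgumentParser(add_help=False)
parent.add_argument("--seed", type=int, default=42)
parser = add_model_args(parent)
args = parser.parse_args(["--model_depth", "5"])

# Hypothetical: turn the --model_* options into keyword arguments for the model.
model_kwargs = {
    k[len("model_"):]: v for k, v in vars(args).items() if k.startswith("model_")
}
print(model_kwargs)  # {'hidden_dim': 128, 'depth': 5}
```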

FCModel also uses general layers such as LinearNormActivation, which is a linear layer followed by a normalisation layer and an activation layer. Because it is general, this class is also used in other models. If you are implementing a layer that is likely to be reused, place it in the operations module; for a method-specific layer, place it in the specific folder.
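To show what such a general layer might look like, here is a minimal sketch of a linear-normalisation-activation block. It is a hypothetical re-implementation for illustration, not YAMLE's LinearNormActivation: it supports only 'batch' or None for the normalisation, and 'relu', 'tanh', 'sigmoid', 'linear' or None for the activation.

```python
from typing import Optional

import torch
import torch.nn as nn

# Activation names this sketch understands ('linear' and None mean no activation).
_ACTIVATIONS = {"relu": nn.ReLU, "tanh": nn.Tanh, "sigmoid": nn.Sigmoid}


class SketchLinearNormActivation(nn.Sequential):
    """Hypothetical Linear -> (optional) norm -> (optional) activation block."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        normalization: Optional[str] = None,
        activation: Optional[str] = "relu",
    ) -> None:
        layers = [nn.Linear(in_features, out_features)]
        if normalization == "batch":
            layers.append(nn.BatchNorm1d(out_features))
        elif normalization is not None:
            raise ValueError(f"Unsupported normalization {normalization}.")
        if activation in _ACTIVATIONS:
            layers.append(_ACTIVATIONS[activation]())
        elif activation not in (None, "linear"):
            raise ValueError(f"Unsupported activation {activation}.")
        super().__init__(*layers)


block = SketchLinearNormActivation(16, 32, normalization="batch", activation="relu")
y = block(torch.randn(8, 16))
print(y.shape, bool((y >= 0).all()))  # torch.Size([8, 32]) True
```

Subclassing nn.Sequential keeps the block's children indexable by position, which is the kind of naming that fuse_modules relies on in replace_layers_for_quantization.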

The last step is to register the new model in the __init__ file of the models module. This is done by adding the model to the AVAILABLE_MODELS dictionary, which maps a model-type string to the model class that model_factory returns:

from typing import Type, Optional

import torch.nn as nn

from yamle.models.fc import FCModel, ResidualFCModel
from yamle.models.convnet import ConvNetModel, ResidualConvNetModel
from yamle.models.resnet import ResNetModel
from yamle.models.densenet import DenseNetModel
from yamle.models.unet import UNetModel
from yamle.models.transformer import TransformerModel
from yamle.models.visual_transformer import VisualTransformerModel
from yamle.models.rnn import RNNModel, RNNAutoEncoderModel
from yamle.models.mixer import MixerModel
from yamle.models.vgg import VGGModel

AVAILABLE_MODELS = {
    "fc": FCModel,
    "convnet": ConvNetModel,
    "residualconvnet": ResidualConvNetModel,
    "residualfc": ResidualFCModel,
    "resnet": ResNetModel,
    "densenet": DenseNetModel,
    "vgg": VGGModel,
    "unet": UNetModel,
    "transformer": TransformerModel,
    "visualtransformer": VisualTransformerModel,
    "mixer": MixerModel,
    "rnn": RNNModel,
    "rnnautoencoder": RNNAutoEncoderModel,
    None: nn.Identity,
}


def model_factory(model_type: Optional[str] = None) -> Type[nn.Module]:
    """This function is used to return the model class based on the model type.

    Args:
        model_type (Optional[str]): The type of model to look up. `None` returns `nn.Identity`.
    """
    if model_type not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model type {model_type}.")
    return AVAILABLE_MODELS[model_type]
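Registering a model therefore amounts to importing it and adding one entry to the dictionary. The self-contained sketch below mimics that mechanism with a hypothetical MyModel and a local copy of the lookup logic, so it runs without YAMLE installed; in YAMLE itself the entry would go into AVAILABLE_MODELS in the models package's __init__ file.

```python
from typing import Optional, Type

import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical new model to register."""


AVAILABLE_MODELS = {
    "mymodel": MyModel,  # the new entry: model-type string -> model class
    None: nn.Identity,
}


def model_factory(model_type: Optional[str] = None) -> Type[nn.Module]:
    """Return the registered model class (not an instance) for model_type."""
    if model_type not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model type {model_type}.")
    return AVAILABLE_MODELS[model_type]


model_cls = model_factory("mymodel")
print(model_cls is MyModel)  # True
```

Note that the factory returns a class; the caller is responsible for instantiating it with inputs_dim, outputs_dim, task, seed and the model-specific arguments.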