import argparse
import math
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.ao.quantization import DeQuantStub, QuantStub, fuse_modules
from yamle.models.operations import OutputActivation, ReshapeOutput
from yamle.models.model import BaseModel
from yamle.defaults import REGRESSION_KEY, CLASSIFICATION_KEY
[docs]
class EmptyBlock(nn.Module):
    """Identity marker module.

    The block leaves its input untouched; its position inside a layer stack
    signals where the hidden states should be cached.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the input tensor as-is.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The very same tensor object that was passed in.
        """
        return inputs
[docs]
class DenseLayer(nn.Module):
    """This class defines the primitive of the DenseNet architecture.

    The layer is a bottleneck (norm -> ReLU -> 1x1 conv) followed by
    norm -> ReLU -> 3x3 conv and dropout.

    Args:
        inplanes (int): Number of input features.
        growth_rate (int): Number of output features.
        normalization (Type[nn.Module]): The normalization to use.
        normalization_kwargs (Dict[str, Any]): The keyword arguments for the normalization.
        bn_size (int): Bottleneck size.
        dropout_rate (float): Dropout rate.
    """

    def __init__(
        self,
        inplanes: int,
        growth_rate: int,
        normalization: Type[nn.Module],
        normalization_kwargs: Dict[str, Any],
        bn_size: int,
        dropout_rate: float,
    ) -> None:
        super().__init__()
        # Width of the intermediate (bottleneck) representation.
        bottleneck_planes = bn_size * growth_rate
        self.norm1 = normalization(inplanes, **normalization_kwargs)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            inplanes, bottleneck_planes, kernel_size=1, stride=1, bias=False
        )
        self.norm2 = normalization(bottleneck_planes, **normalization_kwargs)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            bottleneck_planes,
            growth_rate,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.dropout = nn.Dropout(p=dropout_rate)

    def bn_function(
        self, inputs: Union[torch.Tensor, List[torch.Tensor]]
    ) -> torch.Tensor:
        """This function applies the bottleneck function.

        Args:
            inputs (Union[torch.Tensor, List[torch.Tensor]]): Either a single
                tensor or a list of tensors; a list is concatenated along
                dimension 1 before the bottleneck is applied.

        Returns:
            torch.Tensor: Output of the 1x1 bottleneck convolution.
        """
        # `torch.cat` needs a sequence of tensors, so wrap a lone tensor.
        if isinstance(inputs, torch.Tensor):
            inputs = [inputs]
        concatenated_features = torch.cat(inputs, 1)
        return self.conv1(self.relu1(self.norm1(concatenated_features)))

    def forward(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
        """This function applies the forward pass.

        Args:
            inputs (Union[torch.Tensor, List[torch.Tensor]]): Input tensor, or
                the list of feature maps to concatenate along dimension 1.

        Returns:
            torch.Tensor: The newly produced features, after dropout.
        """
        new_features = self.bn_function(inputs)
        new_features = self.conv2(self.relu2(self.norm2(new_features)))
        return self.dropout(new_features)

    def replace_layers_for_quantization(self) -> None:
        """This function replaces the layers for quantization."""
        # Unfortunately the batch normalisation is the first layer of the block
        # which means that we cannot fuse it with the convolutional layer
        pass
[docs]
class Transition(nn.Module):
    """Transition layer of the DenseNet architecture.

    Applies normalization, ReLU and a 1x1 convolution that maps `inplanes`
    channels to `outplanes` channels, followed by 2x2 average pooling with
    stride 2.

    Args:
        inplanes (int): Number of input features.
        outplanes (int): Number of output features.
        normalization (Type[nn.Module]): The normalization to use.
        normalization_kwargs (Dict[str, Any]): The keyword arguments for the normalization.
    """

    def __init__(
        self,
        inplanes: int,
        outplanes: int,
        normalization: Type[nn.Module],
        normalization_kwargs: Dict[str, Any],
    ) -> None:
        super().__init__()
        self.norm = normalization(inplanes, **normalization_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(
            in_channels=inplanes,
            out_channels=outplanes,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run the transition: norm -> ReLU -> 1x1 conv -> average pool.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The pooled output of the 1x1 convolution.
        """
        normalized = self.norm(inputs)
        activated = self.relu(normalized)
        projected = self.conv(activated)
        return self.avgpool(projected)
[docs]
class DenseBlock(nn.Module):
    """DenseBlock of the DenseNet architecture.

    A stack of `n_layers` dense layers in which every layer receives all
    previously produced feature maps; the input width of layer `i` is
    therefore `inplanes + i * growth_rate`.

    Args:
        inplanes (int): Number of input features.
        growth_rate (int): Number of output features.
        normalization (Type[nn.Module]): The normalization to use.
        normalization_kwargs (Dict[str, Any]): The keyword arguments for the normalization.
        bn_size (int): Bottleneck size.
        dropout_rate (float): Dropout rate.
        n_layers (int): Number of layers.
    """

    def __init__(
        self,
        inplanes: int,
        growth_rate: int,
        normalization: Type[nn.Module],
        normalization_kwargs: Dict[str, Any],
        bn_size: int,
        dropout_rate: float,
        n_layers: int,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [
                DenseLayer(
                    inplanes=inplanes + index * growth_rate,
                    growth_rate=growth_rate,
                    normalization=normalization,
                    normalization_kwargs=normalization_kwargs,
                    bn_size=bn_size,
                    dropout_rate=dropout_rate,
                )
                for index in range(n_layers)
            ]
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run every dense layer on the growing list of feature maps.

        Args:
            inputs (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The input and all layer outputs concatenated along
            dimension 1.
        """
        collected = [inputs]
        for dense_layer in self.layers:
            # Each layer sees everything produced so far, then contributes its own output.
            collected.append(dense_layer(collected))
        return torch.cat(collected, dim=1)
[docs]
class DenseNetModel(BaseModel):
    """This class defines the DenseNet architecture as described in the paper.

    Densely Connected Convolutional Networks: https://arxiv.org/abs/1608.06993
    The code is based on the implementation of torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py

    Args:
        layers (List[int]): Number of layers in each block.
        depth (Optional[int]): The depth of the network with respect to the length of the layers list.
            `None` uses all entries of `layers`.
        bn_size (int): Bottleneck size.
        growth_rate (int): The growth rate multiplier.
        initial_planes (int): Number of initial planes.
        width_multiplier (float): Width multiplier to multiply the initial number of features.
            The product is truncated to an integer channel count.
        normalization (Optional[str]): The normalization to use. Can be either 'batch', 'instance', `group`, `layer`, or `None`. Defaults to 'batch'.
        dropout_rate (float): Dropout rate.
    """

    tasks = [
        CLASSIFICATION_KEY,
        REGRESSION_KEY,
    ]

    def __init__(
        self,
        layers: List[int] = [6, 12, 24, 16],
        depth: Optional[int] = 4,
        bn_size: int = 4,
        growth_rate: int = 32,
        initial_planes: int = 64,
        width_multiplier: float = 1.0,
        normalization: Optional[str] = "batch",
        dropout_rate: float = 0.0,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super(DenseNetModel, self).__init__(*args, **kwargs)
        assert depth is None or (
            len(layers) >= depth and depth >= 1
        ), f"Depth must be between 1 and {len(layers)} but got {depth}"
        # Channel counts must be integers: `initial_planes * width_multiplier`
        # is a float whenever the multiplier is a float (including the default
        # 1.0), which the convolution / normalization constructors cannot use.
        self._initial_planes = int(initial_planes * width_multiplier)
        self._width_multiplier = width_multiplier
        self._dropout_rate = dropout_rate
        self._depth = len(layers) if depth is None else depth

        assert normalization in [
            "batch",
            "instance",
            "group",
            "layer",
            None,
        ], f"Normalization {normalization} is not supported."
        # `None` falls through to nn.Identity, i.e. no normalization.
        norm = nn.Identity
        norm_kwargs = {}
        if normalization == "batch":
            norm = nn.BatchNorm2d
        elif normalization == "instance":
            norm = nn.InstanceNorm2d
            norm_kwargs = {"affine": True}
        elif normalization == "group":
            # NOTE(review): nn.GroupNorm takes (num_groups, num_channels), so the
            # `norm(planes)` call convention used throughout looks incompatible
            # with this option - confirm before relying on it.
            norm = nn.GroupNorm
        elif normalization == "layer":
            # NOTE(review): nn.LayerNorm(planes) normalizes over trailing
            # dimensions of that size, not over the channel dimension of a
            # (N, C, H, W) tensor - confirm this option behaves as intended.
            norm = nn.LayerNorm
        self._normalization = norm
        self._norm_kwargs = norm_kwargs

        # Stem: `_inputs_dim[0]` is the number of input channels; indices 1 and
        # 2 are presumably the spatial sizes (TODO confirm against BaseModel).
        # Small inputs keep their resolution, larger ones are downsampled.
        if (
            len(self._inputs_dim) == 3
            and self._inputs_dim[1] <= 64
            and self._inputs_dim[2] <= 64
        ):
            self._input = nn.Conv2d(
                self._inputs_dim[0],
                self._initial_planes,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )
        else:
            self._input = nn.Conv2d(
                self._inputs_dim[0],
                self._initial_planes,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False,
            )

        self._layers = nn.ModuleList()
        self._layers.append(norm(self._initial_planes, **norm_kwargs))
        self._layers.append(nn.ReLU(inplace=True))
        if self._inputs_dim[1] > 64 or self._inputs_dim[2] > 64:
            self._layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

        num_features = self._initial_planes
        self._planes = []  # Will collect the number of output planes for each block
        for i in range(self._depth):
            self._layers.append(
                DenseBlock(
                    inplanes=num_features,
                    n_layers=layers[i],
                    bn_size=bn_size,
                    growth_rate=growth_rate,
                    dropout_rate=dropout_rate,
                    normalization=norm,
                    normalization_kwargs=norm_kwargs,
                )
            )
            num_features += layers[i] * growth_rate
            # Every block except the last is followed by a transition that
            # halves the number of channels.
            if i != self._depth - 1:
                self._layers.append(
                    Transition(
                        inplanes=num_features,
                        outplanes=num_features // 2,
                        normalization=norm,
                        normalization_kwargs=norm_kwargs,
                    )
                )
                num_features = num_features // 2
            # Marker whose output `forward` caches as this stage's hidden state.
            self._layers.append(EmptyBlock())
            self._planes.append(num_features)

        self._layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self._layers.append(nn.Flatten())
        self._output = nn.Linear(self._planes[self._depth - 1], self._outputs_dim)
        self._output_activation = OutputActivation(self._task, dim=1)

    def forward(
        self,
        x: torch.Tensor,
        staged_output: bool = False,
        input_kwargs: Optional[Dict[str, Any]] = None,
        output_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """The forward function of the DenseNet model.

        Args:
            x (torch.Tensor): The input tensor.
            staged_output (bool): Whether to return the output of each layer. Defaults to False.
            input_kwargs (Optional[Dict[str, Any]]): The kwargs for the input layer. `None` means no kwargs.
            output_kwargs (Optional[Dict[str, Any]]): The kwargs for the output layer. `None` means no kwargs.

        Returns:
            The final output, or a tuple of the final output and the list of
            cached stage outputs when `staged_output` is True.
        """
        input_kwargs = {} if input_kwargs is None else input_kwargs
        output_kwargs = {} if output_kwargs is None else output_kwargs

        layers_outputs = []
        x = self._input(x, **input_kwargs)
        for layer in self._layers:
            x = layer(x)
            # An EmptyBlock marks the end of a stage: cache the hidden state there.
            if isinstance(layer, EmptyBlock):
                layers_outputs.append(x)
        # The last stage is reported after pooling and flattening, i.e. as the
        # exact tensor that is fed to the final layer.
        if isinstance(self._layers[-1], nn.Flatten) and staged_output:
            layers_outputs[-1] = x
        x = self.final_layer(x, **output_kwargs)
        if staged_output:
            return x, layers_outputs
        return x

    def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor:
        """This function is used to get the final layer output.

        Args:
            x (torch.Tensor): The flattened features.
            **output_kwargs (Any): Forwarded to the output layer.
        """
        x = self._output(x, **output_kwargs)
        return self._output_activation(x)

    def add_method_specific_layers(self, method: str, **kwargs: Any) -> None:
        """This method is used to add method specific layers to the model.

        Args:
            method (str): The method to use. Supported are `dun`, `mimmo` and `early_exit`.
            **kwargs (Any): Method options read here: `heads`, `available_heads`
                and `num_members` for `dun`/`mimmo`; `gamma` and `heads` for `early_exit`.

        Raises:
            ValueError: If the method is not supported.
        """
        super().add_method_specific_layers(method, **kwargs)
        if method in ["dun", "mimmo"]:
            self._reshaping_layers = nn.ModuleList()
            final_planes = self._planes[self._depth - 1]
            with_heads = "heads" in kwargs and kwargs["heads"]
            if with_heads:
                self._heads = nn.ModuleList()
            available_heads = [True] * (self._depth - 1)
            if "available_heads" in kwargs:
                available_heads = kwargs["available_heads"]
            for i, available_head in enumerate(available_heads):
                if not available_head:
                    continue
                # Map the stage output to the width of the last stage and pool it.
                # The norm kwargs are forwarded so that the layer is configured
                # like every other normalization in the network.
                self._reshaping_layers.append(
                    nn.Sequential(
                        nn.Conv2d(self._planes[i], final_planes, 1),
                        self._normalization(final_planes, **self._norm_kwargs),
                        nn.ReLU(),
                        nn.AdaptiveAvgPool2d(1),
                        nn.Flatten(),
                    )
                )
                if with_heads:
                    # NOTE(review): indexing `_output[0]` assumes the output
                    # layer is a container at this point, presumably set up by
                    # the parent call above - confirm against BaseModel.
                    self._heads.append(
                        nn.Sequential(
                            nn.Linear(final_planes, self._output[0].out_features),
                            ReshapeOutput(num_members=kwargs["num_members"]),
                        )
                    )
        elif method in ["early_exit"]:
            gamma = kwargs["gamma"]
            hidden_feature_size_output = self._output.in_features
            size_output = self._output.out_features
            self._reshaping_layers = nn.ModuleList()
            heads = [1] * self._depth
            if "heads" in kwargs and kwargs["heads"] is not None:
                heads = kwargs["heads"]
            assert (
                len(heads) == self._depth
            ), f"Number of heads should be {self._depth}, but got {len(heads)}"
            for i in range(1, self._depth):
                if not heads[i - 1]:
                    continue
                sequence = [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
                if gamma > 0:
                    # Earlier exits (smaller i) get a wider hidden layer.
                    hidden_feature_size = int(
                        math.sqrt(1 + gamma) ** (self._depth - i)
                        * hidden_feature_size_output
                    )
                    sequence.append(nn.Linear(self._planes[i - 1], hidden_feature_size))
                    sequence.append(nn.ReLU())
                    sequence.append(nn.Linear(hidden_feature_size, size_output))
                else:
                    sequence.append(nn.Linear(self._planes[i - 1], size_output))
                self._reshaping_layers.append(nn.Sequential(*sequence))
        else:
            raise ValueError(f"Method {method} is not supported.")

    def replace_layers_for_quantization(self) -> None:
        """Fuses all the operations in the network.

        In this function we only need to fuse layers that are not in the blocks.
        e.g. the reshaping layers added by the method.
        """
        if self._added_method_specific_layers:
            if self._method in ["dun", "mimmo"]:
                for i in range(len(self._reshaping_layers)):
                    # Entries 0, 1, 2 are the conv, normalization and ReLU.
                    self._reshaping_layers[i] = fuse_modules(
                        self._reshaping_layers[i], [["0", "1", "2"]], inplace=True
                    )
                if "heads" in self._method_kwargs and self._method_kwargs["heads"]:
                    for i in range(len(self._heads)):
                        self._heads[i] = nn.Sequential(self._heads[i], DeQuantStub())
                if self._method == "dun":
                    fuse_modules(
                        self, [["_input", "_layers.0", "_layers.1"]], inplace=True
                    )
                else:
                    fuse_modules(
                        self, [["_input.1", "_layers.0", "_layers.1"]], inplace=True
                    )
            elif self._method in ["early_exit"]:
                # `_reshaping_layers` only holds the enabled heads, stored
                # contiguously, so iterate over what exists instead of indexing
                # by block number (which is wrong once any head is disabled).
                for idx in range(len(self._reshaping_layers)):
                    exit_head = self._reshaping_layers[idx]
                    if self._method_kwargs["gamma"] > 0:
                        # Entries 2 and 3 are the hidden Linear and its ReLU.
                        exit_head = fuse_modules(exit_head, [["2", "3"]])
                    self._reshaping_layers[idx] = nn.Sequential(
                        exit_head, DeQuantStub()
                    )
                fuse_modules(self, [["_input", "_layers.0", "_layers.1"]], inplace=True)
        else:
            # Find the input convolution, could be either _input or _input.X if the input is
            # in a sequential block.
            input_conv = "_input"
            if isinstance(self._input, (nn.Sequential, nn.ModuleList)):
                for i, layer in enumerate(self._input):
                    if isinstance(layer, nn.Conv2d):
                        input_conv = f"_input.{i}"
                        break
            fuse_modules(self, [[input_conv, "_layers.0", "_layers.1"]], inplace=True)

        # Add quantization stubs to the input and dequantization stubs to the output.
        self._input = nn.Sequential(QuantStub(), self._input)
        self._output = nn.Sequential(self._output, DeQuantStub())

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Add specific arguments to the parser.

        Args:
            parent_parser (argparse.ArgumentParser): The parser to extend.

        Returns:
            argparse.ArgumentParser: The parser returned by the parent class,
            extended with the DenseNet options.
        """
        parser = super(DenseNetModel, DenseNetModel).add_specific_args(parent_parser)
        parser.add_argument(
            "--model_layers",
            type=str,
            default="[6,12,24,16]",
            help="Number of layers in each block.",
        )
        parser.add_argument(
            "--model_depth",
            type=int,
            default=None,
            help="The depth of the network respective to the the length of the layers list.",
        )
        parser.add_argument(
            "--model_bn_size", type=int, default=4, help="Bottleneck size."
        )
        parser.add_argument(
            "--model_initial_planes",
            type=int,
            default=64,
            help="Number of initial planes.",
        )
        parser.add_argument(
            "--model_growth_rate",
            type=int,
            default=32,
            help="Number of output features.",
        )
        parser.add_argument(
            "--model_width_multiplier",
            type=float,
            default=1.0,
            help="Width multiplier that multiplies the initial number of features.",
        )
        parser.add_argument(
            "--model_normalization",
            type=str,
            default="batch",
            help="The normalization to use. Can be either 'batch', 'instance', `group`, `layer`, or `None`.",
        )
        parser.add_argument(
            "--model_dropout_rate", type=float, default=0.0, help="Dropout rate."
        )
        return parser