Source code for yamle.models.visual_transformer
from typing import Tuple, Dict, Any, Union, List
import torch
from torch import nn
import argparse
from einops.layers.torch import Rearrange
from einops import repeat
import math
from yamle.models.operations import (
OutputActivation,
Reduction,
Add,
Lambda,
ReshapeOutput,
)
from yamle.models.transformer import TransformerEncoderLayer
from yamle.models.model import BaseModel
from yamle.models.specific.mcdropout import disable_dropout_replacement
from yamle.defaults import REGRESSION_KEY, CLASSIFICATION_KEY
[docs]
class SpatialPositionalEmbedding(nn.Module):
    """Embed a 2D image as a sequence of patch tokens for the visual transformer.

    The image is cut into non-overlapping ``patch_size x patch_size`` patches.
    Each patch is flattened, layer-normalised, linearly projected to
    ``embedding_dim`` and layer-normalised again. Learnable class tokens are
    optionally prepended to the patch sequence and a learnable positional
    embedding is optionally added, followed by dropout.

    Args:
        inputs_dim (Tuple[int, int, int]): The dimension of the input as
            ``(channels, height, width)``.
        patch_size (int): The side length of a (square) patch. Both the height
            and the width of the input must be divisible by it.
        embedding_dim (int): The dimension of the embedding.
        dropout (float): The dropout rate applied to the embedded sequence.
            Defaults to 0.0.
        num_cls_tokens (int): The number of class tokens. No class token is
            created when it is not positive. Defaults to 1.
        positional_embedding (bool): Whether to use positional embedding.
            Defaults to True.

    Raises:
        AssertionError: If the height or the width is not divisible by
            ``patch_size``.
    """

    def __init__(
        self,
        inputs_dim: Tuple[int, int, int],
        patch_size: int,
        embedding_dim: int,
        dropout: float = 0.0,
        num_cls_tokens: int = 1,
        positional_embedding: bool = True,
    ) -> None:
        super().__init__()
        C, H, W = inputs_dim
        assert (
            H % patch_size == 0
        ), f"Image dimensions must be divisible by the patch size. Got {H} and {patch_size}."
        assert (
            W % patch_size == 0
        ), f"Image dimensions must be divisible by the patch size. Got {W} and {patch_size}."

        self._inputs_dim = inputs_dim
        self._patch_size = patch_size
        self._embedding_dim = embedding_dim
        self._dropout = dropout
        self._num_cls_tokens = num_cls_tokens
        # Count the patches along each axis separately: the rearrange pattern
        # below produces (H / p) * (W / p) tokens, which only equals
        # (H / p) ** 2 for square inputs.
        self._num_patches = (H // patch_size) * (W // patch_size)
        self._patch_dim = C * patch_size**2

        self._to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size
            ),
            nn.LayerNorm(self._patch_dim),
            nn.Linear(self._patch_dim, embedding_dim),
            nn.LayerNorm(embedding_dim),
        )
        # Entry "2" of the sequential is the `nn.Linear` patch projection.
        # NOTE(review): presumably this exempts the projection from the
        # MC-dropout layer replacement - confirm against
        # `disable_dropout_replacement`.
        disable_dropout_replacement(self._to_patch_embedding._modules["2"])

        # One learnable position vector per token (patches + class tokens).
        self._positional_embedding = (
            nn.Parameter(
                torch.randn(1, self._num_patches + self._num_cls_tokens, embedding_dim),
                requires_grad=True,
            )
            if positional_embedding
            else None
        )
        self._cls_token = (
            nn.Parameter(
                torch.randn(1, self._num_cls_tokens, embedding_dim), requires_grad=True
            )
            if num_cls_tokens > 0
            else None
        )
        self._drop = nn.Dropout(dropout)
        self._add = Add()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images into a token sequence.

        Args:
            x (torch.Tensor): The input images, laid out as
                ``(batch, channels, height, width)`` by the rearrange pattern.

        Returns:
            torch.Tensor: The embedded sequence after dropout. The class
            tokens, when present, come first along dimension 1 and are
            followed by the patch tokens.
        """
        x = self._to_patch_embedding(x)
        b, _, _ = x.shape
        if self._cls_token is not None:
            # Broadcast the shared class tokens over the batch dimension.
            cls_tokens = repeat(self._cls_token, "() n d -> b n d", b=b)
            x = torch.cat((cls_tokens, x), dim=1)
        if self._positional_embedding is not None:
            x = self._add(x, self._positional_embedding)
        return self._drop(x)

    def get_cls_token_indices(self) -> torch.Tensor:
        """Return the sequence indices of the class tokens.

        The class tokens are added as the first tokens in the sequence, so the
        indices are ``0, ..., num_cls_tokens - 1``.
        """
        return torch.arange(self._num_cls_tokens)
[docs]
class VisualTransformerModel(BaseModel):
    """This class is used to create a visual transformer model.

    The input image is embedded by :class:`SpatialPositionalEmbedding`, passed
    through ``depth`` transformer encoder layers and a final layer norm, pooled
    (mean over the tokens or by selecting the class tokens) and mapped to the
    outputs by a linear layer followed by the task-specific output activation.

    Note that the command line defaults registered in ``add_specific_args``
    differ from the constructor defaults for ``pooling``, ``num_heads``,
    ``depth`` and ``dropout``.

    Args:
        patch_size (int): The size of the patch to be used.
        pooling (str): The pooling to be used. It can be either `mean` or `cls`.
        embedding_dim (int): The number of expected features in the input.
        num_heads (int): The number of heads in the multiheadattention models.
        depth (int): The number of sub-encoder-layers in the encoder.
        num_cls_tokens (int): The number of class tokens. Defaults to 1.
        hidden_dim (int): The dimension of the feedforward network model.
        width_multiplier (int): The width multiplier for the hidden dimension.
        dropout (float): The dropout value.

    Raises:
        AssertionError: If ``pooling`` is neither `mean` nor `cls`, or if `cls`
            pooling is requested without any class token.
    """

    tasks = [
        CLASSIFICATION_KEY,
        REGRESSION_KEY,
    ]

    def __init__(
        self,
        patch_size: int = 4,
        pooling: str = "mean",
        embedding_dim: int = 128,
        num_heads: int = 6,
        depth: int = 4,
        num_cls_tokens: int = 1,
        hidden_dim: int = 512,
        width_multiplier: int = 1,
        dropout: float = 0.0,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # NOTE(review): `self._inputs_dim`, `self._outputs_dim` and `self._task`
        # are read below without being assigned here - presumably they are set
        # by `BaseModel.__init__`; confirm.
        super().__init__(*args, **kwargs)
        assert pooling in [
            "mean",
            "cls",
        ], f"Pooling must be either `mean` or `cls`. Got {pooling}."
        # Without a class token the `cls` head would be a linear layer with
        # zero input features, so fail early instead.
        assert (
            pooling != "cls" or num_cls_tokens > 0
        ), f"Pooling `cls` requires at least one class token. Got {num_cls_tokens}."

        self._embedding_dim = embedding_dim
        self._num_heads = num_heads
        self._head_dim = embedding_dim // num_heads
        self._hidden_dim = hidden_dim * width_multiplier
        self._patch_size = patch_size
        self._pooling = pooling
        self._num_cls_tokens = num_cls_tokens
        self._depth = depth

        self._input = SpatialPositionalEmbedding(
            self._inputs_dim, patch_size, embedding_dim, dropout, num_cls_tokens
        )

        # Encoder stack, followed by a final norm and the pooling operation.
        self._layers = nn.ModuleList()
        for _ in range(depth):
            self._layers.append(
                TransformerEncoderLayer(
                    self._embedding_dim,
                    self._num_heads,
                    self._head_dim,
                    self._hidden_dim,
                    dropout,
                    causal=False,
                )
            )
        self._layers.append(nn.LayerNorm(self._embedding_dim))
        # For `cls` pooling the class tokens are selected in `final_layer`.
        self._layers.append(
            Reduction(dim=1, reduction="mean") if pooling == "mean" else nn.Identity()
        )

        # With `cls` pooling all class tokens are concatenated before the head.
        self._output = (
            nn.Linear(self._embedding_dim, self._outputs_dim)
            if pooling == "mean"
            else nn.Linear(
                self._embedding_dim * self._num_cls_tokens, self._outputs_dim
            )
        )
        self._output_activation = OutputActivation(self._task, dim=1)

    def add_method_specific_layers(self, method: str, **kwargs: Any) -> None:
        """This method is used to add method specific layers to the model.

        For `dun` and `mimmo` one reshaping block is created per available
        intermediate head (and optionally a prediction head per block). For
        `early_exit` one exit block is created for each enabled depth except
        the last.

        Args:
            method (str): The method to use. One of `dun`, `mimmo` or
                `early_exit`.
            **kwargs (Any): Method specific options. `dun`/`mimmo` read
                `heads`, `available_heads` and `num_members`; `early_exit`
                reads `gamma` and `heads`.

        Raises:
            ValueError: If ``method`` is not supported.
            AssertionError: If the `early_exit` ``heads`` does not have
                ``depth`` entries.
        """
        super().add_method_specific_layers(method, **kwargs)
        if method in ["dun", "mimmo"]:
            self._reshaping_layers = nn.ModuleList()
            with_heads = "heads" in kwargs and kwargs["heads"]
            if with_heads:
                self._heads = nn.ModuleList()
            available_heads = [True] * (self._depth - 1)
            if "available_heads" in kwargs:
                available_heads = kwargs["available_heads"]
            for available_head in available_heads:
                if not available_head:
                    continue
                layers = [
                    nn.Linear(self._embedding_dim, self._embedding_dim),
                    nn.GELU(),
                    nn.LayerNorm(self._embedding_dim),
                    Reduction(dim=1, reduction="mean")
                    if self._pooling == "mean"
                    else Lambda(lambda x: x[:, self._input.get_cls_token_indices()]),
                ]
                self._reshaping_layers.append(nn.Sequential(*layers))
                if with_heads:
                    # NOTE(review): indexing `self._output[0]` assumes the
                    # parent call replaced `self._output` with an indexable
                    # container for these methods - confirm in `BaseModel`.
                    head = [
                        nn.Linear(self._embedding_dim, self._output[0].out_features),
                        ReshapeOutput(num_members=kwargs["num_members"]),
                    ]
                    self._heads.append(nn.Sequential(*head))
        elif method in ["early_exit"]:
            gamma = kwargs["gamma"]
            self._reshaping_layers = nn.ModuleList()
            # In this branch `self._output` is accessed directly as the
            # linear output layer.
            hidden_feature_size_output = self._output.in_features
            size_output = self._output.out_features
            heads = [1] * (self._depth)
            if "heads" in kwargs and kwargs["heads"] is not None:
                heads = kwargs["heads"]
            assert (
                len(heads) == self._depth
            ), f"Number of heads should be {self._depth}, but got {len(heads)}"
            for i in range(1, self._depth):
                if not heads[i - 1]:
                    continue
                # The hidden width grows by sqrt(1 + gamma) for every layer
                # that the exit is away from the final depth.
                hidden_feature_size = int(
                    math.sqrt(1 + gamma) ** (self._depth - i)
                    * hidden_feature_size_output
                )
                sequence = [
                    Reduction(dim=1, reduction="mean")
                    if self._pooling == "mean"
                    else Lambda(
                        lambda x: x[:, self._input.get_cls_token_indices()].reshape(
                            x.shape[0], -1
                        )
                    ),
                    nn.LayerNorm(hidden_feature_size_output),
                ]
                if gamma > 0:
                    sequence.append(
                        nn.Linear(hidden_feature_size_output, hidden_feature_size)
                    )
                    sequence.append(nn.GELU())
                    sequence.append(nn.LayerNorm(hidden_feature_size))
                    sequence.append(nn.Linear(hidden_feature_size, size_output))
                else:
                    sequence.append(nn.Linear(hidden_feature_size_output, size_output))
                self._reshaping_layers.append(nn.Sequential(*sequence))
        else:
            raise ValueError(f"Method {method} is not supported.")

    def forward(
        self,
        x: torch.Tensor,
        staged_output: bool = False,
        input_kwargs: Union[Dict[str, Any], None] = None,
        output_kwargs: Union[Dict[str, Any], None] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Forward pass of the model.

        Args:
            x (torch.Tensor): The input tensor.
            staged_output (bool, optional): Whether to return the output of each
                transformer encoder layer. Defaults to False.
            input_kwargs (Dict[str, Any], optional): The input kwargs. Not used
                by this model. Defaults to None, treated as empty.
            output_kwargs (Dict[str, Any], optional): The output kwargs,
                forwarded to ``final_layer``. Defaults to None, treated as
                empty.

        Returns:
            The model output, or a tuple of the model output and the list of
            staged outputs when ``staged_output`` is True.
        """
        # `None` defaults avoid sharing one mutable dict between calls.
        output_kwargs = {} if output_kwargs is None else output_kwargs

        layers_outputs = []
        x = self._input(x)
        for layer in self._layers:
            x = layer(x)
            if staged_output and isinstance(layer, TransformerEncoderLayer):
                layers_outputs.append(x)

        # With mean pooling the last staged output is replaced by the
        # normalised and pooled features. The emptiness check covers a model
        # without encoder layers, where there is nothing to replace.
        if staged_output and layers_outputs and isinstance(self._layers[-1], Reduction):
            layers_outputs[-1] = x

        x = self.final_layer(x, **output_kwargs)
        if staged_output:
            return x, layers_outputs
        return x

    def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor:
        """This function is used to get the final layer output.

        Args:
            x (torch.Tensor): The features produced by the encoder stack.
            **output_kwargs (Any): Extra keyword arguments passed to the output
                layer.

        Returns:
            torch.Tensor: The output after the output layer and the output
            activation.
        """
        if self._pooling == "cls":
            # The class tokens are the first tokens; concatenate them per sample.
            B = x.shape[0]
            x = x[:, self._input.get_cls_token_indices()].reshape(B, -1)
        x = self._output(x, **output_kwargs)
        return self._output_activation(x)

    @staticmethod
    def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """This function is used to add specific arguments to the parser.

        The parent arguments are registered first. The defaults registered
        here for `--model_pooling`, `--model_num_heads`, `--model_depth` and
        `--model_dropout` differ from the constructor defaults.

        Args:
            parser (argparse.ArgumentParser): The parser to extend.

        Returns:
            argparse.ArgumentParser: The parser with the model arguments added.
        """
        parser = super(
            VisualTransformerModel, VisualTransformerModel
        ).add_specific_args(parser)
        parser.add_argument(
            "--model_patch_size",
            type=int,
            default=4,
            help="The size of the patch to be used.",
        )
        parser.add_argument(
            "--model_embedding_dim",
            type=int,
            default=128,
            help="The number of expected features in the input.",
        )
        parser.add_argument(
            "--model_pooling",
            type=str,
            default="cls",
            choices=["mean", "cls"],
            help="The pooling to be used.",
        )
        parser.add_argument(
            "--model_num_heads",
            type=int,
            default=4,
            help="The number of heads in the multiheadattention models.",
        )
        parser.add_argument(
            "--model_depth",
            type=int,
            default=2,
            help="The number of sub-encoder-layers in the encoder.",
        )
        parser.add_argument(
            "--model_num_cls_tokens",
            type=int,
            default=1,
            help="The number of cls tokens.",
        )
        parser.add_argument(
            "--model_hidden_dim",
            type=int,
            default=512,
            help="The dimension of the feedforward network model.",
        )
        parser.add_argument(
            "--model_width_multiplier",
            type=int,
            default=1,
            help="The width multiplier for the hidden dimension.",
        )
        parser.add_argument(
            "--model_dropout", type=float, default=0.1, help="The dropout value."
        )
        return parser