Source code for yamle.models.transformer
import argparse
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange

from yamle.defaults import TEXT_CLASSIFICATION_KEY
from yamle.models.model import BaseModel
from yamle.models.operations import (
    MatrixMultiplication,
    OutputActivation,
    ResidualLayer,
)
from yamle.models.specific.mcdropout import disable_dropout_replacement
[docs]
class PreNorm(nn.Sequential):
    """Run a module on a layer-normalized version of its input.

    The container holds two children, in order: an ``nn.LayerNorm`` over the
    last ``dim`` features, followed by the wrapped ``module``.

    Args:
        dim (int): Size of the feature dimension passed to ``nn.LayerNorm``.
        module (nn.Module): The module applied to the normalized input.
    """

    def __init__(self, dim: int, module: nn.Module) -> None:
        # Normalization comes first so that `module` only ever sees normalized input.
        normalization = nn.LayerNorm(dim)
        super().__init__(normalization, module)
[docs]
class FeedForward(nn.Sequential):
    """Position-wise feed-forward block.

    The children are, in order: ``dense(dim, hidden_dim)``, ``nn.GELU``,
    ``nn.Dropout``, ``dense(hidden_dim, dim)`` and a final ``nn.Dropout``.

    Args:
        dim (int): The dimension of the input (and of the output).
        hidden_dim (int): The dimension of the hidden layer.
        dropout (float): The dropout rate used by both dropout layers.
        dense (nn.Module): The layer class (or factory) called as
            ``dense(in_features, out_features)`` to build the two dense layers.
            Defaults to ``nn.Linear``.
    """

    def __init__(
        self, dim: int, hidden_dim: int, dropout: float, dense: nn.Module = nn.Linear
    ) -> None:
        super().__init__(
            dense(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            dense(hidden_dim, dim),
            nn.Dropout(dropout),
        )
        self._dim = dim
        self._hidden_dim = hidden_dim
        self._dropout = dropout
        # Disable dropout replacement for the two dense layers (children 0 and 3);
        # the explicit nn.Dropout children are left untouched.
        disable_dropout_replacement(self[0])
        disable_dropout_replacement(self[3])

    def extra_repr(self) -> str:
        """Return the configuration shown in ``repr``.

        Empty pieces are dropped before joining, so the string does not start
        with a stray ``", "`` when the parent class contributes nothing.
        """
        pieces = [
            super().extra_repr(),
            f"dim={self._dim}, hidden_dim={self._hidden_dim}, dropout={self._dropout}",
        ]
        return ", ".join(piece for piece in pieces if piece)
[docs]
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are produced by a single bias-free linear layer,
    split into ``heads`` heads of size ``dim_head``, and the attention weights
    are scaled by ``dim_head ** -0.5`` before the softmax. The result is
    projected back to ``dim`` unless a single head already has width ``dim``.

    Args:
        dim (int): The dimension of the input.
        heads (int): The number of heads.
        dim_head (int): The dimension of each head.
        dropout (float): The dropout rate (applied to the attention weights and
            after the output projection).
        causal (bool): Whether to use causal attention. Defaults to False.
    """

    def __init__(
        self, dim: int, heads: int, dim_head: int, dropout: float, causal: bool = False
    ) -> None:
        super().__init__()
        inner_dim = dim_head * heads
        # With one head of width `dim` the output already has the right size,
        # so the output projection is replaced by an identity.
        project_out = not (heads == 1 and dim_head == dim)
        self._dim = dim
        self._inner_dim = inner_dim
        self._dim_head = dim_head
        self._dropout = dropout
        self._heads = heads
        self._scale = dim_head**-0.5
        self._causal = causal
        self._attend = nn.Softmax(dim=-1)
        self._drop = nn.Dropout(dropout)
        self._to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self._to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )
        # The two matrix products go through MatrixMultiplication modules
        # (project helper) rather than a bare operator.
        self._matrix_multiplication1 = MatrixMultiplication()
        self._matrix_multiplication2 = MatrixMultiplication()
        # Disable dropout replacement for qkv layer and the first layer of the output layer
        disable_dropout_replacement(self._to_qkv)
        if isinstance(self._to_out, nn.Sequential):
            disable_dropout_replacement(self._to_out[0])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x (torch.Tensor): Input of shape ``(batch, seq_len, dim)``.

        Returns:
            torch.Tensor: The attended output; ``(batch, seq_len, dim)`` when the
            output projection is used, ``(batch, seq_len, heads * dim_head)``
            otherwise (the two coincide in that case).
        """
        qkv = self._to_qkv(x).chunk(3, dim=-1)
        q, k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=self._heads) for t in qkv
        )
        dots = self._matrix_multiplication1(q, k.transpose(-1, -2)) * self._scale
        if self._causal:
            # Strictly upper-triangular (key index > query index) positions are
            # hidden. The mask is built once as (n, n) and broadcast over batch
            # and heads instead of materializing a full (b, h, n, n) tensor.
            num_queries, num_keys = dots.shape[-2:]
            query_idx = torch.arange(num_queries, device=dots.device).unsqueeze(-1)
            key_idx = torch.arange(num_keys, device=dots.device).unsqueeze(0)
            dots = dots.masked_fill(key_idx > query_idx, float("-inf"))
        attn = self._attend(dots)
        attn = self._drop(attn)
        out = self._matrix_multiplication2(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self._to_out(out)

    def extra_repr(self) -> str:
        """Return the configuration shown in ``repr``.

        Empty pieces are dropped before joining, so the string does not start
        with a stray ``", "`` when the parent class contributes nothing.
        """
        pieces = [
            super().extra_repr(),
            f"dim={self._dim}, inner_dim={self._inner_dim}, dim_head={self._dim_head}, dropout={self._dropout}, heads={self._heads}, scale={self._scale}, causal={self._causal}",
        ]
        return ", ".join(piece for piece in pieces if piece)
[docs]
class TransformerEncoderLayer(nn.Sequential):
    """A pre-norm transformer layer: attention block followed by feed-forward block.

    The container holds two children, each wrapped in a ``ResidualLayer``
    (project helper; presumably adds the block input to the block output):

    1. ``PreNorm(dim, Attention(...))``
    2. ``PreNorm(dim, FeedForward(dim, mlp_dim, dropout))``

    Being an ``nn.Sequential``, the layer is called with the input tensor only;
    causal masking is fixed at construction time through ``causal``.

    Args:
        dim (int): The dimension of the input.
        heads (int): The number of heads.
        dim_head (int): The dimension of each head.
        mlp_dim (int): The dimension of the hidden layer in the feed-forward layer.
        dropout (float): The dropout rate.
        causal (bool): Whether to use causal attention. Defaults to False.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        mlp_dim: int,
        dropout: float,
        causal: bool = False,
    ) -> None:
        attention = Attention(
            dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            causal=causal,
        )
        feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
        super().__init__(
            ResidualLayer(PreNorm(dim, attention)),
            ResidualLayer(PreNorm(dim, feed_forward)),
        )
        self._dim = dim
        self._heads = heads
        self._dim_head = dim_head
        self._mlp_dim = mlp_dim
        self._dropout = dropout
        self._causal = causal

    def extra_repr(self) -> str:
        """Return the configuration shown in ``repr``.

        Empty pieces are dropped before joining, so the string does not start
        with a stray ``", "`` when the parent class contributes nothing.
        """
        pieces = [
            super().extra_repr(),
            f"dim={self._dim}, heads={self._heads}, dim_head={self._dim_head}, mlp_dim={self._mlp_dim}, dropout={self._dropout}, causal={self._causal}",
        ]
        return ", ".join(piece for piece in pieces if piece)
[docs]
class PositionalEncoding(nn.Module):
    """Embed token indices and add fixed sinusoidal positional encodings.

    Tokens are embedded, scaled by ``sqrt(embedding_dim)``, summed with a
    precomputed sine/cosine table and passed through dropout. The input is
    batch-first: ``(batch_size, seq_len)``.

    Args:
        inputs_dim (int): The total size of token embeddings (vocabulary size).
        embedding_dim (int): The number of expected features in the input.
        dropout (float): The dropout value.
        max_len (int): The max length of the expected input.
    """

    def __init__(
        self, inputs_dim: int, embedding_dim: int, dropout: float, max_len: int = 5000
    ) -> None:
        super().__init__()
        self._embedding = nn.Embedding(inputs_dim, embedding_dim)
        self._dropout = nn.Dropout(p=dropout)
        self._scale = math.sqrt(embedding_dim)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
        )
        angles = position * div_term
        # The buffer keeps its (max_len, 1, embedding_dim) layout so that
        # previously saved state dicts remain loadable.
        pe = torch.zeros(max_len, 1, embedding_dim)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For an odd embedding_dim there is one fewer cosine slot than sine slots.
        pe[:, 0, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
        self.register_buffer("_pe", pe)
        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and add the positional encoding of each sequence position.

        Args:
            x (torch.Tensor): Token indices of shape ``(batch_size, seq_len)``.

        Returns:
            torch.Tensor: Tensor of shape ``(batch_size, seq_len, embedding_dim)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` of the table.
        """
        x = self._embedding(x) * self._scale
        seq_len = x.size(1)
        if seq_len > self._pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length "
                f"{self._pe.size(0)}."
            )
        # Index the table by sequence position (dim 1 of the batch-first input),
        # then move it to (1, seq_len, embedding_dim) so it broadcasts over the batch.
        x = x + self._pe[:seq_len].transpose(0, 1)
        return self._dropout(x)

    def reset_parameters(self) -> None:
        """Initialize the embedding weights uniformly in ``[-0.1, 0.1]``."""
        self._embedding.weight.data.uniform_(-0.1, 0.1)
[docs]
class TransformerModel(BaseModel):
    """Decoder-only Transformer built from causal self-attention layers.

    Token indices are embedded with positional encodings, passed through
    ``num_decoder_layers`` causal :class:`TransformerEncoderLayer` blocks and
    projected to ``self._outputs_dim`` values by a linear layer whose weight is
    shared with the token embedding.

    Args:
        embedding_dim (int): The embedding dimensions of the model.
        num_heads (int): The number of heads in the multiheadattention models.
        num_decoder_layers (int): The number of sub-decoder-layers in the decoder.
        hidden_dim (int): The dimension of the feedforward network model.
        dropout (float): The dropout value.
        *args (Any): Forwarded to ``BaseModel.__init__``.
        **kwargs (Any): Forwarded to ``BaseModel.__init__``.
    """

    tasks = [TEXT_CLASSIFICATION_KEY]

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        num_decoder_layers: int,
        hidden_dim: int,
        dropout: float,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super(TransformerModel, self).__init__(*args, **kwargs)
        # NOTE(review): `self._outputs_dim` and `self._task` are presumably set by
        # BaseModel.__init__ - confirm. The output dimension doubles as the
        # vocabulary size of the embedding.
        self._positional_encoding = PositionalEncoding(
            self._outputs_dim, embedding_dim, dropout
        )
        # We use TransformerEncoderLayer as the decoder layer because the causal
        # mask is enabled simply by constructing it with `causal=True`.
        head_dim = embedding_dim // num_heads
        self._decoder = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    dim=embedding_dim,
                    heads=num_heads,
                    dim_head=head_dim,
                    mlp_dim=hidden_dim,
                    dropout=dropout,
                    causal=True,
                )
                for _ in range(num_decoder_layers)
            ]
        )
        self._output = nn.Linear(embedding_dim, self._outputs_dim)
        # Implement weight parameter sharing in the output layer and the positional encoding layer
        self._output.weight = self._positional_encoding._embedding.weight
        self._output_activation = OutputActivation(self._task, dim=2)
        self._depth = num_decoder_layers
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Zero the output bias and draw the output weight uniformly in ``[-0.1, 0.1]``.

        The output weight is shared with the token embedding, so this also
        re-initializes the embedding.
        """
        self._output.bias.data.zero_()
        self._output.weight.data.uniform_(-0.1, 0.1)

    def forward(
        self,
        x: torch.Tensor,
        staged_output: bool = False,
        input_kwargs: Optional[Dict[str, Any]] = None,
        output_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Forward pass of the model.

        Note that the input has a shape of `(batch_size, seq_len)`.

        Args:
            x (torch.Tensor): The input tensor.
            staged_output (bool): Whether to return the output of each layer.
            input_kwargs (Optional[Dict[str, Any]]): The kwargs for the input layer.
                Accepted for interface compatibility; not used by this model.
            output_kwargs (Optional[Dict[str, Any]]): The kwargs for the output layer.
                ``None`` is treated as an empty dict.

        Returns:
            The final output, or ``(final_output, layers_outputs)`` when
            ``staged_output`` is True, where ``layers_outputs`` holds the output
            of every decoder layer in order.
        """
        output_kwargs = {} if output_kwargs is None else output_kwargs
        layers_outputs: List[torch.Tensor] = []
        x = self._positional_encoding(x)
        for layer in self._decoder:
            # The layers are nn.Sequential containers and take the tensor only;
            # causality was configured when they were constructed.
            x = layer(x)
            if staged_output:
                layers_outputs.append(x)
        x = self.final_layer(x, **output_kwargs)
        if staged_output:
            return x, layers_outputs
        return x

    def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor:
        """Apply the output projection followed by the output activation."""
        x = self._output(x, **output_kwargs)
        return self._output_activation(x)

    def generate(
        self, input: torch.Tensor, max_len: int, temperature: float = 1.0, **kwargs: Any
    ) -> torch.Tensor:
        """Autoregressively sample ``max_len`` new tokens after the prompt.

        At every step the whole sequence is re-encoded, the logits of the last
        position are divided by ``temperature``, and one token per sample is
        drawn from the resulting softmax distribution.

        Args:
            input (torch.Tensor): Prompt token indices of shape ``(batch_size, seq_len)``.
            max_len (int): The number of tokens to generate.
            temperature (float): Softmax temperature. Defaults to 1.0.
            **kwargs (Any): Unused.

        Returns:
            torch.Tensor: The newly generated tokens, shape ``(batch_size, max_len)``.
        """
        prompt_len = input.size(1)
        for _ in range(max_len):
            x = self._positional_encoding(input)
            for layer in self._decoder:
                x = layer(x)
            # Keep only the last position: (batch_size, embedding_dim). A 2-D
            # tensor is required because torch.multinomial rejects 3-D input.
            logits = self._output(x[:, -1, :])
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            # Add the new token to the input
            input = torch.cat([input, next_token], dim=1)
        # Slice by the prompt length so that max_len == 0 yields an empty result.
        return input[:, prompt_len:]

    def add_method_specific_layers(self, method: str, **kwargs: Any) -> None:
        """This method is used to add method specific layers to the model.

        For ``"dun"`` and ``"mimmo"`` one ``nn.Identity`` reshaping layer is
        registered per decoder layer except the last.

        Args:
            method (str): The method to use.
            **kwargs (Any): Forwarded to the parent implementation.

        Raises:
            ValueError: If ``method`` is neither ``"dun"`` nor ``"mimmo"``. The
                parent implementation has already run by then.
        """
        super().add_method_specific_layers(method, **kwargs)
        if method in ["dun", "mimmo"]:
            self._reshaping_layers = nn.ModuleList(
                [nn.Identity() for _ in range(self._depth - 1)]
            )
        else:
            raise ValueError(f"Method {method} is not supported.")

    @staticmethod
    def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add the model's command-line arguments to ``parser`` and return it.

        The parent class's arguments are added first.
        """
        parser = super(TransformerModel, TransformerModel).add_specific_args(parser)
        parser.add_argument(
            "--model_embedding_dim",
            type=int,
            default=200,
            help="The number of expected features in the input.",
        )
        parser.add_argument(
            "--model_num_heads",
            type=int,
            default=2,
            help="The number of heads in the multiheadattention models.",
        )
        parser.add_argument(
            "--model_num_decoder_layers",
            type=int,
            default=2,
            help="The number of decoder layers.",
        )
        parser.add_argument(
            "--model_hidden_dim",
            type=int,
            default=200,
            help="The dimension of the feedforward network model.",
        )
        parser.add_argument(
            "--model_dropout", type=float, default=0.2, help="The dropout value."
        )
        return parser