from typing import List, Tuple, Union, Dict, Any
import torch.nn as nn
import torch
import argparse
from yamle.models.operations import LSTM, OutputActivation
from yamle.models.model import BaseModel
from yamle.defaults import REGRESSION_KEY, CLASSIFICATION_KEY, RECONSTRUCTION_KEY
[docs]
class RNNModel(BaseModel):
    """LSTM-based sequence model for regression and classification.

    The input is projected by a linear layer, passed through a stack of
    ``depth`` LSTM layers, and the second value returned by the last LSTM layer
    is mapped to the outputs by a linear layer followed by the task-specific
    output activation.

    Args:
        hidden_dim (int): The base dimension of the hidden layers.
        width_multiplier (int): The width multiplier for the hidden layers. The
            effective hidden width is ``hidden_dim * width_multiplier``.
        depth (int): The number of LSTM layers. Must be at least 1.

    Raises:
        ValueError: If ``depth`` is smaller than 1.
    """

    tasks = [REGRESSION_KEY, CLASSIFICATION_KEY]

    def __init__(
        self,
        hidden_dim: int,
        width_multiplier: int,
        depth: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Fail fast: with no LSTM layer `forward` has nothing to feed into the
        # output layer.
        if depth < 1:
            raise ValueError(f"`depth` must be at least 1, but got {depth}.")
        super().__init__(*args, **kwargs)
        # Only the last entry of the input dimensions is used as the feature size.
        self._inputs_dim = self._inputs_dim[-1]
        self._width_multiplier = width_multiplier
        self._depth = depth
        # The scaled width is what every layer below must be built with;
        # previously the unscaled `hidden_dim` was used, which silently ignored
        # `width_multiplier`.
        self._hidden_dim = hidden_dim * width_multiplier

        self._input = nn.Linear(self._inputs_dim, self._hidden_dim)
        self._layers = nn.ModuleList(
            [LSTM(self._hidden_dim, self._hidden_dim) for _ in range(depth)]
        )
        self._output = nn.Linear(self._hidden_dim, self._outputs_dim)
        self._output_activation = OutputActivation(self._task, dim=1)

    def forward(
        self,
        x: torch.Tensor,
        staged_output: bool = False,
        input_kwargs: Union[Dict[str, Any], None] = None,
        output_kwargs: Union[Dict[str, Any], None] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """The forward function of the model.

        Args:
            x (torch.Tensor): The input tensor of shape
                `(batch_size, seq_len, inputs_dim)`.
            staged_output (bool): Whether to also return the second output of
                each LSTM layer. Defaults to False.
            input_kwargs (Dict[str, Any]): The kwargs for the input layer.
                ``None`` (the default) is treated as no kwargs.
            output_kwargs (Dict[str, Any]): The kwargs for the output layer.
                ``None`` (the default) is treated as no kwargs.

        Returns:
            The model output, or a tuple of the model output and the list of
            per-layer outputs when ``staged_output`` is True.
        """
        assert (
            len(x.shape) == 3
        ), f"The input shape should be `(batch_size, seq_len, inputs_dim)`, but got {x.shape}."
        input_kwargs = {} if input_kwargs is None else input_kwargs
        output_kwargs = {} if output_kwargs is None else output_kwargs

        layers_outputs = []
        h = None
        x = self._input(x, **input_kwargs)
        for layer in self._layers:
            # The LSTM wrapper returns three values; the second one is presumably
            # the last hidden state -- confirm against
            # `yamle.models.operations.LSTM`.
            x, h, _ = layer(x)
            if staged_output:
                layers_outputs.append(h)

        # Only the second output of the last LSTM layer is used for the prediction.
        x = self.final_layer(h, **output_kwargs)
        if staged_output:
            return x, layers_outputs
        return x

    def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor:
        """Apply the output layer followed by the output activation.

        Args:
            x (torch.Tensor): The features fed into the output layer.
            **output_kwargs (Any): Extra kwargs forwarded to the output layer.
        """
        x = self._output(x, **output_kwargs)
        return self._output_activation(x)

    def add_method_specific_layers(self, method: str, **kwargs: Any) -> None:
        """This method is used to add method specific layers to the model.

        Args:
            method (str): The method to use. Only `"dun"` is supported.

        Raises:
            ValueError: If ``method`` is not supported.
        """
        super().add_method_specific_layers(method, **kwargs)
        if method != "dun":
            raise ValueError(f"Method {method} is not supported.")
        # Every LSTM layer in this model has the same width, so each reshaping
        # layer maps `hidden_dim -> hidden_dim`. The previous code read
        # `self._hidden_dims`, which this class never defines.
        self._reshaping_layers = nn.ModuleList(
            [
                nn.Linear(self._hidden_dim, self._hidden_dim)
                for _ in range(self._depth - 1)
            ]
        )

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add the model specific arguments to the parent parser."""
        parser = super(RNNModel, RNNModel).add_specific_args(parent_parser)
        parser.add_argument(
            "--model_hidden_dim",
            type=int,
            default=128,
            help="The dimensions of the hidden layers.",
        )
        parser.add_argument(
            "--model_width_multiplier",
            type=int,
            default=1,
            help="The width multiplier for the hidden layers.",
        )
        parser.add_argument(
            "--model_depth", type=int, default=3, help="The number of hidden layers."
        )
        return parser
[docs]
class RNNAutoEncoderModel(BaseModel):
    """LSTM autoencoder for sequence reconstruction.

    This is an autoencoder model, so the input and output dimensions are the same.
    Both the encoder and the decoder consist of LSTM layers.
    The encoder's last hidden state is repeated and used as the input to the decoder.
    Then all the hidden states of the decoder are processed through a linear layer
    to get two times the input dimension, one for mean and one for variance.

    Args:
        hidden_dim (int): The base dimension of the hidden layers.
        width_multiplier (int): The width multiplier for the hidden layers. The
            effective hidden width is ``hidden_dim * width_multiplier``.
        encoder_depth (int): The number of hidden layers for the encoder.
        decoder_depth (int): The number of hidden layers for the decoder.
    """

    tasks = [RECONSTRUCTION_KEY]

    def __init__(
        self,
        hidden_dim: int,
        width_multiplier: int,
        encoder_depth: int,
        decoder_depth: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Only the last entry of the input dimensions is used as the feature size.
        self._inputs_dim = self._inputs_dim[-1]
        self._width_multiplier = width_multiplier
        self._encoder_depth = encoder_depth
        self._decoder_depth = decoder_depth
        self._depth = encoder_depth + decoder_depth
        # The scaled width is what every layer below must be built with;
        # previously the unscaled `hidden_dim` was used, which silently ignored
        # `width_multiplier`.
        self._hidden_dim = hidden_dim * width_multiplier

        self._input = nn.Linear(self._inputs_dim, self._hidden_dim)

        # Layout: encoder LSTMs -> repeater -> decoder LSTMs.
        # NOTE(review): `InputRepeater` is not imported in the visible part of
        # this file -- confirm it is defined or imported elsewhere.
        self._layers = nn.ModuleList()
        for _ in range(encoder_depth):
            self._layers.append(LSTM(self._hidden_dim, self._hidden_dim))
        self._layers.append(InputRepeater())
        for _ in range(decoder_depth):
            self._layers.append(LSTM(self._hidden_dim, self._hidden_dim))

        self._output = nn.Linear(self._hidden_dim, self._outputs_dim)
        self._output_activation = OutputActivation(self._task, dim=2)

    def forward(
        self,
        x: torch.Tensor,
        staged_output: bool = False,
        input_kwargs: Union[Dict[str, Any], None] = None,
        output_kwargs: Union[Dict[str, Any], None] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        """The forward function of the model.

        Args:
            x (torch.Tensor): The input tensor of shape
                `(batch_size, seq_len, inputs_dim)`.
            staged_output (bool): Whether to also return the sequence output of
                each LSTM layer. Defaults to False.
            input_kwargs (Dict[str, Any]): The kwargs for the input layer.
                ``None`` (the default) is treated as no kwargs.
            output_kwargs (Dict[str, Any]): The kwargs for the output layer.
                ``None`` (the default) is treated as no kwargs.

        Returns:
            The reconstruction permuted to `(batch_size, outputs_dim, seq_len)`,
            or a tuple of it and the list of per-layer outputs when
            ``staged_output`` is True.
        """
        assert (
            len(x.shape) == 3
        ), f"The input shape should be `(batch_size, seq_len, inputs_dim)`, but got {x.shape}."
        input_kwargs = {} if input_kwargs is None else input_kwargs
        output_kwargs = {} if output_kwargs is None else output_kwargs

        layers_outputs = []
        h = None
        # The sequence length is needed to expand the encoder state for the decoder.
        T = x.shape[1]
        x = self._input(x, **input_kwargs)
        for layer in self._layers:
            if isinstance(layer, InputRepeater):
                # Bridge between encoder and decoder: the decoder input is built
                # from the encoder's last `h` and the sequence length; it is not
                # recorded as a staged output.
                x = layer(h, T)
                continue
            x, h, _ = layer(x)
            if staged_output:
                layers_outputs.append(x)

        x = self.final_layer(x, **output_kwargs)
        # Permute the output to be of shape (batch_size, outputs_dim, seq_len)
        x = x.permute(0, 2, 1)
        if staged_output:
            return x, layers_outputs
        return x

    def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor:
        """Apply the output layer followed by the output activation.

        Args:
            x (torch.Tensor): The features fed into the output layer.
            **output_kwargs (Any): Extra kwargs forwarded to the output layer.
        """
        x = self._output(x, **output_kwargs)
        return self._output_activation(x)

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add the model specific arguments to the parent parser."""
        parser = super(RNNAutoEncoderModel, RNNAutoEncoderModel).add_specific_args(
            parent_parser
        )
        parser.add_argument(
            "--model_hidden_dim",
            type=int,
            default=128,
            help="The dimensions of the hidden layers.",
        )
        parser.add_argument(
            "--model_width_multiplier",
            type=int,
            default=1,
            help="The width multiplier for the hidden layers.",
        )
        parser.add_argument(
            "--model_encoder_depth",
            type=int,
            default=3,
            help="The number of hidden layers for the encoder.",
        )
        parser.add_argument(
            "--model_decoder_depth",
            type=int,
            default=3,
            help="The number of hidden layers for the decoder.",
        )
        return parser