yamle.models.rnn module
- class yamle.models.rnn.RNNModel(hidden_dim, width_multiplier, depth, *args, **kwargs)
Bases: BaseModel
This class is used to create an LSTM model with the given parameters.
- tasks = ['regression', 'classification']
- forward(x, staged_output=False, input_kwargs={}, output_kwargs={})
The forward function of the model.
- Return type:
Union[Tensor, Tuple[Tensor, List[Tensor]]]
- final_layer(x, **output_kwargs)
This function is used to get the final layer output.
- Return type:
Tensor
- add_method_specific_layers(method, **kwargs)
This method is used to add method-specific layers to the model.
- Parameters:
method (str) – The method to use.
- Return type:
None
- static add_specific_args(parent_parser)
This method is used to add the model specific arguments to the parent parser.
- Return type:
ArgumentParser
- training: bool
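RNNModel.forward returns either a single Tensor or a tuple of a Tensor and a list of Tensors, depending on staged_output. The following is a minimal PyTorch sketch of how such a stacked-LSTM model might behave; it is not yamle's implementation. It assumes that depth is the number of stacked LSTM layers, that width_multiplier scales hidden_dim, that staged_output=True returns each layer's output sequence as the list, and that final_layer maps the last time step to the output. TinyRNNModel, inputs_dim and outputs_dim are hypothetical names, and input_kwargs/output_kwargs are omitted.

```python
from typing import List, Tuple, Union

import torch
import torch.nn as nn


class TinyRNNModel(nn.Module):
    """Hypothetical stand-in for RNNModel; not the yamle implementation."""

    def __init__(self, inputs_dim: int, outputs_dim: int, hidden_dim: int,
                 width_multiplier: float, depth: int) -> None:
        super().__init__()
        # Assumption: width_multiplier scales the hidden size.
        hidden = int(hidden_dim * width_multiplier)
        # One single-layer LSTM per level of depth, so that each stage's
        # output sequence can be collected when staged_output=True.
        self.layers = nn.ModuleList(
            [nn.LSTM(inputs_dim if i == 0 else hidden, hidden, batch_first=True)
             for i in range(depth)]
        )
        self.head = nn.Linear(hidden, outputs_dim)

    def final_layer(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: map the last time step's features to the output dimension.
        return self.head(x[:, -1, :])

    def forward(self, x: torch.Tensor, staged_output: bool = False
                ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        stages: List[torch.Tensor] = []
        for lstm in self.layers:
            x, _ = lstm(x)  # x: (batch, time, hidden)
            stages.append(x)
        out = self.final_layer(x)
        return (out, stages) if staged_output else out


torch.manual_seed(0)
model = TinyRNNModel(inputs_dim=3, outputs_dim=2, hidden_dim=8,
                     width_multiplier=2.0, depth=2)
x = torch.randn(4, 10, 3)  # (batch, time, features)
y = model(x)  # plain output
y_staged, stages = model(x, staged_output=True)  # output plus per-layer sequences
```

Keeping each LSTM layer as a separate module, rather than one nn.LSTM with num_layers=depth, is what makes the per-layer outputs available for the staged list in this sketch.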
- class yamle.models.rnn.RNNAutoEncoderModel(hidden_dim, width_multiplier, encoder_depth, decoder_depth, *args, **kwargs)
Bases: BaseModel
This class is used to create an LSTM autoencoder model with the given parameters.
Because this is an autoencoder, the input and output dimensions are the same. Both the encoder and the decoder consist of LSTM layers. The encoder's last hidden state is repeated and used as the input to the decoder. All hidden states of the decoder are then passed through a linear layer that produces twice the input dimension: one half for the mean and one half for the variance.
- tasks = ['reconstruction']
- forward(x, staged_output=False, input_kwargs={}, output_kwargs={})
The forward function of the model.
- Return type:
Union[Tensor, Tuple[Tensor, List[Tensor]]]
- final_layer(x, **output_kwargs)
This function is used to get the final layer output.
- Return type:
Tensor
- static add_specific_args(parent_parser)
This method is used to add the model specific arguments to the parent parser.
- Return type:
ArgumentParser
- training: bool
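The encoder/decoder data flow described for RNNAutoEncoderModel can be sketched in PyTorch as follows. This is an illustrative sketch, not yamle's code: TinyLSTMAutoEncoder and inputs_dim are hypothetical names, width_multiplier is omitted, the mean and variance are returned as a pair rather than as one tensor of twice the input dimension, and the softplus used to keep the variance positive is an assumption, since the description does not say how positivity is enforced.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyLSTMAutoEncoder(nn.Module):
    """Hypothetical sketch of the LSTM autoencoder data flow; not yamle's code."""

    def __init__(self, inputs_dim: int, hidden_dim: int,
                 encoder_depth: int, decoder_depth: int) -> None:
        super().__init__()
        self.encoder = nn.LSTM(inputs_dim, hidden_dim,
                               num_layers=encoder_depth, batch_first=True)
        self.decoder = nn.LSTM(hidden_dim, hidden_dim,
                               num_layers=decoder_depth, batch_first=True)
        # Twice the input dimension: one half for the mean, one half for the variance.
        self.head = nn.Linear(hidden_dim, 2 * inputs_dim)

    def forward(self, x: torch.Tensor):
        seq_len = x.shape[1]
        _, (h_n, _) = self.encoder(x)  # h_n: (encoder_depth, batch, hidden)
        z = h_n[-1]  # last layer's final hidden state: (batch, hidden)
        z = z.unsqueeze(1).repeat(1, seq_len, 1)  # repeat over time: (batch, time, hidden)
        h, _ = self.decoder(z)  # all decoder hidden states: (batch, time, hidden)
        mean, raw_var = self.head(h).chunk(2, dim=-1)
        # Assumption: softplus plus a small floor keeps the variance positive.
        var = F.softplus(raw_var) + 1e-6
        return mean, var


torch.manual_seed(0)
ae = TinyLSTMAutoEncoder(inputs_dim=3, hidden_dim=16, encoder_depth=2, decoder_depth=1)
x = torch.randn(4, 10, 3)  # (batch, time, features)
mean, var = ae(x)  # both have the same shape as x
```

Because the decoder only sees the repeated summary vector, the reconstruction has to flow through the encoder's final hidden state, which is the bottleneck of this design.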
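add_specific_args takes a parent ArgumentParser and returns an ArgumentParser. A minimal standard-library sketch of that pattern for the autoencoder's constructor arguments follows. The function name add_autoencoder_args, the --model_* flag names, the defaults, and the reading of width_multiplier as a hidden-size scale are all hypothetical, not yamle's actual options.

```python
import argparse


def add_autoencoder_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Hypothetical sketch: register autoencoder options on the parent parser."""
    # Flag names and defaults are illustrative assumptions; yamle's may differ.
    group = parent_parser.add_argument_group("RNNAutoEncoderModel")
    group.add_argument("--model_hidden_dim", type=int, default=32,
                       help="Hidden size of the LSTM layers.")
    group.add_argument("--model_width_multiplier", type=float, default=1.0,
                       help="Assumed scaling factor for the hidden size.")
    group.add_argument("--model_encoder_depth", type=int, default=2,
                       help="Number of encoder LSTM layers.")
    group.add_argument("--model_decoder_depth", type=int, default=2,
                       help="Number of decoder LSTM layers.")
    # The arguments are added to the parent parser itself, which is returned.
    return parent_parser


parser = add_autoencoder_args(argparse.ArgumentParser())
args = parser.parse_args(["--model_encoder_depth", "3"])
```

Using an argument group keeps the model's options listed under their own heading in --help output while still parsing into the same namespace.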