yamle.models.mixer module
- class yamle.models.mixer.MixerLayer(tokens_dim, tokens_hidden_dim, channels_dim, channels_hidden_dim, dropout=0.0)
Bases: Sequential
This class implements an MLP-Mixer layer.
It consists of a token-mixing MLP and a channel-mixing MLP.
- Parameters:
tokens_dim (int) – The dimension of the token-mixing MLP. This is the embedding dimension.
tokens_hidden_dim (int) – The dimension of the hidden layer in the token-mixing MLP.
channels_dim (int) – The dimension of the channel-mixing MLP. This is the number of patches.
channels_hidden_dim (int) – The dimension of the hidden layer in the channel-mixing MLP.
dropout (float) – The dropout rate.
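Below is a minimal PyTorch sketch of one Mixer layer, following the formulation in the original MLP-Mixer paper. It applies a pre-norm token-mixing MLP across the patch axis and then a pre-norm channel-mixing MLP across the embedding axis, each with a residual connection. This is not yamle's implementation. MixerLayerSketch, MlpBlock and their argument names are hypothetical, and because yamle's MixerLayer subclasses Sequential, its internal structure may differ. In the paper, the token-mixing MLP acts across patches and the channel-mixing MLP acts across embedding channels. How that maps onto this page's tokens_dim and channels_dim is not asserted here.

```python
import torch
from torch import nn


class MlpBlock(nn.Module):
    """Two-layer MLP with GELU, applied along the last tensor dimension."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MixerLayerSketch(nn.Module):
    """Hypothetical Mixer layer; expects input of shape (batch, num_patches, embed_dim)."""

    def __init__(self, num_patches: int, embed_dim: int, token_hidden: int,
                 channel_hidden: int, dropout: float = 0.0) -> None:
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.token_mlp = MlpBlock(num_patches, token_hidden, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.channel_mlp = MlpBlock(embed_dim, channel_hidden, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Token mixing: transpose to (B, D, P) so the MLP mixes across patches,
        # then transpose back to (B, P, D) and add the residual.
        y = self.norm1(x).transpose(1, 2)
        x = x + self.token_mlp(y).transpose(1, 2)
        # Channel mixing: the MLP mixes each patch's embedding independently.
        x = x + self.channel_mlp(self.norm2(x))
        return x


layer = MixerLayerSketch(num_patches=64, embed_dim=128, token_hidden=512, channel_hidden=2048)
out = layer(torch.randn(2, 64, 128))
print(out.shape)  # torch.Size([2, 64, 128])
```

The layer preserves the input shape, which is what allows several layers to be stacked back to back.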
- class yamle.models.mixer.MixerModel(patch_size=4, tokens_dim=128, tokens_hidden_dim=512, channels_hidden_dim=2048, num_layers=8, dropout=0.0, *args, **kwargs)
Bases: BaseModel
This class is used to create an MLP-Mixer model.
- Parameters:
patch_size (int) – The size of the patches the input is split into.
tokens_dim (int) – The dimension of the token-mixing MLP. This is the embedding dimension.
tokens_hidden_dim (int) – The dimension of the hidden layer in the token-mixing MLP.
channels_hidden_dim (int) – The dimension of the hidden layer in the channel-mixing MLP.
num_layers (int) – The number of Mixer layers in the model.
dropout (float) – The dropout rate.
task (str) – The task to be performed: either classification or regression.
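The number of patches follows directly from the input size and patch_size. The sketch below works this out under a few assumptions: patches are non-overlapping squares, the input has shape (channels, height, width), and height and width are divisible by patch_size. patch_geometry is a hypothetical helper, not part of yamle. The first value it returns is presumably what MixerModel passes to MixerLayer as channels_dim (documented above as the number of patches). The second value is the length of each flattened patch before it is projected to the embedding dimension.

```python
def patch_geometry(input_shape, patch_size):
    """Return (num_patches, values_per_patch) for non-overlapping square patches.

    Hypothetical helper: assumes input_shape is (channels, height, width) and
    that height and width are both divisible by patch_size.
    """
    channels, height, width = input_shape
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    num_patches = (height // patch_size) * (width // patch_size)
    values_per_patch = channels * patch_size * patch_size
    return num_patches, values_per_patch


# A 3x32x32 input with the default patch_size=4:
print(patch_geometry((3, 32, 32), 4))  # (64, 48)
```

For a 3x32x32 input, the default patch_size=4 gives 8 x 8 = 64 patches of 3 x 4 x 4 = 48 values each.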
- tasks = ['classification', 'regression']
- add_method_specific_layers(method, **kwargs)
Adds method-specific layers to the model.
- Parameters:
method¶ (str) – The method to use.
- Return type:
None
- forward(x, staged_output=False, input_kwargs={}, output_kwargs={})
Forward pass of the model.
- Return type:
Union[Tensor,Tuple[Tensor,List[Tensor]]]
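The return annotation shows that forward returns either a single tensor or a tuple of a tensor and a list of tensors. Presumably the tuple is returned when staged_output=True. The PyTorch sketch below shows one common way to implement this kind of return contract, collecting each layer's output in a list. StagedStack is hypothetical, and this page does not document which intermediate tensors yamle actually records.

```python
from typing import List, Tuple, Union

import torch
from torch import nn


class StagedStack(nn.Module):
    """Hypothetical layer stack that can also return per-layer outputs."""

    def __init__(self, dim: int, num_layers: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(
        self, x: torch.Tensor, staged_output: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        stages: List[torch.Tensor] = []
        for layer in self.layers:
            x = torch.relu(layer(x))
            stages.append(x)  # record the output of every layer
        # Return only the final tensor unless staged outputs are requested.
        return (x, stages) if staged_output else x


model = StagedStack(dim=16, num_layers=3)
x = torch.randn(4, 16)
out = model(x)
final, stages = model(x, staged_output=True)
print(out.shape, len(stages))  # torch.Size([4, 16]) 3
```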
- final_layer(x, **output_kwargs)
Returns the output of the final layer.
- Return type:
Tensor
- static add_specific_args(parser)
Adds the model-specific arguments to the given parser.
- Return type:
ArgumentParser
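add_specific_args follows the common pattern of extending an argparse parser with model options. The minimal sketch below illustrates that pattern, assuming a child parser built with parents=[parser]. The class and all flag names are hypothetical: they simply mirror MixerModel's constructor defaults and are not yamle's actual command-line options.

```python
import argparse


class MixerModelArgsSketch:
    """Hypothetical stand-in illustrating the add_specific_args pattern."""

    @staticmethod
    def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Build a child parser that inherits every argument of the given parser.
        parser = argparse.ArgumentParser(parents=[parser], add_help=False)
        # Illustrative flags mirroring the constructor defaults documented above.
        parser.add_argument("--patch_size", type=int, default=4)
        parser.add_argument("--tokens_dim", type=int, default=128)
        parser.add_argument("--tokens_hidden_dim", type=int, default=512)
        parser.add_argument("--channels_hidden_dim", type=int, default=2048)
        parser.add_argument("--num_layers", type=int, default=8)
        parser.add_argument("--dropout", type=float, default=0.0)
        return parser


base = argparse.ArgumentParser(add_help=False)
base.add_argument("--seed", type=int, default=0)
parser = MixerModelArgsSketch.add_specific_args(base)
args = parser.parse_args(["--num_layers", "4", "--dropout", "0.1"])
print(args.num_layers, args.dropout, args.patch_size, args.seed)  # 4 0.1 4 0
```

Returning a new parser instead of mutating the argument in place keeps the caller's base parser unchanged, which matches the documented ArgumentParser return type.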
- training: bool