Extending Method
In this Tutorial we will demonstrate how to extend the BaseMethod class to create a new model.
More concretely, we will be implementing the Monte Carlo Dropout method from the paper "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning".
The method is based on inserting a dropout layer before each learnable layer in the network. Because these dropout layers stay active at test time, every forward pass uses a different dropout mask, which approximates sampling from the posterior distribution over the weights. The method is implemented in the MCDropoutMethod class.
To start an implementation from scratch, we first need to import the MCSamplingMethod class. This class implements the basic functionality of the Monte Carlo sampling methods. It is a subclass of the BaseMethod class, which implements the basic functionality of all methods. The Monte Carlo sampling runs the network multiple times with different dropout masks. The output of the network is then averaged over all runs.
For a new method you would ideally create a new file in the yamle/methods folder e.g. yamle/methods/mcdropout.py. In this file you would then create a new class e.g. MCDropoutMethod that inherits from the MCSamplingMethod class.
class MCDropoutMethod(MCSamplingMethod):
    """This class is the extension of the base method for which the prediction is performed using Monte Carlo dropout.

    Dropout layers are inserted into the model according to `mode` and are always on,
    including during evaluation.

    Args:
        p (float): The dropout probability to be used for Monte Carlo dropout, in `[0, 1]`.
        mode (str): Where to add dropout layers, can be either `all`, `last`, `partial` or `custom`:
            `all` - in front of every supported layer of the model,
            `last` - only in front of the output layer (`self.model._output`),
            `partial` - only for layers within a relative depth range of the model,
            `custom` - only for the layers whose index is in `depth_indices`.
        no_input_replacement (bool): If `True`, the input layer (`self.model._input`) is marked
            via `disable_dropout_replacement` so that it is not replaced with dropout.
        conv_filter_dropout (bool): Whether to use `Dropout2d`/`Dropout3d` in front of
            `nn.Conv2d`/`nn.Conv3d` layers instead of `Dropout1d`.
        depth_portion_to_replace_start (float): Fraction of the counted layers at which the
            replacement starts. Required for `partial` mode.
        depth_portion_to_replace_end (float): Fraction of the counted layers at which the
            replacement ends. Required for `partial` mode.
        depth_indices (Tuple[int, ...]): The indices of the layers to replace with dropout.
            Required for `custom` mode. String values (e.g. coming from the command line)
            are converted to integers.
    """

    def __init__(
        self,
        p: float = 0.5,
        mode: str = "all",
        no_input_replacement: bool = False,
        conv_filter_dropout: bool = False,
        depth_portion_to_replace_start: Optional[float] = None,
        depth_portion_to_replace_end: Optional[float] = None,
        depth_indices: Optional[Tuple[int, ...]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        assert mode in [
            "all",
            "last",
            "partial",
            "custom",
        ], f"mode must be either `all`, `last`, `partial` or `custom`, got {mode}"
        assert 0 <= p <= 1, f"p must be in [0, 1], got {p}"
        if mode == "partial":
            assert (
                depth_portion_to_replace_start is not None
                and depth_portion_to_replace_end is not None
            ), "`depth_portion_to_replace_start` and `depth_portion_to_replace_end` must be provided for `partial` mode."
        if mode == "custom":
            assert (
                depth_indices is not None
            ), "`depth_indices` must be provided for `custom` mode."
            # The command line delivers the indices as strings, while the layer
            # counters they are matched against are compared with integers in
            # `partial` mode. Without this conversion no layer would ever match.
            depth_indices = tuple(int(index) for index in depth_indices)
        self._p = p
        self._mode = mode
        if no_input_replacement:
            disable_dropout_replacement(self.model._input)
        dropout_mapping: Dict[nn.Module, nn.Module] = {
            nn.Linear: Dropout1d,
            nn.Conv1d: Dropout1d,
            nn.Conv2d: Dropout2d if conv_filter_dropout else Dropout1d,
            nn.Conv3d: Dropout3d if conv_filter_dropout else Dropout1d,
            nn.Dropout: Dropout1d,
            nn.Dropout2d: Dropout2d,
            nn.Dropout3d: Dropout3d,
        }
        if mode == "all":
            replace_with_dropout(self.model, p, dropout_mapping)
        elif mode == "last":
            self.model._output = nn.Sequential(Dropout1d(p=p), self.model._output)
        elif mode == "partial":
            # Decide which family of layers is counted *before* replacing anything.
            # The linear/conv layers are used, falling back to LSTM layers when the
            # model has none. Replacing first with an empty linear/conv range would
            # already rewrite the LSTM layers (they presumably carry no `_counter`
            # yet), silently ignoring the requested depth range.
            max_count = count_linear_conv(self)
            if max_count == 0:
                # Maybe an LSTM model
                max_count = count_lstm(self)
            start_count = int(depth_portion_to_replace_start * max_count)
            end_count = int(depth_portion_to_replace_end * max_count)
            replace_with_dropout(
                self.model, p, dropout_mapping, (start_count, end_count)
            )
        elif mode == "custom":
            # The counting helpers are called for their side effect (presumably
            # tagging layers with the `_counter` read by `replace_with_dropout`).
            # LSTM layers take precedence; otherwise linear/conv layers are counted.
            if count_lstm(self) == 0:
                count_linear_conv(self)
            replace_with_dropout(
                self.model, p, dropout_mapping, custom_indices=depth_indices
            )
        else:
            raise NotImplementedError(
                f"mode {mode} is not implemented for MCDropoutMethod"
            )

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add the specific arguments for the class.

        Every argument carries the `--method_` prefix to avoid clashes with the
        arguments of other components.
        """
        parser = super(MCDropoutMethod, MCDropoutMethod).add_specific_args(
            parent_parser
        )
        parser.add_argument(
            "--method_p",
            type=float,
            default=0.5,
            help="The dropout probability to be used for Monte Carlo dropout.",
        )
        parser.add_argument(
            "--method_mode",
            type=str,
            choices=["all", "last", "partial", "custom"],
            default="all",
            help="Where to add dropout layers, can be either `all`, `last`, `partial` or `custom`.",
        )
        parser.add_argument(
            "--method_no_input_replacement",
            type=int,
            choices=[0, 1],
            default=1,
            help="If 1, do not place dropout in front of the input layer.",
        )
        parser.add_argument(
            "--method_conv_filter_dropout",
            type=int,
            choices=[0, 1],
            default=0,
            help="Whether to place 2D dropout on the convolutional filters.",
        )
        parser.add_argument(
            "--method_depth_portion_to_replace_start",
            type=float,
            default=None,
            help="The depth portion of the model to start replacing with dropout.",
        )
        parser.add_argument(
            "--method_depth_portion_to_replace_end",
            type=float,
            default=None,
            help="The depth portion of the model to end replacing with dropout.",
        )
        parser.add_argument(
            "--method_depth_indices",
            # Layer indices are integers; parsing them as strings would make
            # the membership test against the layer counters never match.
            type=int,
            nargs="+",
            default=None,
            help="The indices of the layers to replace with dropout.",
        )
        return parser
In the __init__() method we first call the super().__init__() method to initialize the MCSamplingMethod class with any parent arguments. These include for example the number of samples to take or the number of epochs to train for.
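In `__init__()` we also define any arguments that are specific to Monte Carlo Dropout, e.g. `p`, the dropout probability, or `mode`, which controls where the dropout layers are inserted: in the entire network (`all`), only before the last layer (`last`), within a depth range of the network (`partial`) or at specific layer indices (`custom`).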
To make a model compatible with the Monte Carlo Dropout method, we need to insert dropout layers into the network and make sure that they are always active. We therefore implement a custom Dropout layer that always applies dropout, even in evaluation mode. Custom modules like this one should be implemented under the yamle/models/specific folder for consistency.
class Dropout1d(nn.Module):
    """Dropout layer that is always active, with its probability stored in a buffer.

    Unlike `nn.Dropout`, this layer calls the dropout function with `training=True`
    unconditionally, so dropout is applied in both training and evaluation mode.
    This is what Monte Carlo dropout needs at test time.
    The probability is registered as the `_p` buffer (not an `nn.Parameter`), so it
    is part of the module's state dict but is not a trainable parameter.

    Args:
        p (float): The probability of an element to be zeroed.
        inplace (bool): If set to `True`, will do this operation in-place.
    """

    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super().__init__()
        self.register_buffer("_p", torch.tensor(p))
        self.inplace = inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout to `x` regardless of the module's `training` flag."""
        return _dropout(x, p=self._p, training=True, inplace=self.inplace)

    def extra_repr(self) -> str:
        # Report `p` as a plain number instead of the buffer tensor's repr
        # (which would read `p=tensor(0.5000)`).
        return f"p={self._p.item():g}, inplace={self.inplace}"
The `__init__()` method is also where we modify the model by inserting the dropout layers via the `replace_with_dropout` function. Supporting functions like this one can be added to the yamle/models/specific/mcdropout.py file for consistency.
def _lookup_dropout(
    dropout_mapping: Dict[nn.Module, nn.Module], layer: nn.Module
) -> nn.Module:
    """Return the dropout class mapped to the type of `layer`, also matching subclasses.

    The type's MRO is walked so that a subclass of a mapped layer (e.g. a custom
    `nn.Linear` subclass) resolves to its parent's entry. This keeps the lookup
    consistent with the `isinstance` checks used to select the layers.

    Raises:
        KeyError: If neither the layer's type nor any of its bases is in the mapping.
    """
    for klass in type(layer).__mro__:
        if klass in dropout_mapping:
            return dropout_mapping[klass]
    raise KeyError(type(layer))


def replace_with_dropout(
    model: nn.Module,
    p: float,
    dropout_mapping: Dict[nn.Module, nn.Module],
    depth_start_end: Optional[Tuple[int, int]] = None,
    custom_indices: Optional[Tuple[int, ...]] = None,
) -> None:
    """Insert always-on dropout into `model` in place, visiting its children recursively.

    * Every `nn.Linear`, `nn.Conv1d`, `nn.Conv2d` or `nn.Conv3d` layer is replaced by
      `nn.Sequential(dropout, layer)`, i.e. the dropout is applied *in front of* the
      layer. The dropout class is taken from `dropout_mapping`.
    * Every `LSTM` layer is replaced by a `DropoutLSTM` with the same input and hidden size.
    * Every existing `nn.Dropout`, `nn.Dropout2d` or `nn.Dropout3d` layer is replaced by
      its mapped dropout class with probability `p`, regardless of the filters below.

    Layers whose `DISABLED_DROPOUT_KEY` attribute is truthy are left untouched.

    Args:
        model (nn.Module): The model to replace the layers in.
        p (float): The probability in the `Dropout` layer.
        dropout_mapping (Dict[nn.Module, nn.Module]): The mapping from the original layer
            type to the `Dropout` class. Subclasses of a mapped type use the parent's entry.
        depth_start_end (Optional[Tuple[int, int]]): Inclusive `(start, end)` range of layer
            counters (the layers' `_counter` attribute) to replace. For example, with
            `depth_start_end=(2, 8)` only layers with counters 2 through 8 are replaced.
            Layers without a `_counter` attribute are still replaced.
        custom_indices (Optional[Tuple[int, ...]]): The layer counters to replace.
            Layers without a `_counter` attribute are skipped.

    Raises:
        AssertionError: If both `depth_start_end` and `custom_indices` are given.
        KeyError: If a selected layer type has no entry in `dropout_mapping`.
    """
    assert (
        depth_start_end is None or custom_indices is None
    ), "Either depth_start_end or custom_indices should be specified, but not both."
    for name, child in model.named_children():
        if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, LSTM)):
            if getattr(child, DISABLED_DROPOUT_KEY, False):
                continue
            # NOTE(review): layers without a `_counter` are replaced when a depth range
            # is given but skipped for `custom_indices` - confirm this asymmetry is intended.
            has_counter = hasattr(child, "_counter")
            if depth_start_end is not None and has_counter:
                start, end = depth_start_end
                if not start <= child._counter <= end:
                    continue
            if custom_indices is not None:
                if not has_counter or child._counter not in custom_indices:
                    continue
            if isinstance(child, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
                dropout_cls = _lookup_dropout(dropout_mapping, child)
                setattr(model, name, nn.Sequential(dropout_cls(p=p), child))
            elif isinstance(child, LSTM):
                setattr(
                    model,
                    name,
                    DropoutLSTM(
                        p=p,
                        input_size=child._input_size,
                        hidden_size=child._hidden_size,
                    ),
                )
        elif isinstance(child, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            setattr(model, name, _lookup_dropout(dropout_mapping, child)(p=p))
        else:
            replace_with_dropout(
                child, p, dropout_mapping, depth_start_end, custom_indices
            )
Now let’s talk about how to customise the training, validation and test steps of a method. These are generally defined in the _training_step method, the _validation_step method and the _test_step method. In general, these functions call a default _step method that is defined in the BaseMethod class.
def _step(
    self,
    batch: List[torch.Tensor],
    batch_idx: int,
    optimizer_idx: Optional[int] = None,
    phase: str = TRAIN_KEY,
) -> Dict[str, Any]:
    """This method is used to perform a single training, validation or test step.

    The data is split into inputs and targets, the forward pass is performed with
    `self._predict` and the loss is computed with `self._loss`.
    The predictions have shape `(batch_size, num_members, num_outputs)`, with
    `num_members=1` for the base method. An average of the predictions is also
    computed across the ensemble members with `average_predictions`.

    Args:
        batch (List[torch.Tensor]): The batch of data, unpacked as `(inputs, targets)`.
        batch_idx (int): The index of the batch.
        optimizer_idx (Optional[int]): The index of the optimizer; not used here.
        phase (str): The phase of the step; not used here.

    Returns:
        Dict[str, Any]: The loss together with detached copies of the inputs, targets,
        per-member predictions and mean predictions.
    """
    x, y = batch
    y_hat = self._predict(x)
    loss = self._loss(y_hat, y)
    y_hat_mean = average_predictions(y_hat, self._task)
    return {
        # Only the loss is left attached to the graph, presumably so it can be
        # back-propagated; everything else is detached for logging/metrics.
        LOSS_KEY: loss,
        TARGET_KEY: y.detach(),
        INPUT_KEY: x.detach(),
        PREDICTION_KEY: y_hat.detach(),
        MEAN_PREDICTION_KEY: y_hat_mean.detach(),
    }
This method is responsible for running the network and calculating the loss. Custom behaviour for training, validation or testing can be defined by overriding the corresponding step methods. In this case, it is not necessary to modify any of these methods, since MCSamplingMethod already implements the correct behaviour by overriding the `_predict` method, which runs the network once per member (sample) and concatenates the outputs.
def _predict(self, x: torch.Tensor, **forward_kwargs: Any) -> torch.Tensor:
    """Run the parent forward pass once per member and join the samples.

    In training mode `self.training_num_members` samples are drawn, otherwise
    `self._num_members`. The per-sample outputs are concatenated along `MEMBERS_DIM`.
    """
    if self.training:
        num_samples = self.training_num_members
    else:
        num_samples = self._num_members
    # Bind the parent implementation once: the zero-argument `super()` form
    # cannot be used inside a comprehension on Python versions before 3.12.
    parent_predict = super()._predict
    samples = [parent_predict(x, **forward_kwargs) for _ in range(num_samples)]
    return torch.cat(samples, dim=MEMBERS_DIM)
Lastly, we need to be able to provide the arguments of the method to the MCDropoutMethod class. This is done by overriding the add_specific_args method. This method is called by the BaseMethod class when the arguments are parsed.
@staticmethod
def add_specific_args(
    parent_parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    """This method is used to add the specific arguments for the class.

    Every argument carries the `--method_` prefix to avoid clashes with the
    arguments of other components.
    """
    parser = super(MCDropoutMethod, MCDropoutMethod).add_specific_args(
        parent_parser
    )
    parser.add_argument(
        "--method_p",
        type=float,
        default=0.5,
        help="The dropout probability to be used for Monte Carlo dropout.",
    )
    parser.add_argument(
        "--method_mode",
        type=str,
        choices=["all", "last", "partial", "custom"],
        default="all",
        help="Where to add dropout layers, can be either `all`, `last`, `partial` or `custom`.",
    )
    parser.add_argument(
        "--method_no_input_replacement",
        type=int,
        choices=[0, 1],
        default=1,
        help="If 1, do not place dropout in front of the input layer.",
    )
    parser.add_argument(
        "--method_conv_filter_dropout",
        type=int,
        choices=[0, 1],
        default=0,
        help="Whether to place 2D dropout on the convolutional filters.",
    )
    parser.add_argument(
        "--method_depth_portion_to_replace_start",
        type=float,
        default=None,
        help="The depth portion of the model to start replacing with dropout.",
    )
    parser.add_argument(
        "--method_depth_portion_to_replace_end",
        type=float,
        default=None,
        help="The depth portion of the model to end replacing with dropout.",
    )
    parser.add_argument(
        "--method_depth_indices",
        # Layer indices are integers; parsing them as strings would make the
        # membership test against the layer counters never match.
        type=int,
        nargs="+",
        default=None,
        help="The indices of the layers to replace with dropout.",
    )
    return parser
Notice the --method_ prefix in the argument names. This is necessary to avoid conflicts with other arguments that can be parsed via the command line.
The last step would be to add the new method to the __init__ file in the yamle/methods folder. This is necessary to make the method available via the command line.
from typing import Type
from yamle.methods.method import BaseMethod
from yamle.methods.augmentation_classification import (
CutOutImageClassificationMethod,
CutMixImageClassificationMethod,
MixUpImageClassificationMethod,
RandomErasingImageClassificationMethod,
)
from yamle.methods.contrastive import SimCLRVisionMethod
from yamle.methods.mimo import (
MIMOMethod,
MixMoMethod,
DataMUXMethod,
UnMixMoMethod,
MIMMOMethod,
SAEMethod,
MixVitMethod,
)
from yamle.methods.pe import PEMethod
from yamle.methods.mcdropout import (
MCDropoutMethod,
MCDropConnectMethod,
MCStandOutMethod,
MCDropBlockMethod,
MCStochasticDepthMethod,
)
from yamle.methods.ensemble import (
EnsembleMethod,
SnapsotEnsembleMethod,
GradientBoostingEnsembleMethod,
)
from yamle.methods.moe import (
MultiHeadEnsembleMethod,
MixtureOfExpertsMethod,
)
from yamle.methods.dun import DUNMethod
from yamle.methods.early_exit import EarlyExitMethod
from yamle.methods.sngp import SNGPMethod
from yamle.methods.be import BEMethod
from yamle.methods.temperature_scaling import TemperatureMethod
from yamle.methods.rbnn import RBNNMethod
from yamle.methods.svi import (
SVIRTMethod,
SVILRTMethod,
SVIFlipOutRTMethod,
SVIFlipOutDropConnectMethod,
SVILRTVDMethod,
)
from yamle.methods.delta_uq import DeltaUQMethod
from yamle.methods.gp import GPMethod
from yamle.methods.evidential_regression import (
EvidentialRegressionMethod,
)
from yamle.methods.sgld import SGLDMethod
from yamle.methods.laplace import LaplaceMethod
from yamle.methods.swag import SWAGMethod
# Registry mapping the command-line name of a method to its class.
# `method_factory` looks method names up here; insertion order is preserved.
AVAILABLE_METHODS = dict(
    base=BaseMethod,
    # Self-supervised / augmentation-based methods.
    simclrvision=SimCLRVisionMethod,
    cutout=CutOutImageClassificationMethod,
    cutmix=CutMixImageClassificationMethod,
    mixup=MixUpImageClassificationMethod,
    random_erasing=RandomErasingImageClassificationMethod,
    # Multi-input multi-output family.
    mimo=MIMOMethod,
    mimmo=MIMMOMethod,
    sae=SAEMethod,
    mixmo=MixMoMethod,
    mixvit=MixVitMethod,
    unmixmo=UnMixMoMethod,
    datamux=DataMUXMethod,
    pe=PEMethod,
    # Stochastic variational inference family.
    svirt=SVIRTMethod,
    svilrt=SVILRTMethod,
    svilrtvd=SVILRTVDMethod,
    sviflipout_gaussian=SVIFlipOutRTMethod,
    sviflipout_dropconnect=SVIFlipOutDropConnectMethod,
    # Monte Carlo sampling family.
    mcdropout=MCDropoutMethod,
    mcdropconnect=MCDropConnectMethod,
    mcstandout=MCStandOutMethod,
    mcdropblock=MCDropBlockMethod,
    mcstochasticdepth=MCStochasticDepthMethod,
    # Ensembles and mixtures.
    ensemble=EnsembleMethod,
    snapshot_ensemble=SnapsotEnsembleMethod,
    multi_head_ensemble=MultiHeadEnsembleMethod,
    mixture_of_experts=MixtureOfExpertsMethod,
    gradient_boosting_ensemble=GradientBoostingEnsembleMethod,
    # Remaining methods.
    sgld=SGLDMethod,
    dun=DUNMethod,
    swag=SWAGMethod,
    early_exit=EarlyExitMethod,
    sngp=SNGPMethod,
    be=BEMethod,
    temperature=TemperatureMethod,
    rbnn=RBNNMethod,
    delta_uq=DeltaUQMethod,
    gp=GPMethod,
    laplace=LaplaceMethod,
    evidential_regression=EvidentialRegressionMethod,
)
def method_factory(method_type: str) -> Type[BaseMethod]:
    """Look up the method class registered under `method_type` in `AVAILABLE_METHODS`.

    The class itself is returned, not an instance; the caller instantiates it.

    Args:
        method_type (str): The registered name of the method, e.g. `"mcdropout"`.

    Returns:
        Type[BaseMethod]: The method class registered under `method_type`.

    Raises:
        ValueError: If `method_type` is not a registered method name.
    """
    try:
        return AVAILABLE_METHODS[method_type]
    except KeyError:
        raise ValueError(f"Unknown method type {method_type}.") from None
Afterwards you can run the method via the command line. For example:
python3 yamle/cli/train.py --method mcdropout --trainer_devices "[0]" \
    --datamodule mnist --datamodule_batch_size 256 \
    --method_optimizer adam --method_learning_rate 3e-4 \
    --regularizer l2 --method_regularizer_weight 1e-5 \
    --loss crossentropy --save_path ./experiments --trainer_epochs 3 \
    --model_hidden_dim 32 --model_depth 3 \
    --datamodule_validation_portion 0.1 --datamodule_pad_to_32 1 \
    --method_p 0.3 --method_mode all --method_num_members 10