Extending DataModule

In this tutorial, we demonstrate how to extend the BaseDataModule class to create a custom DataModule.

We will look at how to add the MNIST dataset to YAMLE through a custom DataModule. MNIST is a dataset of handwritten digits and a popular benchmark for testing image classification models. The dataset is available through the torchvision package.

To start implementing any datamodule, we recommend looking at the BaseDataModule class. It accepts many arguments that can be used to customize the datamodule:



class BaseDataModule(ABC):
    """General data module returning training, validation and test data loaders.

    Args:
        validation_portion (float): Portion of the training data to use for validation.
        test_portion (float): Portion of the training data to use for test if test data is not provided.
        calibration_portion (float): Portion of the training data to use for calibration.
        seed (int): Seed for the random number generator.
        data_dir (str): Path to the data directory.
        train_splits (Optional[int]): Number of splits to use for the training data.
        train_splits_proportions (Optional[List[float]]): Proportions of the training data to use for each split.
        train_size (Optional[int]): Size of the training data.
        train_transform (Optional[List[str]]): Transformations to apply to the training data. Note that if the list is provided, it is ordered.
        test_transform (Optional[List[str]]): Transformations to apply to the test data. Note that if the list is provided, it is ordered.
        test_augmentations (Optional[List[str]]): Augmentations to apply to the test data. Note that if the list is provided, it is ordered.
        train_target_transform (Optional[List[str]]): Transformations to apply to the training targets. Note that if the list is provided, it is ordered.
        test_target_transform (Optional[List[str]]): Transformations to apply to the test targets. Note that if the list is provided, it is ordered.
        train_joint_transform (Optional[List[str]]): Transformations to apply to the training data as well as the targets. Note that if the list is provided, it is ordered.
        test_joint_transform (Optional[List[str]]): Transformations to apply to the test data as well as the targets. Note that if the list is provided, it is ordered.
        num_workers (Optional[int]): Number of workers to use for the data loaders. Defaults to None.
        batch_size (int): Batch size to use for the data loaders. Defaults to 32.
        pin_memory (bool): Whether to use pinned memory for the data loaders. Defaults to True.
    """

This class already contains a lot of useful functionality, e.g. automatically splitting the dataset into training, validation and calibration portions in the setup method:

    def setup(self, *args: Any, **kwargs: Any) -> None:
        """Create the validation/calibration splits if needed and wrap every split.

        The training and test datasets must always be provided; the validation and
        calibration datasets are optional. A validation or calibration dataset that
        is already set is kept and its portion is ignored. A missing one with a
        positive portion is carved out of the training dataset with ``random_split``,
        using a generator seeded with ``self._seed``. Finally, every split that is not
        yet a ``SurrogateDataset`` is wrapped in one: the training split with the
        training transformations; the validation, calibration and test splits with
        the test transformations.
        """
        carve_validation = (
            self._validation_dataset is None and self._validation_portion > 0
        )
        if carve_validation or (
            self._calibration_dataset is None and self._calibration_portion > 0
        ):
            # A portion only applies to a split that was not supplied up front.
            val_fraction = (
                0 if self._validation_dataset is not None else self._validation_portion
            )
            cal_fraction = (
                0 if self._calibration_dataset is not None else self._calibration_portion
            )
            n_total = len(self._train_dataset)
            n_val = int(val_fraction * n_total)
            n_cal = int(cal_fraction * n_total)
            # The seeded generator makes the split reproducible for a given seed.
            remaining_train, val_subset, cal_subset = random_split(
                self._train_dataset,
                [n_total - n_val - n_cal, n_val, n_cal],
                generator=torch.Generator().manual_seed(self._seed),
            )

            # Adopt only non-empty subsets for splits that are still missing; the
            # training set is replaced only if at least one subset was adopted.
            train_was_split = False
            if self._validation_dataset is None and len(val_subset) > 0:
                self._validation_dataset = val_subset
                train_was_split = True
            if self._calibration_dataset is None and len(cal_subset) > 0:
                self._calibration_dataset = cal_subset
                train_was_split = True
            if train_was_split:
                self._train_dataset = remaining_train

        if not isinstance(self._train_dataset, SurrogateDataset):
            self._train_dataset = SurrogateDataset(
                self._train_dataset,
                transform=self.train_transform(),
                target_transform=self.train_target_transform(),
                joint_transform=self.train_joint_transform(),
            )

        def wrap_with_test_transforms(dataset):
            # Validation, calibration and test splits share the test-time transforms.
            return SurrogateDataset(
                dataset,
                transform=self.test_transform(),
                target_transform=self.test_target_transform(),
                joint_transform=self.test_joint_transform(),
            )

        # Optional splits are wrapped only when present; the test split always is.
        for attribute in ("_validation_dataset", "_calibration_dataset"):
            candidate = getattr(self, attribute)
            if candidate is not None and not isinstance(candidate, SurrogateDataset):
                setattr(self, attribute, wrap_with_test_transforms(candidate))
        if not isinstance(self._test_dataset, SurrogateDataset):
            self._test_dataset = wrap_with_test_transforms(self._test_dataset)

        logging.info(f"Train dataset total size: {self.train_dataset_size()}")
        if self._train_splits is not None:
            for split_index in range(self._train_splits):
                split_size = len(self.train_dataset(split=split_index))
                logging.info(
                    f"Train dataset size for split {split_index}: {split_size}"
                )
        logging.info(f"Validation dataset size: {self.validation_dataset_size()}")
        logging.info(f"Calibration dataset size: {self.calibration_dataset_size()}")
        logging.info(f"Test dataset size: {self.test_dataset_size()}")

Note that the setup method wraps the datasets into a SurrogateDataset, which is a wrapper around the torch.utils.data.Dataset class. This wrapper makes it possible to control the data and target transformations manually.

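The transformations are provided by the train_transform, train_target_transform and train_joint_transform methods and their test_* counterparts. These methods are called when each split is wrapped: the training split uses the train_* transformations, while the validation, calibration and test splits use the test_* transformations.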

Then there is the prepare_data method, which is used to download the dataset. This method is called only once per machine, not once per GPU, which prevents the dataset from being downloaded multiple times. The prepare_data method is called before the setup method.

Now let’s start with the implementation of the MNIST datamodule. Many of the torchvision datasets can be processed in a similar way, so we will create two classes: one for general torchvision classification datasets and one specifically for MNIST.

The torchvision classification datamodule is implemented in TorchvisionClassificationDataModule.

class TorchvisionClassificationDataModule(VisionClassificationDataModule):
    """Data module for the torchvision classification datasets.

    Args:
        dataset (str): Name of the torchvision dataset. Currently supported are `mnist`,
            `fashionmnist`, `cifar10`, `cifar3`, `cifar5`, `cifar100`, `tinyimagenet`
            and `svhn`.
        pad_to_32 (bool): Whether to pad the images to 32x32. Only supported for the
            28x28 datasets `mnist` and `fashionmnist`. Defaults to False.

    Raises:
        ValueError: If `dataset` is not supported, or if `pad_to_32` is requested for a
            dataset other than `mnist` or `fashionmnist`.
    """

    outputs_dtype = torch.long

    # Dataset names accepted by ``__init__``; each one has a branch in ``prepare_data``.
    _SUPPORTED_DATASETS = (
        "mnist",
        "fashionmnist",
        "cifar10",
        "cifar3",
        "cifar5",
        "cifar100",
        "tinyimagenet",
        "svhn",
    )
    # Datasets with 28x28 images, the only ones ``pad_to_32`` applies to.
    _PADDABLE_DATASETS = ("mnist", "fashionmnist")

    def __init__(
        self, dataset: str, pad_to_32: bool = False, *args: Any, **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        if dataset not in self._SUPPORTED_DATASETS:
            raise ValueError("Dataset not supported.")
        self._dataset = dataset
        self._pad_to_32 = pad_to_32
        if pad_to_32:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if dataset not in self._PADDABLE_DATASETS:
                raise ValueError("Padding only supported for 28x28 images.")
            # Keep the channel count; the spatial size becomes the padded 32x32.
            self.inputs_dim = (self.inputs_dim[0], 32, 32)

    def prepare_data(self) -> None:
        """Download the dataset into `self._train_dataset` and `self._test_dataset`.

        Only the training and test splits are created here; no validation dataset is
        assigned by this method. For `svhn` and `tinyimagenet` the datasets are
        selected via `split=`; for `tinyimagenet` the "val" split is used as the test set.
        """
        super().prepare_data()
        if self._dataset == "mnist":
            self._train_dataset = torchvision.datasets.MNIST(
                self._data_dir, train=True, download=True
            )
            self._test_dataset = torchvision.datasets.MNIST(
                self._data_dir, train=False, download=True
            )
        elif self._dataset == "fashionmnist":
            self._train_dataset = torchvision.datasets.FashionMNIST(
                self._data_dir, train=True, download=True
            )
            self._test_dataset = torchvision.datasets.FashionMNIST(
                self._data_dir, train=False, download=True
            )
        elif self._dataset == "svhn":
            self._train_dataset = torchvision.datasets.SVHN(
                self._data_dir, split="train", download=True
            )
            self._test_dataset = torchvision.datasets.SVHN(
                self._data_dir, split="test", download=True
            )
        elif self._dataset == "cifar10":
            self._train_dataset = torchvision.datasets.CIFAR10(
                self._data_dir, train=True, download=True
            )
            self._test_dataset = torchvision.datasets.CIFAR10(
                self._data_dir, train=False, download=True
            )
        elif self._dataset in ("cifar3", "cifar5"):
            # Reduced versions of CIFAR-10 (3 or 5 classes), built identically from a
            # subset selected by `self._indices`. NOTE(review): `_indices` is not set
            # in this class; presumably the cifar3/cifar5 subclasses provide it — confirm.
            self._train_dataset = ClassificationDatasetSubset(
                torchvision.datasets.CIFAR10(self._data_dir, train=True, download=True),
                indices=self._indices,
            )
            self._test_dataset = ClassificationDatasetSubset(
                torchvision.datasets.CIFAR10(
                    self._data_dir, train=False, download=True
                ),
                indices=self._indices,
            )
        elif self._dataset == "cifar100":
            self._train_dataset = torchvision.datasets.CIFAR100(
                self._data_dir, train=True, download=True
            )
            self._test_dataset = torchvision.datasets.CIFAR100(
                self._data_dir, train=False, download=True
            )
        elif self._dataset == "tinyimagenet":
            self._train_dataset = TinyImageNet(
                self._data_dir, split="train", download=True
            )
            self._test_dataset = TinyImageNet(
                self._data_dir, split="val", download=True
            )
        else:
            raise ValueError("Dataset not supported.")

        if self._pad_to_32:
            # Presumably pads 2 pixels per side (28 + 2 * 2 = 32) — TODO confirm
            # against InputImagePaddingDataset.
            self._train_dataset = InputImagePaddingDataset(self._train_dataset, 2)
            self._test_dataset = InputImagePaddingDataset(self._test_dataset, 2)

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Extend the parent's argument parser with `--datamodule_pad_to_32` (0 or 1, default 0)."""
        parser = super(
            TorchvisionClassificationDataModule, TorchvisionClassificationDataModule
        ).add_specific_args(parent_parser)
        parser.add_argument(
            "--datamodule_pad_to_32",
            type=int,
            choices=[0, 1],
            default=0,
            help="Whether to pad the images to 32x32.",
        )
        return parser

It inherits from VisionClassificationDataModule, which implements useful methods for debugging and for plotting the predictions or the applied augmentations.

Any datamodule can also declare custom arguments through add_specific_args, e.g. the datamodule_pad_to_32 argument:

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Register the torchvision-specific command-line flags on top of the parent's.

        Args:
            parent_parser: Parser to extend.

        Returns:
            The parser returned by the parent class, with `--datamodule_pad_to_32`
            (0 or 1, default 0) added.
        """
        parent = super(
            TorchvisionClassificationDataModule, TorchvisionClassificationDataModule
        )
        extended = parent.add_specific_args(parent_parser)
        extended.add_argument(
            "--datamodule_pad_to_32",
            type=int,
            default=0,
            choices=[0, 1],
            help="Whether to pad the images to 32x32.",
        )
        return extended

Note the datamodule_ prefix, which is used to avoid name clashes with other arguments and to separate the datamodule arguments from all other arguments.

The module can accept custom arguments such as pad_to_32, which pads the images to a size of 32x32 pixels. This is useful if you want to use a model that requires a certain input size, or to apply the corruption augmentations that are common in the field of out-of-distribution detection. Notice that, in practice, the user only needs to fill in the prepare_data method, which downloads the training and test datasets and places them at the _data_dir location. The setup method then wraps the datasets into a SurrogateDataset and splits the training dataset into training, validation and calibration portions.

Finally, we create a concrete MNIST datamodule, TorchvisionClassificationDataModuleMNIST, which inherits from TorchvisionClassificationDataModule:

class TorchvisionClassificationDataModuleMNIST(TorchvisionClassificationDataModule):
    """Data module for the MNIST dataset (single-channel 28x28 images, 10 classes).

    Test augmentations listed in `VisionCorruption.available_augmentations` may only
    be used together with `pad_to_32`.
    """

    inputs_dim = (1, 28, 28)
    outputs_dim = 10
    targets_dim = 1
    # Normalization statistics (one value per channel).
    mean = (0.1307,)
    std = (0.3081,)

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__("mnist", *args, **kwargs)

    def prepare_data(self) -> None:
        """Download MNIST and validate the requested test augmentations.

        Raises:
            ValueError: If an augmentation from
                `VisionCorruption.available_augmentations` is requested while
                `pad_to_32` is disabled.
        """
        super().prepare_data()
        if self.test_augmentations is None:
            return
        for augmentation in self.test_augmentations:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if (
                augmentation in VisionCorruption.available_augmentations
                and not self._pad_to_32
            ):
                raise ValueError(
                    f"Padding to 32 is required for MNIST with augmentation {augmentation}."
                )

Note that each concrete datamodule, i.e. one that implements a specific dataset, needs to specify the inputs_dim, outputs_dim and targets_dim attributes, and optionally the mean and std attributes. These attributes are used to normalize the data and to determine the input and output dimensions of the model.

The last step is to register the new datamodule in the __init__ module alongside all the other available datamodules:

from typing import Type
from yamle.data.datamodule import BaseDataModule
from yamle.data.classification import (
    ToyTwoMoonsClassificationDataModule,
    ToyTwoCirclesClassificationDataModule,
    TorchvisionClassificationDataModuleMNIST,
    TorchvisionClassificationDataModuleCIFAR10,
    TorchvisionClassificationDataModuleCIFAR5,
    TorchvisionClassificationDataModuleCIFAR3,
    TorchvisionClassificationDataModuleCIFAR100,
    TorchvisionClassificationDataModuleFashionMNIST,
    TinyImageNetClassificationDataModule,
    TorchvisionClassificationDataModuleSVHN,
    BreastCancerUCIClassificationDataModule,
    AdultIncomeUCIClassificationDataModule,
    CarEvaluationUCIClassificationDataModule,
    CreditUCIClassificationDataModule,
    DermatologyUCIClassificationDataModule,
    PneumoniaMNISTClassificationDataModule,
    DermaMNISTClassificationDataModule,
    BreastMNISTClassificationDataModule,
    BloodMNISTClassificationDataModule,
    ECG5000ClassificationDataModule,
)
from yamle.data.regression import (
    ToyRegressionDataModule,
    ConcreteUCIRegressionDataModule,
    EnergyUCIRegressionDataModule,
    BostonUCIRegressionDataModule,
    TemperatureTimeSeriesDataModule,
    WineQualityUCIRegressionDataModule,
    YachtUCIRegressionDataModule,
    AbaloneUCIRegressionDataModule,
    TelemonitoringUCIRegressionDataModule,
    RetinaMNISTDataModule,
    WikiFaceRegressionDataModule,
    TorchvisionRotationRegressionDataModuleMNIST,
    TorchvisionRotationRegressionDataModuleCIFAR10,
    TorchvisionRotationRegressionDataModuleFashionMNIST,
    TorchvisionRotationRegressionDataModuleSVHN,
    TorchvisionRotationRegressionDataModuleCIFAR100,
    TinyImageNetRotationRegressionDataModule,
)
from yamle.data.segmentation import TorchvisionSegmentationDataModuleCityscapes
from yamle.data.text import (
    TorchtextClassificationModelWikiText2,
    TorchtextClassificationModelWikiText103,
    TorchtextClassificationModelIMDB,
    Shakespeare,
)
from yamle.data.depth import NYUv2DataModule
from yamle.data.reconstruction import ECG5000ReconstructionDataModule

# Registry mapping each name accepted by `data_factory` to its data module class.
# To make a new datamodule selectable, import it above and add an entry here.
# NOTE(review): entry order is the dict's iteration order; confirm whether any
# caller (e.g. CLI choices/help text) depends on it before reordering.
AVAILABLE_DATAMODULES = {
    # Torchvision image classification.
    "mnist": TorchvisionClassificationDataModuleMNIST,
    "cifar3": TorchvisionClassificationDataModuleCIFAR3,
    "cifar5": TorchvisionClassificationDataModuleCIFAR5,
    "cifar10": TorchvisionClassificationDataModuleCIFAR10,
    "cifar100": TorchvisionClassificationDataModuleCIFAR100,
    "svhn": TorchvisionClassificationDataModuleSVHN,
    "fashionmnist": TorchvisionClassificationDataModuleFashionMNIST,
    "tinyimagenet": TinyImageNetClassificationDataModule,
    "wiki_face": WikiFaceRegressionDataModule,
    "pneumoniamnist": PneumoniaMNISTClassificationDataModule,
    "breastmnist": BreastMNISTClassificationDataModule,
    "retinamnist": RetinaMNISTDataModule,
    "dermamnist": DermaMNISTClassificationDataModule,
    "bloodmnist": BloodMNISTClassificationDataModule,
    "toyregression": ToyRegressionDataModule,
    "toymoons": ToyTwoMoonsClassificationDataModule,
    "toycircles": ToyTwoCirclesClassificationDataModule,
    "ecg5000classification": ECG5000ClassificationDataModule,
    "ecg5000reconstruction": ECG5000ReconstructionDataModule,
    "cityscapes": TorchvisionSegmentationDataModuleCityscapes,
    "wikitext2": TorchtextClassificationModelWikiText2,
    "wikitext103": TorchtextClassificationModelWikiText103,
    "imdb": TorchtextClassificationModelIMDB,
    "shakespeare": Shakespeare,
    "concrete": ConcreteUCIRegressionDataModule,
    "energy": EnergyUCIRegressionDataModule,
    "boston": BostonUCIRegressionDataModule,
    "wine": WineQualityUCIRegressionDataModule,
    "yacht": YachtUCIRegressionDataModule,
    "abalone": AbaloneUCIRegressionDataModule,
    "telemonitoring": TelemonitoringUCIRegressionDataModule,
    "breastcancer": BreastCancerUCIClassificationDataModule,
    "adultincome": AdultIncomeUCIClassificationDataModule,
    "carevaluation": CarEvaluationUCIClassificationDataModule,
    "credit": CreditUCIClassificationDataModule,
    "dermatology": DermatologyUCIClassificationDataModule,
    "temperature": TemperatureTimeSeriesDataModule,
    "nyuv2": NYUv2DataModule,
    # Rotation-regression variants of the image datasets.
    "rotation_mnist": TorchvisionRotationRegressionDataModuleMNIST,
    "rotation_cifar10": TorchvisionRotationRegressionDataModuleCIFAR10,
    "rotation_fashionmnist": TorchvisionRotationRegressionDataModuleFashionMNIST,
    "rotation_svhn": TorchvisionRotationRegressionDataModuleSVHN,
    "rotation_cifar100": TorchvisionRotationRegressionDataModuleCIFAR100,
    "rotation_tinyimagenet": TinyImageNetRotationRegressionDataModule,
}


def data_factory(data_type: str) -> Type[BaseDataModule]:
    """Return the data module class registered under ``data_type``.

    The class itself is returned, not an instance; the caller is responsible
    for instantiating it.

    Args:
        data_type: Key in ``AVAILABLE_DATAMODULES``, e.g. ``"mnist"``.

    Returns:
        The registered data module class.

    Raises:
        ValueError: If ``data_type`` is not registered. The message lists the
            registered names so a typo is easy to spot.
    """
    try:
        return AVAILABLE_DATAMODULES[data_type]
    except KeyError:
        available = ", ".join(sorted(AVAILABLE_DATAMODULES))
        raise ValueError(
            f"Unknown data type {data_type}. Available data types: {available}."
        ) from None