# Source code for yamle.data

from typing import Type
from yamle.data.datamodule import BaseDataModule
from yamle.data.classification import (
    ToyTwoMoonsClassificationDataModule,
    ToyTwoCirclesClassificationDataModule,
    TorchvisionClassificationDataModuleMNIST,
    TorchvisionClassificationDataModuleCIFAR10,
    TorchvisionClassificationDataModuleCIFAR5,
    TorchvisionClassificationDataModuleCIFAR3,
    TorchvisionClassificationDataModuleCIFAR100,
    TorchvisionClassificationDataModuleFashionMNIST,
    TinyImageNetClassificationDataModule,
    TorchvisionClassificationDataModuleSVHN,
    BreastCancerUCIClassificationDataModule,
    AdultIncomeUCIClassificationDataModule,
    CarEvaluationUCIClassificationDataModule,
    CreditUCIClassificationDataModule,
    DermatologyUCIClassificationDataModule,
    PneumoniaMNISTClassificationDataModule,
    DermaMNISTClassificationDataModule,
    BreastMNISTClassificationDataModule,
    BloodMNISTClassificationDataModule,
    ECG5000ClassificationDataModule,
)
from yamle.data.regression import (
    ToyRegressionDataModule,
    ConcreteUCIRegressionDataModule,
    EnergyUCIRegressionDataModule,
    BostonUCIRegressionDataModule,
    TemperatureTimeSeriesDataModule,
    WineQualityUCIRegressionDataModule,
    YachtUCIRegressionDataModule,
    AbaloneUCIRegressionDataModule,
    TelemonitoringUCIRegressionDataModule,
    RetinaMNISTDataModule,
    WikiFaceRegressionDataModule,
    TorchvisionRotationRegressionDataModuleMNIST,
    TorchvisionRotationRegressionDataModuleCIFAR10,
    TorchvisionRotationRegressionDataModuleFashionMNIST,
    TorchvisionRotationRegressionDataModuleSVHN,
    TorchvisionRotationRegressionDataModuleCIFAR100,
    TinyImageNetRotationRegressionDataModule,
)
from yamle.data.segmentation import TorchvisionSegmentationDataModuleCityscapes
from yamle.data.text import (
    TorchtextClassificationModelWikiText2,
    TorchtextClassificationModelWikiText103,
    TorchtextClassificationModelIMDB,
    Shakespeare,
)
from yamle.data.depth import NYUv2DataModule
from yamle.data.reconstruction import ECG5000ReconstructionDataModule

# Registry mapping a data-type name (string key) to the data module class
# imported above from the `yamle.data` submodules. `data_factory` looks names
# up in this dictionary and raises `ValueError` for any key not listed here.
# The values are classes, not instances; per `data_factory`'s return
# annotation they are expected to be `BaseDataModule` subclasses.
AVAILABLE_DATAMODULES = {
    "mnist": TorchvisionClassificationDataModuleMNIST,
    "cifar3": TorchvisionClassificationDataModuleCIFAR3,
    "cifar5": TorchvisionClassificationDataModuleCIFAR5,
    "cifar10": TorchvisionClassificationDataModuleCIFAR10,
    "cifar100": TorchvisionClassificationDataModuleCIFAR100,
    "svhn": TorchvisionClassificationDataModuleSVHN,
    "fashionmnist": TorchvisionClassificationDataModuleFashionMNIST,
    "tinyimagenet": TinyImageNetClassificationDataModule,
    "wiki_face": WikiFaceRegressionDataModule,
    "pneumoniamnist": PneumoniaMNISTClassificationDataModule,
    "breastmnist": BreastMNISTClassificationDataModule,
    "retinamnist": RetinaMNISTDataModule,
    "dermamnist": DermaMNISTClassificationDataModule,
    "bloodmnist": BloodMNISTClassificationDataModule,
    "toyregression": ToyRegressionDataModule,
    "toymoons": ToyTwoMoonsClassificationDataModule,
    "toycircles": ToyTwoCirclesClassificationDataModule,
    "ecg5000classification": ECG5000ClassificationDataModule,
    "ecg5000reconstruction": ECG5000ReconstructionDataModule,
    "cityscapes": TorchvisionSegmentationDataModuleCityscapes,
    "wikitext2": TorchtextClassificationModelWikiText2,
    "wikitext103": TorchtextClassificationModelWikiText103,
    "imdb": TorchtextClassificationModelIMDB,
    "shakespeare": Shakespeare,
    "concrete": ConcreteUCIRegressionDataModule,
    "energy": EnergyUCIRegressionDataModule,
    "boston": BostonUCIRegressionDataModule,
    "wine": WineQualityUCIRegressionDataModule,
    "yacht": YachtUCIRegressionDataModule,
    "abalone": AbaloneUCIRegressionDataModule,
    "telemonitoring": TelemonitoringUCIRegressionDataModule,
    "breastcancer": BreastCancerUCIClassificationDataModule,
    "adultincome": AdultIncomeUCIClassificationDataModule,
    "carevaluation": CarEvaluationUCIClassificationDataModule,
    "credit": CreditUCIClassificationDataModule,
    "dermatology": DermatologyUCIClassificationDataModule,
    "temperature": TemperatureTimeSeriesDataModule,
    "nyuv2": NYUv2DataModule,
    "rotation_mnist": TorchvisionRotationRegressionDataModuleMNIST,
    "rotation_cifar10": TorchvisionRotationRegressionDataModuleCIFAR10,
    "rotation_fashionmnist": TorchvisionRotationRegressionDataModuleFashionMNIST,
    "rotation_svhn": TorchvisionRotationRegressionDataModuleSVHN,
    "rotation_cifar100": TorchvisionRotationRegressionDataModuleCIFAR100,
    "rotation_tinyimagenet": TinyImageNetRotationRegressionDataModule,
}


def data_factory(data_type: str) -> Type[BaseDataModule]:
    """Return the data module class registered under ``data_type``.

    The class itself is returned, not an instance of it; instantiating the
    class is left to the caller.

    Args:
        data_type: A key of ``AVAILABLE_DATAMODULES``.

    Returns:
        The class stored in ``AVAILABLE_DATAMODULES`` under ``data_type``.

    Raises:
        ValueError: If ``data_type`` is not a key of ``AVAILABLE_DATAMODULES``.
    """
    try:
        return AVAILABLE_DATAMODULES[data_type]
    except KeyError:
        # Re-raise as ValueError with the same message as the explicit
        # membership check; `from None` hides the internal KeyError.
        raise ValueError(f"Unknown data type {data_type}.") from None