yamle.data.segmentation module#
- class yamle.data.segmentation.TorchvisionSegmentationDataModule(dataset, **kwargs)[source]#
Bases:
BaseDataModule
Data module for the torchvision segmentation datasets.
- Parameters:
dataset¶ (str) – Name of the torchvision dataset. Currently, only cityscapes is supported.
seed¶ (int) – Seed for the random number generator.
data_dir¶ (str) – Path to the data directory.
train_transform¶ (Callable) – Transformations to apply to the training data. Default: transforms.ToTensor() followed by transforms.Normalize(mean, std).
test_transform¶ (Callable) – Transformations to apply to the test data. Default: transforms.ToTensor() followed by transforms.Normalize(mean, std).
- mean = None#
- std = None#
- inputs_dim = None#
- outputs_dim = None#
- task = 'segmentation'#
- inputs_dtype = torch.float32#
- outputs_dtype = torch.int64#
- available_test_augmentations: List[str]#
- get_transform(name)[source]#
Helper function that returns the transform with the given name.
- Return type:
Callable
- plot(tester, save_path, specific_name='')[source]#
Plot random samples from the training and validation sets to check whether the data is predicted correctly.
- Return type:
None
- prepare_data()[source]#
Download and prepare the data. The resulting datasets are stored in self._train_dataset, self._validation_dataset and self._test_dataset.
- Return type:
None
- available_transforms: List[str]#
- test_augmentations: List[str]#
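The default transforms described above end with transforms.Normalize(mean, std), which applies (value - mean) / std to each channel. The following is a minimal, dependency-free sketch of that arithmetic only; the function name normalize_pixel and the statistics used are hypothetical and chosen for illustration, and the sketch does not reproduce this module's actual transform pipeline.

```python
from typing import List, Sequence


def normalize_pixel(
    pixel: Sequence[float],
    mean: Sequence[float],
    std: Sequence[float],
) -> List[float]:
    """Apply per-channel (value - mean) / std to a single pixel."""
    return [(v - m) / s for v, m, s in zip(pixel, mean, std)]


# Hypothetical per-channel statistics, for illustration only.
mean = [0.5, 0.5, 0.5]
std = [0.25, 0.25, 0.25]

# One RGB pixel already scaled to [0, 1], as ToTensor() would produce.
pixel = [0.5, 0.75, 0.0]

normalized = normalize_pixel(pixel, mean, std)
print(normalized)
```

A subclass is expected to supply real values for mean and std, since both default to None on this base class.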
- class yamle.data.segmentation.TorchvisionSegmentationDataModuleCityscapes(**kwargs)[source]#
Bases:
TorchvisionSegmentationDataModule
Data module for the Cityscapes dataset.
- inputs_dim = (3, 512, 256)#
- mean = [0.28689554, 0.32513303, 0.28389177]#
- std = [0.18696375, 0.19017339, 0.18720214]#
- ignore_indices = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, 34]#
- outputs_dim = 35#
- targets_dim = (35, 512, 256)#
- plot(tester, save_path, specific_name='')[source]#
Plot random samples from the training and validation sets to check whether the data is predicted correctly.
- Return type:
None
- available_transforms: List[str]#
- available_test_augmentations: List[str]#
- test_augmentations: List[str]#
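The ignore_indices and outputs_dim attributes above imply that 19 of the 35 label values remain once the listed indices are excluded. How this module consumes ignore_indices is not stated on this page, so the sketch below only assumes a common convention: pixels whose target label is in the list are skipped when computing a metric. The helper names evaluated_classes and pixel_accuracy are hypothetical and not part of yamle.

```python
from typing import List, Sequence

# Values copied from the class attributes documented above.
IGNORE_INDICES = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, 34]
OUTPUTS_DIM = 35


def evaluated_classes(num_classes: int, ignore: Sequence[int]) -> List[int]:
    """Return the label values that are not ignored."""
    ignored = set(ignore)
    return [c for c in range(num_classes) if c not in ignored]


def pixel_accuracy(
    predictions: Sequence[int],
    targets: Sequence[int],
    ignore: Sequence[int],
) -> float:
    """Accuracy over pixels whose target label is not ignored (assumed convention)."""
    ignored = set(ignore)
    kept = [(p, t) for p, t in zip(predictions, targets) if t not in ignored]
    if not kept:
        return 0.0
    return sum(p == t for p, t in kept) / len(kept)


classes = evaluated_classes(OUTPUTS_DIM, IGNORE_INDICES)

# Flattened toy labels: targets 0 and 34 are ignored, so three pixels are scored.
targets = [7, 8, 0, 11, 34]
predictions = [7, 8, 7, 12, 7]
accuracy = pixel_accuracy(predictions, targets, IGNORE_INDICES)

print(len(classes), accuracy)
```

In this toy input two of the three scored pixels match, and the predictions at ignored positions have no effect on the result.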