yamle.data.segmentation module

class yamle.data.segmentation.TorchvisionSegmentationDataModule(dataset, **kwargs)

Bases: BaseDataModule

Data module for the torchvision segmentation datasets.

Parameters:
  • dataset (str) – Name of the torchvision dataset. Currently, only cityscapes is supported.

  • seed (int) – Seed for the random number generator.

  • data_dir (str) – Path to the data directory.

  • train_transform (Callable) – Transformations to apply to the training data. Default: transforms.ToTensor(), transforms.Normalize(mean, std).

  • test_transform (Callable) – Transformations to apply to the test data. Default: transforms.ToTensor(), transforms.Normalize(mean, std).
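The default transforms first convert an image to a tensor with values in [0, 1] and then normalize each channel c as (x - mean[c]) / std[c], which is what transforms.Normalize computes. The sketch below reproduces only that per-channel step in pure Python so it runs without torchvision. It assumes a channel-first (C x H x W) image stored as nested lists and reuses the mean and std values documented for TorchvisionSegmentationDataModuleCityscapes; normalize_chw is a hypothetical helper, not part of YAMLE or torchvision.

```python
from typing import List, Sequence

# Per-channel statistics as documented for TorchvisionSegmentationDataModuleCityscapes.
MEAN = [0.28689554, 0.32513303, 0.28389177]
STD = [0.18696375, 0.19017339, 0.18720214]


def normalize_chw(
    image: List[List[List[float]]], mean: Sequence[float], std: Sequence[float]
) -> List[List[List[float]]]:
    """Hypothetical helper: apply (x - mean[c]) / std[c] to a C x H x W image."""
    if len(image) != len(mean) or len(mean) != len(std):
        raise ValueError("image, mean and std must have the same number of channels")
    return [
        [[(value - m) / s for value in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# A 3 x 1 x 2 image whose pixels equal the per-channel mean normalizes to zeros.
image = [[[m, m]] for m in MEAN]
print(normalize_chw(image, MEAN, STD))  # [[[0.0, 0.0]], [[0.0, 0.0]], [[0.0, 0.0]]]
```

Because the statistics are fixed class attributes, the same normalization applies to the training and test transforms, which keeps train-time and test-time inputs on the same scale.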

mean = None
std = None
inputs_dim = None
outputs_dim = None
task = 'segmentation'
inputs_dtype = torch.float32
outputs_dtype = torch.int64
available_test_augmentations: List[str]
get_transform(name)

Helper function that returns the transform registered under the given name.

Return type:

Callable
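A name-based lookup like get_transform is commonly implemented with a registry dictionary. The following is a minimal sketch under that assumption, not YAMLE's implementation: the names "identity" and "horizontal_flip" are hypothetical (the real names are whatever available_transforms lists), and a sample is represented as an (image, mask) pair of nested lists so the example runs without torchvision.

```python
from typing import Callable, Dict, List, Tuple

Grid = List[List[int]]
Sample = Tuple[Grid, Grid]  # (image, mask), both H x W for simplicity


def identity(sample: Sample) -> Sample:
    return sample


def horizontal_flip(sample: Sample) -> Sample:
    # Flip the image and the mask together so labels stay aligned with pixels.
    image, mask = sample
    return [row[::-1] for row in image], [row[::-1] for row in mask]


# Hypothetical registry; these names are illustrative, not YAMLE's.
TRANSFORMS: Dict[str, Callable[[Sample], Sample]] = {
    "identity": identity,
    "horizontal_flip": horizontal_flip,
}


def get_transform(name: str) -> Callable[[Sample], Sample]:
    """Return the transform registered under `name`."""
    try:
        return TRANSFORMS[name]
    except KeyError:
        raise ValueError(
            f"Unknown transform {name!r}; available: {sorted(TRANSFORMS)}"
        ) from None


flip = get_transform("horizontal_flip")
print(flip(([[1, 2, 3]], [[0, 0, 7]])))  # ([[3, 2, 1]], [[7, 0, 0]])
```

For segmentation, any geometric transform must be applied identically to the image and its mask; photometric transforms such as normalization apply to the image only.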

plot(tester, save_path, specific_name='')

Plot random samples from the training and validation sets to visually check whether the predictions on them are correct.

Return type:

None

prepare_data()

Download and prepare the data. The resulting splits are stored in self._train_dataset, self._validation_dataset and self._test_dataset.

Return type:

None
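This page does not state how the three splits are derived. One common pattern, assumed here purely for illustration, is to carve a validation set out of the training indices using a random generator seeded with the seed parameter, so that the split is reproducible across runs. split_indices and validation_fraction are hypothetical names, not part of YAMLE.

```python
import random
from typing import List, Tuple


def split_indices(
    num_samples: int, validation_fraction: float, seed: int
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: reproducibly split range(num_samples) into train/validation indices."""
    if not 0.0 <= validation_fraction < 1.0:
        raise ValueError("validation_fraction must be in [0, 1)")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # local RNG, so the global random state is untouched
    num_validation = int(num_samples * validation_fraction)
    return indices[num_validation:], indices[:num_validation]


train_idx, val_idx = split_indices(num_samples=10, validation_fraction=0.2, seed=42)
print(len(train_idx), len(val_idx))  # 8 2
```

In PyTorch, such index lists can be wrapped with torch.utils.data.Subset; torch.utils.data.random_split with a seeded torch.Generator achieves the same effect directly.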

available_transforms: List[str]
test_augmentations: List[str]
class yamle.data.segmentation.TorchvisionSegmentationDataModuleCityscapes(**kwargs)

Bases: TorchvisionSegmentationDataModule

Data module for the Cityscapes dataset.

inputs_dim = (3, 512, 256)
mean = [0.28689554, 0.32513303, 0.28389177]
std = [0.18696375, 0.19017339, 0.18720214]
ignore_indices = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, 34]
outputs_dim = 35
targets_dim = (35, 512, 256)
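ignore_indices lists 16 of the 35 Cityscapes label ids, leaving 19 ids, which matches the 19 classes used in the standard Cityscapes benchmark. How YAMLE consumes ignore_indices is not documented here. One common approach, shown as a hypothetical sketch, builds a lookup table that maps ignored ids to a sentinel value (255 is assumed here, not taken from YAMLE) and the remaining ids to contiguous class indices; build_lookup and remap_mask are illustrative names.

```python
from typing import List

NUM_LABEL_IDS = 35  # outputs_dim of TorchvisionSegmentationDataModuleCityscapes
IGNORE_INDICES = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, 34]
IGNORE_VALUE = 255  # assumed sentinel value, not taken from YAMLE


def build_lookup(num_ids: int, ignore: List[int], ignore_value: int) -> List[int]:
    """Hypothetical helper: map each label id to a contiguous class index or ignore_value."""
    ignored = set(ignore)
    lookup = []
    next_class = 0
    for label_id in range(num_ids):
        if label_id in ignored:
            lookup.append(ignore_value)
        else:
            lookup.append(next_class)
            next_class += 1
    return lookup


def remap_mask(mask: List[List[int]], lookup: List[int]) -> List[List[int]]:
    # Replace every label id in an H x W mask by its lookup entry.
    return [[lookup[label_id] for label_id in row] for row in mask]


LOOKUP = build_lookup(NUM_LABEL_IDS, IGNORE_INDICES, IGNORE_VALUE)
num_classes = sum(1 for value in LOOKUP if value != IGNORE_VALUE)
print(num_classes)  # 19
print(remap_mask([[7, 8, 0], [26, 33, 34]], LOOKUP))  # [[0, 1, 255], [13, 18, 255]]
```

Pixels mapped to the sentinel can then be excluded from training, for example with torch.nn.CrossEntropyLoss(ignore_index=255).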
plot(tester, save_path, specific_name='')

Plot random samples from the training and validation sets to visually check whether the predictions on them are correct.

Return type:

None

available_transforms: List[str]
available_test_augmentations: List[str]
test_augmentations: List[str]