yamle.data.depth module
- class yamle.data.depth.DepthEstimationDataModule(dataset, **kwargs)
Bases: BaseDataModule
Data module for depth estimation.
- Parameters:
dataset (str) – Name of the torchvision dataset. Currently, only nyudepthv2 is supported.
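The `dataset` argument is a plain string key. As a minimal sketch, assuming a simple name-to-class registry, the lookup from that key to a concrete data module could look like the following. The `DATA_MODULES` dict and `resolve_data_module` function are hypothetical illustrations, not part of yamle, and the two classes are empty stand-ins for the real ones.

```python
from typing import Dict, Type


class DepthEstimationDataModule:
    """Stand-in for the real base class; only used to illustrate dispatch."""


class NYUv2DataModule(DepthEstimationDataModule):
    """Stand-in for the NYUv2 data module."""


# Hypothetical registry mapping the `dataset` string to a data-module class.
DATA_MODULES: Dict[str, Type[DepthEstimationDataModule]] = {
    "nyudepthv2": NYUv2DataModule,
}


def resolve_data_module(dataset: str) -> Type[DepthEstimationDataModule]:
    """Return the data-module class registered under `dataset`."""
    try:
        return DATA_MODULES[dataset]
    except KeyError:
        supported = ", ".join(sorted(DATA_MODULES))
        raise ValueError(
            f"Unsupported dataset {dataset!r}; supported: {supported}"
        ) from None
```

Failing early with the list of supported names makes a typo in a configuration file easy to diagnose.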
- mean = None
- std = None
- inputs_dim = None
- outputs_dim = None
- task = 'depth_estimation'
- inputs_dtype = torch.float32
- outputs_dtype = torch.float32
- available_test_augmentations: List[str]
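The base class leaves `mean`, `std`, `inputs_dim` and `outputs_dim` as `None`, and a concrete subclass such as `NYUv2DataModule` overrides them. A sketch of that pattern follows; `BaseDepthModule`, `NYUv2Like` and the `missing_attributes` helper (which reports which placeholders a subclass forgot to set) are hypothetical names for illustration, not yamle code.

```python
from typing import List, Optional, Sequence, Tuple


class BaseDepthModule:
    # Placeholders that a concrete dataset subclass is expected to override.
    mean: Optional[Sequence[float]] = None
    std: Optional[Sequence[float]] = None
    inputs_dim: Optional[Tuple[int, ...]] = None
    outputs_dim: Optional[int] = None
    task = "depth_estimation"

    @classmethod
    def missing_attributes(cls) -> List[str]:
        """Names of required class attributes that are still None."""
        required = ("mean", "std", "inputs_dim", "outputs_dim")
        return [name for name in required if getattr(cls, name) is None]


class NYUv2Like(BaseDepthModule):
    # Values copied from the NYUv2DataModule attributes documented on this page.
    inputs_dim = (3, 320, 240)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    outputs_dim = 2
```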
- get_transform(name)
Helper function that returns the transform registered under the given name.
- Return type:
Callable
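`get_transform` maps a transform name to a callable. The sketch below shows one way such a name lookup can be structured. The transform names (`"identity"`, `"hflip"`) and the nested-list image type are hypothetical placeholders, not the names or data types yamle actually uses.

```python
from typing import Callable, Dict, List

Image = List[List[float]]  # toy single-channel image: a list of pixel rows


def _identity(image: Image) -> Image:
    return image


def _hflip(image: Image) -> Image:
    # Mirror each row left-to-right (horizontal flip).
    return [list(reversed(row)) for row in image]


# Hypothetical name -> transform table.
_TRANSFORMS: Dict[str, Callable[[Image], Image]] = {
    "identity": _identity,
    "hflip": _hflip,
}


def get_transform(name: str) -> Callable[[Image], Image]:
    """Return the transform registered under `name`."""
    if name not in _TRANSFORMS:
        raise ValueError(f"Unknown transform {name!r}; available: {sorted(_TRANSFORMS)}")
    return _TRANSFORMS[name]
```

For example, `get_transform("hflip")([[1, 2, 3]])` returns `[[3, 2, 1]]`. For depth estimation, a geometric transform such as a flip must be applied identically to the input image and its depth target, otherwise the pixel correspondence is lost.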
- plot(tester, save_path, specific_name='')
Plot random samples from the training and validation sets to check whether they are predicted correctly.
- Return type:
None
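As a minimal sketch of the sample-selection step only (the plotting itself is omitted), drawing a reproducible random subset of indices from one split could look like this. `pick_plot_indices` and its `seed` parameter are hypothetical; the page does not say how `plot` chooses its samples.

```python
import random
from typing import List


def pick_plot_indices(dataset_size: int, num_samples: int, seed: int = 0) -> List[int]:
    """Pick distinct random sample indices to visualise from one split."""
    rng = random.Random(seed)  # local RNG, so plotting does not disturb global random state
    k = min(num_samples, dataset_size)  # never ask for more samples than exist
    return sorted(rng.sample(range(dataset_size), k))
```

Using a dedicated, seeded `random.Random` instance keeps the plotted samples identical across runs, which makes visual comparisons between models easier.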
- prepare_data()
Download and prepare the data. The resulting splits are stored in self._train_dataset, self._validation_dataset and self._test_dataset.
- Return type:
None
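This page does not say how `self._validation_dataset` is obtained. One common approach, shown here purely as an assumption, is to carve a validation subset out of the training data with a fixed seed. The `split_indices` function and its `val_fraction` parameter are hypothetical.

```python
import random
from typing import List, Tuple


def split_indices(
    num_train: int, val_fraction: float = 0.1, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Split range(num_train) into disjoint train and validation index lists."""
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError("val_fraction must be in [0, 1)")
    indices = list(range(num_train))
    random.Random(seed).shuffle(indices)  # deterministic shuffle
    num_val = int(round(num_train * val_fraction))
    return indices[num_val:], indices[:num_val]
```

With PyTorch, the two index lists could then be wrapped with `torch.utils.data.Subset(dataset, indices)` to obtain the training and validation datasets.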
- available_transforms: List[str]
- test_augmentations: List[str]
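Judging from the attribute names (this is an inference, not stated on this page), `test_augmentations` selects a subset of `available_test_augmentations`. A sketch of that consistency check, with a hypothetical `validate_augmentations` helper:

```python
from typing import Iterable, List


def validate_augmentations(requested: Iterable[str], available: Iterable[str]) -> List[str]:
    """Return `requested` as a list, raising if any name is not available."""
    requested_list = list(requested)
    available_set = set(available)
    unknown = [name for name in requested_list if name not in available_set]
    if unknown:
        raise ValueError(f"Unsupported augmentations: {unknown}")
    return requested_list
```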
- class yamle.data.depth.NYUv2DataModule(*args, **kwargs)
Bases: DepthEstimationDataModule
Data module for NYUv2.
- inputs_dim = (3, 320, 240)
- mean = [0.485, 0.456, 0.406]
- std = [0.229, 0.224, 0.225]
- outputs_dim = 2
- targets_dim = (2, 320, 240)
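The `mean` and `std` values above are the standard ImageNet per-channel statistics, and `inputs_dim` is given channels first. A minimal sketch of per-channel normalisation, x' = (x - mean[c]) / std[c], applied to a single RGB pixel with values in [0, 1], together with its inverse (useful before plotting an input). `normalize_pixel` and `denormalize_pixel` are hypothetical helpers; the real module presumably applies this to tensors through its transforms.

```python
from typing import Sequence, Tuple

# Values documented for NYUv2DataModule.
INPUTS_DIM: Tuple[int, int, int] = (3, 320, 240)
TARGETS_DIM: Tuple[int, int, int] = (2, 320, 240)
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: Sequence[float]) -> Tuple[float, ...]:
    """Apply (x - mean[c]) / std[c] to one channel-ordered pixel in [0, 1]."""
    if len(rgb) != INPUTS_DIM[0]:
        raise ValueError(f"expected {INPUTS_DIM[0]} channels, got {len(rgb)}")
    return tuple((x - m) / s for x, m, s in zip(rgb, MEAN, STD))


def denormalize_pixel(values: Sequence[float]) -> Tuple[float, ...]:
    """Invert normalize_pixel, e.g. before plotting an input image."""
    return tuple(v * s + m for v, m, s in zip(values, MEAN, STD))
```

On tensors, `torchvision.transforms.Normalize(mean, std)` performs the same per-channel operation. Note that `inputs_dim` and `targets_dim` share the same spatial size, so inputs and targets lie on the same pixel grid, and the two target channels match `outputs_dim = 2`.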
- plot(tester, save_path, specific_name='')
Plot random samples from the training and validation sets to check whether they are predicted correctly.
- Return type:
None
- available_transforms: List[str]
- available_test_augmentations: List[str]
- test_augmentations: List[str]