yamle.data.depth module

class yamle.data.depth.DepthEstimationDataModule(dataset, **kwargs)

Bases: BaseDataModule

Data module for depth estimation.

Parameters:

dataset (str) – Name of the torchvision dataset. Currently, only nyudepthv2 is supported.

mean = None
std = None
inputs_dim = None
outputs_dim = None
task = 'depth_estimation'
inputs_dtype = torch.float32
outputs_dtype = torch.float32
available_test_augmentations: List[str]

get_transform(name)

Helper function that returns the transform with the given name.

Return type:

Callable
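
A name-based lookup like this is often implemented as a dictionary dispatch. The following is a minimal, dependency-free sketch of that pattern and not yamle's actual implementation: the registry `_TRANSFORMS`, the transform names `identity` and `clip_depth`, the clipping range, and the standalone `get_transform` function are all hypothetical and chosen only for illustration.

```python
from typing import Callable, Dict, List

# A transform here maps a flat list of depth values to a new list.
Transform = Callable[[List[float]], List[float]]


def identity(values: List[float]) -> List[float]:
    """Return a copy of the input unchanged (hypothetical transform)."""
    return list(values)


def clip_depth(values: List[float]) -> List[float]:
    """Clamp every value to [0.0, 10.0]; the range is an assumption."""
    return [min(max(v, 0.0), 10.0) for v in values]


# Hypothetical registry mapping transform names to callables.
_TRANSFORMS: Dict[str, Transform] = {
    "identity": identity,
    "clip_depth": clip_depth,
}


def get_transform(name: str) -> Transform:
    """Look up a transform by name, failing loudly on unknown names."""
    try:
        return _TRANSFORMS[name]
    except KeyError:
        available = ", ".join(sorted(_TRANSFORMS))
        raise ValueError(
            f"Unknown transform {name!r}; available: {available}"
        ) from None
```

Raising an error that lists the registered names makes a misspelled transform name easy to diagnose. In the real module, the valid names are presumably those exposed through `available_transforms`.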

plot(tester, save_path, specific_name='')

Plot random samples from the training and validation sets to check whether the data is predicted correctly.

Return type:

None

prepare_data()

Download and prepare the data. The resulting datasets are stored in self._train_dataset, self._validation_dataset and self._test_dataset.

Return type:

None

available_transforms: List[str]
test_augmentations: List[str]

class yamle.data.depth.NYUv2DataModule(*args, **kwargs)

Bases: DepthEstimationDataModule

Data module for NYUv2.

inputs_dim = (3, 320, 240)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
outputs_dim = 2
targets_dim = (2, 320, 240)

plot(tester, save_path, specific_name='')

Plot random samples from the training and validation sets to check whether the data is predicted correctly.

Return type:

None

available_transforms: List[str]
available_test_augmentations: List[str]
test_augmentations: List[str]
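
The `mean` and `std` values of NYUv2DataModule match the widely used ImageNet per-channel statistics. Such values are typically applied as `(x - mean) / std` on each RGB channel and inverted as `x * std + mean` before plotting. The sketch below illustrates that arithmetic on a single pixel in plain Python. It assumes inputs are already scaled to [0, 1], and the helper names `normalize_pixel` and `denormalize_pixel` are hypothetical; how yamle itself applies these attributes is not stated on this page.

```python
from typing import List, Sequence

# Values copied from the NYUv2DataModule attributes documented above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(
    rgb: Sequence[float],
    mean: Sequence[float] = MEAN,
    std: Sequence[float] = STD,
) -> List[float]:
    """Apply (x - mean) / std to each channel of one RGB pixel."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


def denormalize_pixel(
    rgb: Sequence[float],
    mean: Sequence[float] = MEAN,
    std: Sequence[float] = STD,
) -> List[float]:
    """Invert the normalization, e.g. before plotting an image."""
    return [x * s + m for x, m, s in zip(rgb, mean, std)]


pixel = [0.5, 0.25, 0.75]  # an arbitrary RGB value in [0, 1]
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)
```

A pixel equal to `MEAN` normalizes to zero on every channel, and `restored` equals `pixel` up to floating-point rounding.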