# Source code for yamle.third_party.nyuv2

"""
author: Mihai Suteu
date: 15/05/19

Adapted from: https://github.com/xapharius/pytorch-nyuv2/
"""

import os
import h5py
import shutil
import tarfile
import zipfile
import numpy as np

from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url


class NYUv2(Dataset):
    """
    PyTorch wrapper for the NYUv2 dataset.

    Data sources available: Semantic Segmentation and Depth Estimation.

    1. Semantic Segmentation: 1 channel representing one of the 14
       (0 - background) classes, read from the ``{split}_seg13`` folder.
    2. Depth Images: 1 channel, read from the ``{split}_depth`` folder. When
       created by this module's download step, the PNGs hold depth values
       multiplied by 1e4 and stored as 16-bit integers.

    Args:
        root (str): Root directory of the dataset. It must contain the
            ``train_rgb`` / ``test_rgb`` folders plus, depending on ``task``,
            ``train_seg13`` / ``test_seg13`` or ``train_depth`` /
            ``test_depth``.
        train (bool, optional): If True, use the training split, otherwise
            the test split.
        download (bool, optional): If True, download the missing data from
            the internet into ``root``. Folders that already exist are not
            downloaded again.
        task (str, optional): Choose from: segmentation, depth.

    Raises:
        ValueError: If ``task`` is not supported.
        RuntimeError: If the required folders are missing (after the
            optional download).
    """

    # Label folder suffix (``{split}_{suffix}``) for each supported task.
    _TARGET_FOLDERS = {"segmentation": "seg13", "depth": "depth"}

    def __init__(
        self,
        root: str,
        train: bool = True,
        download: bool = False,
        task: str = "segmentation",
    ) -> None:
        super().__init__()
        self.root = root
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
        if task not in self._TARGET_FOLDERS:
            raise ValueError(
                f"Task {task} not supported. Choose from: segmentation, depth"
            )
        self.task = task
        self.train = train
        self._split = "train" if train else "test"

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError(
                "Dataset not complete. You can use download=True to download it"
            )

        # The RGB folder defines the sample list; label files share its names.
        self._files = sorted(os.listdir(os.path.join(root, f"{self._split}_rgb")))

    def __getitem__(self, index: int):
        """
        Return the ``(image, target)`` pair at ``index``.

        Both items are PIL images opened with ``Image.open``. The target is
        the file with the same name as the RGB image, taken from the task's
        label folder.
        """
        name = self._files[index]
        x = Image.open(os.path.join(self.root, f"{self._split}_rgb", name))
        target = self._TARGET_FOLDERS.get(self.task)
        if target is None:
            raise NotImplementedError(f"Task {self.task} not implemented")
        y = Image.open(os.path.join(self.root, f"{self._split}_{target}", name))
        return x, y

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self._files)

    def _check_exists(self) -> bool:
        """
        Only checking for folder existence.

        Return True when the RGB folder and the task's label folder exist for
        both the train and the test split. The RGB folders are included
        because ``__init__`` lists ``{split}_rgb`` and ``download`` uses this
        check to decide whether anything needs fetching.
        """
        target = self._TARGET_FOLDERS[self.task]
        return all(
            os.path.exists(os.path.join(self.root, f"{split}_{part}"))
            for split in ("train", "test")
            for part in ("rgb", target)
        )

    def download(self):
        """
        Download the RGB images plus the labels for ``self.task`` into ``root``.

        Does nothing if every required folder already exists.
        """
        if self._check_exists():
            return
        # RGB first: download_depth derives the train/test split from the
        # file names in ``root/train_rgb``.
        download_rgb(self.root)
        if self.task == "segmentation":
            download_seg(self.root)
        elif self.task == "depth":
            download_depth(self.root)
def download_rgb(root: str):
    """
    Download and extract the NYUv2 RGB images for both splits into ``root``.

    For each split, the archive is fetched unless it is already present in
    ``root``. It is then extracted and moved to ``root/train_rgb`` or
    ``root/test_rgb``, and its files are renamed to the third ``_``-separated
    component of their original name. A split whose destination folder
    already exists is skipped entirely.

    :param root: directory that receives the archives and extracted folders
    """
    train_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz"
    test_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz"

    def _proc(url: str, dst: str):
        if os.path.exists(dst):
            return
        tar = os.path.join(root, url.split("/")[-1])
        if not os.path.exists(tar):
            download_url(url, root)
        if os.path.exists(tar):
            _unpack(tar)
            # _unpack extracts into the archive path minus its last extension.
            # ``str.rstrip(".tgz")`` would instead strip any trailing '.', 't',
            # 'g' or 'z' characters, not the suffix.
            _replace_folder(tar.rsplit(".", 1)[0], dst)
            _rename_files(dst, lambda x: x.split("_")[2])

    _proc(train_url, os.path.join(root, "train_rgb"))
    _proc(test_url, os.path.join(root, "test_rgb"))
def download_seg(root: str):
    """
    Download and extract the 13-class segmentation labels for both splits.

    For each split, the archive is fetched into ``root`` unless already
    present. It is then extracted and moved to ``root/train_seg13`` or
    ``root/test_seg13``, and its files are renamed to the fourth
    ``_``-separated component of their original name. A split whose
    destination folder already exists is skipped entirely.

    :param root: directory that receives the archives and extracted folders
    """
    train_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz"
    test_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz"

    def _proc(url: str, dst: str):
        if os.path.exists(dst):
            return
        tar = os.path.join(root, url.split("/")[-1])
        if not os.path.exists(tar):
            download_url(url, root)
        if os.path.exists(tar):
            _unpack(tar)
            # _unpack extracts into the archive path minus its last extension.
            # ``str.rstrip(".tgz")`` would instead strip any trailing '.', 't',
            # 'g' or 'z' characters, not the suffix.
            _replace_folder(tar.rsplit(".", 1)[0], dst)
            _rename_files(dst, lambda x: x.split("_")[3])

    _proc(train_url, os.path.join(root, "train_seg13"))
    _proc(test_url, os.path.join(root, "test_seg13"))
def download_depth(root: str):
    """
    Build ``root/train_depth`` and ``root/test_depth`` from the labelled
    NYUv2 ``.mat`` file.

    The ``.mat`` file is downloaded into ``root`` if it is not already there.
    The train/test assignment comes from the file names (without extension)
    in ``root/train_rgb``, so the RGB images must be downloaded first.
    Nothing happens when both depth folders already exist.

    :param root: directory holding the dataset folders
    """
    mat_url = (
        "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"
    )
    have_train = os.path.exists(os.path.join(root, "train_depth"))
    have_test = os.path.exists(os.path.join(root, "test_depth"))
    if have_train and have_test:
        return

    mat_path = os.path.join(root, mat_url.split("/")[-1])
    if not os.path.exists(mat_path):
        download_url(mat_url, root)
    if not os.path.exists(mat_path):
        return

    rgb_names = os.listdir(os.path.join(root, "train_rgb"))
    train_ids = [name.split(".")[0] for name in rgb_names]
    _create_depth_files(mat_path, root, train_ids)
def _unpack(file: str): """ Unpacks tar and zip, does nothing for any other type :param file: path of file """ path = file.rsplit(".", 1)[0] if file.endswith(".tgz"): tar = tarfile.open(file, "r:gz") tar.extractall(path) tar.close() elif file.endswith(".zip"): zip = zipfile.ZipFile(file, "r") zip.extractall(path) zip.close() def _rename_files(folder: str, rename_func: callable): """ Renames all files inside a folder based on the passed rename function :param folder: path to folder that contains files :param rename_func: function renaming filename (not including path) str -> str """ imgs_old = os.listdir(folder) imgs_new = [rename_func(file) for file in imgs_old] for img_old, img_new in zip(imgs_old, imgs_new): shutil.move(os.path.join(folder, img_old), os.path.join(folder, img_new)) def _replace_folder(src: str, dst: str): """ Rename src into dst, replacing/overwriting dst if it exists. """ if os.path.exists(dst): shutil.rmtree(dst) shutil.move(src, dst) def _create_depth_files(mat_file: str, root: str, train_ids: list): """ Extract the depth arrays from the mat file into images :param mat_file: path to the official labelled dataset .mat file :param root: The root directory of the dataset :param train_ids: the IDs of the training images as string (for splitting) """ os.mkdir(os.path.join(root, "train_depth")) os.mkdir(os.path.join(root, "test_depth")) train_ids = set(train_ids) depths = h5py.File(mat_file, "r")["depths"] for i in range(len(depths)): img = (depths[i] * 1e4).astype(np.uint16).T id_ = str(i + 1).zfill(4) folder = "train" if id_ in train_ids else "test" save_path = os.path.join(root, f"{folder}_depth", id_ + ".png") Image.fromarray(img).save(save_path)