from typing import Callable, Any, Tuple, Union
import torch
from torchvision.datasets import Cityscapes
from pytorch_lightning import LightningModule
import matplotlib.pyplot as plt
import numpy as np
import os
from yamle.data.datamodule import BaseDataModule
from yamle.data.transforms import (
JointToTensor,
JointResize,
JointTargetSqueeze,
JointNormalize,
)
from yamle.utils.file_utils import plots_file
from yamle.utils.operation_utils import classification_uncertainty_decomposition
from yamle.defaults import (
SEGMENTATION_KEY,
MEAN_PREDICTION_KEY,
PREDICTION_KEY,
TRAIN_KEY,
VALIDATION_KEY,
TEST_KEY,
INPUT_KEY,
TARGET_KEY,
AVERAGE_WEIGHTS_KEY,
)
[docs]
class TorchvisionSegmentationDataModule(BaseDataModule):
    """Data module for the torchvision segmentation datasets.

    The constructor registers the joint (input + target) transforms
    `jointtotensor`, `jointresize`, `jointtargetsqueeze` and `jointnormalize`
    as available transforms. When the base class leaves
    `_train_joint_transform` / `_test_joint_transform` set to `None`, those
    four transforms (in that order) are used as the default pipeline.

    Args:
        dataset (str): Name of the torchvision dataset. Currently only `cityscapes` is supported.
        **kwargs (Any): Forwarded unchanged to the base class constructor
            (presumably `seed`, `data_dir` and the transform settings —
            confirm the accepted names against `BaseDataModule`).

    Raises:
        ValueError: If `dataset` is not one of the supported datasets.
    """

    # Per-channel normalisation statistics; concrete subclasses override these.
    mean = None
    std = None
    # Input/output dimensionality; concrete subclasses override these.
    inputs_dim = None
    outputs_dim = None
    task = SEGMENTATION_KEY
    inputs_dtype = torch.float32
    outputs_dtype = torch.long

    def __init__(self, dataset: str, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        if dataset not in ["cityscapes"]:
            raise ValueError("Dataset not supported.")
        self._dataset = dataset
        self.available_test_augmentations = []

        joint_transforms = [
            "jointtotensor",
            "jointresize",
            "jointtargetsqueeze",
            "jointnormalize",
        ]
        self.available_transforms += list(joint_transforms)
        # Each pipeline gets its own list so that later in-place edits of one
        # do not leak into the other.
        if self._train_joint_transform is None:
            self._train_joint_transform = list(joint_transforms)
        if self._test_joint_transform is None:
            self._test_joint_transform = list(joint_transforms)

    def _denormalize(self, image: torch.Tensor) -> torch.Tensor:
        """Undo the channel-wise normalisation of an image.

        Computes `image * std + mean`, with `self.std` and `self.mean`
        broadcast over the two trailing dimensions of `image`
        (assumes a channel-first `(C, H, W)` image — TODO confirm).

        Args:
            image (torch.Tensor): The normalised image.

        Returns:
            torch.Tensor: The denormalised image, on the same device as `image`.
        """
        # Create the statistics on the image's device so that non-CPU inputs
        # do not fail with a device mismatch.
        mean = torch.tensor(self.mean, device=image.device)
        std = torch.tensor(self.std, device=image.device)
        return image * std[:, None, None] + mean[:, None, None]

    def _get_prediction(
        self,
        tester: LightningModule,
        x: torch.Tensor,
        y: Union[torch.Tensor, int],
        phase: str = TRAIN_KEY,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns the prediction of the model for a single sample.

        The sample is moved to the tester's device, given a leading batch
        dimension and passed through the step function matching `phase`.

        Args:
            tester (LightningModule): The module whose step function is called.
            x (torch.Tensor): A single input sample (without batch dimension).
            y (Union[torch.Tensor, int]): The matching target (without batch dimension).
            phase (str): One of `TRAIN_KEY`, `VALIDATION_KEY` or `TEST_KEY`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
                The arg-max labels of the mean prediction and the total,
                aleatoric and epistemic uncertainties, all on the CPU and with
                the batch dimension removed.

        Raises:
            ValueError: If `phase` is not one of the three supported keys.
        """
        super()._get_prediction(tester, x, y, phase)
        batch = [
            x.to(tester.device).unsqueeze(0),
            y.to(tester.device).unsqueeze(0),
        ]
        if phase == TRAIN_KEY:
            output = tester.training_step(batch, batch_idx=0)
        elif phase == VALIDATION_KEY:
            output = tester.validation_step(batch, batch_idx=0)
        elif phase == TEST_KEY:
            output = tester.test_step(batch, batch_idx=0)
        else:
            # Without this branch `output` would be unbound below.
            raise ValueError(
                f"Unknown phase {phase!r}; expected one of "
                f"{TRAIN_KEY!r}, {VALIDATION_KEY!r}, {TEST_KEY!r}."
            )

        y_hat = output[PREDICTION_KEY]
        y_hat_mean = output[MEAN_PREDICTION_KEY]
        averaging_weights = (
            output[AVERAGE_WEIGHTS_KEY].cpu()
            if AVERAGE_WEIGHTS_KEY in output
            else None
        )
        labels = y_hat_mean.argmax(dim=1)
        # NOTE(review): `y_hat` stays on the tester's device while the weights
        # are moved to the CPU — confirm that
        # `classification_uncertainty_decomposition` tolerates mixed devices.
        total, aleatoric, epistemic = classification_uncertainty_decomposition(
            y_hat, probabilities=True, weights=averaging_weights
        )
        return (
            labels.cpu().squeeze(0),
            total.cpu().squeeze(0),
            aleatoric.cpu().squeeze(0),
            epistemic.cpu().squeeze(0),
        )

    @torch.no_grad()
    def plot(
        self, tester: LightningModule, save_path: str, specific_name: str = ""
    ) -> None:
        """Plot random samples from the training and validation set to check if the data is correctly predicted

        This implementation has no body and does nothing; it returns `None`.
        """

    def prepare_data(self) -> None:
        """Download and prepare the data, the data is stored in `self._train_dataset`, `self._validation_dataset` and `self._test_dataset`.

        Raises:
            ValueError: If `self._dataset` is not a supported dataset.
        """
        super().prepare_data()
        if self._dataset != "cityscapes":
            raise ValueError("Dataset not supported.")

        root = os.path.join(self._data_dir, "cityscapes")

        def build(split: str) -> Cityscapes:
            """Create the finely annotated, semantic-target dataset for one split."""
            return Cityscapes(
                root,
                split=split,
                mode="fine",
                target_type="semantic",
            )

        self._train_dataset = build("train")
        self._validation_dataset = build("val")
        self._test_dataset = build("test")
[docs]
class TorchvisionSegmentationDataModuleCityscapes(TorchvisionSegmentationDataModule):
    """Data module for the Cityscapes dataset."""

    inputs_dim = (3, 512, 256)
    mean = [0.28689554, 0.32513303, 0.28389177]
    std = [0.18696375, 0.19017339, 0.18720214]
    # Positions in `Cityscapes.classes` of the classes flagged `ignore_in_eval`.
    ignore_indices = [i for i, c in enumerate(Cityscapes.classes) if c.ignore_in_eval]
    # One output per entry of `Cityscapes.classes`.
    outputs_dim = len(Cityscapes.classes)
    targets_dim = (outputs_dim, 512, 256)

    def __init__(self, **kwargs: Any) -> None:
        super().__init__("cityscapes", **kwargs)

    def _decode_target_to_rgb(self, target: torch.Tensor) -> np.ndarray:
        """Decode an integer label map to an RGB image.

        Every pixel whose value `c` satisfies `0 <= c < self.outputs_dim` is
        painted with `Cityscapes.classes[c].color`; every other pixel stays
        black.

        Args:
            target (torch.Tensor): Label map of shape `(height, width)`.

        Returns:
            np.ndarray: Array of shape `(height, width, 3)` with values in `[0, 1]`.

        Raises:
            ValueError: If `target` is not two-dimensional.
        """
        if target.dim() != 2:
            raise ValueError(
                f"Expected a target of shape (height, width), got {tuple(target.shape)}."
            )
        target = target.detach().cpu()

        # Row `c` of the palette holds the normalised colour of class `c`.
        palette = (
            torch.tensor(
                [Cityscapes.classes[c].color for c in range(self.outputs_dim)],
                dtype=torch.float32,
            )
            / 255.0
        )

        indices = target.long()
        # Only exact, in-range class values are coloured; anything else
        # (including non-integral values) is left black.
        valid = (indices >= 0) & (indices < self.outputs_dim) & (indices == target)

        rgb = np.zeros((target.shape[0], target.shape[1], 3))
        rgb[valid.numpy()] = palette[indices[valid]].numpy()
        return rgb

    def _plot_row(
        self, axes_row: Any, tester: LightningModule, dataset: Any, phase: str
    ) -> None:
        """Draw one random sample of `dataset` and its predictions on a row of six axes."""
        idx = np.random.randint(0, len(dataset))
        x, y = dataset[idx]
        labels, total, aleatoric, epistemic = self._get_prediction(
            tester, x, y, phase
        )
        # (title, image, colour map) for each of the six panels, left to right.
        panels = (
            ("Input", self._denormalize(x).permute(1, 2, 0), None),
            ("Target", self._decode_target_to_rgb(y), None),
            ("Prediction", self._decode_target_to_rgb(labels), None),
            ("Total uncertainty", total, "jet"),
            ("Aleatoric uncertainty", aleatoric, "jet"),
            ("Epistemic uncertainty", epistemic, "jet"),
        )
        for ax, (title, image, cmap) in zip(axes_row, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")

    @torch.no_grad()
    def plot(
        self, tester: LightningModule, save_path: str, specific_name: str = ""
    ) -> None:
        """Plot random samples from the training and validation set to check if the data is correctly predicted

        The first row shows a random training sample, the second row a random
        validation sample (left empty when there is no validation dataset).
        Each row shows the input, target, prediction and the total, aleatoric
        and epistemic uncertainty maps.

        Args:
            tester (LightningModule): The module used to obtain the predictions.
            save_path (str): Passed to `plots_file` to build the output path.
            specific_name (str): Passed to `plots_file` to build the output path.
        """
        fig, axs = plt.subplots(2, 6, figsize=(20, 10))
        try:
            self._plot_row(axs[0], tester, self._train_dataset, TRAIN_KEY)
            if self._validation_dataset is not None:
                self._plot_row(
                    axs[1], tester, self._validation_dataset, VALIDATION_KEY
                )
            else:
                # Hide the unused row instead of saving empty axes frames.
                for ax in axs[1]:
                    ax.axis("off")
            fig.tight_layout()
            fig.savefig(plots_file(save_path, specific_name), bbox_inches="tight")
        finally:
            # Always release the figure, even if prediction or saving fails.
            plt.close(fig)