Source code for yamle.data.dataset_wrappers
from typing import Any, Callable, Optional
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch.utils.data import Dataset
[docs]
class SurrogateDataset(Dataset):
    """Wrap a dataset so that transforms run lazily, at item-access time.

    Because the transforms live on the wrapper rather than on the wrapped
    dataset, different wrappers around splits of the same dataset (e.g.
    training and validation) can each carry their own transforms.

    Args:
        dataset (Dataset): Dataset to wrap. Indexing it must yield a
            ``(data, target)`` pair.
        transform (Optional[Callable]): Applied to the data only.
        target_transform (Optional[Callable]): Applied to the target only.
        joint_transform (Optional[Callable]): Called as
            ``joint_transform(data, target)`` and must return the new
            ``(data, target)`` pair. It runs before the other two transforms.
    """

    def __init__(
        self,
        dataset: Dataset,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        joint_transform: Optional[Callable] = None,
    ) -> None:
        self._dataset = dataset
        self._joint_transform = joint_transform
        self._transform = transform
        self._target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset."""
        return len(self._dataset)

    def __getitem__(self, index: int) -> Any:
        """Return the ``(data, target)`` pair at ``index`` after all transforms."""
        sample, label = self._dataset[index]

        # Order matters: the joint transform sees the raw pair first, then the
        # per-element transforms are applied to its outputs.
        joint = self._joint_transform
        if joint is not None:
            sample, label = joint(sample, label)

        on_input = self._transform
        sample = sample if on_input is None else on_input(sample)

        on_target = self._target_transform
        label = label if on_target is None else on_target(label)

        return sample, label
[docs]
class ImageRotationDataset(Dataset):
    """Dataset wrapper that turns any image dataset into a rotation-regression task.

    The original target is discarded. Each access rotates the image by a
    randomly drawn angle and returns that angle, min-max scaled with respect
    to ``[min_angle, max_angle]``, as the new target.

    A new angle is drawn from the internal generator on every
    ``__getitem__`` call, so the same index yields different rotations across
    calls; the sequence of angles is reproducible for a given ``seed`` and
    access order.

    Args:
        dataset (Dataset): Dataset to wrap. Indexing it must yield a
            ``(data, target)`` pair.
        max_angle (float): Upper bound of the rotation angle in degrees.
            Defaults to 90. When both bounds are integer-valued, whole-degree
            angles are drawn and this bound is exclusive.
        min_angle (float): Lower bound of the rotation angle in degrees.
            Defaults to 0.
        seed (int): Seed for the random number generator. Defaults to 42.

    Raises:
        ValueError: If ``min_angle`` is not strictly smaller than ``max_angle``.
    """

    def __init__(
        self,
        dataset: Dataset,
        max_angle: float = 90,
        min_angle: float = 0,
        seed: int = 42,
    ) -> None:
        # Fail at construction time instead of on the first item access; an
        # empty range would also make the min-max scaling divide by zero.
        # Written as ``not a < b`` so that NaN bounds are rejected as well.
        if not min_angle < max_angle:
            raise ValueError(
                f"min_angle ({min_angle}) must be strictly smaller than "
                f"max_angle ({max_angle})."
            )
        self._dataset = dataset
        self._max_angle = max_angle
        self._min_angle = min_angle
        self._seed = seed
        self._generator = torch.Generator().manual_seed(self._seed)

    def _sample_angle(self) -> torch.Tensor:
        """Draw one rotation angle as a tensor of shape ``(1,)``."""
        low, high = self._min_angle, self._max_angle
        if float(low).is_integer() and float(high).is_integer():
            # Integer-valued bounds: whole-degree angles in [low, high).
            # torch.randint only accepts integer bounds, hence the casts.
            return torch.randint(
                int(low), int(high), (1,), generator=self._generator
            )
        # Non-integer bounds cannot go through torch.randint; draw from a
        # continuous uniform distribution over the range instead.
        return low + (high - low) * torch.rand(1, generator=self._generator)

    def __getitem__(self, index: int) -> Any:
        """Return ``(rotated_data, scaled_angle)`` for the item at ``index``.

        ``scaled_angle`` is a tensor of shape ``(1,)`` holding
        ``(angle - min_angle) / (max_angle - min_angle)``.
        """
        data, _ = self._dataset[index]
        angle = self._sample_angle()
        # Scale the angle to between 0 and 1 through min-max scaling
        scaled_angle = (angle - self._min_angle) / (self._max_angle - self._min_angle)
        data = F.rotate(data, angle.item())
        return data, scaled_angle

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset."""
        return len(self._dataset)