from argparse import ArgumentParser
from typing import Tuple, Union, Any
import torch
from torch.utils.data import TensorDataset
from pytorch_lightning import LightningModule
import scienceplots
import matplotlib.pyplot as plt
plt.style.use("science")
from yamle.data.custom import ECG5000DataModule
from yamle.defaults import (
RECONSTRUCTION_KEY,
TRAIN_KEY,
VALIDATION_KEY,
TEST_KEY,
INPUT_KEY,
TARGET_KEY,
PREDICTION_KEY,
AVERAGE_WEIGHTS_KEY,
)
from yamle.utils.operation_utils import regression_uncertainty_decomposition
from yamle.utils.file_utils import plots_file
class ECG5000ReconstructionDataModule(ECG5000DataModule):
    """Reconstruction data module for the ECG5000 dataset.

    The target of every sample is its own input signal, so the module is meant for
    training reconstruction (autoencoder-style) models.

    If `anomaly` is set to `True`, then all the anomalous cases from the training set
    (every class other than 0, after the labels are shifted to start from 0) are moved
    to the test set. This can be used to train an autoencoder only on normal cases,
    which is then expected to reconstruct the anomalous cases poorly, i.e. to build
    an autoencoder for anomaly detection.

    Args:
        anomaly (bool): If `True`, then all the anomalous cases from the training set
            are moved to the test set. Default: `False`.
    """

    outputs_dim = 2
    outputs_dtype = torch.float32
    inputs_dtype = torch.float32
    targets_dim = 140
    task = RECONSTRUCTION_KEY

    def __init__(self, anomaly: bool = False, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Fall back to the "datanormalize" target transform when none was configured.
        if self._train_target_transform is None:
            self._train_target_transform = ["datanormalize"]
        if self._test_target_transform is None:
            self._test_target_transform = ["datanormalize"]
        self._anomaly = anomaly
        self._train_classes: Union[torch.Tensor, None] = None  # Classes of the train data
        self._test_classes: Union[torch.Tensor, None] = None  # Classes of the test data

    def prepare_data(self) -> None:
        """Download and prepare the data.

        The first column of each dataframe returned by the parent is the class label
        (starting from 1; shifted here to start from 0) and the remaining columns are
        the signal. Inputs and targets are both the signal, shaped
        ``(num_samples, sequence_length, 1)``. The per-split class labels and the
        mean/std and max/min statistics of the training inputs are stored as well.
        """
        # The train and test data are pandas dataframes.
        train_data, test_data = super().prepare_data()

        def to_tensors(frame: Any) -> Tuple[torch.Tensor, torch.Tensor]:
            # Split a dataframe into float signals and zero-based long class labels.
            signals = torch.from_numpy(frame.iloc[:, 1:].to_numpy()).float()
            classes = torch.from_numpy((frame.iloc[:, 0] - 1).to_numpy()).long()
            return signals, classes

        def to_sequences(signals: torch.Tensor) -> torch.Tensor:
            # (num_samples, sequence_length) -> (num_samples, sequence_length, 1)
            return signals.float().unsqueeze(-1)

        train_inputs, train_classes = to_tensors(train_data)
        test_inputs, test_classes = to_tensors(test_data)

        if self._anomaly:
            # Class 0 is treated as normal; every other class is anomalous and is
            # moved from the training set to the test set.
            anomalous = train_classes != 0
            test_inputs = torch.cat([test_inputs, train_inputs[anomalous]])
            test_classes = torch.cat([test_classes, train_classes[anomalous]])
            train_inputs = train_inputs[~anomalous]
            train_classes = train_classes[~anomalous]

        self._train_classes = train_classes
        self._test_classes = test_classes

        # Reconstruction task: the target of every sample is its own input.
        self._train_dataset = TensorDataset(
            to_sequences(train_inputs), to_sequences(train_inputs)
        )
        self._test_dataset = TensorDataset(
            to_sequences(test_inputs), to_sequences(test_inputs)
        )

        # Statistics of the training inputs; `plot` uses mean/std to undo normalization.
        self._data_mean, self._data_std = self._mean_std(self._train_dataset, index=0)
        self._data_max, self._data_min = self._max_min(self._train_dataset, index=0)

    def _get_prediction(
        self,
        tester: LightningModule,
        x: torch.Tensor,
        y: Union[torch.Tensor, int],
        phase: str = TRAIN_KEY,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Union[torch.Tensor, None]]:
        """Run one step of ``tester`` on a batch and return its outputs on the CPU.

        Args:
            tester: Module whose ``training_step``, ``validation_step`` or
                ``test_step`` is invoked, selected by ``phase``.
            x: The input batch.
            y: The target batch (or a scalar target).
            phase: One of ``TRAIN_KEY``, ``VALIDATION_KEY`` or ``TEST_KEY``.

        Returns:
            The prediction, the input, the target and the averaging weights, the
            latter being ``None`` when the step output does not contain them.

        Raises:
            ValueError: If ``phase`` is not one of the known phases.
        """
        super()._get_prediction(tester, x, y, phase)
        steps = {
            TRAIN_KEY: tester.training_step,
            VALIDATION_KEY: tester.validation_step,
            TEST_KEY: tester.test_step,
        }
        if phase not in steps:
            raise ValueError(f"Unknown phase: {phase}.")

        x = x.to(tester.device)
        # Reconstruction targets are real-valued signals; casting them to an integer
        # dtype would truncate them, so they are kept in floating point.
        y = torch.tensor(y) if isinstance(y, int) else y
        y = y.float().to(tester.device)

        output = steps[phase]([x, y], batch_idx=0)

        y_hat = output[PREDICTION_KEY].cpu()
        x = output[INPUT_KEY].cpu()
        y = output[TARGET_KEY].cpu()
        average_weights = (
            output[AVERAGE_WEIGHTS_KEY].cpu() if AVERAGE_WEIGHTS_KEY in output else None
        )
        return y_hat, x, y, average_weights

    def plot(
        self, tester: LightningModule, save_path: str, specific_name: str = ""
    ) -> None:
        """Plot reconstructions with uncertainty bands for the first samples of each split.

        One row per split (train, validation, test) and one column per sample (at most
        10). Each panel shows the input, the target, the predicted mean and shaded
        bands of one predictive and one aleatoric standard deviation around the mean,
        all mapped back to the original data scale.
        """
        super().plot(tester, save_path, specific_name)
        max_columns = 10
        fig, axs = plt.subplots(nrows=3, ncols=max_columns, figsize=(60, 15))
        data_mean = self._data_mean.squeeze()
        data_std = self._data_std.squeeze()
        dataloaders = [
            self.train_dataloader(),
            self.validation_dataloader(),
            self.test_dataloader(),
        ]
        for i, dataloader in enumerate(dataloaders):
            inputs, targets = next(iter(dataloader))
            # The model is queried in test mode for every split.
            y_hat, _, _, average_weights = self._get_prediction(
                tester, inputs, targets, TEST_KEY
            )
            # No gradients are needed for plotting; this also lets matplotlib convert
            # the tensors to arrays.
            with torch.no_grad():
                (
                    mean,
                    predictive_variance,
                    aleatoric_variance,
                    _,
                ) = regression_uncertainty_decomposition(y_hat, weights=average_weights)
                # Undo the normalization: a location is scaled and shifted, while a
                # spread (standard deviation) is only scaled.
                mean = mean * data_std + data_mean
                predictive_std = torch.sqrt(predictive_variance) * data_std
                aleatoric_std = torch.sqrt(aleatoric_variance) * data_std
            inputs = inputs * self._data_std + self._data_mean
            targets = targets * self._data_std + self._data_mean

            positions = torch.arange(mean.shape[1])
            # Guard against batches smaller than the number of columns.
            for j in range(min(max_columns, len(inputs))):
                ax = axs[i, j]
                ax.plot(inputs[j, :], color="black", label="Input")
                ax.plot(targets[j, :], color="red", label="Target")
                ax.plot(mean[j, :], color="blue", linestyle="--", label="Mean")
                ax.fill_between(
                    positions,
                    mean[j, :] - predictive_std[j, :],
                    mean[j, :] + predictive_std[j, :],
                    color="blue",
                    alpha=0.3,
                    label="Predictive Uncertainty",
                )
                ax.fill_between(
                    positions,
                    mean[j, :] - aleatoric_std[j, :],
                    mean[j, :] + aleatoric_std[j, :],
                    color="orange",
                    alpha=0.3,
                    label="Aleatoric Uncertainty",
                )
                ax.grid()
                if i == 0 and j == 0:
                    ax.legend()
        fig.savefig(plots_file(save_path, specific_name), bbox_inches="tight")
        # Closing the figure is sufficient; calling `plt.clf()` afterwards would act on
        # (and create) a new, empty current figure.
        plt.close(fig)

    @staticmethod
    def add_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        """Extend the parent's arguments with ``--datamodule_anomaly`` (0 or 1)."""
        # Zero-argument super() is unavailable in a staticmethod, so the class is
        # passed explicitly.
        parser = super(
            ECG5000ReconstructionDataModule, ECG5000ReconstructionDataModule
        ).add_specific_args(parent_parser)
        parser.add_argument(
            "--datamodule_anomaly",
            type=int,
            choices=[0, 1],
            default=0,
            help="If `True`, then all the anomalous cases from the training set are appended to the test set. Default: `False`.",
        )
        return parser