# Source code for yamle.methods.dun

from yamle.utils.operation_utils import average_predictions
from yamle.defaults import (
    TINY_EPSILON,
    LOSS_KEY,
    LOSS_KL_KEY,
    TARGET_KEY,
    PREDICTION_KEY,
    MEAN_PREDICTION_KEY,
    TRAIN_KEY,
    PREDICTION_PER_MEMBER_KEY,
    TARGET_PER_MEMBER_KEY,
    VALIDATION_KEY,
    TEST_KEY,
    MEMBERS_DIM,
    INPUT_KEY,
    AVERAGE_WEIGHTS_KEY,
    MIN_TENDENCY,
)
from yamle.utils.regularizer_utils import disable_regularizer
from yamle.evaluation.metrics.algorithmic import metrics_factory
from yamle.methods.uncertain_method import MemberMethod
from typing import Any, List, Dict, Optional
import torch
import torchmetrics
import argparse
import logging

# Module-level logger shared with PyTorch Lightning. NOTE: this rebinds the
# name ``logging`` (shadowing the stdlib module) to a ``logging.Logger``
# instance, so the ``logging.warning(...)`` / ``logging.info(...)`` calls in
# this file are routed through the "pytorch_lightning" logger.
logging = logging.getLogger(name="pytorch_lightning")


class DUNMethod(MemberMethod):
    """Depth Uncertainty in Neural Networks (DUN) method.

    This class extends the base member method so that the prediction is made
    with the Depth Uncertainty in Neural Networks approach: the model's
    ``final_layer`` is applied repeatedly, once to the output of each hidden
    layer (stage), giving one prediction per depth. Every depth is treated as
    a member whose weight comes from a learnable categorical distribution over
    depths (``alphas``), regularised towards a uniform prior (``prior_betas``)
    through a KL divergence term.

    Args:
        alpha (float): The alpha parameter for the KL divergence (trade-off
            between the likelihood and the prior). Must be greater than 0.
        warm_starting_epochs (int): The number of epochs to train the model
            without changing the depth weights; during these training epochs
            the prior is used as the depth weights.

    Raises:
        ValueError: If ``alpha`` is not greater than 0, if the model has no
            ``_depth`` attribute, or if the model already defines ``_alphas``
            or ``_prior_betas``.
    """

    # Key under which the per-epoch history of the depth weights is stored in
    # the state dict (kept unchanged so existing checkpoints still load).
    _ALPHAS_CONTAINER_KEY = "alphas_container"

    def __init__(
        self, alpha: float, warm_starting_epochs: int, *args: Any, **kwargs: Any
    ) -> None:
        # Validate before mutating ``kwargs`` or building the base method.
        # An ``assert`` would be stripped when Python runs with ``-O``; the
        # negated comparison also rejects NaN.
        if not alpha > 0:
            raise ValueError(
                f"The alpha parameter should be greater than 0, but it is {alpha}."
            )
        model = kwargs["model"]
        if not hasattr(model, "_depth"):
            raise ValueError(
                "The model should have a `_depth` attribute which is the number of hidden layers."
            )
        # One member per hidden layer plus one for the output of the full model.
        num_members = model._depth + 1
        kwargs["num_members"] = num_members
        # NOTE: this intentionally updates the caller's ``metrics_kwargs`` dict
        # in place, as the original implementation did.
        kwargs["metrics_kwargs"]["num_members"] = num_members
        self._depth = model._depth
        super().__init__(*args, **kwargs)
        logging.warning(
            f"The number of members is set to {self._num_members} because of the depth of the model."
        )
        self._loss.set_reduction_per_member("sum")
        self._loss.set_reduction_per_sample("sum")
        logging.warning(
            f"The reduction per member is set to {self._loss._reduction_per_member} and the reduction per sample is set to {self._loss._reduction_per_sample}."
        )
        self._alpha = alpha
        if hasattr(self.model, "_alphas"):
            raise ValueError("The model should not have an `_alphas` attribute.")
        if hasattr(self.model, "_prior_betas"):
            raise ValueError("The model should not have an `_prior_betas` attribute.")
        self._warm_starting_epochs = warm_starting_epochs
        # Unnormalised logits of the distribution over depths; all zeros give a
        # uniform distribution after the softmax in `alphas`.
        self.model._alphas = torch.nn.Parameter(torch.zeros(self._num_members))
        # History of the depth weights: one row is appended per training epoch.
        self._alphas_container = torch.empty((0, self._num_members))
        # Uniform prior over depths, registered as a (non-trainable) buffer.
        self.model.register_buffer(
            "_prior_betas", torch.ones(self._num_members) / self._num_members
        )
        disable_regularizer(self.model._alphas)
        self.model.add_method_specific_layers(method="dun")

    @property
    def alphas(self) -> torch.Tensor:
        """Return the weights of the distribution over depths.

        During training in the first ``warm_starting_epochs`` epochs the fixed
        prior is returned instead, so the depth logits receive no gradient.
        Otherwise the softmax of the learnable logits ``model._alphas`` is
        returned.
        """
        if self.current_epoch < self._warm_starting_epochs and self.training:
            return self.prior_betas
        return torch.softmax(self.model._alphas, dim=0)

    @property
    def prior_betas(self) -> torch.Tensor:
        """Return the prior weights over depths (the model's ``_prior_betas`` buffer)."""
        return self.model._prior_betas

    def _kl_divergence(self) -> torch.Tensor:
        """Compute ``KL(alphas || prior_betas)`` between the depth posterior and prior."""
        # Evaluate the (softmax-based) weights once instead of twice.
        alphas = self.alphas
        return torch.sum(
            alphas * torch.log(alphas / self.prior_betas + TINY_EPSILON)
        )

    def _create_metrics(self, metrics_kwargs: Dict[str, Any]) -> None:
        """Create per-member metrics for training, validation and testing, plus a KL metric."""
        self.metrics = {
            TRAIN_KEY: metrics_factory(**metrics_kwargs, per_member=True),
            VALIDATION_KEY: metrics_factory(**metrics_kwargs, per_member=True),
            TEST_KEY: metrics_factory(**metrics_kwargs, per_member=True),
        }
        self._add_additional_metrics(
            {LOSS_KL_KEY: torchmetrics.MeanMetric()}, tendencies=[MIN_TENDENCY]
        )

    def _predict(self, x: torch.Tensor, **forward_kwargs: Any) -> torch.Tensor:
        """Perform a forward pass and return one prediction per depth.

        The model is run once with ``staged_output=True``. Each intermediate
        stage is passed through its reshaping layer and then through the
        shared ``final_layer``. The full model output is appended last. The
        predictions are stacked along ``MEMBERS_DIM``.
        """
        last_layer, stages = self.model(x, staged_output=True, **forward_kwargs)
        # The last stage feeds the model's own output (`last_layer`), so it is
        # not decoded a second time.
        stages = stages[:-1]
        outputs = []
        for i, h in enumerate(stages):
            h = self.model._reshaping_layers[i](h)
            h = self.model.final_layer(h)
            outputs.append(h)
        outputs.append(last_layer)
        return torch.stack(outputs, dim=MEMBERS_DIM)

    def _step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        optimizer_idx: Optional[int] = None,
        phase: str = TRAIN_KEY,
    ) -> Dict[str, torch.Tensor]:
        """Perform a single step and collect the outputs used by the metrics.

        In training, the loss is ``L / B + alpha * KL / N``. Here ``L`` is the
        ``alphas``-weighted per-member loss from ``_loss_per_member``, ``B`` is
        the batch size and ``N`` is the training set size. Outside training no
        loss is computed; the loss and KL entries are zero tensors.
        """
        x, y = batch
        # Replicate the targets once per member (depth).
        y_per_member = torch.stack(
            [y for _ in range(self._num_members)], dim=MEMBERS_DIM
        )
        y_hat = self._predict(x)
        # Evaluate the depth weights once per step and reuse them everywhere.
        alphas = self.alphas
        batch_weights = torch.stack([alphas for _ in range(y_hat.shape[0])], dim=0)
        if self.training:
            N_train = self._datamodule.train_dataset_size()
            batch_size = x.shape[0]
            loss = -(N_train / batch_size) * self._loss_per_member(
                y_hat, y_per_member, weights_per_member=alphas
            )
            kl_divergence = self._kl_divergence()
            loss = loss - self._alpha * kl_divergence
            loss = -loss / N_train
        else:
            loss = torch.tensor(0.0, device=y_hat.device)
            kl_divergence = torch.tensor(0.0, device=y_hat.device)
        y_hat_mean = average_predictions(y_hat, self._task, weights=batch_weights)
        # `y` is the original (non-replicated) target; this avoids assuming the
        # position of MEMBERS_DIM when recovering it from `y_per_member`.
        return {
            LOSS_KEY: loss,
            LOSS_KL_KEY: kl_divergence.detach(),
            TARGET_KEY: y.detach(),
            INPUT_KEY: x.detach(),
            PREDICTION_KEY: y_hat.detach(),
            MEAN_PREDICTION_KEY: y_hat_mean.detach(),
            PREDICTION_PER_MEMBER_KEY: y_hat.detach(),
            TARGET_PER_MEMBER_KEY: y_per_member.detach(),
            AVERAGE_WEIGHTS_KEY: batch_weights.detach(),
        }

    def on_train_epoch_end(self) -> None:
        """Log the current depth weights and append them to the history."""
        super().on_train_epoch_end()
        alphas = self.alphas
        logging.info(f"Alphas: {alphas}")
        self._alphas_container = torch.cat(
            [self._alphas_container, alphas.unsqueeze(0).detach().cpu()], dim=0
        )

    def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Return the module state plus the history of the depth weights.

        All arguments are forwarded to the parent ``state_dict`` so the
        standard ``nn.Module`` signature keeps working.
        """
        state_dict = super().state_dict(*args, **kwargs)
        state_dict[self._ALPHAS_CONTAINER_KEY] = self._alphas_container
        return state_dict

    def load_state_dict(
        self, state_dict: Dict[str, Any], *args: Any, **kwargs: Any
    ) -> Any:
        """Load the module state and restore the history of the depth weights.

        The ``alphas_container`` entry is removed from a copy of
        ``state_dict`` before delegating to the parent implementation. It is
        not a parameter or buffer, so a strict load would otherwise reject it.
        Extra arguments (e.g. ``strict``) are forwarded. If the entry is
        absent, the current history is kept.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        from collections import OrderedDict

        alphas_container = None
        if self._ALPHAS_CONTAINER_KEY in state_dict:
            alphas_container = state_dict[self._ALPHAS_CONTAINER_KEY]
            # Copy instead of mutating the caller's dict. Keep `_metadata`,
            # which `nn.Module.load_state_dict` reads for version info.
            metadata = getattr(state_dict, "_metadata", None)
            state_dict = OrderedDict(
                (key, value)
                for key, value in state_dict.items()
                if key != self._ALPHAS_CONTAINER_KEY
            )
            if metadata is not None:
                state_dict._metadata = metadata
        result = super().load_state_dict(state_dict, *args, **kwargs)
        if alphas_container is not None:
            # Keep the history on CPU: `on_train_epoch_end` concatenates CPU
            # tensors, and a checkpoint may have been mapped to another device.
            self._alphas_container = torch.as_tensor(alphas_container).detach().cpu()
        return result

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Add the DUN-specific command-line arguments to the parent parser."""
        parser = super(DUNMethod, DUNMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_alpha",
            type=float,
            default=1.0,
            help="The alpha to be used for the trade-off between the likelihood and the prior.",
        )
        parser.add_argument(
            "--method_warm_starting_epochs",
            type=int,
            default=0,
            help="The number of epochs to use the prior betas.",
        )
        return parser