Source code for yamle.losses.loss

from abc import ABC, abstractmethod
from typing import Any, Optional, Union, List

import torch
import argparse

from yamle.defaults import SUPPORTED_TASKS


[docs] class BaseLoss(ABC): """This is the base class for all losses. Args: reduction_per_member (str): The reduction to be used for the loss. Given that the input has the shape (batch_size, num_members, ...), the loss will be reduced over the `num_members` dimension. reduction_per_sample (str): The reduction to be used for the loss. Given that the input has the shape (batch_size, num_members, ...), the loss will be reduced over the `batch_size` dimension. reduction_per_feature (str): The reduction to be used for the loss. Given that the input has the shape (batch_size, num_members, num_features, ...), the loss will be reduced over the `num_features` dimension. """ tasks = SUPPORTED_TASKS def __init__( self, reduction_per_member: str = "mean", reduction_per_sample: str = "mean", reduction_per_feature: str = "none", task: str = None, ) -> None: self.set_reduction_per_member(reduction_per_member) self.set_reduction_per_sample(reduction_per_sample) self.set_reduction_per_feature(reduction_per_feature) assert task in self.tasks, f"Unknown task {task} for this loss. Must be one of {self.tasks}." @abstractmethod def __call__(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """This method is used to compute the loss. Args: y_hat (torch.Tensor): The predictions. y (torch.Tensor): The ground truth. """ raise NotImplementedError("This method is not implemented.") def _process_sample_loss( self, sample_loss: torch.Tensor, member: int, weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: """This method is used to process the loss per sample in the batch.""" if weights is not None: if weights.dim() == 1: sample_loss = sample_loss * weights elif weights.dim() == 2: sample_loss = sample_loss * weights[:, member] else: raise ValueError( f"Expected weights to be of dimension 1 or 2. Got {weights.dim()}." ) if self._reduction_per_sample == "mean": return sample_loss.mean(dim=0) elif self._reduction_per_sample == "sum": return sample_loss.sum(dim=0) elif self._reduction_per_sample == "none": return sample_loss else: raise ValueError( f"Expected reduction per sample to be one of 'mean', 'sum' or 'none'. Got {self._reduction_per_sample}." ) def _process_member_loss( self, member_loss: torch.Tensor, members: int ) -> torch.Tensor: """This method is used to process the loss per member.""" if self._reduction_per_member == "mean": return member_loss / members elif self._reduction_per_member == "sum": return member_loss elif self._reduction_per_member == "none": return member_loss else: raise ValueError( f"Expected reduction per member to be one of 'mean', 'sum' or 'none'. Got {self._reduction_per_member}." ) def _process_feature_loss( self, feature_loss: torch.Tensor, dim: Union[int, List[int]] ) -> torch.Tensor: """This method is used to process the loss per feature.""" if self._reduction_per_feature == "mean": return torch.mean(feature_loss, dim=dim) elif self._reduction_per_feature == "sum": return torch.sum(feature_loss, dim=dim) elif self._reduction_per_feature == "none": return feature_loss else: raise ValueError( f"Expected reduction per feature to be one of 'mean', 'sum' or 'none'. Got {self._reduction_per_feature}." ) def __repr__(self) -> str: return f"BaseLoss(reduction_per_member={self._reduction_per_member}, reduction_per_sample={self._reduction_per_sample}, reduction_per_feature={self._reduction_per_feature})"
[docs] def set_reduction_per_member(self, reduction_per_member: str) -> None: """This method is used to set the reduction per member.""" assert reduction_per_member in [ "mean", "sum", "none", ], f"Expected reduction per member to be one of 'mean', 'sum' or 'none'. Got {reduction_per_member}." self._reduction_per_member = reduction_per_member
[docs] def set_reduction_per_sample(self, reduction_per_sample: str) -> None: """This method is used to set the reduction per sample.""" assert reduction_per_sample in [ "mean", "sum", "none", ], f"Expected reduction per sample to be one of 'mean', 'sum' or 'none'. Got {reduction_per_sample}." self._reduction_per_sample = reduction_per_sample
[docs] def set_reduction_per_feature(self, reduction_per_feature: str) -> None: """This method is used to set the reduction per feature.""" assert reduction_per_feature in [ "mean", "sum", "none", ], f"Expected reduction per feature to be one of 'mean', 'sum' or 'none'. Got {reduction_per_feature}." self._reduction_per_feature = reduction_per_feature
[docs] @staticmethod def add_specific_args( parent_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: """This method is used to add the loss specific arguments to the parent parser.""" parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) parser.add_argument( "--loss_reduction_per_member", type=str, default="sum", help="The reduction to be used for the loss. Given that the input has the shape (batch_size, num_members, ...), the loss will be reduced over the `num_members` dimension.", ) parser.add_argument( "--loss_reduction_per_sample", type=str, default="mean", help="The reduction to be used for the loss. Given that the input has the shape (batch_size, num_members, ...), the loss will be reduced over the `batch_size` dimension.", ) parser.add_argument( "--loss_reduction_per_feature", type=str, default="none", help="The reduction to be used for the loss. Given that the input has the shape (batch_size, num_members, ...), the loss will be reduced over the `...` dimension.", ) return parser
[docs] class DummyLoss(BaseLoss): """This is a dummy loss class which always returns 0.0 tensor.""" def __call__(self, *args: Any, **kwargs: Any) -> torch.Tensor: """This method is used to compute the loss. In this case it will just return 0 and put it on some device. """ # Iterate through all the arguments and if one of them is a tensor, get it's device device = "cpu" for arg in args: if isinstance(arg, torch.Tensor): device = arg.device break return torch.tensor(0.0).to(device) def __repr__(self) -> str: return f"DummyLoss()"