Source code for yamle.losses.classification

from typing import Any, Optional, Union, List

import torch
import argparse

from yamle.losses.loss import BaseLoss
from yamle.defaults import TINY_EPSILON, CLASSIFICATION_KEY, TEXT_CLASSIFICATION_KEY, SEGMENTATION_KEY


[docs] def one_hot( y: torch.Tensor, num_classes: int, label_smoothing: float = 0.0 ) -> torch.Tensor: """One-hot encodes the target and directly applies label smoothing. Taken from: https://github.com/pytorch/pytorch/issues/7455#issuecomment-513062631 """ return ( torch.zeros(size=(y.size(0), num_classes, *y.size()[1:]), device=y.device) .fill_(label_smoothing / (num_classes - 1)) .scatter_(1, y.data.unsqueeze(1), 1.0 - label_smoothing) )
[docs] class CrossEntropyLoss(BaseLoss): """This defines the base cross-entropy loss. It assumes that the input shape is `(batch_size, num_members, num_classes)`. No matter what the reduction it is always averaged over the `num_members`. The target is assumed to be of shape `(batch_size)`, it needs to be one-hot encoded. The input is assumed to be probabilities. The loss can also be weighted by a weight tensor of shape `(batch_size)`. Args: label_smoothing (float): The amount of label smoothing to apply. one_hot_target (bool): Whether the target is already one-hot encoded. Defaults to `False`. class_weights (Optional[Union[torch.Tensor, List[float]]]): The weights to apply to each class. Defaults to `None`. flatten (bool): Whether to flatten the predictions and the targets. Defaults to `False`. """ tasks = [CLASSIFICATION_KEY, SEGMENTATION_KEY] def __init__( self, label_smoothing: float = 0.0, one_hot_target: bool = False, class_weights: Optional[Union[torch.Tensor, List[float]]] = None, flatten: bool = False, *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) assert 0.0 <= label_smoothing <= 1.0, "Label smoothing must be between 0 and 1." self._label_smoothing = label_smoothing self._one_hot_target = one_hot_target self._class_weights = class_weights self._flatten = flatten def _loss(self, y_hat: torch.Tensor, y_one_hot: torch.Tensor) -> torch.Tensor: """Computes the negative log likelihood loss.""" # Check if `y_hat` is already probabilities. assert ( y_hat.shape == y_one_hot.shape ), f"The shapes of the predictions and the targets do not match. Got {y_hat.shape} and {y_one_hot.shape}." assert torch.all(y_hat >= 0.0) and torch.all( y_hat <= 1.1 ), f"The predictions are not probabilities. Got {y_hat[y_hat < 0.0]} or {y_hat[y_hat > 1.1]}." # Check that y_hat sums approximately to 1. assert torch.allclose( y_hat.sum(dim=1), torch.ones_like(y_hat.sum(dim=1)), atol=1e-2 ), f"The predictions do not sum to 1. Got {y_hat.sum(dim=1)}." class_weights = None if self._class_weights is not None: if isinstance(self._class_weights, list): self._class_weights = torch.tensor( self._class_weights, device=y_hat.device ) assert ( self._class_weights.shape[0] == y_hat.shape[-1] ), f"The number of class weights ({self._class_weights.shape[0]}) does not match the number of classes ({y_hat.shape[-1]})." # Check the device. self._class_weights = self._class_weights.to(y_hat.device) class_weights = self._class_weights else: class_weights = torch.ones(y_hat.shape[-1], device=y_hat.device) return -torch.sum( y_one_hot * class_weights * torch.log(y_hat + TINY_EPSILON), dim=1 ) def __call__( self, y_hat: torch.Tensor, y: torch.Tensor, weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: """This method is used to compute the loss.""" num_members = y_hat.shape[1] if self._flatten: B = y_hat.shape[0] C = y_hat.shape[2] y_hat = y_hat.reshape(B, num_members, C, -1) y = y.reshape(B, -1) loss = 0.0 one_hot_y = ( one_hot(y, y_hat.shape[-1], self._label_smoothing) if not self._one_hot_target else y ) for i in range(num_members): sample_loss = self._loss(y_hat[:, i], one_hot_y) loss += self._process_sample_loss(sample_loss, i, weights) return self._process_member_loss(loss, num_members) def __repr__(self) -> str: return f"CrossEntropyLoss(reduction_per_sample={self._reduction_per_sample}, reduction_per_member={self._reduction_per_member}, reduction_per_feature={self._reduction_per_feature}, label_smoothing={self._label_smoothing}, one_hot_target={self._one_hot_target}, class_weights={self._class_weights}"
[docs] @staticmethod def add_specific_args( parent_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: parser = super(CrossEntropyLoss, CrossEntropyLoss).add_specific_args( parent_parser ) parser.add_argument( "--loss_label_smoothing", type=float, default=0.0, help="The amount of label smoothing to apply.", ) parser.add_argument( "--loss_one_hot_target", type=int, choices=[0, 1], default=0, help="Whether the target is already one-hot encoded.", ) parser.add_argument( "--loss_class_weights", type=str, default=None, help="The weights to apply to each class.", ) parser.add_argument( "--loss_flatten", type=int, choices=[0, 1], default=0, help="Whether to flatten the predictions and the targets.", ) return parser
[docs] class TextCrossEntropyLoss(CrossEntropyLoss): """This defines the base cross-entropy loss. It assumes that the input shape is `(batch_size, num_members, sequence_length, num_classes)`. No matter what the reduction it is always averaged over the `num_members`. The target is assumed to be of shape `(batch_size, sequence_length)`. The input is assumed to be probabilities. The loss can also be weighted by a weight tensor of shape `(batch_size)`. Args: label_smoothing (float): The amount of label smoothing to apply. """ tasks = [TEXT_CLASSIFICATION_KEY] def _loss(self, y_hat: torch.Tensor, y_one_hot: torch.Tensor) -> torch.Tensor: """Computes the negative log likelihood loss.""" # Check if `y_hat` is already probabilities. assert ( y_hat.shape == y_one_hot.shape ), f"The shapes of the predictions and the targets do not match. Got {y_hat.shape} and {y_one_hot.shape}." assert torch.all(y_hat >= 0.0) and torch.all( y_hat <= 1.1 ), f"The predictions are not probabilities. Got {y_hat[y_hat < 0.0]} or {y_hat[y_hat > 1.1]}." return -torch.sum(y_one_hot * torch.log(y_hat + TINY_EPSILON), dim=2) def __call__( self, y_hat: torch.Tensor, y: torch.Tensor, weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: """This method is used to compute the loss.""" num_members = y_hat.shape[1] loss = 0.0 y_original_shape = y.shape one_hot_y = one_hot(y.reshape(-1), y_hat.shape[-1], self._label_smoothing) one_hot_y = one_hot_y.reshape(y_original_shape[0], y_original_shape[1], -1) for i in range(num_members): sample_loss = self._loss(y_hat[:, i], one_hot_y).mean(dim=1) loss += self._process_sample_loss(sample_loss, i, weights) return self._process_member_loss(loss, num_members) def __repr__(self) -> str: return f"TextCrossEntropyLoss(reduction_per_sample={self._reduction_per_sample}, reduction_per_member={self._reduction_per_member}, reduction_per_feature={self._reduction_per_feature}, label_smoothing={self._label_smoothing})"