Source code for yamle.losses.segmentation

from typing import Any, List, Optional

import torch
import argparse

from yamle.losses.loss import BaseLoss
from yamle.defaults import TINY_EPSILON


[docs] class SoftIntersectionOverUnionLoss(BaseLoss): """This defines the soft intersection over union loss for semantic segmentation. It assumes that the input shape is `(batch_size, num_members, num_classes, height, width)`. No matter what the reduction it is always averaged over the `num_members`. The input is assumed to be probabilities. The loss can also be weighted by a weight tensor of shape `(batch_size)`. Args: factor (float): The softness factor. Defaults to 1.0. ignore_indices (List[int]): The indices to ignore. Defaults to []. """ def __init__( self, factor: float = 1.0, ignore_indices: List[int] = [], **kwargs: Any ) -> None: super().__init__(**kwargs) self._factor = factor self._ignore_indices = ignore_indices def __call__( self, y_hat: torch.Tensor, y: torch.Tensor, weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: """This method is used to compute the loss.""" num_members = y_hat.shape[1] loss = 0.0 for i in range(num_members): sample_loss = self._soft_iou(y_hat[:, i], y) loss += self._process_sample_loss(sample_loss, i, weights) return self._process_member_loss(loss, num_members) def _soft_iou(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """This method is used to compute the soft IoU loss.""" if y_hat.shape != y.shape: y_one_hot = torch.zeros( y_hat.shape, device=y_hat.device, requires_grad=False ).long() y = y_one_hot.scatter_(1, y.unsqueeze(1), 1) intersection = torch.sum(y_hat * y, dim=[2, 3]) union = torch.sum(y_hat, dim=[2, 3]) + torch.sum(y, dim=[2, 3]) - intersection if self._ignore_indices is not None: for i in self._ignore_indices: intersection[:, i] = 0 union[:, i] = 0 total_classes = y_hat.shape[1] - len(self._ignore_indices) return 1 - torch.sum( (intersection + self._factor) / (union + self._factor) / total_classes, dim=1, ) def __repr__(self) -> str: return f"SoftIntersectionOverUnionLoss(reduction_per_sample={self._reduction_per_sample}, reduction_per_member={self._reduction_per_member}, reduction_per_feature={self._reduction_per_feature}, factor={self._factor}, ignore_indices={self._ignore_indices})"
[docs] @staticmethod def add_specific_args( parent_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: """This method is used to add the loss specific arguments to the parent parser.""" parser = super( SoftIntersectionOverUnionLoss, SoftIntersectionOverUnionLoss ).add_specific_args(parent_parser) parser.add_argument( "--loss_factor", type=float, default=1.0, help="The softness factor." ) parser.add_argument( "--loss_ignore_indices", type=str, default="[]", help="The indices to ignore.", ) return parser
[docs] class FocalLoss(BaseLoss): """This defines the focal loss for semantic segmentation. It assumes that the input shape is `(batch_size, num_members, num_classes, height, width)`. No matter what the reduction it is always averaged over the `num_members`. The input is assumed to be probabilities. The loss can also be weighted by a weight tensor of shape `(batch_size)`. Args: alpha (float): The alpha factor. Defaults to 0.25. gamma (float): The gamma factor. Defaults to 2.0. ignore_indices (List[int]): The indices to ignore. Defaults to []. """ def __init__( self, alpha: float = 0.25, gamma: float = 2.0, ignore_indices: List[int] = [], *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self._alpha = alpha self._gamma = gamma self._ignore_indices = ignore_indices def __call__( self, y_hat: torch.Tensor, y: torch.Tensor, weights: Optional[torch.Tensor] = None, ) -> torch.Tensor: """This method is used to compute the loss.""" num_members = y_hat.shape[1] loss = 0.0 for i in range(num_members): sample_loss = self._focal_loss(y_hat[:, i], y) loss += self._process_sample_loss(sample_loss, i, weights) return self._process_member_loss(loss, num_members) def _focal_loss(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """This method is used to compute the focal loss.""" if y_hat.shape != y.shape: y_one_hot = torch.zeros(y_hat.shape, device=y_hat.device).long() y = y_one_hot.scatter_(1, y.unsqueeze(1), 1) # Compute the focal loss focal = ( -self._alpha * y * torch.pow(1 - y_hat + TINY_EPSILON, self._gamma) * torch.log(y_hat + TINY_EPSILON) ) focal = torch.sum(focal, dim=(2, 3)) loss = 0.0 for i in range(y_hat.shape[1]): if i not in self._ignore_indices: loss += focal[:, i] return loss / (y_hat.shape[1] - len(self._ignore_indices)) def __repr__(self) -> str: return f"FocalLoss(reduction_per_sample={self._reduction_per_sample}, reduction_per_member={self._reduction_per_member}, reduction_per_feature={self._reduction_per_feature}, alpha={self._alpha}, gamma={self._gamma}, ignore_indices={self._ignore_indices})"
[docs] @staticmethod def add_specific_args( parent_parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: """This method is used to add the loss specific arguments to the parent parser.""" parser = super(FocalLoss, FocalLoss).add_specific_args(parent_parser) parser.add_argument( "--loss_alpha", type=float, default=0.25, help="The alpha factor." ) parser.add_argument( "--loss_gamma", type=float, default=2.0, help="The gamma factor." ) parser.add_argument( "--loss_ignore_indices", type=str, default="[]", help="The indices to ignore.", ) return parser