# Source code for yamle.losses.evidential_regression

"""
Adapted from: https://github.com/aamini/evidential-deep-learning/
"""

from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import numpy as np

from yamle.losses.loss import BaseLoss
from yamle.defaults import TINY_EPSILON, REGRESSION_KEY, DEPTH_ESTIMATION_KEY, RECONSTRUCTION_KEY


class NIG_NLL(nn.Module):
    """Negative log-likelihood loss for Normal Inverse Gamma (NIG) distribution.

    With ``two_beta_lambda = 2 * beta * (1 + v)``, the element-wise loss is::

        0.5 * log(pi / v)
        - alpha * log(two_beta_lambda)
        + (alpha + 0.5) * log(v * (y - gamma) ** 2 + two_beta_lambda)
        + lgamma(alpha) - lgamma(alpha + 0.5)

    ``TINY_EPSILON`` is added to ``v`` in the denominator and inside every
    logarithm to keep the computation finite near zero.
    """

    def forward(
        self,
        y: torch.Tensor,
        gamma: torch.Tensor,
        v: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        reduce: bool = True,
    ) -> torch.Tensor:
        """Compute the loss function.

        Args:
            y: Regression targets; must broadcast against ``gamma``.
            gamma: Predicted mean parameter of the NIG distribution.
            v: Predicted ``nu`` parameter. Presumably constrained positive by
                the model's output head; not enforced here.
            alpha: Predicted ``alpha`` parameter (same caveat as ``v``).
            beta: Predicted ``beta`` parameter (same caveat as ``v``).
            reduce: If ``True`` (default), return the mean over all elements.
                If ``False``, return the element-wise NLL unreduced.

        Returns:
            The scalar mean NLL, or the element-wise NLL when ``reduce`` is
            ``False``.
        """
        two_beta_lambda = 2 * beta * (1 + v)
        nll = (
            # np.pi is a plain Python float, so no CPU tensor is allocated
            # per call; it broadcasts against ``v`` like a 0-dim tensor would.
            0.5 * torch.log(np.pi / (v + TINY_EPSILON) + TINY_EPSILON)
            - alpha * torch.log(two_beta_lambda + TINY_EPSILON)
            + (alpha + 0.5)
            * torch.log(v * (y - gamma) ** 2 + two_beta_lambda + TINY_EPSILON)
            + torch.lgamma(alpha)
            - torch.lgamma(alpha + 0.5)
        )
        return torch.mean(nll) if reduce else nll
class NIG_Reg(nn.Module):
    """Regularization loss for Normal Inverse Gamma (NIG) distribution.

    The absolute prediction error ``|y - gamma|`` is multiplied by the
    evidence term ``2 * v + alpha``. For the same error, predictions with
    more evidence therefore incur a larger penalty.
    """

    def forward(
        self,
        y: torch.Tensor,
        gamma: torch.Tensor,
        v: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        reduce: bool = True,
    ) -> torch.Tensor:
        """Compute the loss function.

        Args:
            y: Regression targets; must broadcast against ``gamma``.
            gamma: Predicted mean parameter of the NIG distribution.
            v: Predicted ``nu`` parameter.
            alpha: Predicted ``alpha`` parameter.
            beta: Predicted ``beta`` parameter. Unused. It is accepted so this
                module shares the call signature of the NLL term.
            reduce: If ``True`` (default), return the mean over all elements.
                If ``False``, return the element-wise regulariser unreduced.

        Returns:
            The scalar mean regulariser, or the element-wise values when
            ``reduce`` is ``False``.
        """
        error = torch.abs(y - gamma)
        evidence = 2 * v + alpha
        reg = error * evidence
        return torch.mean(reg) if reduce else reg
class EvidentialRegressionLoss(BaseLoss):
    """Evidential regression loss for probabilistic regression.

    The last dimension of the network output is split into the four NIG
    parameters ``(gamma, v, alpha, beta)``. The loss returns the NIG negative
    log-likelihood and the evidence regulariser as a pair. They are presumably
    weighted and combined by the caller.
    """

    tasks = [REGRESSION_KEY, DEPTH_ESTIMATION_KEY, RECONSTRUCTION_KEY]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._nll = NIG_NLL()
        self._reg = NIG_Reg()

    def __call__(
        self,
        y_hat: torch.Tensor,
        y: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the NLL and regularisation terms.

        Args:
            y_hat: Network output whose last dimension holds
                ``(gamma, v, alpha, beta)`` in that order.
            y: Regression targets. They may either match the rank of
                ``y_hat`` (trailing size-1 dimension) or omit that trailing
                dimension.
            weights: Accepted for interface compatibility; currently ignored
                by this loss.

        Returns:
            A tuple ``(loss_nll, loss_reg)`` of scalar tensors.

        Raises:
            ValueError: If the last dimension of ``y_hat`` is not 4.
        """
        if y_hat.shape[-1] != 4:
            raise ValueError(
                "Expected the last dimension of y_hat to hold 4 NIG parameters "
                f"(gamma, v, alpha, beta), got shape {tuple(y_hat.shape)}."
            )
        gamma, v, alpha, beta = torch.split(y_hat, 1, dim=-1)
        # torch.split keeps a trailing size-1 dim on each parameter. A target
        # without it (e.g. (B,) vs (B, 1)) would broadcast ``y - gamma`` to an
        # outer product (B, B) and silently corrupt both terms.
        if y.dim() == gamma.dim() - 1:
            y = y.unsqueeze(-1)
        loss_nll = self._nll(y, gamma, v, alpha, beta)
        loss_reg = self._reg(y, gamma, v, alpha, beta)
        return loss_nll, loss_reg

    def __repr__(self) -> str:
        return "EvidentialRegression()"