Source code for yamle.losses

from typing import Type, Callable, Optional

from yamle.losses.loss import DummyLoss
from yamle.losses.classification import CrossEntropyLoss, TextCrossEntropyLoss
from yamle.losses.contrastive import NoiseContrastiveEstimatorLoss
from yamle.losses.regression import (
    GaussianNegativeLogLikelihoodLoss,
    MeanSquaredError,
    QuantileRegressionLoss,
)
from yamle.losses.segmentation import FocalLoss, SoftIntersectionOverUnionLoss
from yamle.losses.evidential_regression import EvidentialRegressionLoss

AVAILABLE_LOSSES = {
    "crossentropy": CrossEntropyLoss,
    "nce": NoiseContrastiveEstimatorLoss,
    "textcrossentropy": TextCrossEntropyLoss,
    "gaussiannll": GaussianNegativeLogLikelihoodLoss,
    "mse": MeanSquaredError,
    "quantile": QuantileRegressionLoss,
    "focal": FocalLoss,
    "softiou": SoftIntersectionOverUnionLoss,
    "evidentialregression": EvidentialRegressionLoss,
    None: DummyLoss,
    "dummy": DummyLoss,
}


[docs] def loss_factory(loss_type: Optional[str] = None) -> Type[Callable]: """This function is used to create a loss instance based on the loss type. Args: loss_type (str): The type of loss to create. """ if loss_type not in AVAILABLE_LOSSES: raise ValueError(f"Unknown loss type {loss_type}.") return AVAILABLE_LOSSES[loss_type]