from typing import Any, Optional
import torch
import argparse
import torch.nn.functional as F
from yamle.losses.loss import BaseLoss
from yamle.defaults import TINY_EPSILON
class NoiseContrastiveEstimatorLoss(BaseLoss):
    """This defines the noise contrastive estimation (NCE) loss.

    It assumes that the input shape is `(batch_size, num_members, num_classes)` for the
    anchor predictions `y_hat` and the positive predictions `y_hat_positive`, and
    `(batch_size, num_members, K, num_classes)` for the `K` negative predictions
    `y_hat_negative`. No matter what the reduction it is always averaged over the
    `num_members`.

    For every member the per-sample loss is

        -log( exp(s(y, y+) / T) / (exp(s(y, y+) / T) + sum_k exp(s(y, y-_k) / T)) )

    where `s` is the chosen similarity function and `T` is the temperature.

    Args:
        temperature (float): The temperature to use for the softmax. Defaults to 1.0.
        similarity (str): The similarity function to use. Defaults to `cosine`. Choices are `cosine` and `dot`.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        similarity: str = "cosine",
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        assert similarity in [
            "cosine",
            "dot",
        ], f"Similarity function must be either `cosine` or `dot`. Got {similarity}."
        assert (
            temperature > 0
        ), f"Temperature must be greater than 0. Got {temperature}."
        self._similarity = similarity
        self._temperature = temperature

    def _cosine_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes the cosine similarity between two tensors along their last dimension.

        The last dimension is reduced away, e.g. `(B, C), (B, C) -> (B,)` and
        `(B, K, C), (B, K, C) -> (B, K)`.
        """
        return F.cosine_similarity(x, y, dim=-1, eps=TINY_EPSILON)

    def _dot_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes the dot product between two tensors along their last dimension.

        Like `_cosine_similarity`, the last dimension is reduced away so that both
        similarity functions return tensors of the same shape. (A `matmul` against the
        transpose would instead build a pairwise `(B, B)` / `(B, K, K)` matrix.)
        """
        return (x * y).sum(dim=-1)

    def _loss(
        self,
        y_hat: torch.Tensor,
        y_hat_positive: torch.Tensor,
        y_hat_negative: torch.Tensor,
    ) -> torch.Tensor:
        """Computes the NCE loss for a single member.

        Args:
            y_hat (torch.Tensor): The default predictions for `x` samples. The shape is `(batch_size, num_classes)`.
            y_hat_positive (torch.Tensor): The predictions for the positive samples. The shape is `(batch_size, num_classes)`.
            y_hat_negative (torch.Tensor): The predictions for the negative samples. The shape is `(batch_size, K, num_classes)`.

        Returns:
            torch.Tensor: The per-sample loss with shape `(batch_size,)`.
        """
        assert (
            y_hat.shape == y_hat_positive.shape
        ), f"The shapes of the predictions do not match. Got {y_hat.shape}, {y_hat_positive.shape}."
        assert (
            y_hat.shape[0] == y_hat_negative.shape[0]
        ), f"The batch sizes of the predictions do not match. Got {y_hat.shape[0]}, {y_hat_negative.shape[0]}."
        if self._similarity == "cosine":
            similarity_fn = self._cosine_similarity
        elif self._similarity == "dot":
            similarity_fn = self._dot_similarity
        else:
            raise NotImplementedError(
                f"Similarity function {self._similarity} is not implemented."
            )
        # Logit of the positive pair for every sample: `(batch_size,)`.
        logits_positive = similarity_fn(y_hat, y_hat_positive) / self._temperature
        # Pair each anchor with each of its K negatives. `expand_as` is a view, so the
        # anchor is not copied K times. Result: `(batch_size, K)`.
        logits_negative = (
            similarity_fn(y_hat.unsqueeze(1).expand_as(y_hat_negative), y_hat_negative)
            / self._temperature
        )
        # Positive logit in column 0 followed by the negatives: `(batch_size, 1 + K)`.
        logits = torch.cat([logits_positive.unsqueeze(1), logits_negative], dim=1)
        # -log(exp(p) / (exp(p) + sum_k exp(n_k))) == logsumexp([p, n_1..n_K]) - p.
        # Evaluating in log space avoids overflowing `exp` (-> inf / inf = nan) for
        # large logits, e.g. unnormalised dot products or small temperatures.
        return torch.logsumexp(logits, dim=1) - logits_positive

    def __call__(
        self,
        y_hat: torch.Tensor,
        y_hat_positive: torch.Tensor,
        y_hat_negative: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """This method is used to compute the NCE loss.

        Args:
            y_hat (torch.Tensor): The default predictions, shape `(batch_size, num_members, num_classes)`.
            y_hat_positive (torch.Tensor): The positive predictions, shape `(batch_size, num_members, num_classes)`.
            y_hat_negative (torch.Tensor): The negative predictions, shape `(batch_size, num_members, K, num_classes)`.
            weights (Optional[torch.Tensor]): Forwarded unchanged to `_process_sample_loss`
                for every member; its interpretation is defined by the base class.

        Returns:
            torch.Tensor: The per-member losses accumulated and passed through
            `_process_member_loss`.
        """
        num_members = y_hat_positive.shape[1]
        loss = 0.0
        for i in range(num_members):
            sample_loss = self._loss(
                y_hat[:, i], y_hat_positive[:, i], y_hat_negative[:, i]
            )
            loss += self._process_sample_loss(sample_loss, i, weights)
        return self._process_member_loss(loss, num_members)

    def __repr__(self) -> str:
        return f"NoiseContrastiveEstimatorLoss(reduction_per_sample={self._reduction_per_sample}, reduction_per_member={self._reduction_per_member}, reduction_per_feature={self._reduction_per_feature}, similarity={self._similarity}, temperature={self._temperature})"

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Adds the NCE-specific command line arguments on top of the base loss arguments.

        Adds `--loss_temperature` (float, default 1.0) and `--loss_similarity`
        (`cosine` or `dot`, default `cosine`).
        """
        # Zero-argument `super()` is unavailable inside a staticmethod, so the class is
        # passed explicitly to reach the parent implementation.
        parser = super(
            NoiseContrastiveEstimatorLoss, NoiseContrastiveEstimatorLoss
        ).add_specific_args(parent_parser)
        parser.add_argument(
            "--loss_temperature",
            type=float,
            default=1.0,
            help="The temperature to use for the softmax.",
        )
        parser.add_argument(
            "--loss_similarity",
            type=str,
            choices=["cosine", "dot"],
            default="cosine",
            help="The similarity function to use.",
        )
        return parser