yamle.losses.contrastive module

class yamle.losses.contrastive.NoiseContrastiveEstimatorLoss(temperature=1.0, similarity='cosine', *args, **kwargs)

Bases: BaseLoss

This defines the noise contrastive estimation (NCE) loss.

It assumes that the input has shape (batch_size, num_members, num_classes). Regardless of the reduction setting, the loss is always averaged over the num_members dimension.

Parameters:
  • temperature (float) – The temperature to use for the softmax. Defaults to 1.0.

  • similarity (str) – The similarity function to use. Defaults to cosine. Choices are cosine and dot.
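
As a rough illustration of how the two parameters interact, the sketch below computes a cosine or dot-product similarity between two vectors and applies a temperature-scaled softmax to a list of similarity scores. It is not yamle's implementation: the helper names `similarity_score` and `tempered_softmax` are hypothetical, and plain Python lists stand in for tensors.

```python
import math


def similarity_score(a, b, similarity="cosine"):
    """Hypothetical helper: similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    if similarity == "dot":
        return dot
    if similarity == "cosine":
        # Assumes both vectors are non-zero; a zero vector would divide by zero.
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        return dot / (norm_a * norm_b)
    raise ValueError(f"Unknown similarity: {similarity!r}")


def tempered_softmax(scores, temperature=1.0):
    """Hypothetical helper: softmax over scores divided by the temperature."""
    scaled = [s / temperature for s in scores]
    shift = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - shift) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

In this sketch, cosine similarity ignores vector length while the dot product does not, and a temperature below 1.0 sharpens the softmax distribution while a temperature above 1.0 flattens it.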

static add_specific_args(parent_parser)

Adds the loss-specific arguments to the parent parser.

Return type:

ArgumentParser
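
The general pattern of a static method that takes a parent parser and returns an `ArgumentParser` can be sketched with the standard `argparse` module. The class `SketchLoss` and the flag names `--loss_temperature` and `--loss_similarity` are assumptions made for this example; the actual flags registered by yamle are not listed on this page.

```python
from argparse import ArgumentParser


class SketchLoss:
    """Stand-in class, not part of yamle; it only illustrates the pattern."""

    @staticmethod
    def add_specific_args(parent_parser):
        # Build a new parser that inherits the parent's arguments.
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # The flag names below are assumptions chosen for this example.
        parser.add_argument("--loss_temperature", type=float, default=1.0)
        parser.add_argument(
            "--loss_similarity",
            type=str,
            default="cosine",
            choices=["cosine", "dot"],
        )
        return parser


parent = ArgumentParser(add_help=False)
parser = SketchLoss.add_specific_args(parent)
args = parser.parse_args(["--loss_temperature", "0.5"])
```

Here `args.loss_temperature` holds the value parsed from the command line, and `args.loss_similarity` falls back to its default because the flag was not supplied.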