yamle.losses.classification module

yamle.losses.classification.one_hot(y, num_classes, label_smoothing=0.0)

One-hot encodes the target and directly applies label smoothing.

Taken from: pytorch/pytorch#7455

Return type:

Tensor
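The encoding described above can be illustrated with a minimal sketch. The function name smoothed_one_hot is hypothetical, and the smoothing convention (scale the one-hot vector by 1 - label_smoothing, then add label_smoothing / num_classes to every entry) is an assumption based on the common formulation; the actual YAMLE implementation may differ in detail.

```python
import torch
import torch.nn.functional as F


def smoothed_one_hot(y: torch.Tensor, num_classes: int, label_smoothing: float = 0.0) -> torch.Tensor:
    """Hypothetical stand-in for one_hot: encode integer labels, then smooth."""
    encoded = F.one_hot(y, num_classes).float()
    # Assumed convention: keep (1 - eps) of the mass on the true class and
    # spread eps uniformly over all classes, so each row still sums to 1.
    return encoded * (1.0 - label_smoothing) + label_smoothing / num_classes


targets = torch.tensor([0, 2])
smoothed = smoothed_one_hot(targets, num_classes=4, label_smoothing=0.1)
```

With label_smoothing=0.0 the sketch reduces to a plain one-hot encoding.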

class yamle.losses.classification.CrossEntropyLoss(label_smoothing=0.0, one_hot_target=False, class_weights=None, flatten=False, *args, **kwargs)

Bases: BaseLoss

This defines the base cross-entropy loss.

It assumes that the input has shape (batch_size, num_members, num_classes). Regardless of the reduction, the loss is always averaged over num_members. The target is assumed to have shape (batch_size) and needs to be one-hot encoded.

The input is assumed to be probabilities.

The loss can also be weighted by a weight tensor of shape (batch_size).

Parameters:
  • label_smoothing (float) – The amount of label smoothing to apply. Defaults to 0.0.

  • one_hot_target (bool) – Whether the target is already one-hot encoded. Defaults to False.

  • class_weights (Optional[Union[torch.Tensor, List[float]]]) – The weights to apply to each class. Defaults to None.

  • flatten (bool) – Whether to flatten the predictions and the targets. Defaults to False.
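The shape conventions described for this class can be sketched as follows. This is an illustrative reconstruction under stated assumptions, not YAMLE's code: the function name ensemble_cross_entropy, the eps stabiliser, and the final mean reduction are all hypothetical choices.

```python
import torch
import torch.nn.functional as F


def ensemble_cross_entropy(probs, target, weights=None, eps=1e-8):
    """Hypothetical sketch: cross-entropy on probabilities, averaged over members.

    probs:   (batch_size, num_members, num_classes), rows sum to 1
    target:  (batch_size,) integer class labels
    weights: optional (batch_size,) per-sample weights
    """
    num_classes = probs.shape[-1]
    one_hot_target = F.one_hot(target, num_classes).float()  # (B, C)
    # Broadcast the target over the member dimension and take the negative
    # log-likelihood of the true class for every member.
    per_member = -(one_hot_target.unsqueeze(1) * torch.log(probs + eps)).sum(dim=-1)  # (B, M)
    per_sample = per_member.mean(dim=1)  # always averaged over num_members
    if weights is not None:
        per_sample = per_sample * weights
    return per_sample.mean()  # assumed "mean" reduction over the batch


torch.manual_seed(0)
logits = torch.randn(4, 3, 5)
probs = torch.softmax(logits, dim=-1)  # the loss expects probabilities
target = torch.tensor([0, 1, 2, 3])
loss = ensemble_cross_entropy(probs, target)
```

Because the input is assumed to be probabilities, the sketch applies softmax before calling the loss rather than inside it.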

tasks = ['classification', 'segmentation']

static add_specific_args(parent_parser)

Adds the loss-specific arguments to the parent parser.

Return type:

ArgumentParser
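A static method of this kind is often written by wrapping the parent parser in a new ArgumentParser. The sketch below shows that general pattern only; the class name SketchLoss and the flag names --loss_label_smoothing and --loss_flatten are hypothetical and are not taken from YAMLE.

```python
from argparse import ArgumentParser


class SketchLoss:
    """Hypothetical loss class used only to illustrate the parser pattern."""

    @staticmethod
    def add_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        # Inherit everything already registered on the parent parser.
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # Assumed flag names; the real ones may differ.
        parser.add_argument("--loss_label_smoothing", type=float, default=0.0)
        parser.add_argument("--loss_flatten", type=int, choices=[0, 1], default=0)
        return parser


parent = ArgumentParser(add_help=False)
parser = SketchLoss.add_specific_args(parent)
args = parser.parse_args(["--loss_label_smoothing", "0.1"])
```

Returning a new parser, instead of mutating the parent, lets several components chain their own add_specific_args calls.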

class yamle.losses.classification.TextCrossEntropyLoss(label_smoothing=0.0, one_hot_target=False, class_weights=None, flatten=False, *args, **kwargs)

Bases: CrossEntropyLoss

This defines the cross-entropy loss for text (sequence) inputs.

It assumes that the input has shape (batch_size, num_members, sequence_length, num_classes). Regardless of the reduction, the loss is always averaged over num_members. The target is assumed to have shape (batch_size, sequence_length).

The input is assumed to be probabilities.

The loss can also be weighted by a weight tensor of shape (batch_size).

Parameters:
  • label_smoothing (float) – The amount of label smoothing to apply. Defaults to 0.0.

tasks = ['text_classification']
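For sequence inputs, the extra sequence_length dimension can be handled as in the sketch below. As before, this is an assumption-laden illustration rather than YAMLE's implementation: the function name sequence_ensemble_cross_entropy, the eps stabiliser, and the mean over batch and sequence positions are hypothetical.

```python
import torch
import torch.nn.functional as F


def sequence_ensemble_cross_entropy(probs, target, eps=1e-8):
    """Hypothetical sketch for per-token cross-entropy with ensemble members.

    probs:  (batch_size, num_members, sequence_length, num_classes)
    target: (batch_size, sequence_length) integer class labels
    """
    num_classes = probs.shape[-1]
    one_hot_target = F.one_hot(target, num_classes).float()  # (B, S, C)
    # Insert a member axis so the target broadcasts against every member.
    per_token = -(one_hot_target.unsqueeze(1) * torch.log(probs + eps)).sum(dim=-1)  # (B, M, S)
    per_token = per_token.mean(dim=1)  # average over num_members -> (B, S)
    return per_token.mean()  # assumed mean over batch and sequence positions


torch.manual_seed(0)
seq_logits = torch.randn(2, 3, 6, 5)
seq_probs = torch.softmax(seq_logits, dim=-1)
seq_target = torch.randint(0, 5, (2, 6))
seq_loss = sequence_ensemble_cross_entropy(seq_probs, seq_target)
```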