"""Feature regularizers (source of module ``yamle.regularizers.feature``)."""
from typing import Any, List
from yamle.regularizers.regularizer import BaseRegularizer
import torch
import argparse
[docs]
class L1FeatureRegularizer(BaseRegularizer):
    """L1 (absolute value) penalty on output features, averaged over the batch."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the regularization loss.

        Args:
            x: Feature tensor; the size of its first dimension is used as the
                batch size.

        Returns:
            The sum of the absolute values of all entries of ``x`` divided by
            the batch size.
        """
        n_samples = x.shape[0]
        total_magnitude = x.abs().sum()
        return total_magnitude / n_samples

    def __repr__(self) -> str:
        return "L1FeatureRegularizer()"
[docs]
class L2FeatureRegularizer(BaseRegularizer):
    """L2 (squared magnitude) penalty on output features, averaged over the batch."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the regularization loss.

        Args:
            x: Feature tensor; the size of its first dimension is used as the
                batch size.

        Returns:
            Half of the sum of the squared entries of ``x`` divided by the
            batch size.
        """
        batch_size = x.shape[0]
        return (torch.sum(x**2) * 0.5) / batch_size

    def __repr__(self) -> str:
        # Report the actual class name, consistent with the other feature
        # regularizers in this module (previously returned "L2Feature()").
        return "L2FeatureRegularizer()"
[docs]
class InnerProductFeatureRegularizer(BaseRegularizer):
    """This is a class for inner product regularization.

    Given a tensor `x` which can be split in dimension `dim` into `n` tensors
    `x_1, ..., x_n`, the regularization loss is calculated as:

    `loss = sum_{i=1}^{n} sum_{j=i+1}^{n} x_i * x_j`
    `loss = loss / (n*(n-1)/2)`

    where `x_i * x_j` is the inner product of the flattened slices, averaged
    over the batch. When fewer than two slices exist there are no pairs and
    the loss is zero.

    Args:
        dim (int): The dimension over which split the tensor to then calculate
            the inner product as a cartesian product.
    """

    def __init__(self, dim: int = 1, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._dim = dim

    def _split_and_reshape_tensor_on_dim(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split ``x`` into size-1 slices along ``self._dim`` and flatten each one.

        Args:
            x: Input tensor; the size of its first dimension is used as the
                batch size.

        Returns:
            One tensor per index along ``self._dim``, each reshaped to
            ``(batch_size, -1)``.
        """
        batch_size = x.shape[0]
        # ``reshape`` instead of ``view``: slices taken along a dimension other
        # than 1 are generally non-contiguous, and ``view`` raises on those.
        return [
            piece.squeeze(dim=self._dim).reshape(batch_size, -1)
            for piece in torch.split(x, 1, dim=self._dim)
        ]

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the mean pairwise inner product between the slices of ``x``.

        Args:
            x: Tensor to be split along ``self._dim``.

        Returns:
            The batch-averaged inner product summed over all unordered pairs
            of slices and divided by the number of pairs; a zero scalar tensor
            when there are fewer than two slices.
        """
        slices = self._split_and_reshape_tensor_on_dim(x)
        n_slices = len(slices)
        if n_slices < 2:
            # No pairs to compare; avoid dividing by zero pairs below.
            return x.new_zeros(())

        loss = 0.0
        for i, first in enumerate(slices):
            for second in slices[i + 1 :]:
                loss += torch.sum(first * second, dim=1).mean()
        return loss / (n_slices * (n_slices - 1) / 2)

    @staticmethod
    def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """This method is used to add specific arguments to the parser."""
        # Explicit two-argument ``super`` because static methods have no
        # implicit class or instance to resolve the parent from.
        parser = super(
            InnerProductFeatureRegularizer, InnerProductFeatureRegularizer
        ).add_specific_args(parser)
        parser.add_argument(
            "--regularizer_dim",
            type=int,
            default=1,
            help="The dimension over which split the tensor to then calculate the inner product as a cartesian product.",
        )
        return parser

    def __repr__(self) -> str:
        return f"InnerProductFeatureRegularizer(dim={self._dim})"
[docs]
class CosineSimilarityFeatureRegularizer(InnerProductFeatureRegularizer):
    """This is a class for cosine similarity regularization.

    Given a tensor `x` which can be split in dimension `dim` into `n` tensors
    `x_1, ..., x_n`, the regularization loss is calculated as:

    `loss = sum_{i=1}^{n} sum_{j=i+1}^{n} cos(x_i, x_j)`
    `loss = loss / (n*(n-1)/2)`

    The `cos` function is the cosine similarity between `x_i` and `x_j`.
    `cos(x_i, x_j) = x_i * x_j / (||x_i|| * ||x_j||)`

    When fewer than two slices exist there are no pairs and the loss is zero.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the mean pairwise cosine similarity between the slices of ``x``.

        Args:
            x: Tensor to be split along ``self._dim``.

        Returns:
            The batch-averaged cosine similarity summed over all unordered
            pairs of slices and divided by the number of pairs; a zero scalar
            tensor when there are fewer than two slices.
        """
        slices = self._split_and_reshape_tensor_on_dim(x)
        n_slices = len(slices)
        if n_slices < 2:
            # No pairs to compare; avoid dividing by zero pairs below.
            return x.new_zeros(())

        loss = 0.0
        for i, first in enumerate(slices):
            for second in slices[i + 1 :]:
                loss += torch.cosine_similarity(first, second, dim=1).mean()
        return loss / (n_slices * (n_slices - 1) / 2)

    def __repr__(self) -> str:
        return f"CosineSimilarityFeatureRegularizer(dim={self._dim})"
[docs]
class CorrelationFeatureRegularizer(CosineSimilarityFeatureRegularizer):
    """This is a class for correlation regularization.

    Correlation is the cosine similarity between centered versions of x and y.
    Unlike the cosine, the correlation is invariant to both scale and location
    changes of x and y.

    When fewer than two slices exist there are no pairs and the loss is zero.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the mean pairwise correlation between the slices of ``x``.

        Each slice along ``self._dim`` is flattened per sample, centered by
        subtracting its per-sample mean, and then compared to every other
        slice with the cosine similarity.

        Args:
            x: Tensor to be split along ``self._dim``.

        Returns:
            The batch-averaged correlation summed over all unordered pairs of
            slices and divided by the number of pairs; a zero scalar tensor
            when there are fewer than two slices.
        """
        slices = self._split_and_reshape_tensor_on_dim(x)
        centered = [
            piece - piece.mean(dim=1, keepdim=True) for piece in slices
        ]
        n_slices = len(centered)
        if n_slices < 2:
            # No pairs to compare; avoid dividing by zero pairs below.
            return x.new_zeros(())

        # The cosine similarity is evaluated directly on the centered slices.
        # Re-stacking the already flattened 2-D slices on ``self._dim`` so the
        # parent can split them again only works when that dimension still
        # addresses the slice axis of the 3-D stack, which is not the case in
        # general (e.g. ``dim > 2``).
        loss = 0.0
        for i, first in enumerate(centered):
            for second in centered[i + 1 :]:
                loss += torch.cosine_similarity(first, second, dim=1).mean()
        return loss / (n_slices * (n_slices - 1) / 2)

    def __repr__(self) -> str:
        return f"CorrelationFeatureRegularizer(dim={self._dim})"