Source code for yamle.regularizers.weight
from typing import Any
from yamle.regularizers.regularizer import BaseRegularizer
import torch
import torch.nn as nn
[docs]
class L1Regularizer(BaseRegularizer):
    """This is a class for L1 regularization.

    The loss is the sum of absolute values of all parameters returned by
    ``self.get_parameters(model)``.
    """

    def __call__(self, model: nn.Module) -> torch.Tensor:
        """Calculate the L1 regularization loss.

        Args:
            model: The model whose parameters (as selected by ``get_parameters``)
                are penalized.

        Returns:
            A scalar tensor ``sum_i |w_i|``. It is placed on the device of the model's
            first parameter, or on the default device if the model has no parameters.
        """
        # ``next(..., None)`` avoids a bare StopIteration for parameter-free models.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None
        loss = torch.tensor(0.0, device=device)
        for param in self.get_parameters(model):
            loss += torch.sum(torch.abs(param))
        return loss

    def __repr__(self) -> str:
        return "L1()"
[docs]
class L2Regularizer(BaseRegularizer):
    """This is a class for L2 regularization.

    The loss is half the sum of squares of all parameters returned by
    ``self.get_parameters(model)``.
    """

    def __call__(self, model: nn.Module) -> torch.Tensor:
        """Calculate the L2 regularization loss.

        Args:
            model: The model whose parameters (as selected by ``get_parameters``)
                are penalized.

        Returns:
            A scalar tensor ``0.5 * sum_i w_i**2``. It is placed on the device of the
            model's first parameter, or on the default device if the model has no
            parameters.
        """
        # ``next(..., None)`` avoids a bare StopIteration for parameter-free models.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None
        loss = torch.tensor(0.0, device=device)
        for param in self.get_parameters(model):
            loss += torch.sum(param**2)
        # The 0.5 factor makes the gradient of the penalty equal to ``w`` itself.
        return loss * 0.5

    def __repr__(self) -> str:
        return "L2()"
[docs]
class L1L2Regularizer(BaseRegularizer):
    """This is a class for combined L1 and L2 regularization.

    The loss is ``sum_i (0.5 * w_i**2 + |w_i|)`` over all parameters returned by
    ``self.get_parameters(model)``. The two terms are added without any relative
    weighting.
    """

    def __call__(self, model: nn.Module) -> torch.Tensor:
        """Calculate the combined L1 + L2 regularization loss.

        Args:
            model: The model whose parameters (as selected by ``get_parameters``)
                are penalized.

        Returns:
            A scalar tensor. It is placed on the device of the model's first parameter,
            or on the default device if the model has no parameters.
        """
        # ``next(..., None)`` avoids a bare StopIteration for parameter-free models.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None
        loss = torch.tensor(0.0, device=device)
        for param in self.get_parameters(model):
            loss += 0.5 * torch.sum(param**2)
            loss += torch.sum(torch.abs(param))
        return loss

    def __repr__(self) -> str:
        return "L1L2()"
[docs]
class WeightDecayRegularizer(BaseRegularizer):
    r"""This is a class for weight decay regularization.

    It is implemented in an inefficient manner to be compatible with any optimizer.
    During the ``__call__`` method, the weights at time ``t`` are cached.
    Then, during the ``on_after_training_step`` method, the weights, which were already
    updated by the optimizer, are further updated by weight decay.

    The weight decay is applied as follows:

        w_{t+1} = (1 - \eta * weight) * w_{t} - \eta * ∇L(w_{t})

    Hence, after the optimization step, assuming that only
    ``w_{t+1} = w_{t} - \eta * ∇L(w_{t})`` was applied, we still need to apply the
    ``-\eta * weight * w_{t}`` term. In other words, the decay coefficient ``weight``
    is scaled by the learning rate ``\eta``.
    """

    def __call__(self, model: nn.Module) -> torch.Tensor:
        """Cache all the weight values *before* the optimization step.

        This is done such that the weights can then be updated at the very end of the
        training batch by ``on_after_training_step``. This regularizer adds nothing to
        the loss itself.

        Args:
            model: The model whose parameters (as selected by ``get_parameters``)
                are cached.

        Returns:
            A zero scalar tensor. It is placed on the device of the model's first
            parameter, or on the default device if the model has no parameters.
        """
        for param in self.get_parameters(model):
            # Snapshot w_t; the clone keeps it intact when the optimizer later
            # updates ``param`` in place.
            param._cached_weight = param.detach().clone()
        # ``next(..., None)`` avoids a bare StopIteration for parameter-free models.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None
        return torch.tensor(0.0, device=device)

    def on_after_training_step(
        self, model: nn.Module, weight: float, lr: float, *args: Any, **kwargs: Any
    ) -> None:
        """Apply the decay term ``-lr * weight * w_t`` using the cached weights.

        Args:
            model: The model whose parameters were cached by ``__call__``.
            weight: The weight decay coefficient.
            lr: The learning rate used to scale the decay coefficient.

        Raises:
            AttributeError: If a parameter has no cached weight, i.e. ``__call__`` was
                not invoked on this model since the last step.
        """
        for param in self.get_parameters(model):
            cached_weight = getattr(param, "_cached_weight", None)
            if cached_weight is None:
                raise AttributeError(
                    "No cached weight found on parameter; WeightDecayRegularizer.__call__ "
                    "must be invoked before on_after_training_step."
                )
            param.data.add_(cached_weight, alpha=-weight * lr)
            # Reset the cached weight so stale snapshots are never reused.
            del param._cached_weight

    def __repr__(self) -> str:
        return "WeightDecay()"