from typing import Any
import torch.nn as nn
from yamle.regularizers.regularizer import BaseRegularizer
import torch
import argparse
[docs]
class ShrinkAndPerturbRegularizer(BaseRegularizer):
"""This is a class for a shrink and perturb regularization.
It shrinks the weights by a factor of `l` and adds a noise sampled from a
normal distribution with mean 0 and standard deviation `std` to the weights at
a certain epoch frequency.
There is also a second argument which limits the starting epoch and the ending epoch
within which the shrink and perturb regularization is applied.
It follows the paper: https://arxiv.org/pdf/1910.08475.pdf
Args:
l (float): The factor by which the weights are shrunk.
std (float): The standard deviation of the normal distribution from which the noise is sampled.
start_epoch (int): The epoch at which the shrink and perturb regularization starts. Default is 0, which means that the regularization is applied from the beginning of the training.
end_epoch (int): The epoch at which the shrink and perturb regularization ends. Default is -1, which means that the regularization is applied until the end of the training.
epoch_frequency (int): The frequency at which the shrink and perturb regularization is applied.
"""
def __init__(
self,
l: float,
std: float,
start_epoch: int,
end_epoch: int,
epoch_frequency: int,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
assert 0 <= l <= 1, f"The shrink factor must be between 0 and 1, but got {l}."
assert (
std >= 0
), f"The standard deviation of the normal distribution must be non-negative, but got {std}."
assert (
start_epoch >= 0
), f"The start epoch must be non-negative, but got {start_epoch}."
assert (
end_epoch == -1 or end_epoch >= start_epoch
), f"The end epoch must be greater than or equal to the start epoch, but got {end_epoch} and {start_epoch}."
assert (
epoch_frequency > 0
), f"The epoch frequency must be greater than 0, but got {epoch_frequency}."
self._l = l
self._std = std
self._start_epoch = start_epoch
self._end_epoch = end_epoch
self._epoch_frequency = epoch_frequency
[docs]
def on_after_train_epoch(
self, model: nn.Module, epoch: int, *args: Any, **kwargs: Any
) -> None:
"""Add noise to the weights after a given training epoch.
For all parameters that require gradients, the weights are shrunk by a factor of `l` and a noise sampled from a
normal distribution with mean 0 and standard deviation `std` is added to the weights.
"""
if (
epoch >= self._start_epoch
and (epoch <= self._end_epoch or self._end_epoch == -1)
and epoch % self._epoch_frequency == 0
and epoch != 0
):
for param in model.parameters():
if param.requires_grad:
param.data = (
param.data * self._l + torch.randn_like(param.data) * self._std
)
[docs]
@staticmethod
def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""This method is used to add specific arguments to the parser."""
parser = super(
ShrinkAndPerturbRegularizer, ShrinkAndPerturbRegularizer
).add_specific_args(parser)
parser.add_argument(
"--regularizer_l",
type=float,
default=0.1,
help="The factor by which the weights are shrunk.",
)
parser.add_argument(
"--regularizer_std",
type=float,
default=0.1,
help="The standard deviation of the normal distribution from which the noise is sampled.",
)
parser.add_argument(
"--regularizer_start_epoch",
type=int,
default=0,
help="The epoch at which the shrink and perturb regularization starts.",
)
parser.add_argument(
"--regularizer_end_epoch",
type=int,
default=-1,
help="The epoch at which the shrink and perturb regularization ends.",
)
parser.add_argument(
"--regularizer_epoch_frequency",
type=int,
default=1,
help="The frequency at which the shrink and perturb regularization is applied.",
)
return parser