Source code for yamle.regularizers.gradient

from typing import Any

import torch.nn as nn
from yamle.defaults import TINY_EPSILON
from yamle.regularizers.regularizer import BaseRegularizer

import torch
import argparse


class GradientNoiseRegularizer(BaseRegularizer):
    """Add annealed zero-mean Gaussian noise to the gradients after each backward pass.

    It follows the paper: https://arxiv.org/pdf/1511.06807.pdf

    At a given epoch the noise variance is ``eta / ((1 + epoch) ** gamma + TINY_EPSILON)``,
    so the noise shrinks as training progresses (for ``gamma > 0``).

    Args:
        eta (float): Base variance of the noise, i.e. (approximately) the variance used at
            epoch 0. Must be strictly positive.
        gamma (float): Decay exponent applied to ``1 + epoch``; must lie in ``[0, 1]``.
            ``gamma = 0`` keeps the variance constant over epochs.

    Raises:
        AssertionError: If ``eta`` is not positive or ``gamma`` is outside ``[0, 1]``.
    """

    def __init__(self, eta: float, gamma: float, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # A single positivity check: the former `eta >= 0` assertion was redundant because
        # it was immediately followed by the stricter `eta > 0` one.
        assert eta > 0, f"The noise variance `eta` must be positive, but got {eta}."
        assert 0 <= gamma <= 1, f"The factor must be between 0 and 1, but got {gamma}."
        self._eta = eta
        self._gamma = gamma

    def _var(self, epoch: int) -> float:
        """Return the variance of the noise at a given epoch.

        Args:
            epoch (int): The current epoch index.

        Returns:
            float: ``eta / ((1 + epoch) ** gamma + TINY_EPSILON)``; the epsilon only guards
            the denominator.
        """
        return self._eta / ((1 + epoch) ** self._gamma + TINY_EPSILON)

    def on_after_backward(
        self, model: nn.Module, epoch: int, *args: Any, **kwargs: Any
    ) -> None:
        """Add noise to the gradients after the backward pass.

        Every parameter of ``model`` that currently holds a gradient gets independent
        Gaussian noise with variance ``self._var(epoch)`` added to it in place.

        Args:
            model (nn.Module): The model whose parameter gradients are perturbed.
            epoch (int): The current epoch, used to anneal the noise.
        """
        # `torch.randn_like` draws unit-variance samples, so they must be scaled by the
        # standard deviation (square root of the variance), not by the variance itself.
        std = self._var(epoch) ** 0.5
        for param in model.parameters():
            if param.grad is not None:
                param.grad += torch.randn_like(param.grad) * std

    @staticmethod
    def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """This method is used to add specific arguments to the parser.

        Args:
            parser (argparse.ArgumentParser): The parser to extend.

        Returns:
            argparse.ArgumentParser: The parser with the parent's arguments plus
            ``--regularizer_eta`` and ``--regularizer_gamma``.
        """
        # Zero-argument `super()` is unavailable inside a staticmethod, hence the explicit form.
        parser = super(
            GradientNoiseRegularizer, GradientNoiseRegularizer
        ).add_specific_args(parser)
        parser.add_argument(
            "--regularizer_eta",
            type=float,
            default=0.1,
            help="The base variance of the normal distribution from which the noise is sampled.",
        )
        parser.add_argument(
            "--regularizer_gamma",
            type=float,
            default=0.55,
            help="The exponent with which the noise variance decays over epochs.",
        )
        return parser

    def __repr__(self) -> str:
        """Return a string showing the configured ``eta`` and ``gamma``."""
        return f"GradientNoiseRegularizer(eta={self._eta}, gamma={self._gamma})"