"""Helpers for disabling and re-enabling regularization on parameters (``yamle.utils.regularizer_utils``)."""

from collections.abc import Iterator
from typing import Iterable, List, Union

import torch.nn as nn

from yamle.defaults import DISABLED_REGULARIZER_KEY


def disable_regularizer(
    parameters: Union[nn.Parameter, Iterable[nn.Parameter]],
) -> None:
    """Disable regularization (e.g. weight decay) for the given parameters.

    Each parameter is tagged by setting the attribute named by
    ``DISABLED_REGULARIZER_KEY`` to ``True``. This function only sets the
    flag; honoring it is up to whichever code reads that attribute.

    Args:
        parameters: A single ``nn.Parameter``, or a list, tuple, set or
            iterator of parameters (for example the generator returned by
            ``nn.Module.parameters()``). Iterators are consumed.

    Raises:
        ValueError: If ``parameters`` is none of the accepted types.
    """
    # Check for a single Parameter first: a Parameter is a Tensor, and tensors
    # are themselves iterable, so it must not fall through to the loop below.
    if isinstance(parameters, nn.Parameter):
        setattr(parameters, DISABLED_REGULARIZER_KEY, True)
        return

    # Accept explicit containers and iterators only. Arbitrary iterables are
    # rejected on purpose so that a plain (non-Parameter) tensor or a string
    # is not silently iterated element-by-element.
    if not isinstance(parameters, (list, tuple, set, frozenset, Iterator)):
        raise ValueError(
            "The parameters should be either a single parameter or a list, tuple, "
            f"set or iterator of parameters. Got {type(parameters)}."
        )

    for param in parameters:
        setattr(param, DISABLED_REGULARIZER_KEY, True)
def enable_regularizer(
    parameters: Union[nn.Parameter, Iterable[nn.Parameter]],
) -> None:
    """Re-enable regularization (e.g. weight decay) for the given parameters.

    Each parameter is tagged by setting the attribute named by
    ``DISABLED_REGULARIZER_KEY`` to ``False``. This function only sets the
    flag; honoring it is up to whichever code reads that attribute.

    Args:
        parameters: A single ``nn.Parameter``, or a list, tuple, set or
            iterator of parameters (for example the generator returned by
            ``nn.Module.parameters()``). Iterators are consumed.

    Raises:
        ValueError: If ``parameters`` is none of the accepted types.
    """
    # Check for a single Parameter first: a Parameter is a Tensor, and tensors
    # are themselves iterable, so it must not fall through to the loop below.
    if isinstance(parameters, nn.Parameter):
        setattr(parameters, DISABLED_REGULARIZER_KEY, False)
        return

    # Accept explicit containers and iterators only. Arbitrary iterables are
    # rejected on purpose so that a plain (non-Parameter) tensor or a string
    # is not silently iterated element-by-element.
    if not isinstance(parameters, (list, tuple, set, frozenset, Iterator)):
        raise ValueError(
            "The parameters should be either a single parameter or a list, tuple, "
            f"set or iterator of parameters. Got {type(parameters)}."
        )

    for param in parameters:
        setattr(param, DISABLED_REGULARIZER_KEY, False)