Source code for yamle.regularizers

from typing import Type, Callable, Optional


from yamle.regularizers.regularizer import DummyRegularizer
from yamle.regularizers.feature import (
    L1FeatureRegularizer,
    L2FeatureRegularizer,
    InnerProductFeatureRegularizer,
    CorrelationFeatureRegularizer,
    CosineSimilarityFeatureRegularizer,
)
from yamle.regularizers.weight import (
    L1Regularizer,
    L2Regularizer,
    L1L2Regularizer,
    WeightDecayRegularizer,
)
from yamle.regularizers.gradient import GradientNoiseRegularizer
from yamle.regularizers.model import ShrinkAndPerturbRegularizer

AVAILABLE_REGULARIZERS = {
    "l1": L1Regularizer,
    "l2": L2Regularizer,
    "weight_decay": WeightDecayRegularizer,
    "l1l2": L1L2Regularizer,
    "l1_feature": L1FeatureRegularizer,
    "l2_feature": L2FeatureRegularizer,
    "inner_product_feature": InnerProductFeatureRegularizer,
    "correlation_feature": CorrelationFeatureRegularizer,
    "cosine_similarity_feature": CosineSimilarityFeatureRegularizer,
    "gradient_noise": GradientNoiseRegularizer,
    "shrink_and_perturb": ShrinkAndPerturbRegularizer,
    None: DummyRegularizer,
    "none": DummyRegularizer,
    "dummy": DummyRegularizer,
}


[docs] def regularizer_factory(regularizer_type: Optional[str] = None) -> Type[Callable]: """This function is used to create a regularizer instance based on the regularizer type. Args: regularizer_type (str): The type of regularizer to create. """ if regularizer_type not in AVAILABLE_REGULARIZERS: raise ValueError(f"Unknown regularizer type {regularizer_type}.") return AVAILABLE_REGULARIZERS[regularizer_type]