Source code for yamle.pruning

from typing import Optional, Type, Callable

from yamle.pruning.unstructured.magnitude import UnstructuredMagnitudePruner
from yamle.pruning.pruner import DummyPruner

AVAILABLE_PRUNERS = {
    "unstructured_magnitude": UnstructuredMagnitudePruner,
    None: DummyPruner,
    "dummy": DummyPruner,
    "none": DummyPruner,
}


[docs] def pruner_factory(pruner_type: Optional[str] = None) -> Type[Callable]: """This function is used to create a pruner instance based on the pruner type. Args: pruner_type (str): The type of pruner to create. """ if pruner_type not in AVAILABLE_PRUNERS: raise ValueError(f"Unknown pruner type {pruner_type}.") return AVAILABLE_PRUNERS[pruner_type]