Source code for yamle.pruning.unstructured.magnitude

from typing import Any
import torch
import torch.nn as nn
import argparse

from yamle.pruning.pruner import BasePruner
from yamle.utils.pruning_utils import (
    get_all_prunable_weights,
    is_layer_prunable,
    is_parameter_prunable,
)


class UnstructuredMagnitudePruner(BasePruner):
    """Unstructured magnitude-based pruner.

    A single magnitude threshold is computed over all prunable weights of the
    model (as gathered by ``get_all_prunable_weights``), and every prunable
    parameter entry whose absolute value is strictly below that threshold is
    pruned.

    Args:
        percentage (float): The fraction of weights to prune, in ``[0, 1]``.
    """

    def __init__(self, percentage: float, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        assert 0.0 <= percentage <= 1.0, "Pruning percentage must be between 0 and 1."
        self._percentage = percentage

    def __call__(self, m: nn.Module) -> float:
        """Prune the model in place and return the magnitude threshold used.

        Args:
            m (nn.Module): The model whose prunable parameters are pruned.

        Returns:
            float: The absolute-value threshold; entries strictly below it
            were passed to ``prune_parameter`` as "prune". ``0.0`` when the
            requested percentage selects no weight at all.
        """
        # Get all the weights in the model.
        # NOTE(review): assumes this is a 1-D tensor of every prunable weight,
        # so that ``len`` is the total weight count -- confirm against
        # ``get_all_prunable_weights``.
        weights = get_all_prunable_weights(m)
        magnitudes = torch.abs(weights)

        # ``torch.kthvalue`` is 1-based, so k == 0 (percentage of 0, a product
        # that truncates below 1, or no prunable weights) must not reach it.
        # A zero threshold makes every mask below all-False: nothing is pruned.
        k = int(self._percentage * len(weights))
        if k < 1:
            threshold = torch.zeros(
                (), dtype=magnitudes.dtype, device=magnitudes.device
            )
        else:
            # Magnitude of the k-th smallest weight.
            threshold = magnitudes.kthvalue(k)[0]

        # Prune the weights
        for module in m.modules():
            if not is_layer_prunable(module):
                continue
            for p in module.parameters():
                if is_parameter_prunable(p):
                    # Create a mask to prune the weights, `True` means prune
                    mask = torch.abs(p.data) < threshold
                    self.prune_parameter(p, mask)
        return threshold.item()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(percentage={self._percentage})"

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Add the pruner-specific arguments to the parent parser.

        Args:
            parent_parser (argparse.ArgumentParser): Parser to extend.

        Returns:
            argparse.ArgumentParser: The parser returned by the parent class's
            ``add_specific_args``, with ``--pruner_percentage`` added.
        """
        # Static methods have no implicit class, so the two-argument form of
        # ``super`` is used to reach the parent's static method.
        parser = super(
            UnstructuredMagnitudePruner, UnstructuredMagnitudePruner
        ).add_specific_args(parent_parser)
        parser.add_argument(
            "--pruner_percentage",
            type=float,
            default=0.5,
            help="The percentage of weights to prune.",
        )
        return parser