Source code for yamle.utils.pruning_utils
from typing import List, Union
import torch
import torch.nn as nn
from yamle.defaults import DISABLED_PRUNING_KEY
[docs]
def enable_pruning(m: Union[nn.Parameter, List[nn.Parameter]]) -> None:
"""Enable pruning for the given parameters.
Args:
m (Union[nn.Parameter, List[nn.Parameter]]): The parameters to enable pruning for.
"""
if isinstance(m, nn.Parameter):
setattr(m, DISABLED_PRUNING_KEY, False)
elif isinstance(m, list):
for param in m:
setattr(param, DISABLED_PRUNING_KEY, False)
else:
raise ValueError(
f"The parameters should be either a list of parameters or a single parameter. Got {type(m)}."
)
[docs]
def disable_pruning(m: Union[nn.Parameter, List[nn.Parameter]]) -> None:
"""Disable pruning for the given parameters.
Args:
m (Union[nn.Parameter, List[nn.Parameter]]): The parameters to disable pruning for.
"""
if isinstance(m, nn.Parameter):
setattr(m, DISABLED_PRUNING_KEY, True)
elif isinstance(m, list):
for param in m:
setattr(param, DISABLED_PRUNING_KEY, True)
else:
raise ValueError(
f"The parameters should be either a list of parameters or a single parameter. Got {type(m)}."
)
[docs]
def is_layer_prunable(layer: nn.Module) -> bool:
"""Check if a layer is prunable.
Args:
layer (nn.Module): The layer to check.
"""
return (
isinstance(layer, nn.Linear)
or issubclass(type(layer), nn.Linear)
or isinstance(layer, nn.Conv2d)
or issubclass(type(layer), nn.Conv2d)
)
[docs]
def is_parameter_prunable(param: Union[nn.Parameter, List[nn.Parameter]]) -> bool:
"""Check if a parameter is prunable.
Args:
param (Union[nn.Parameter, torch.Tensor]): The parameter to check.
"""
if isinstance(param, nn.Parameter):
if not hasattr(param, DISABLED_PRUNING_KEY):
return True
return hasattr(param, DISABLED_PRUNING_KEY) and not getattr(
param, DISABLED_PRUNING_KEY
)
elif isinstance(param, list):
for p in param:
if not is_parameter_prunable(p):
return False
return True
else:
raise ValueError(
f"The parameters should be either a list of parameters or a single parameter. Got {type(param)}."
)
[docs]
def get_all_prunable_weights(module: nn.Module) -> torch.Tensor:
"""Get all the prunable weights in the model.
All the parameters of the prunable layers will be flattened into a single vector.
These weights will be returned in a single Tensor.
Args:
m (nn.Module): The model to get the weights from.
"""
weights = []
for m in module.modules():
if is_layer_prunable(m):
for p in m.parameters():
if is_parameter_prunable(p):
weights.append(p.data.view(-1))
return torch.cat(weights)
[docs]
def get_all_prunable_modules(module: nn.Module) -> List[nn.Module]:
"""Get all the prunable layers in the model.
Args:
m (nn.Module): The model to get the layers from.
"""
layers = []
for m in module.modules():
if is_layer_prunable(m):
layers.append(m)
return layers