yamle.utils.pruning_utils module
- yamle.utils.pruning_utils.enable_pruning(m)
Enable pruning for the given parameters.
- Parameters:
m (Union[nn.Parameter, List[nn.Parameter]]) – A single parameter, or a list of parameters, to enable pruning for.
- Return type:
None
- yamle.utils.pruning_utils.disable_pruning(m)
Disable pruning for the given parameters.
- Parameters:
m (Union[nn.Parameter, List[nn.Parameter]]) – A single parameter, or a list of parameters, to disable pruning for.
- Return type:
None
- yamle.utils.pruning_utils.is_layer_prunable(layer)
Check whether a layer is prunable.
- Parameters:
layer (nn.Module) – The layer to check.
- Return type:
bool
- yamle.utils.pruning_utils.is_parameter_prunable(param)
Check whether a parameter is prunable.
- Parameters:
param (Union[nn.Parameter, torch.Tensor]) – The parameter to check.
- Return type:
bool
- yamle.utils.pruning_utils.get_all_prunable_weights(module)
Get all the prunable weights in the model.
All parameters of the prunable layers are flattened and concatenated into a single 1-D tensor, which is returned.
- Parameters:
module (nn.Module) – The model to get the weights from.
- Return type:
Tensor