yamle.utils.pruning_utils module

yamle.utils.pruning_utils.enable_pruning(m)

Enable pruning for the given parameters.

Parameters:

m (Union[nn.Parameter, List[nn.Parameter]]) – The parameters to enable pruning for.

Return type:

None

yamle.utils.pruning_utils.disable_pruning(m)

Disable pruning for the given parameters.

Parameters:

m (Union[nn.Parameter, List[nn.Parameter]]) – The parameters to disable pruning for.

Return type:

None
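Both functions accept either a single parameter or a list of parameters. The following is a minimal sketch, not YAMLE's implementation. It assumes, hypothetically, that the on/off switch is a boolean attribute stored on the parameter tensor under the invented name `_prunable`. The helper `_as_list` is likewise introduced only for this example.

```python
from typing import List, Union

from torch import nn

# Hypothetical attribute name; YAMLE's real storage mechanism may differ.
_PRUNABLE_ATTR = "_prunable"


def _as_list(m: Union[nn.Parameter, List[nn.Parameter]]) -> List[nn.Parameter]:
    # Normalise the "single parameter or list of parameters" argument.
    return [m] if isinstance(m, nn.Parameter) else list(m)


def enable_pruning(m: Union[nn.Parameter, List[nn.Parameter]]) -> None:
    # Mark every given parameter as a pruning candidate.
    for p in _as_list(m):
        setattr(p, _PRUNABLE_ATTR, True)


def disable_pruning(m: Union[nn.Parameter, List[nn.Parameter]]) -> None:
    # Exclude every given parameter from pruning.
    for p in _as_list(m):
        setattr(p, _PRUNABLE_ATTR, False)


layer = nn.Linear(4, 2)
enable_pruning(layer.weight)                 # a single parameter
disable_pruning([layer.weight, layer.bias])  # a list of parameters
```

Normalising the argument up front keeps both functions to a single loop. Because `list(m)` also accepts any iterable, the sketch works with `module.parameters()` as well.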

yamle.utils.pruning_utils.is_layer_prunable(layer)

Check if a layer is prunable.

Parameters:

layer (nn.Module) – The layer to check.

Return type:

bool
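The criterion that makes a layer prunable is not stated here. One plausible rule is the layer's type. The sketch below assumes, purely for illustration, that linear and convolutional layers are prunable. The tuple `PRUNABLE_LAYER_TYPES` is hypothetical, and YAMLE's actual rule may differ; for example, it could inspect the layer's parameters instead.

```python
from torch import nn

# Hypothetical set of prunable layer types, chosen for illustration only.
PRUNABLE_LAYER_TYPES = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)


def is_layer_prunable(layer: nn.Module) -> bool:
    # Under this assumption, a layer is prunable iff it is a whitelisted type.
    return isinstance(layer, PRUNABLE_LAYER_TYPES)


print(is_layer_prunable(nn.Linear(8, 4)))  # True
print(is_layer_prunable(nn.ReLU()))        # False
```

Note that a container such as `nn.Sequential` is not itself prunable under this rule, even if it holds prunable children.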

yamle.utils.pruning_utils.is_parameter_prunable(param)

Check if a parameter is prunable.

Parameters:

param (Union[nn.Parameter, torch.Tensor]) – The parameter to check.

Return type:

bool
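Since `param` may be an `nn.Parameter` or a plain `torch.Tensor`, a check should tolerate tensors that were never marked. The minimal sketch below assumes, hypothetically, that prunability is recorded as a boolean attribute named `_prunable` on the tensor. A missing attribute is treated as "not prunable"; this is an illustrative convention, not necessarily the one YAMLE uses.

```python
from typing import Union

import torch
from torch import nn


def is_parameter_prunable(param: Union[nn.Parameter, torch.Tensor]) -> bool:
    # Hypothetical convention: a boolean "_prunable" attribute on the tensor
    # marks it as prunable; an absent attribute means "not prunable".
    return bool(getattr(param, "_prunable", False))


w = nn.Parameter(torch.zeros(3, 3))
print(is_parameter_prunable(w))  # False: never marked
w._prunable = True
print(is_parameter_prunable(w))  # True
```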

yamle.utils.pruning_utils.get_all_prunable_weights(module)

Get all the prunable weights in the model.

The parameters of all prunable layers are flattened and concatenated into a single 1-D Tensor, which is returned.

Parameters:

module (nn.Module) – The model to get the weights from.

Return type:

Tensor
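The flatten-and-concatenate behaviour described above can be sketched as follows. This is a hypothetical re-implementation that reuses an illustrative type-based prunability rule (`PRUNABLE_LAYER_TYPES` is an assumption, not YAMLE's definition). Returning an empty tensor when nothing is prunable is also a choice made for this sketch.

```python
import torch
from torch import nn

# Hypothetical prunability rule (layer type), for illustration only.
PRUNABLE_LAYER_TYPES = (nn.Linear, nn.Conv2d)


def get_all_prunable_weights(module: nn.Module) -> torch.Tensor:
    # Collect the parameters owned directly by each prunable layer;
    # recurse=False avoids counting a nested layer's parameters twice.
    params = [
        p
        for layer in module.modules()
        if isinstance(layer, PRUNABLE_LAYER_TYPES)
        for p in layer.parameters(recurse=False)
    ]
    if not params:
        return torch.empty(0)
    # Flatten each parameter and concatenate into one 1-D vector.
    return torch.cat([p.reshape(-1) for p in params])


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
flat = get_all_prunable_weights(model)
print(flat.shape)  # torch.Size([23]): (4*3 + 3) + (3*2 + 2)
```

`torch.cat` keeps the autograd graph. Callers that only need the values, for example to pick a global magnitude threshold, can call `.detach()` on the result.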

yamle.utils.pruning_utils.get_all_prunable_modules(module)

Get all the prunable layers in the model.

Parameters:

module (nn.Module) – The model to get the prunable layers from.

Return type:

List[Module]
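A traversal over the whole module tree is enough to gather the prunable layers, including nested ones. The sketch below is hypothetical and again assumes a type-based rule via the invented tuple `PRUNABLE_LAYER_TYPES`. YAMLE may decide prunability differently.

```python
from typing import List

from torch import nn

# Hypothetical prunability rule (layer type), for illustration only.
PRUNABLE_LAYER_TYPES = (nn.Linear, nn.Conv2d)


def get_all_prunable_modules(module: nn.Module) -> List[nn.Module]:
    # module.modules() walks the whole tree (including `module` itself)
    # in depth-first pre-order, so nested layers are found too.
    return [m for m in module.modules() if isinstance(m, PRUNABLE_LAYER_TYPES)]


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Sequential(nn.Flatten(), nn.Linear(8, 10)),
)
print([type(m).__name__ for m in get_all_prunable_modules(model)])
# ['Conv2d', 'Linear']
```

The returned list could, for instance, be turned into `[(m, "weight") for m in layers]` and passed to PyTorch's `torch.nn.utils.prune.global_unstructured`. That pairing is one possible use, not something this module is documented to do.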