Source code for yamle.quantization.quantizer

import argparse
import copy
import logging
from abc import ABC, abstractmethod
from typing import Any

import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer

from yamle.defaults import FLOAT_MODEL_KEY, QUANTIZED_KEY, QUANTIZED_MODEL_KEY
from yamle.models.operations import Add
from yamle.models.specific.mcdropout import Dropout1d, Dropout2d, Dropout3d
from yamle.quantization.models.operations import QuantizableAdd
from yamle.quantization.models.specific.mcdropout import (
    QuantisedDropout1d,
    QuantisedDropout2d,
    QuantisedDropout3d,
)

# NOTE: this rebinds the name `logging` from the standard-library module to the
# "pytorch_lightning" logger object. Below this line, `logging.warning(...)`
# calls are sent to that logger rather than to the root logger.
logging = logging.getLogger("pytorch_lightning")


class BaseQuantizer(ABC):
    """This is the base class for all quantization methods.

    The quantizer's call method will be used to quantize the model.

    Args:
        activation_bits (int): The number of bits to use for the activation.
        weight_bits (int): The number of bits to use for the weight.
    """

    def __init__(self, activation_bits: int, weight_bits: int) -> None:
        # The messages are f-strings so that the offending value is actually
        # interpolated into the error text.
        assert (
            0 <= activation_bits <= 8
        ), f"The number of bits for the activation must be between 0 and 8. Got {activation_bits}."
        assert (
            0 <= weight_bits <= 8
        ), f"The number of bits for the weight must be between 0 and 8. Got {weight_bits}."
        self._activation_bits = activation_bits
        self._weight_bits = weight_bits

    @abstractmethod
    def __call__(self, trainer: Trainer, method: LightningModule) -> None:
        """This method is used to quantize the model.

        A copy of the model is saved before quantization.
        First the model is prepared for quantization.
        Then the trainer is queried for the dataloader - this can be used to
        calibrate the model or fine-tune it.
        Then the model is quantized.
        The original model is kept such that it can be recovered.
        """
        raise NotImplementedError("This method needs to be implemented.")

    @abstractmethod
    def prepare(self, *args: Any, **kwargs: Any) -> None:
        """This method is used to prepare the model for quantization."""
        raise NotImplementedError("This method needs to be implemented.")

    @abstractmethod
    def get_qconfig(self) -> Any:
        """This method is used to get the quantization configuration."""
        raise NotImplementedError("This method needs to be implemented.")

    def cleanup(self, *args: Any, **kwargs: Any) -> None:
        """This method is used to clean up the model after quantization.

        The base implementation does nothing; subclasses may override it.
        """
        pass

    def save_original_model(self, method: LightningModule) -> None:
        """This method is used to create a copy of the original model.

        The copy is stored on ``method`` under the ``FLOAT_MODEL_KEY`` attribute.
        """
        # Give a warning if original model already exists
        if hasattr(method, FLOAT_MODEL_KEY):
            logging.warning("Original model already exists. Overwriting it.")
        # NOTE: `nn.Module.cpu()` moves the live `method.model` to the CPU in
        # place before the deep copy is taken.
        model_copy = copy.deepcopy(method.model.cpu())
        setattr(method, FLOAT_MODEL_KEY, model_copy)

    def save_quantized_model(self, method: LightningModule) -> None:
        """This method is used to save the quantized model.

        The copy is stored on ``method`` under the ``QUANTIZED_MODEL_KEY`` attribute.
        """
        # Give a warning if quantized model already exists
        if hasattr(method, QUANTIZED_MODEL_KEY):
            logging.warning("Quantized model already exists. Overwriting it.")
        # NOTE: `nn.Module.cpu()` moves the live `method.model` to the CPU in
        # place before the deep copy is taken.
        q_model = copy.deepcopy(method.model.cpu())
        setattr(method, QUANTIZED_MODEL_KEY, q_model)

    def recover(self, method: LightningModule) -> None:
        """This method is used to recover the original model.

        Nothing happens unless ``method`` has a truthy ``QUANTIZED_KEY``
        attribute. Otherwise the current model is stored under
        ``QUANTIZED_MODEL_KEY``, the model saved under ``FLOAT_MODEL_KEY`` is
        put back as ``method.model``, the ``FLOAT_MODEL_KEY`` attribute is
        removed and ``QUANTIZED_KEY`` is set to ``False``.
        """
        if hasattr(method, QUANTIZED_KEY) and getattr(method, QUANTIZED_KEY):
            setattr(method, QUANTIZED_MODEL_KEY, getattr(method, "model"))
            setattr(method, "model", getattr(method, FLOAT_MODEL_KEY))
            delattr(method, FLOAT_MODEL_KEY)
            setattr(method, QUANTIZED_KEY, False)

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add the quantizer specific arguments to the parent parser."""
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--quantizer_activation_bits",
            type=int,
            default=8,
            help="The number of bits to use for the activation.",
        )
        parser.add_argument(
            "--quantizer_weight_bits",
            type=int,
            default=8,
            help="The number of bits to use for the weight.",
        )
        return parser

    def replace_layers_for_quantization(self, model: nn.Module) -> None:
        """This function takes a model and replaces any special layers with their quantizable counterparts.

        e.g. Add -> FloatFunctional.add

        After the replacement, every submodule that itself defines a
        ``replace_layers_for_quantization`` method has that method called.
        """

        def _recursive_replace(module: nn.Module) -> None:
            for name, child in module.named_children():
                # The dropout probability is read from the layer being
                # replaced (`child`), not from its parent container.
                if isinstance(child, Add):
                    setattr(module, name, QuantizableAdd())
                elif isinstance(child, Dropout1d):
                    setattr(module, name, QuantisedDropout1d(child._p))
                elif isinstance(child, Dropout2d):
                    setattr(module, name, QuantisedDropout2d(child._p))
                elif isinstance(child, Dropout3d):
                    setattr(module, name, QuantisedDropout3d(child._p))
                else:
                    # Only descend into children that were not replaced.
                    _recursive_replace(child)

        _recursive_replace(model)

        for m in model.modules():
            if hasattr(m, "replace_layers_for_quantization"):
                m.replace_layers_for_quantization()

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
class DummyQuantizer(BaseQuantizer):
    """This is a dummy quantizer that does not perform any quantization."""

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """Accept and ignore any arguments; no quantization is performed."""
        pass

    def prepare(self, *args: Any, **kwargs: Any) -> None:
        """Accept and ignore any arguments; nothing is prepared.

        The signature mirrors the abstract ``BaseQuantizer.prepare`` so that a
        call written against the base class interface does not raise a
        ``TypeError`` on the dummy quantizer.
        """
        pass

    def get_qconfig(self) -> None:
        """Return ``None``; the dummy quantizer has no quantization configuration."""
        pass