Source code for yamle.quantization.qat

import argparse
import logging
from typing import Any

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import MovingAverageMinMaxObserver


from yamle.defaults import QUANTIZED_KEY
from yamle.quantization.quantizer import BaseQuantizer

# NOTE: this rebinds the name ``logging`` from the stdlib module to the
# "pytorch_lightning" ``Logger`` instance. Below this line ``logging.info(...)``
# is a call on that logger, and module-level helpers such as
# ``logging.getLogger`` are no longer reachable through this name.
logging = logging.getLogger("pytorch_lightning")


class QATQuantizer(BaseQuantizer):
    """This is the quantization-aware training quantizer class.

    It performs quantization-aware training on the model. In contrast to the
    static quantizer, the quantizer uses the calibration (validation) dataset
    to fine-tune the model. It uses the same optimiser as the one used for
    training the model.

    Args:
        learning_rate (float): The learning rate to use for the fine-tuning.
        epochs (int): The number of epochs to use for the fine-tuning.
        *args: Forwarded unchanged to the parent quantizer constructor.
        **kwargs: Forwarded unchanged to the parent quantizer constructor.
    """

    def __init__(
        self, learning_rate: float, epochs: int, *args: Any, **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        self._learning_rate = learning_rate
        self._epochs = epochs

    def __call__(self, trainer: Trainer, method: LightningModule) -> None:
        """This method is used to quantize the model.

        A copy of the model is saved before quantization. First the model is
        prepared for quantization. Then the trainer is queried to fine-tune
        the model. Then the model is quantized. The original model is kept
        such that it can be recovered.

        Args:
            trainer (Trainer): The trainer used to run the fine-tuning.
            method (LightningModule): The method whose ``model`` is quantized.
        """
        self.save_original_model(method)
        self.prepare(trainer, method)
        # ``prepare`` has swapped in the fine-tuning hyperparameters. Run the
        # rest under ``try``/``finally`` so they are restored by ``cleanup``
        # even when fine-tuning or saving raises.
        try:
            trainer.fine_tune(self._epochs)
            # Stop the observers from collecting further statistics.
            method.model.apply(torch.ao.quantization.disable_observer)
            self.save_quantized_model(method)
            logging.info("Model quantized.")
            logging.info(method.model)
            setattr(method, QUANTIZED_KEY, True)
        finally:
            self.cleanup(method, trainer)

    def prepare(self, trainer: Trainer, method: LightningModule) -> None:
        """This method is used to prepare the model for quantization.

        It replaces the layers for quantization, attaches the qconfig and
        inserts the fake-quantization modules in place. It then caches the
        original hyperparameters for the optimisation and replaces the
        hyperparameters with the ones for the fine-tuning.

        Args:
            trainer (Trainer): The trainer whose ``_epochs`` is overridden.
            method (LightningModule): The method whose ``model`` is prepared
                and whose ``hparams.learning_rate`` is overridden.
        """
        method.model.eval()
        self.replace_layers_for_quantization(method.model)
        method.model.qconfig = self.get_qconfig()
        # The model is switched back to training mode before ``prepare_qat``.
        method.model.train()
        torch.ao.quantization.prepare_qat(method.model, inplace=True)
        logging.info("Model prepared for quantization.")
        logging.info(method.model)

        # Cache the original hyperparameters
        self._original_hyperparameters = {
            "learning_rate": method.hparams.learning_rate,
            "epochs": trainer._epochs,
        }

        # Replace the hyperparameters
        method.hparams.learning_rate = self._learning_rate
        trainer._epochs = self._epochs

    def cleanup(
        self, method: LightningModule, trainer: Trainer, *args: Any, **kwargs: Any
    ) -> None:
        """This method is used to clean up the model after quantization.

        It calls the parent ``cleanup`` with the extra arguments, then
        restores the hyperparameters cached by ``prepare`` and drops the
        cache.

        Args:
            method (LightningModule): The method whose learning rate is
                restored.
            trainer (Trainer): The trainer whose ``_epochs`` is restored.
            *args: Forwarded to the parent ``cleanup``.
            **kwargs: Forwarded to the parent ``cleanup``.
        """
        super().cleanup(*args, **kwargs)

        # Recover the original hyperparameters
        method.hparams.learning_rate = self._original_hyperparameters["learning_rate"]
        trainer._epochs = self._original_hyperparameters["epochs"]
        del self._original_hyperparameters

    def get_qconfig(self) -> Any:
        """This method is used to get the quantization configuration.

        We use the number of activation and weight bits to create the
        quantization configuration. Activations use an unsigned range
        ``[0, 2**bits - 1]`` with a per-tensor affine scheme; weights use a
        signed range ``[-2**bits / 2, 2**bits / 2 - 1]`` with a per-tensor
        symmetric scheme. Both use a moving-average min/max observer.

        Returns:
            Any: The ``QConfig`` holding the activation and weight
            fake-quantization factories.
        """
        # Specify the qconfig manually based on the activation and weight bits.
        # NOTE(review): ``_activation_bits`` and ``_weight_bits`` are not set
        # in this class; presumably the parent constructor sets them.
        activation_bits = self._activation_bits
        weight_bits = self._weight_bits

        activation_fq = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=int(2**activation_bits - 1),
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,  # Since this is in simulation, we don't want to reduce the range
        )
        weight_fq = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-int((2**weight_bits) / 2),
            quant_max=int((2**weight_bits) / 2 - 1),
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,  # Since this is in simulation, we don't want to reduce the range
        )
        return torch.ao.quantization.QConfig(activation=activation_fq, weight=weight_fq)

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add specific arguments to the parser.

        Args:
            parent_parser (argparse.ArgumentParser): The parser to extend.

        Returns:
            argparse.ArgumentParser: The parser returned by the parent class,
            extended with ``--quantizer_learning_rate`` and
            ``--quantizer_epochs``.
        """
        # Zero-argument ``super()`` is unavailable inside a staticmethod, so
        # the class is named explicitly.
        parser = super(QATQuantizer, QATQuantizer).add_specific_args(parent_parser)
        parser.add_argument(
            "--quantizer_learning_rate",
            type=float,
            default=1e-3,
            help="The learning rate to use for the fine-tuning.",
        )
        parser.add_argument(
            "--quantizer_epochs",
            type=int,
            default=1,
            help="The number of epochs to use for the fine-tuning.",
        )
        return parser