# Source code for yamle.quantization.static
import logging
from typing import Any
import torch
from pytorch_lightning import LightningModule, Trainer
from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.observer import HistogramObserver
from yamle.defaults import QUANTIZED_KEY
from yamle.quantization.quantizer import BaseQuantizer
# NOTE: this rebinds the name ``logging`` from the standard-library module to
# the "pytorch_lightning" logger object. Every ``logging.<level>(...)`` call
# below therefore goes to that logger, and module-level helpers such as
# ``logging.getLogger`` are no longer reachable through this name.
logging = logging.getLogger("pytorch_lightning")
# [docs]
class StaticQuantizer(BaseQuantizer):
    """Static post-training quantizer that only *simulates* quantization.

    Fake-quantize modules are attached to the model, their observers are
    calibrated through the trainer, and fake quantization is then switched on
    with the observers frozen. The bit widths for activations and weights are
    read from ``self._activation_bits`` and ``self._weight_bits``. The model
    is never converted to real quantized kernels.
    """

    def __call__(self, trainer: Trainer, method: LightningModule) -> None:
        """Quantize ``method.model`` in place.

        The steps are, in order:

        1. A copy of the original model is saved so that it can be recovered.
        2. The model is prepared for quantization (see :meth:`prepare`).
        3. ``trainer.calibrate()`` is run; the trainer supplies the data used
           to calibrate (or fine-tune) the prepared model.
        4. Fake quantization is enabled and the observers are disabled, so the
           calibrated ranges stay fixed while quantization is simulated.
        5. The quantized model is saved and ``method`` is flagged as quantized.

        Args:
            trainer: Trainer used to calibrate the prepared model. It must
                provide a ``calibrate()`` method.
            method: Module wrapper whose ``model`` attribute is quantized.
        """
        self.save_original_model(method)
        self.prepare(trainer, method)

        # Observers are active and fake quantization is off at this point
        # (see `prepare`), so this pass only collects range statistics.
        trainer.calibrate()

        # Freeze the collected statistics and start simulating quantization.
        method.model.apply(torch.ao.quantization.enable_fake_quant)
        method.model.apply(torch.ao.quantization.disable_observer)

        self.save_quantized_model(method)
        logging.info("Model quantized.")
        logging.info(method.model)
        setattr(method, QUANTIZED_KEY, True)

    def prepare(self, trainer: Trainer, method: LightningModule) -> None:
        """Prepare ``method.model`` for quantization.

        Layers are replaced by their quantization-ready counterparts, the
        qconfig from :meth:`get_qconfig` is attached, and fake-quantize modules
        are inserted in place. Fake quantization is left disabled so that the
        following calibration pass only gathers observer statistics.

        Args:
            trainer: Unused here; kept so the signature stays unchanged.
            method: Module wrapper whose ``model`` attribute is prepared.
        """
        # NOTE(review): presumably the layer replacement expects an
        # inference-mode model - confirm against
        # `replace_layers_for_quantization`.
        method.model.eval()
        self.replace_layers_for_quantization(method.model)
        method.model.qconfig = self.get_qconfig()

        # `prepare_qat` only works on models that are in training mode.
        method.model.train()
        torch.ao.quantization.prepare_qat(method.model, inplace=True)

        # Keep the observers running but do not quantize yet.
        method.model.apply(torch.ao.quantization.disable_fake_quant)
        logging.info("Model prepared for quantization.")
        logging.info(method.model)

    def get_qconfig(self) -> Any:
        """Build the quantization configuration from the configured bit widths.

        Activations use an unsigned, per-tensor affine scheme over
        ``[0, 2**activation_bits - 1]``. Weights use a signed, per-tensor
        symmetric scheme over
        ``[-(2**weight_bits) / 2, (2**weight_bits) / 2 - 1]``. Both use a
        ``HistogramObserver`` to estimate the ranges.

        NOTE(review): the dtypes are fixed to ``torch.quint8``/``torch.qint8``,
        so bit widths above 8 presumably cannot be represented - confirm.

        Returns:
            A ``QConfig`` holding the activation and weight fake-quantize
            factories.
        """
        # The qconfig is specified manually from the activation and weight bits.
        activation_bits = self._activation_bits
        weight_bits = self._weight_bits

        activation_fq = FakeQuantize.with_args(
            observer=HistogramObserver,
            quant_min=0,
            quant_max=int(2**activation_bits - 1),
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,  # Simulation only, so keep the full range.
        )
        weight_fq = FakeQuantize.with_args(
            observer=HistogramObserver,
            quant_min=-int((2**weight_bits) / 2),
            quant_max=int((2**weight_bits) / 2 - 1),
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,  # Simulation only, so keep the full range.
        )
        return torch.ao.quantization.QConfig(
            activation=activation_fq, weight=weight_fq
        )