Source code for yamle.quantization.models.operations
import torch
import torch.nn as nn
[docs]
class QuantizableAdd(nn.Module):
    """A simple class implementing residual addition but with a FloatFunctional object.

    A bare ``x + y`` in ``forward`` is an untracked tensor op, so the addition
    is instead routed through a ``torch.ao.nn.quantized.FloatFunctional``
    submodule, which PyTorch's eager-mode quantization tooling can observe.
    """

    def __init__(self) -> None:
        super().__init__()
        # The attribute name ``_add`` determines the submodule path (and thus
        # the state-dict keys of anything attached to it during quantization
        # preparation), so it should not be renamed.
        self._add = torch.ao.nn.quantized.FloatFunctional()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """The forward function of the residual addition.

        Args:
            x: First addend.
            y: Second addend. Presumably shape-compatible with ``x`` under
                ``torch.add`` broadcasting rules (not checked here).

        Returns:
            The element-wise sum of ``x`` and ``y``, computed via
            ``FloatFunctional.add``.
        """
        return self._add.add(x, y)