Source code for yamle.quantization.models.specific.mcdropout
from typing import Any
import torch
import torch.nn.functional as F
from torch.ao.nn.quantized import FloatFunctional
from torch.ao.quantization import QuantStub
from yamle.models.specific.mcdropout import Dropout1d, Dropout2d, Dropout3d
[docs]
class QuantisedDropout1d(Dropout1d):
    """Element-wise Monte Carlo dropout written in a quantization-friendly form.

    Rather than dropping entries of the input directly, the layer samples a
    dropout mask by running ``F.dropout`` on a tensor of ones shaped like the
    input. The mask is passed through a ``QuantStub`` and multiplied with the
    input via ``FloatFunctional.mul``, so that eager-mode quantization can
    handle both the mask and the product.

    Dropout is applied on every call (``training=True`` is hard-coded), i.e.
    independently of the module's train/eval mode.

    Constructor arguments are forwarded unchanged to the parent ``Dropout1d``.
    NOTE(review): the dropout probability is read from ``self._p``, which the
    parent is expected to provide (upstream docs describe it as remembered in
    an ``nn.Parameter``) — confirm against the parent class.

    Args:
        p (float): The probability of an element to be zeroed.
        inplace (bool): Forwarded to the functional dropout call. It only
            affects the temporary tensor of ones used to build the mask; the
            input ``x`` itself is never modified in place.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Module-based multiply so quantization can observe/convert the product.
        self.quant: FloatFunctional = FloatFunctional()
        # Quantizes the float dropout mask before it meets the input.
        self.quant_stub: QuantStub = QuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by a freshly sampled, quantizable dropout mask."""
        ones = torch.ones_like(x)
        # Element-wise mask: dropped positions are 0, kept ones carry dropout's rescaling.
        dropout_mask = F.dropout(ones, p=self._p, training=True, inplace=self.inplace)
        quantised_mask = self.quant_stub(dropout_mask)
        return self.quant.mul(x, quantised_mask)
[docs]
class QuantisedDropout2d(Dropout2d):
    """Channel-wise (2D) Monte Carlo dropout written in a quantization-friendly form.

    Rather than dropping channels of the input directly, the layer samples a
    dropout mask by running ``F.dropout2d`` on a tensor of ones shaped like
    the input, so whole channels (filters) of the mask are zeroed together.
    The mask is passed through a ``QuantStub`` and multiplied with the input
    via ``FloatFunctional.mul``, so that eager-mode quantization can handle
    both the mask and the product.

    Dropout is applied on every call (``training=True`` is hard-coded), i.e.
    independently of the module's train/eval mode.

    Constructor arguments are forwarded unchanged to the parent ``Dropout2d``.
    NOTE(review): the dropout probability is read from ``self._p``, which the
    parent is expected to provide (upstream docs describe it as remembered in
    an ``nn.Parameter``) — confirm against the parent class.

    Args:
        p (float): The probability of an element to be zeroed.
        inplace (bool): Forwarded to the functional dropout call. It only
            affects the temporary tensor of ones used to build the mask; the
            input ``x`` itself is never modified in place.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Module-based multiply so quantization can observe/convert the product.
        self.quant: FloatFunctional = FloatFunctional()
        # Quantizes the float dropout mask before it meets the input.
        self.quant_stub: QuantStub = QuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by a freshly sampled, quantizable channel mask."""
        ones = torch.ones_like(x)
        # Channel mask: entire filters are zeroed out at once.
        channel_mask = F.dropout2d(ones, p=self._p, training=True, inplace=self.inplace)
        quantised_mask = self.quant_stub(channel_mask)
        return self.quant.mul(x, quantised_mask)
[docs]
class QuantisedDropout3d(Dropout3d):
    """Channel-wise (3D) Monte Carlo dropout written in a quantization-friendly form.

    Rather than dropping channels of the input directly, the layer samples a
    dropout mask by running ``F.dropout3d`` on a tensor of ones shaped like
    the input, so whole channels (filters) of the mask are zeroed together.
    The mask is passed through a ``QuantStub`` and multiplied with the input
    via ``FloatFunctional.mul``, so that eager-mode quantization can handle
    both the mask and the product.

    Dropout is applied on every call (``training=True`` is hard-coded), i.e.
    independently of the module's train/eval mode.

    Constructor arguments are forwarded unchanged to the parent ``Dropout3d``.
    NOTE(review): the dropout probability is read from ``self._p``, which the
    parent is expected to provide (upstream docs describe it as remembered in
    an ``nn.Parameter``) — confirm against the parent class.

    Args:
        p (float): The probability of an element to be zeroed.
        inplace (bool): Forwarded to the functional dropout call. It only
            affects the temporary tensor of ones used to build the mask; the
            input ``x`` itself is never modified in place.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Module-based multiply so quantization can observe/convert the product.
        self.quant: FloatFunctional = FloatFunctional()
        # Quantizes the float dropout mask before it meets the input.
        self.quant_stub: QuantStub = QuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by a freshly sampled, quantizable channel mask."""
        ones = torch.ones_like(x)
        # Channel mask: entire filters are zeroed out at once.
        channel_mask = F.dropout3d(ones, p=self._p, training=True, inplace=self.inplace)
        quantised_mask = self.quant_stub(channel_mask)
        return self.quant.mul(x, quantised_mask)