# Source code for yamle.methods.sngp

from typing import Any
import torch
import argparse

from yamle.methods.method import BaseMethod
from yamle.models.specific.sngp import RFF, spectral_norm
from yamle.defaults import CLASSIFICATION_KEY, SEGMENTATION_KEY


def enable_spectral_normalization(model: torch.nn.Module, coeff: float) -> None:
    """Replace every leaf layer of ``model`` with a spectral-normalized version.

    The model is modified in place. Container modules (modules that have
    children of their own) are traversed recursively. Every leaf module is
    replaced by ``spectral_norm(child, coeff)``.

    Args:
        model (torch.nn.Module): The model to enable spectral normalization for.
        coeff (float): The coefficient for the spectral normalization,
            forwarded unchanged to ``spectral_norm``.
    """
    for name, child in model.named_children():
        # A module with at least one child is a container: recurse into it
        # rather than wrapping it, so that only leaf layers get normalized.
        # ``next(..., None)`` checks for a first child without building a list.
        if next(iter(child.children()), None) is not None:
            enable_spectral_normalization(child, coeff)
        else:
            # The child is replaced under its existing name.
            setattr(model, name, spectral_norm(child, coeff))
class SNGPMethod(BaseMethod):
    """Extension of the base method in which prediction is performed with SNGP.

    The method follows: "Simple and Principled Uncertainty Estimation with
    Deterministic Deep Learning via Distance Awareness".

    Its core consists of two steps:

    1. Enable spectral normalization for all leaf layers of the model (see
       ``enable_spectral_normalization``).
    2. Replace the linear ``._output`` layer with an ``RFF`` layer, which
       provides the Gaussian-process output.

    Args:
        m (float): The gamma for exponential moving average for updating the
            precision matrix.
        random_features (int): The number of random features to use in the RFF
            layer.
        mean_field_factor (float): The factor to use for the mean field
            approximation.
        coeff (float): The coefficient for the spectral normalization.
    """

    tasks = [CLASSIFICATION_KEY, SEGMENTATION_KEY]

    def __init__(
        self,
        m: float = 0.99,
        random_features: int = 512,
        mean_field_factor: float = 1.0,
        coeff: float = 1.0,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        output_layer = self.model._output
        # Validate and read the dimensions from the user's original output layer
        # *before* spectral normalization wraps it. This makes the check fail
        # before the model is mutated, and avoids relying on the wrapped layer
        # still exposing ``in_features`` / ``out_features``.
        assert isinstance(
            output_layer, torch.nn.Linear
        ), "The output layer must be a linear layer"
        in_features = output_layer.in_features
        out_features = output_layer.out_features

        enable_spectral_normalization(self.model, coeff=coeff)

        # The (now spectral-normalized) linear head is discarded and replaced
        # by the RFF Gaussian-process output layer.
        self.model._output = RFF(
            in_features,
            out_features,
            random_features,
            mean_field_factor,
            m,
        )

    def on_train_epoch_start(self) -> None:
        """Flag the final training epoch so that the precision matrix is updated.

        The update is triggered by setting the ``_final_epoch`` flag of the
        output layer to ``True`` when the current epoch is the last one,
        i.e. ``trainer.max_epochs - 1``.

        NOTE(review): if training ends before ``max_epochs`` (e.g. early
        stopping), or ``max_epochs == -1`` (infinite training), this condition
        is never met. In that case ``compute_covariance`` is never called.
        """
        if self.current_epoch == self.trainer.max_epochs - 1:
            self.model._output._final_epoch = True
        return super().on_train_epoch_start()

    def on_train_epoch_end(self) -> None:
        """Finalize the output layer after the flagged final epoch.

        If ``_final_epoch`` was set in ``on_train_epoch_start``, the flag is
        cleared and the output layer's ``compute_covariance`` is called.
        """
        if self.model._output._final_epoch:
            self.model._output._final_epoch = False
            self.model._output.compute_covariance()
        return super().on_train_epoch_end()

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Add the SNGP-specific command-line arguments to ``parent_parser``.

        The ``--method_*`` options mirror the ``__init__`` parameters ``m``,
        ``random_features``, ``mean_field_factor`` and ``coeff`` and use the
        same defaults. The mapping from option to constructor argument is
        presumably done by the caller; confirm there.
        """
        parser = super(SNGPMethod, SNGPMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_m",
            type=float,
            default=0.99,
            help="The gamma for exponential moving average for updating the precision matrix.",
        )
        parser.add_argument(
            "--method_random_features",
            type=int,
            default=512,
            help="The number of random features to use in the RFF layer.",
        )
        parser.add_argument(
            "--method_mean_field_factor",
            type=float,
            default=1.0,
            help="The factor to use for the mean field approximation.",
        )
        parser.add_argument(
            "--method_coeff",
            type=float,
            default=1.0,
            help="The coefficient for the spectral normalization.",
        )
        return parser