Source code for yamle.trainers.calibration
import argparse
from typing import Any
from torch.utils.data import DataLoader
from yamle.trainers.trainer import BaseTrainer
[docs]
class CalibrationTrainer(BaseTrainer):
    """Trainer that first trains the model and then calibrates it.

    Training is run on the training set through the base trainer's ``fit``.
    Afterwards the method's ``calibrate`` hook is invoked. The underlying
    trainer is then re-initialized for ``calibration_epochs`` epochs, and a
    second ``fit`` is run on the datamodule's calibration set.

    Args:
        calibration_epochs (int): The number of epochs to calibrate the model.
    """

    def __init__(self, calibration_epochs: int, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._calibration_epochs = calibration_epochs

    def fit(
        self, train_dataloader: DataLoader, validation_dataloader: DataLoader
    ) -> float:
        """Train the model, then run the calibration phase.

        Args:
            train_dataloader (DataLoader): The dataloader to be used for training.
            validation_dataloader (DataLoader): The dataloader to be used for
                validation; it is passed to both the training and the
                calibration ``fit`` calls.

        Returns:
            float: The sum of the values returned by the training ``fit`` and
            the calibration ``fit`` (presumably elapsed times — confirm against
            ``BaseTrainer.fit``).

        Raises:
            ValueError: If the method does not provide a callable ``calibrate``.
        """
        # Validate before training: otherwise a complete (possibly long)
        # training run is wasted before discovering calibration is impossible.
        if not callable(getattr(self._method, "calibrate", None)):
            raise ValueError("Make sure that the method has a calibrate method.")

        training_time = super().fit(train_dataloader, validation_dataloader)

        calibration_dataloader = self._datamodule.calibration_dataloader()
        self._method.calibrate()
        # Rebuild the underlying trainer so the second fit runs for the
        # calibration epoch budget instead of the training one.
        self._initialize_trainer(epochs=self._calibration_epochs)
        calibration_time = super().fit(calibration_dataloader, validation_dataloader)
        return training_time + calibration_time

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Add calibration-trainer arguments to the given parser.

        The base trainer's arguments are added first, then
        ``--trainer_calibration_epochs``.

        Args:
            parent_parser (ArgumentParser): The parser to which the arguments should be added.

        Returns:
            ArgumentParser: The parser returned by the base class, extended with
            the calibration arguments.
        """
        # Two-argument super() form is required inside a staticmethod.
        parser = super(CalibrationTrainer, CalibrationTrainer).add_specific_args(
            parent_parser
        )
        parser.add_argument(
            "--trainer_calibration_epochs",
            type=int,
            default=10,
            help="The number of epochs to be used for calibration.",
        )
        return parser