"""Source code for ``yamle.trainers``: the trainer registry and its factory function."""

from typing import Type


from yamle.trainers.trainer import BaseTrainer
from yamle.trainers.ensemble import EnsembleTrainer, BaggingTrainer
from yamle.trainers.calibration import CalibrationTrainer

# Registry mapping each trainer type name (a string key) to the trainer class
# it selects. `trainer_factory` resolves its `trainer_type` argument against
# this mapping, so a key must be present here for that name to be accepted.
AVAILABLE_TRAINERS = {
    "base": BaseTrainer,
    "ensemble": EnsembleTrainer,
    "bagging": BaggingTrainer,
    "calibration": CalibrationTrainer,
}


def trainer_factory(trainer_type: str) -> Type[BaseTrainer]:
    """Return the trainer class registered under ``trainer_type``.

    The class object itself is returned; no trainer instance is constructed
    here.

    Args:
        trainer_type (str): Name of the trainer to look up. Must be one of the
            keys of ``AVAILABLE_TRAINERS``.

    Returns:
        Type[BaseTrainer]: The class stored in ``AVAILABLE_TRAINERS`` for
        ``trainer_type``.

    Raises:
        ValueError: If ``trainer_type`` is not a key of ``AVAILABLE_TRAINERS``.
    """
    try:
        return AVAILABLE_TRAINERS[trainer_type]
    except KeyError:
        # `from None` keeps the internal KeyError out of the traceback so the
        # caller sees only the ValueError, as with an explicit membership check.
        raise ValueError(f"Unknown trainer type {trainer_type}.") from None