Source code for yamle.utils.running_utils
from typing import Any, Dict
import argparse
from yamle.data.datamodule import BaseDataModule
[docs]
def prepare_datamodule_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the keyword arguments for the datamodule.

    Args:
        args: Parsed command-line namespace; must expose ``seed``.

    Returns:
        The result of ``extract_kwargs(args, "datamodule_")`` with its
        ``"seed"`` entry set to ``args.seed``.
    """
    kwargs = extract_kwargs(args, "datamodule_")
    # The run-wide seed overrides any "seed" entry that was extracted.
    kwargs.update(seed=args.seed)
    return kwargs
[docs]
def prepare_model_kwargs(
    args: argparse.Namespace, datamodule: BaseDataModule
) -> Dict[str, Any]:
    """Build the keyword arguments for the model.

    Args:
        args: Parsed command-line namespace; must expose ``seed``.
        datamodule: Source of the ``task``, ``outputs_dim`` and
            ``inputs_dim`` values copied into the result.

    Returns:
        The result of ``extract_kwargs(args, "model_")`` extended with the
        ``"task"``, ``"outputs_dim"``, ``"inputs_dim"`` and ``"seed"`` entries.
    """
    kwargs = extract_kwargs(args, "model_")
    # Entries set here override same-named entries that were extracted.
    kwargs.update(
        {
            "task": datamodule.task,
            "outputs_dim": datamodule.outputs_dim,
            "inputs_dim": datamodule.inputs_dim,
            "seed": args.seed,
        }
    )
    return kwargs
[docs]
def prepare_loss_kwargs(
    args: argparse.Namespace, datamodule: BaseDataModule
) -> Dict[str, Any]:
    """Build the keyword arguments for the loss.

    Args:
        args: Parsed command-line namespace.
        datamodule: Source of the ``task`` value copied into the result.

    Returns:
        The result of ``extract_kwargs(args, "loss_")`` with its ``"task"``
        entry set to ``datamodule.task``.
    """
    kwargs = extract_kwargs(args, "loss_")
    kwargs.update(task=datamodule.task)
    return kwargs
[docs]
def prepare_regularizer_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the keyword arguments for the regularizer.

    Args:
        args: Parsed command-line namespace.

    Returns:
        Whatever ``extract_kwargs(args, "regularizer_")`` returns, unmodified.
    """
    return extract_kwargs(args, "regularizer_")
[docs]
def prepare_metrics_kwargs(
    args: argparse.Namespace, datamodule: BaseDataModule
) -> Dict[str, Any]:
    """Build the keyword arguments for the metrics.

    Args:
        args: Parsed command-line namespace; must expose ``method_metrics``
            and may expose ``method_num_members``.
        datamodule: Source of the ``task``, ``outputs_dim`` and
            ``ignore_indices`` values copied into the result.

    Returns:
        A new dictionary with the keys ``"task"``, ``"outputs_dim"``,
        ``"ignore_indices"``, ``"num_members"`` and ``"metrics"``.
    """
    return {
        "task": datamodule.task,
        "outputs_dim": datamodule.outputs_dim,
        "ignore_indices": datamodule.ignore_indices,
        # Fall back to 1 when ``args`` has no ``method_num_members``.
        "num_members": getattr(args, "method_num_members", 1),
        "metrics": args.method_metrics,
    }
[docs]
def prepare_method_kwargs(
    args: argparse.Namespace, datamodule: BaseDataModule
) -> Dict[str, Any]:
    """Build the keyword arguments for the method.

    Args:
        args: Parsed command-line namespace; must expose ``seed``,
            ``datamodule_batch_size`` and ``save_path``, plus whatever
            ``prepare_metrics_kwargs`` and ``prepare_model_kwargs`` read.
        datamodule: Stored in the result and used as the source of the task,
            dimension and dtype entries.

    Returns:
        The result of ``extract_kwargs(args, "method_")`` extended with the
        seed, the datamodule-derived entries, the save path and the nested
        ``"metrics_kwargs"`` and ``"model_kwargs"`` dictionaries.
    """
    kwargs = extract_kwargs(args, "method_")
    # Entries set here override same-named entries that were extracted.
    kwargs.update(
        {
            "seed": args.seed,
            "task": datamodule.task,
            "outputs_dim": datamodule.outputs_dim,
            "targets_dim": datamodule.targets_dim,
            "outputs_dtype": datamodule.outputs_dtype,
            # The batch size is prepended to the datamodule's input dims.
            "inputs_dim": (args.datamodule_batch_size, *datamodule.inputs_dim),
            "inputs_dtype": datamodule.inputs_dtype,
            "datamodule": datamodule,
            "save_path": args.save_path,
            "metrics_kwargs": prepare_metrics_kwargs(args, datamodule),
            "model_kwargs": prepare_model_kwargs(args, datamodule),
        }
    )
    return kwargs
[docs]
def prepare_trainer_kwargs(
    args: argparse.Namespace, datamodule: BaseDataModule
) -> Dict[str, Any]:
    """Build the keyword arguments for the trainer.

    Args:
        args: Parsed command-line namespace; must expose ``save_path``,
            ``st_checkpoint_dir`` and ``no_saving``.
        datamodule: Stored in the result and used as the source of ``task``.

    Returns:
        The result of ``extract_kwargs(args, "trainer_")`` extended with the
        ``"save_path"``, ``"st_checkpoint_dir"``, ``"datamodule"``, ``"task"``
        and ``"no_saving"`` entries.
    """
    kwargs = extract_kwargs(args, "trainer_")
    # Entries set here override same-named entries that were extracted.
    kwargs.update(
        {
            "save_path": args.save_path,
            "st_checkpoint_dir": args.st_checkpoint_dir,
            "datamodule": datamodule,
            "task": datamodule.task,
            "no_saving": args.no_saving,
        }
    )
    return kwargs
[docs]
def prepare_test_trainer_kwargs(
    args: argparse.Namespace, datamodule: BaseDataModule
) -> Dict[str, Any]:
    """Build the keyword arguments for the test trainer.

    Args:
        args: Parsed command-line namespace, forwarded to
            ``prepare_trainer_kwargs``.
        datamodule: Forwarded to ``prepare_trainer_kwargs``.

    Returns:
        The dictionary produced by ``prepare_trainer_kwargs`` with its
        ``"precision"`` entry set to ``32``.
    """
    kwargs = prepare_trainer_kwargs(args, datamodule)
    # Precision is always 32 here, whatever the trainer kwargs contained.
    kwargs.update(precision=32)
    return kwargs
[docs]
def prepare_pruner_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the keyword arguments for the pruner.

    Args:
        args: Parsed command-line namespace.

    Returns:
        Whatever ``extract_kwargs(args, "pruner_")`` returns, unmodified.
    """
    return extract_kwargs(args, "pruner_")
[docs]
def prepare_quantizer_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
    """Build the keyword arguments for the quantizer.

    Args:
        args: Parsed command-line namespace.

    Returns:
        Whatever ``extract_kwargs(args, "quantizer_")`` returns, unmodified.
    """
    return extract_kwargs(args, "quantizer_")