from typing import List, Dict, Any, Union, Tuple, Optional
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import math
from yamle.defaults import (
FROZEN_MASK_KEY,
FROZEN_DATA_KEY,
OPTIMIZER_ID_KEY,
DISABLED_OPTIMIZATION_KEY,
TINY_EPSILON,
)
from yamle.models.specific.sgld import SGLD, pSGLD
[docs]
def get_optimizer(
    name: str, parameters: List[nn.Parameter], optimizer_config: Dict[str, Any]
) -> torch.optim.Optimizer:
    """This function is used to build an optimizer from its name.

    Args:
        name (str): The name of the optimizer. One of `adam`, `sgd`, `sgld` or `psgld`.
        parameters (List[nn.Parameter]): The parameters handed over to the optimizer.
        optimizer_config (Dict[str, Any]): The hyperparameters of the optimizer. Every optimizer reads `lr`,
            `adam` and `sgd` additionally read `weight_decay`, `sgd` and `sgld` additionally read `momentum`.

    Returns:
        torch.optim.Optimizer: The constructed optimizer.

    Raises:
        ValueError: If `name` is not one of the supported optimizers.
        KeyError: If a hyperparameter read by the selected optimizer is missing from `optimizer_config`.
    """
    if name == "adam":
        return torch.optim.Adam(
            parameters,
            lr=optimizer_config["lr"],
            weight_decay=optimizer_config["weight_decay"],
        )
    if name == "sgd":
        return torch.optim.SGD(
            parameters,
            lr=optimizer_config["lr"],
            momentum=optimizer_config["momentum"],
            weight_decay=optimizer_config["weight_decay"],
        )
    if name == "sgld":
        return SGLD(
            parameters, lr=optimizer_config["lr"], momentum=optimizer_config["momentum"]
        )
    if name == "psgld":
        return pSGLD(parameters, lr=optimizer_config["lr"])
    # The message lists every branch above so that it stays a usable hint.
    raise ValueError(
        f"Optimizer {name} is not supported. Please use one of the following: adam, sgd, sgld, psgld."
    )
[docs]
def get_scheduler(
    name: str, optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]
) -> torch.optim.lr_scheduler._LRScheduler:
    """This function is used to build a learning rate scheduler from its name.

    Supported names and the keys they read from `scheduler_config`:

    - `none`: A `LambdaLR` with a constant factor of 1. Reads nothing.
    - `plateau`: `ReduceLROnPlateau`. Reads `mode`, `factor` and `patience`.
    - `linear`: A `LambdaLR` with the factor `1 - epoch / max_epochs`. Reads `max_epochs`.
    - `exponential`: `ExponentialLR`. Reads `gamma`.
    - `cosine`: `CosineAnnealingLR` with `eta_min=0`. Reads `max_epochs`.
    - `step`: `StepLR`. Reads `step_size` and `gamma`.

    Args:
        name (str): The name of the scheduler.
        optimizer (torch.optim.Optimizer): The optimizer whose learning rate is scheduled.
        scheduler_config (Dict[str, Any]): The hyperparameters of the scheduler.

    Returns:
        torch.optim.lr_scheduler._LRScheduler: The constructed scheduler.

    Raises:
        ValueError: If `name` is not one of the supported schedulers.
        KeyError: If a hyperparameter read by the selected scheduler is missing from `scheduler_config`.
    """
    lr_scheduler = torch.optim.lr_scheduler
    if name == "none":
        return lr_scheduler.LambdaLR(optimizer, lambda epoch: 1)
    if name == "plateau":
        # NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases —
        # confirm against the pinned version before upgrading.
        return lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=scheduler_config["mode"],
            factor=scheduler_config["factor"],
            patience=scheduler_config["patience"],
            verbose=True,
        )
    if name == "linear":
        return lr_scheduler.LambdaLR(
            optimizer, lambda epoch: 1 - epoch / scheduler_config["max_epochs"]
        )
    if name == "exponential":
        return lr_scheduler.ExponentialLR(optimizer, gamma=scheduler_config["gamma"])
    if name == "cosine":
        return lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=scheduler_config["max_epochs"], eta_min=0
        )
    if name == "step":
        return lr_scheduler.StepLR(
            optimizer,
            step_size=scheduler_config["step_size"],
            gamma=scheduler_config["gamma"],
        )
    # The message lists exactly the branches handled above.
    raise ValueError(
        f"Scheduler {name} is not supported. Please use one of the following: none, plateau, linear, exponential, cosine, step."
    )
[docs]
class ScalarScheduler(ABC):
    """Abstract interface shared by the scalar schedulers.

    Subclasses have to implement `step`, `get_value` and `reset`. The default `state_dict` is an empty
    dictionary and the default `load_state_dict` ignores its argument.
    """

    @abstractmethod
    def step(self) -> None:
        """Update the scheduler."""
        ...

    @abstractmethod
    def get_value(self) -> float:
        """Return the current value of the scheduler."""
        ...

    @abstractmethod
    def reset(self) -> None:
        """Reset the scheduler."""
        ...

    def state_dict(self) -> Dict[str, Any]:
        """Return the state of the scheduler, which is empty for the base class."""
        return dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load the state of the scheduler; the base class has nothing to restore."""
        return None
[docs]
class LinearScalarScheduler(ScalarScheduler):
    """A scheduler that moves a scalar linearly from `start_value` to `end_value`.

    Before `start_epoch` the value is `start_value`. From `end_epoch` onwards the value is `end_value`.
    In between, the value is interpolated linearly over `end_epoch - start_epoch - 1` epochs, with
    `TINY_EPSILON` added to that denominator (presumably a small positive constant that avoids a
    division by zero for a one-epoch ramp).

    Args:
        start_value (float): The initial value of the scheduler.
        start_epoch (int): The epoch to start the scheduler.
        end_value (float): The final value of the scheduler.
        end_epoch (int): The epoch to end the scheduler.
    """

    def __init__(
        self, start_value: float, start_epoch: int, end_value: float, end_epoch: int
    ) -> None:
        self._start_value, self._end_value = start_value, end_value
        self._start_epoch, self._end_epoch = start_epoch, end_epoch
        self._current_epoch = 0
        # When set, this value overrides the schedule entirely.
        self._hard_value: Optional[float] = None

    def step(self) -> None:
        """Advance the internal epoch counter by one."""
        self._current_epoch = self._current_epoch + 1

    def set_hard_value(self, value: float) -> None:
        """Pin the scheduler to `value`, ignoring the schedule.

        Raises:
            ValueError: If a hard value has already been set.
        """
        if self._hard_value is None:
            self._hard_value = value
            return
        raise ValueError("The hard value is already set.")

    def get_value(self) -> float:
        """Return the hard value if one is set, otherwise the scheduled value for the current epoch."""
        if self._hard_value is not None:
            return self._hard_value
        if self._current_epoch < self._start_epoch:
            return self._start_value
        if self._current_epoch >= self._end_epoch:
            return self._end_value
        value_range = self._end_value - self._start_value
        epochs_elapsed = self._current_epoch - self._start_epoch
        ramp_length = self._end_epoch - self._start_epoch - 1 + TINY_EPSILON
        return self._start_value + value_range * epochs_elapsed / ramp_length

    def reset(self) -> None:
        """Set the epoch counter back to zero; a hard value, if any, is kept."""
        self._current_epoch = 0

    def state_dict(self) -> Dict[str, Any]:
        """Return the epoch counter, the schedule bounds and the hard value as a dictionary."""
        state = super().state_dict()
        state["current_epoch"] = self._current_epoch
        state["start_value"] = self._start_value
        state["start_epoch"] = self._start_epoch
        state["end_value"] = self._end_value
        state["end_epoch"] = self._end_epoch
        state["hard_value"] = self._hard_value
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore every field written by `state_dict`."""
        for key in (
            "current_epoch",
            "start_value",
            "start_epoch",
            "end_value",
            "end_epoch",
            "hard_value",
        ):
            setattr(self, f"_{key}", state_dict[key])
[docs]
class PowerGrowthScalarScheduler(ScalarScheduler):
    """This class defines a power-law scheduler for a scalar value.

    Until the start epoch the return value will be `start_value`.
    After the start epoch the linear progress through the ramp, a number between 0 and 1, is raised to
    `power` and used to interpolate between `start_value` and `end_value`.
    From the end epoch onwards the return value will be `end_value`.

    Args:
        start_value (float): The initial value of the scheduler.
        start_epoch (int): The epoch to start the scheduler.
        end_value (float): The final value of the scheduler. Must not be smaller than `start_value`.
        end_epoch (int): The epoch to end the scheduler.
        power (float): The exponent applied to the linear progress. Defaults to 1.0.
    """

    def __init__(
        self,
        start_value: float,
        start_epoch: int,
        end_value: float,
        end_epoch: int,
        power: float = 1.0,
    ) -> None:
        assert (
            start_value <= end_value
        ), f"The start value {start_value} must be smaller than the end value {end_value}."
        self._start_value = start_value
        self._start_epoch = start_epoch
        self._end_value = end_value
        self._end_epoch = end_epoch
        self._current_epoch = 0
        # This is used to set a hard value for the scheduler, ignoring the schedule.
        self._hard_value: Optional[float] = None
        self._power = power

    def step(self) -> None:
        """This method is used to update the scheduler."""
        self._current_epoch += 1

    def set_hard_value(self, value: float) -> None:
        """This method is used to set a hard value for the scheduler, ignoring the schedule.

        Raises:
            ValueError: If a hard value has already been set.
        """
        if self._hard_value is not None:
            raise ValueError("The hard value is already set.")
        self._hard_value = value

    def get_value(self) -> float:
        """This method is used to get the current value of the scheduler."""
        if self._hard_value is not None:
            return self._hard_value
        if self._current_epoch < self._start_epoch:
            return self._start_value
        if self._current_epoch >= self._end_epoch:
            return self._end_value
        elapsed = self._current_epoch - self._start_epoch
        span = self._end_epoch - self._start_epoch - 1
        # A one-epoch ramp has `span == 0`; the only epoch that reaches this point is the
        # start epoch itself, so the progress is 0 instead of the undefined 0 / 0.
        progress = elapsed / span if span > 0 else 0.0
        progress **= self._power
        return self._start_value + (self._end_value - self._start_value) * progress

    def reset(self) -> None:
        """This method is used to reset the scheduler."""
        self._current_epoch = 0

    def state_dict(self) -> Dict[str, Any]:
        """This method is used to get the state of the scheduler."""
        state_dict = super().state_dict()
        state_dict.update(
            {
                "current_epoch": self._current_epoch,
                "start_value": self._start_value,
                "start_epoch": self._start_epoch,
                "end_value": self._end_value,
                "end_epoch": self._end_epoch,
                "hard_value": self._hard_value,
                # `load_state_dict` restores the exponent, so it has to be saved as well.
                "power": self._power,
            }
        )
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """This method is used to load the state of the scheduler."""
        self._current_epoch = state_dict["current_epoch"]
        self._start_value = state_dict["start_value"]
        self._start_epoch = state_dict["start_epoch"]
        self._end_value = state_dict["end_value"]
        self._end_epoch = state_dict["end_epoch"]
        self._hard_value = state_dict["hard_value"]
        # States saved before `power` was part of `state_dict` do not contain the key;
        # in that case the exponent given at construction is kept.
        self._power = state_dict.get("power", self._power)
[docs]
class SineScalarScheduler(ScalarScheduler):
    """A scheduler that raises a scalar from `start_value` to `end_value` along a quarter sine wave.

    Before `start_epoch` the value is `start_value`. From `end_epoch` onwards the value is `end_value`.
    In between, the fraction of the ramp that has elapsed is mapped through `sin(fraction * pi / 2)`.

    Args:
        start_value (float): The initial value of the scheduler.
        start_epoch (int): The epoch to start the scheduler.
        end_value (float): The final value of the scheduler. Must not be smaller than `start_value`.
        end_epoch (int): The epoch to end the scheduler.
    """

    def __init__(
        self, start_value: float, start_epoch: int, end_value: float, end_epoch: int
    ) -> None:
        assert (
            start_value <= end_value
        ), f"The start value {start_value} must be less than the end value {end_value}."
        self._start_value, self._end_value = start_value, end_value
        self._start_epoch, self._end_epoch = start_epoch, end_epoch
        self._current_epoch = 0
        # When set, this value overrides the schedule entirely.
        self._hard_value: Optional[float] = None

    def step(self) -> None:
        """Advance the internal epoch counter by one."""
        self._current_epoch = self._current_epoch + 1

    def set_hard_value(self, value: float) -> None:
        """Pin the scheduler to `value`, ignoring the schedule.

        Raises:
            ValueError: If a hard value has already been set.
        """
        if self._hard_value is None:
            self._hard_value = value
            return
        raise ValueError("The hard value is already set.")

    def get_value(self) -> float:
        """Return the hard value if one is set, otherwise the scheduled value for the current epoch."""
        if self._hard_value is not None:
            return self._hard_value
        if self._current_epoch < self._start_epoch:
            return self._start_value
        if self._current_epoch >= self._end_epoch:
            return self._end_value
        # The angle runs from 0 at the start epoch towards pi / 2 at the end epoch.
        angle = (
            (self._current_epoch - self._start_epoch)
            / (self._end_epoch - self._start_epoch)
            * (math.pi / 2)
        )
        sine = torch.sin(torch.tensor(angle))
        scheduled = self._start_value + (self._end_value - self._start_value) * sine
        return scheduled.item()

    def reset(self) -> None:
        """Set the epoch counter back to zero; a hard value, if any, is kept."""
        self._current_epoch = 0

    def state_dict(self) -> Dict[str, Any]:
        """Return the epoch counter, the schedule bounds and the hard value as a dictionary."""
        state = super().state_dict()
        state["current_epoch"] = self._current_epoch
        state["start_value"] = self._start_value
        state["start_epoch"] = self._start_epoch
        state["end_value"] = self._end_value
        state["end_epoch"] = self._end_epoch
        state["hard_value"] = self._hard_value
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore every field written by `state_dict`."""
        for key in (
            "current_epoch",
            "start_value",
            "start_epoch",
            "end_value",
            "end_epoch",
            "hard_value",
        ):
            setattr(self, f"_{key}", state_dict[key])
[docs]
class CosineScalarScheduler(ScalarScheduler):
    """This class defines a cosine scheduler for a scalar value.

    Until the start epoch the return value will be `start_value`.
    After the start epoch the return value will be decreased along half a cosine wave until the end epoch.
    From the end epoch onwards the return value will be the `end_value`.

    Args:
        start_value (float): The initial value of the scheduler. Must not be smaller than `end_value`.
        start_epoch (int): The epoch to start the scheduler.
        end_value (float): The final value of the scheduler.
        end_epoch (int): The epoch to end the scheduler.
    """

    def __init__(
        self, start_value: float, start_epoch: int, end_value: float, end_epoch: int
    ) -> None:
        assert (
            start_value >= end_value
        ), f"The start value {start_value} must be greater than the end value {end_value}."
        self._start_value = start_value
        self._start_epoch = start_epoch
        self._end_value = end_value
        self._end_epoch = end_epoch
        self._current_epoch = 0
        # This is used to set a hard value for the scheduler, ignoring the schedule.
        self._hard_value: Optional[float] = None

    def step(self) -> None:
        """This method is used to update the scheduler."""
        self._current_epoch += 1

    def set_hard_value(self, value: float) -> None:
        """This method is used to set a hard value for the scheduler, ignoring the schedule.

        Raises:
            ValueError: If a hard value has already been set.
        """
        if self._hard_value is not None:
            raise ValueError("The hard value is already set.")
        self._hard_value = value

    def get_value(self) -> float:
        """This method is used to get the current value of the scheduler."""
        if self._hard_value is not None:
            return self._hard_value
        if self._current_epoch < self._start_epoch:
            return self._start_value
        if self._current_epoch >= self._end_epoch:
            return self._end_value
        angle = (
            (self._current_epoch - self._start_epoch)
            / (self._end_epoch - self._start_epoch)
            * math.pi
        )
        # The weight falls from 1 at the start epoch towards 0 at the end epoch. It has to
        # multiply the distance measured from `end_value`; anchoring it at `start_value`
        # instead would yield `end_value` at the start epoch and run the schedule backwards.
        weight = (1 + torch.cos(torch.tensor(angle))) / 2
        return (
            self._end_value + (self._start_value - self._end_value) * weight
        ).item()

    def reset(self) -> None:
        """This method is used to reset the scheduler."""
        self._current_epoch = 0

    def state_dict(self) -> Dict[str, Any]:
        """This method is used to get the state of the scheduler."""
        state_dict = super().state_dict()
        state_dict.update(
            {
                "current_epoch": self._current_epoch,
                "start_value": self._start_value,
                "start_epoch": self._start_epoch,
                "end_value": self._end_value,
                "end_epoch": self._end_epoch,
                "hard_value": self._hard_value,
            }
        )
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """This method is used to load the state of the scheduler."""
        self._current_epoch = state_dict["current_epoch"]
        self._start_value = state_dict["start_value"]
        self._start_epoch = state_dict["start_epoch"]
        self._end_value = state_dict["end_value"]
        self._end_epoch = state_dict["end_epoch"]
        self._hard_value = state_dict["hard_value"]
# Registry that maps the string name of a scalar scheduler to the class implementing it.
AVAILABLE_SCALAR_SCHEDULERS = {
"linear": LinearScalarScheduler,
"powergrowth": PowerGrowthScalarScheduler,
"sine": SineScalarScheduler,
"cosine": CosineScalarScheduler,
}
[docs]
@torch.no_grad()
def recover_frozen_weights(model: nn.Module) -> None:
    """This function is used to recover frozen weights after an optimization step.

    The parameters that do have a `FROZEN_MASK_KEY` and `FROZEN_DATA_KEY` attribute will be recovered.
    Both attributes must have the same number of elements as the parameter; they are flattened before use.
    The `FROZEN_MASK_KEY` is used to select the elements that will be recovered. It should only contain
    `True` or `False` (equivalently 1 or 0) values and may mix them freely. The `True` values are the
    elements that will be recovered from the `FROZEN_DATA_KEY`, every other element keeps its current value.

    Args:
        model (nn.Module): The model to recover the frozen weights.
    """
    for param in model.parameters():
        if not (hasattr(param, FROZEN_MASK_KEY) and hasattr(param, FROZEN_DATA_KEY)):
            continue
        frozen_mask = getattr(param, FROZEN_MASK_KEY)
        frozen_data = getattr(param, FROZEN_DATA_KEY)
        # Every element has to be 0 or 1 individually; requiring the whole mask to be
        # uniformly 0 or uniformly 1 would reject the partial masks this function exists for.
        assert torch.all(
            (frozen_mask == 0) | (frozen_mask == 1)
        ), f"The mask ({frozen_mask}) should only contain 0 or 1."
        assert (
            frozen_mask.numel() == param.numel()
        ), f"The number of elements of the mask ({frozen_mask.numel()}) and the parameter ({param.numel()}) are different."
        assert (
            frozen_mask.numel() == frozen_data.numel()
        ), f"The number of elements of the mask ({frozen_mask.numel()}) and the frozen data ({frozen_data.numel()}) are different."
        original_shape = param.shape
        # Work on flat views so that the mask can index the data regardless of its shape.
        current_data = param.data.view(-1)
        frozen_data = frozen_data.view(-1)
        frozen_mask = frozen_mask.view(-1).bool()
        current_data[frozen_mask] = frozen_data[frozen_mask]
        param.data = current_data.view(original_shape).contiguous()
[docs]
def freeze_weights(
    parameters: Union[nn.Parameter, List[nn.Parameter]],
    masks: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
) -> None:
    """This function is used to freeze weights of a model.

    The masks are used to select the weights that will be frozen. The masks should have the same number of
    elements as the parameters. The masks should only contain `True` or `False` (equivalently 1 or 0) values
    and may mix them freely. The `True` values are the elements that will be frozen. If no mask is provided,
    all the weights will be frozen.

    For every parameter the mask is stored under `FROZEN_MASK_KEY` and a detached copy of the current data
    is stored under `FROZEN_DATA_KEY`.

    Args:
        parameters (Union[nn.Parameter, List[nn.Parameter]]): The parameters to freeze.
        masks (Optional[Union[torch.Tensor, List[torch.Tensor]]], optional): The masks to select the weights to freeze. Defaults to None.
    """
    # A single parameter has to be wrapped before the default masks are built; iterating
    # over a bare parameter would walk over its first dimension instead.
    if isinstance(parameters, nn.Parameter):
        parameters = [parameters]
    else:
        # Materialise iterables such as `model.parameters()` so they can be measured and reused.
        parameters = list(parameters)
    if masks is None:
        masks = [torch.ones_like(param) for param in parameters]
    elif isinstance(masks, torch.Tensor):
        masks = [masks]
    assert len(parameters) == len(
        masks
    ), f"The number of parameters ({len(parameters)}) and masks ({len(masks)}) are different."
    for param, mask in zip(parameters, masks):
        assert (
            mask.numel() == param.numel()
        ), f"The number of elements of the mask ({mask.numel()}) and the parameter ({param.numel()}) are different."
        # Every element has to be 0 or 1 individually; a mask that mixes both is valid.
        assert torch.all(
            (mask == 0) | (mask == 1)
        ), f"The mask ({mask}) should only contain 0 or 1."
        frozen_data = param.data.detach().clone()
        setattr(param, FROZEN_MASK_KEY, mask)
        setattr(param, FROZEN_DATA_KEY, frozen_data)
[docs]
def split_optimizer_parameters(
    parameters: Union[nn.Parameter, List[nn.Parameter], List[Tuple[str, nn.Parameter]]]
) -> List[List[Dict[str, Union[List[nn.Parameter], List[Tuple[str, nn.Parameter]]]]]]:
    """This function is used to split the parameters of a model into multiple groups depending on an id.

    Given all model parameters, this function looks at if `OPTIMIZER_ID_KEY` is set as an attribute of the parameter.
    If it is set it will add the parameter to the group with the corresponding `id`. If the `OPTIMIZER_ID_KEY` is
    not assigned it is assumed that the parameter is used by the first optimizer, i.e. the id is 0.

    Args:
        parameters (Union[nn.Parameter, List[nn.Parameter], List[Tuple[str, nn.Parameter]]]): A single parameter,
            a list of parameters or a list of `(name, parameter)` tuples. Tuples are kept as tuples in the output.

    Returns:
        One entry per optimizer id, ordered by ascending id. Each entry is a list holding a single
        dictionary `{"params": [...]}` with the parameters assigned to that id in their input order.

    Raises:
        ValueError: If an entry is not an `nn.Parameter` or a tuple whose second element is one.
    """
    if isinstance(parameters, nn.Parameter):
        parameters = [parameters]
    id_parameter_mapping = {}
    for param in parameters:
        is_named = isinstance(param, tuple)
        p = param[1] if is_named else param
        if not isinstance(p, nn.Parameter):
            raise ValueError(f"The parameter {p} is not a valid parameter.")
        # Parameters without an explicit id belong to the default optimizer with id 0.
        optimizer_id = getattr(p, OPTIMIZER_ID_KEY, 0)
        entry = (param[0], p) if is_named else p
        id_parameter_mapping.setdefault(optimizer_id, []).append(entry)
    # Each id gets its own single-element list of parameter groups.
    return [
        [{"params": id_parameter_mapping[optimizer_id]}]
        for optimizer_id in sorted(id_parameter_mapping)
    ]
[docs]
def set_optimizer_id(
    parameters: Union[List[nn.Parameter], nn.Parameter], optimizer_id: int
) -> None:
    """Tag the given parameters with an optimizer id stored under `OPTIMIZER_ID_KEY`.

    A parameter that already carries the attribute fails the assertion instead of being overwritten.

    Args:
        parameters (Union[List[nn.Parameter], nn.Parameter]): The parameter or parameters to tag.
        optimizer_id (int): The optimizer id to set.
    """
    targets = [parameters] if isinstance(parameters, nn.Parameter) else parameters
    for target in targets:
        already_tagged = hasattr(target, OPTIMIZER_ID_KEY)
        assert (
            not already_tagged
        ), f"The parameter {target} already has an optimizer id."
        setattr(target, OPTIMIZER_ID_KEY, optimizer_id)
[docs]
def disable_optimization(parameters: Union[List[nn.Parameter], nn.Parameter]) -> None:
    """Switch off gradient computation for the given parameters and flag them as disabled.

    Each parameter gets `requires_grad = False` and the attribute `DISABLED_OPTIMIZATION_KEY` set to `True`.

    Args:
        parameters (Union[List[nn.Parameter], nn.Parameter]): The parameter or parameters to disable.
    """
    targets = [parameters] if isinstance(parameters, nn.Parameter) else parameters
    for target in targets:
        is_parameter = isinstance(target, nn.Parameter)
        assert is_parameter, f"The parameter {target} is not a valid parameter."
        target.requires_grad = False
        setattr(target, DISABLED_OPTIMIZATION_KEY, True)
[docs]
def enable_optimization(parameters: Union[List[nn.Parameter], nn.Parameter]) -> None:
    """Switch on gradient computation for the given parameters and clear the disabled flag.

    Each parameter gets `requires_grad = True` and the attribute `DISABLED_OPTIMIZATION_KEY` set to `False`.

    Args:
        parameters (Union[List[nn.Parameter], nn.Parameter]): The parameter or parameters to enable.
    """
    targets = [parameters] if isinstance(parameters, nn.Parameter) else parameters
    for target in targets:
        is_parameter = isinstance(target, nn.Parameter)
        assert is_parameter, f"The parameter {target} is not a valid parameter."
        target.requires_grad = True
        setattr(target, DISABLED_OPTIMIZATION_KEY, False)