Source code for yamle.utils.tracing_utils
from typing import Tuple, Optional
import torch
from pytorch_lightning import LightningModule
from yamle.defaults import (
MODULE_INPUT_SHAPE_KEY,
MODULE_OUTPUT_SHAPE_KEY,
MODULE_NAME_KEY,
)
[docs]
def forward_shape_hook(
module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor
) -> None:
"""This function is used to cache the input and output shapes of a module.
The shapes will be stored in the module as `MODULE_INPUT_SHAPE_KEY` and `MODULE_OUTPUT_SHAPE_KEY`.
Args:
module (torch.nn.Module): The module to cache the input and output shapes of.
input (torch.Tensor): The input to the module.
output (torch.Tensor): The output of the module.
"""
setattr(
module,
MODULE_INPUT_SHAPE_KEY,
[x.shape if isinstance(x, torch.Tensor) else None for x in input],
)
if isinstance(output, torch.Tensor):
setattr(module, MODULE_OUTPUT_SHAPE_KEY, [output.shape])
elif isinstance(output, (tuple, list)):
setattr(
module,
MODULE_OUTPUT_SHAPE_KEY,
[
x.shape if x is not None and isinstance(x, torch.Tensor) else None
for x in output
],
)
else:
setattr(module, MODULE_OUTPUT_SHAPE_KEY, [None])
[docs]
@torch.no_grad()
def get_sample_input_and_target(
method: LightningModule, batch_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This method is used to get the sample input and target of the model."""
input_shape = method._inputs_dim
if batch_size is not None:
input_shape = (batch_size, *input_shape[1:])
if method._inputs_dtype == torch.float:
x = torch.ones(input_shape).to(next(method.model.parameters()).device)
elif method._inputs_dtype == torch.long:
x = torch.randint(0, 1, input_shape).to(next(method.model.parameters()).device)
else:
raise ValueError(f"Input dtype {method._inputs_dtype} is not supported.")
output_shape = method._targets_dim
batch_size = method._inputs_dim[0] if batch_size is None else batch_size
output_shape = (
(batch_size, *output_shape)
if isinstance(output_shape, (tuple, list))
else (batch_size, output_shape)
)
if method._outputs_dtype == torch.float:
y = torch.randn(output_shape).to(next(method.model.parameters()).device)
elif method._outputs_dtype == torch.long:
y = torch.randint(0, 1, output_shape).to(next(method.model.parameters()).device)
if method._targets_dim == 1:
y = y.view(-1)
else:
raise ValueError(f"Output dtype {method._outputs_dtype} is not supported.")
return x, y
[docs]
def get_input_shape_from_model(model: torch.nn.Module) -> Tuple[int, ...]:
"""This method is used to get the input shape of the model."""
return getattr(model, MODULE_INPUT_SHAPE_KEY, None)
[docs]
def get_output_shape_from_model(model: torch.nn.Module) -> Tuple[int, ...]:
"""This method is used to get the output shape of the model."""
return getattr(model, MODULE_OUTPUT_SHAPE_KEY, None)
[docs]
@torch.no_grad()
def trace_input_output_shapes(method: LightningModule) -> None:
"""This method is used to trace the input and output shapes of the model.
Additionally, it will name all the modules in the model.
"""
method.eval()
hooks = []
for m in method.model.modules():
hooks.append(m.register_forward_hook(forward_shape_hook))
batch = get_sample_input_and_target(method)
method.test_step(batch, batch_idx=0)
for hook in hooks:
hook.remove()
method.train()
name_all_modules(method.model)
[docs]
def name_all_modules(model: torch.nn.Module) -> None:
"""This method is used to name all the modules in the model."""
for name, module in model.named_modules():
setattr(module, MODULE_NAME_KEY, name)