# Source code for yamle.utils.export_utils

import torch
import torch.nn as nn
from yamle.utils.tracing_utils import get_input_shape_from_model
import onnx

import logging

# NOTE: this rebinds the module-level name `logging` from the stdlib module to
# the `Logger` instance named "pytorch_lightning". Code in this file that calls
# `logging.info(...)` is therefore calling a method on that logger, and the
# stdlib `logging` module is no longer reachable through this name.
logging = logging.getLogger("pytorch_lightning")


def export_onnx(model: nn.Module, path: str, opset_version: int = 10) -> None:
    """Export the model to ONNX and validate the file that was written.

    The model is switched to evaluation mode (and left in that mode), a random
    dummy input with a batch size of 1 is pushed through it once to count its
    outputs, and the model is then exported with dimension 0 of the input and
    of every output marked as a dynamic ``batch_size`` axis. Finally the
    written file is loaded back and run through the ONNX checker.

    Args:
        model (nn.Module): The model to export.
        path (str): The path to save the model.
        opset_version (int): The ONNX opset version passed to
            ``torch.onnx.export``. Defaults to 10.

    Raises:
        Exception: Nothing is caught here; any error raised by the forward
            pass, ``torch.onnx.export``, ``onnx.load`` or
            ``onnx.checker.check_model`` propagates to the caller.
    """
    model.eval()

    # Presumably the first entry returned by the helper is the shape of the
    # model input with the batch dimension first -- confirm against
    # `yamle.utils.tracing_utils`. Dimension 0 is forced to 1 for the dummy input.
    input_shape = list(get_input_shape_from_model(model)[0])
    input_shape[0] = 1

    # Place the dummy input on the same device as the model. A module without
    # parameters has no device to copy, so fall back to the CPU instead of
    # failing with a bare StopIteration.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    x = torch.randn(*input_shape).to(device)

    # Make a pass to count the number of outputs. Only the structure of the
    # result is needed, so gradient tracking is disabled for this dry run.
    with torch.no_grad():
        outputs = model(x)
    num_outputs = len(outputs) if isinstance(outputs, (list, tuple)) else 1
    output_names = [f"output_{i}" for i in range(num_outputs)]

    # Dimension 0 of the input and of every output is exported as dynamic.
    dynamic_axes = {"input": {0: "batch_size"}}
    dynamic_axes.update({name: {0: "batch_size"} for name in output_names})

    logging.info("Exporting model to ONNX.")
    torch.onnx.export(
        model,
        x,
        path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["input"],
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )

    # Perform sanity check
    logging.info("Checking ONNX model.")
    onnx_model = onnx.load(path)
    onnx.checker.check_model(onnx_model)
    logging.info("ONNX model is valid.")