Pytorch 模型 pth 转onnx模型多输出实验

pytorch 模型预实验完成,进一步部署很重要的一步,转存pth模型为ONNX

Reference:

5.

# -*- coding: utf-8 -*-
"""
Project: TESTCODE
Creator: CHENRAN
Create time: 2022-06-14 14:20
IDE: PyCharm
Introduction:
"""
import torch
print(torch.__version__)

# pth模型转onnx模型
def pth_to_onnx(input, pth_path, onnx_path,):

    model = torch.load(pth_path)
    model.eval()
    loaded_model = torch.jit.load(pth_path)
    loaded_model_output = loaded_model(input.cuda())
    # 指定模型的输入,以及onnx的输出路径
    """
    model:# 正在运行的模型
    input: # 模型输入(或用于多个输入的元组)
    onnx_path: # 保存onnx模型格式的路径名字
    verbose:# 是否打印网络
    opset_version:# 导出模型的ONNX版本
    input_names:# 模型的输入名称
    output_names:# 模型的输出名称
    example_outputs: 模型的输出示例
    """

    output_names_list = [prob, prob_logit, prob_map, height_prob, height_prob_logit,
                        center_mask, visit_mask, center_idx, offset, edge_map,
                         seg_map]

    torch.onnx.export(model, input, onnx_path,
                      opset_version=11,
                      input_names=[input],
                      export_params=True,
                      enable_onnx_checker=True,
                      verbose=True,
                      output_names=output_names_list,
                      example_outputs= loaded_model_output,
                      )
    print("Exporting .pth model to onnx model")
    print("Successful!!!")



def main():
    example = torch.rand(1, 3, 320, 800)
    folder_path = /media/ubuntu/backup/CRData/Eigenlanes_redefine/Modeling/culane/output/train/weight/
    jit_pth_path = folder_path+finalmodelalljit.pth

    onnx_path = folder_path+model.onnx
    pth_to_onnx(input=example, pth_path=jit_pth_path, onnx_path=onnx_path)



if __name__ == __main__:
    main()

Result:

经验分享 程序员 微信小程序 职场和发展