位置: IT常识 - 正文

pytorch模型(.pt)转onnx模型(.onnx)的方法详解(1)(pytorch模型转tflite)

编辑:rootadmin
pytorch模型(.pt)转onnx模型(.onnx)的方法详解(1)

推荐整理分享pytorch模型(.pt)转onnx模型(.onnx)的方法详解(1)(pytorch模型转tflite),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch模型转onnx,pytorch模型转paddle,pytorch模型转tflite,pytorch模型转paddle,pytorch模型转paddle,pytorch模型转换,将pytorch模型转化为tensorflow,pytorch模型转换,内容如对您有帮助,希望把文章链接给更多的朋友!

1. pytorch模型转换到onnx模型

2.运行onnx模型

3.比对onnx模型和pytorch模型的输出结果

 我这里重点是第一点和第二点,第三部分  比较容易

首先你要安装 依赖库:onnx 和 onnxruntime,

pip install onnxpip install onnxruntime 进行安装

也可以使用清华源镜像文件安装  速度会快些。

开始:

1. pytorch模型转换到onnx模型

pytorch 转 onnx 仅仅需要一个函数 torch.onnx.export 

torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)

参数说明:

model——需要导出的pytorch模型args——模型的输入参数,满足输入层的shape正确即可。path——输出的onnx模型的位置。例如‘yolov5.onnx’。export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。verbose——是否打印模型转换信息。default=False。input_names——输入节点名称。default=None。output_names——输出节点名称。default=None。do_constant_folding——是否使用常量折叠,默认即可。default=True。dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。 格式如下 : 1)仅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]} 2)仅dict<int, string> dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:{0:‘batch’,1:‘c’}} 3)mixed dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:[0,1]}opset_version——opset的版本,低版本不支持upsample等操作。pytorch模型(.pt)转onnx模型(.onnx)的方法详解(1)(pytorch模型转tflite)

转化代码:参考1:

import torchimport torch.nnimport onnxmodel = torch.load('best.pt')model.eval()input_names = ['input']output_names = ['output']x = torch.randn(1,3,32,32,requires_grad=True)torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')

 参考2:PlainC3AENetCBAM 是网络模型,如果你没有自己的网络模型,可能成功不了

import ioimport torchimport torch.onnxfrom models.C3AEModel import PlainC3AENetCBAMdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def test(): model = PlainC3AENetCBAM() pthfile = r'/home/joy/Projects/models/emotion/PlainC3AENet.pth' loaded_model = torch.load(pthfile, map_location='cpu') # try: # loaded_model.eval() # except AttributeError as error: # print(error) model.load_state_dict(loaded_model['state_dict']) # model = model.to(device) #data type nchw dummy_input1 = torch.randn(1, 3, 64, 64) # dummy_input2 = torch.randn(1, 3, 64, 64) # dummy_input3 = torch.randn(1, 3, 64, 64) input_names = [ "actual_input_1"] output_names = [ "output1" ] # torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), "C3AE.onnx", verbose=True, input_names=input_names, output_names=output_names) torch.onnx.export(model, dummy_input1, "C3AE_emotion.onnx", verbose=True, input_names=input_names, output_names=output_names)if __name__ == "__main__": test()

直接将PlainC3AENetCBAM替换成需要转换的模型,然后修改pthfile,输入和onnx模型名字然后执行即可。

注意:上面代码中注释的dummy_input2,dummy_input3,torch.onnx.export对应的是多个输入的例子。

在转换过程中遇到的问题汇总

RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible

在转换过程中遇到RuntimeError: Failed to export an ONNX attribute, since it's not constant, please try to make things (e.g., kernel size) static if possible的错误。

我成功的案例,我直接把我训练的网络贴上,成功转换,没有from **   import 模型名词这么委婉,合法,我的比较粗暴

import torchimport torch.nnimport onnxfrom torchvision import transformsimport torch.nn as nnfrom torch.nn import Sequential# 添加模型# 设置数据转换方式preprocess_transform = transforms.Compose([ transforms.ToTensor(), # 把数据转换为张量(Tensor) transforms.Normalize( # 标准化,即使数据服从期望值为 0,标准差为 1 的正态分布 mean=[0.5, ], # 期望 std=[0.5, ] # 标准差 )])class CNN(nn.Module): # 从父类 nn.Module 继承 def __init__(self): # 相当于 C++ 的构造函数 # super() 函数是用于调用父类(超类)的一个方法,是用来解决多重继承问题的 super(CNN, self).__init__() # 第一层卷积层。Sequential(意为序列) 括号内表示要进行的操作 self.conv1 = Sequential( nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) ) # 第二卷积层 self.conv2 = Sequential( nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2) ) # 全连接层(Dense,密集连接层) self.dense = Sequential( nn.Linear(7 * 7 * 128, 1024), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(1024, 10) ) def forward(self, x): # 正向传播 x1 = self.conv1(x) x2 = self.conv2(x1) x = x2.view(-1, 7 * 7 * 128) x = self.dense(x) return x# 训练# 训练和参数优化# 定义求导函数def get_Variable(x): x = torch.autograd.Variable(x) # Pytorch 的自动求导 # 判断是否有可用的 GPU return x.cuda() if torch.cuda.is_available() else x# 判断是否GPUdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# device1 = torch.device('cpu')# 定义网络model = CNN()loaded_model = torch.load('save_model/model.pth', map_location='cuda:0')model.load_state_dict(loaded_model)model.eval()input_names = ['input']output_names = ['output']# x = torch.randn(1,3,32,32,requires_grad=True)x = torch.randn(1, 1, 28, 28, requires_grad=True) # 这个要与你的训练模型网络输入一致。我的是黑白图像torch.onnx.export(model, x, 'save_model/model.onnx', input_names=input_names, output_names=output_names, verbose='True')

前提是你要准备好*.pth模型保持文件

输出结果:

graph(%input : Float(1, 1, 28, 28, strides=[784, 784, 28, 1], requires_grad=1, device=cpu), %dense.0.weight : Float(1024, 6272, strides=[6272, 1], requires_grad=1, device=cpu), %dense.0.bias : Float(1024, strides=[1], requires_grad=1, device=cpu), %dense.3.weight : Float(10, 1024, strides=[1024, 1], requires_grad=1, device=cpu), %dense.3.bias : Float(10, strides=[1], requires_grad=1, device=cpu), %33 : Float(64, 1, 3, 3, strides=[9, 9, 3, 1], requires_grad=0, device=cpu), %34 : Float(64, strides=[1], requires_grad=0, device=cpu), %36 : Float(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu), %37 : Float(128, strides=[1], requires_grad=0, device=cpu)): %input.4 : Float(1, 64, 28, 28, strides=[50176, 784, 28, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%input, %33, %34) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\conv.py:443:0 %21 : Float(1, 64, 28, 28, strides=[50176, 784, 28, 1], requires_grad=1, device=cpu) = onnx::Relu(%input.4) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\functional.py:1442:0 %input.8 : Float(1, 64, 14, 14, strides=[12544, 196, 14, 1], requires_grad=1, device=cpu) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%21) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\functional.py:797:0 %input.16 : Float(1, 128, 14, 14, strides=[25088, 196, 14, 1], requires_grad=1, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%input.8, %36, %37) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\conv.py:443:0 %25 : Float(1, 128, 14, 14, strides=[25088, 196, 14, 1], requires_grad=1, device=cpu) = onnx::Relu(%input.16) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\functional.py:1442:0 %26 : Float(1, 128, 7, 7, strides=[6272, 49, 7, 1], requires_grad=1, device=cpu) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%25) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\functional.py:797:0 %27 : Long(2, strides=[1], device=cpu) = onnx::Constant[value= -1 6272 [ CPULongType{2} ]]() # E:/paddle_project/Pytorch_Imag_Classify/zifu_fenlei/CNN/pt模型转onnx模型.py:51:0 %28 : Float(1, 6272, strides=[6272, 1], requires_grad=1, device=cpu) = onnx::Reshape(%26, %27) # E:/paddle_project/Pytorch_Imag_Classify/zifu_fenlei/CNN/pt模型转onnx模型.py:51:0 %input.20 : Float(1, 1024, strides=[1024, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%28, %dense.0.weight, %dense.0.bias) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\linear.py:103:0 %input.24 : Float(1, 1024, strides=[1024, 1], requires_grad=1, device=cpu) = onnx::Relu(%input.20) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\functional.py:1442:0 %output : Float(1, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%input.24, %dense.3.weight, %dense.3.bias) # D:\ProgramData\Anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\linear.py:103:0 return (%output)

输出结果的device  是CPU,模型加载的时候是GPU。这就是转换的意义吧

2.运行onnx模型

import onnximport onnxruntime as ortmodel = onnx.load('best.onnx')onnx.checker.check_model(model)session = ort.InferenceSession('best.onnx')x=np.random.randn(1,3,32,32).astype(np.float32) # 注意输入type一定要np.float32!!!!!# x= torch.randn(batch_size,chancel,h,w)outputs = session.run(None,input = { 'input' : x })

参考:

Pytorch模型转onnx模型实例_python_脚本之家 (jb51.net)

pytorch模型转onnx模型的方法详解_python_脚本之家 (jb51.net)

本文链接地址:https://www.jiuchutong.com/zhishi/280479.html 转载请保留说明!

上一篇:command.exe是病毒进程吗 command进程安全吗(cmt.exe病毒)

下一篇:Win10 Build 19044.1379/19043.1379更新补丁KB5007253预览版推送

  • 英语loser是什么意思网红语属于什么人

    英语loser是什么意思网红语属于什么人

  • 小米12sultra如何截屏(小米12sultra如何投屏)

    小米12sultra如何截屏(小米12sultra如何投屏)

  • 魅族18spro屏幕尺寸(2021魅族18pro屏幕尺寸)

    魅族18spro屏幕尺寸(2021魅族18pro屏幕尺寸)

  • 华为hry-al00是什么型号(华为hry_al00a是什么型号)

    华为hry-al00是什么型号(华为hry_al00a是什么型号)

  • 华为荣耀10青春版有没有指纹(华为荣耀10青春版拆机视频)

    华为荣耀10青春版有没有指纹(华为荣耀10青春版拆机视频)

  • 微信当前通话对方网络不佳(微信电话当前通话质量不佳)

    微信当前通话对方网络不佳(微信电话当前通话质量不佳)

  • 触摸式蓝牙耳机怎么开机(触摸式蓝牙耳机使用教程)

    触摸式蓝牙耳机怎么开机(触摸式蓝牙耳机使用教程)

  • 天猫超市发货地在哪里(天猫超市发货地在哪里南京)

    天猫超市发货地在哪里(天猫超市发货地在哪里南京)

  • 小米10和10pro手机壳通用吗(小米10和10pro手机膜一样吗)

    小米10和10pro手机壳通用吗(小米10和10pro手机膜一样吗)

  • 下载app提示未完成付款(下载未完成)

    下载app提示未完成付款(下载未完成)

  • 为什么qq相册里的照片显示不出来(为什么qq相册里的照片变小了)

    为什么qq相册里的照片显示不出来(为什么qq相册里的照片变小了)

  • 微信如何取消爱奇艺自动续费(微信如何取消爱奇艺会员续费)

    微信如何取消爱奇艺自动续费(微信如何取消爱奇艺会员续费)

  • 电脑能连接wifi吗(电脑能连接wifi手机连不上怎么回事)

    电脑能连接wifi吗(电脑能连接wifi手机连不上怎么回事)

  • 华为p40屏幕上的小白圈怎么去掉(华为p40屏幕上的悬浮按钮怎么去掉)

    华为p40屏幕上的小白圈怎么去掉(华为p40屏幕上的悬浮按钮怎么去掉)

  • 微信举报是匿名的吗(微信举报匿名能查到微信两位个号)

    微信举报是匿名的吗(微信举报匿名能查到微信两位个号)

  • 微博不能评论的原因(微博不能评论的文章)

    微博不能评论的原因(微博不能评论的文章)

  • 群红包不能领怎么回事(群里的红包领不了是什么原因)

    群红包不能领怎么回事(群里的红包领不了是什么原因)

  • vivos5有语音助手吗(vivos5语音助手不让别人叫醒)

    vivos5有语音助手吗(vivos5语音助手不让别人叫醒)

  • 怎么看微信聊天记录时间(怎么看微信聊天记录多少g)

    怎么看微信聊天记录时间(怎么看微信聊天记录多少g)

  • 苹果wifi版可以插卡吗(苹果wifi版可以登id吗)

    苹果wifi版可以插卡吗(苹果wifi版可以登id吗)

  • 6splus尺寸长宽多少(6s plus尺寸是多少)

    6splus尺寸长宽多少(6s plus尺寸是多少)

  • 移动盒子光信号闪红灯怎么回事(移动盒子光信号闪烁)

    移动盒子光信号闪红灯怎么回事(移动盒子光信号闪烁)

  • 华为askaloox是什么型号(华为型号ask-al00x是哪款)

    华为askaloox是什么型号(华为型号ask-al00x是哪款)

  • ico图标怎么弄透明(ico格式背景透明)

    ico图标怎么弄透明(ico格式背景透明)

  • 本地连接删除了怎么恢复(本地连接删除了会怎么样)

    本地连接删除了怎么恢复(本地连接删除了会怎么样)

  • 为什么全教育平台登录不了(全国教育平台用户名是什么)

    为什么全教育平台登录不了(全国教育平台用户名是什么)

  • 红米airdots能单独使用吗(红米airdots单耳)

    红米airdots能单独使用吗(红米airdots单耳)

  • 华为p30pro可以同时登录两个微信吗(华为p30pro可以互相定位吗)

    华为p30pro可以同时登录两个微信吗(华为p30pro可以互相定位吗)

  • 苹果备忘录怎么看字数(苹果备忘录怎么导出来长图文)

    苹果备忘录怎么看字数(苹果备忘录怎么导出来长图文)

  • 苹果呼叫转移一直转圈(苹果呼叫转移一直转圈怎么办)

    苹果呼叫转移一直转圈(苹果呼叫转移一直转圈怎么办)

  • safari尚未接入互联网怎么设置(safari尚未接入互联网,自动关闭怎么回事)

    safari尚未接入互联网怎么设置(safari尚未接入互联网,自动关闭怎么回事)

  • U盘装机大师 U盘启动盘制作教程(U盘装系统图文教程)(u盘装机大师怎么用)

    U盘装机大师 U盘启动盘制作教程(U盘装系统图文教程)(u盘装机大师怎么用)

  • 美国个税计算器2021计算器
  • 礼品的进项税能抵扣吗
  • 无形资产摊销的会计科目
  • 应付账款不需要函证
  • 负债的账面价值减去未来期间计算应纳税所得额
  • 外商投资企业采购国产设备退税后续监管办法
  • 个税专项扣除是什么时候开始实行
  • 捐赠和赞助业务的税务处理怎么做?
  • 房地产公司扣减土地出让金怎么入账?
  • 短期投资款取消退回计入什么科目?
  • 进口货物报关费可以计入制造费用
  • 交通补贴可以抵扣个税吗
  • 办理营业执照需要钱吗
  • 医疗机构交企业所得税吗
  • 施工单位的项目
  • 无法读取金税盘时间版本怎么解决
  • 非银行支付机构条例(征求意见稿)
  • 季度不超30万需计提增值税吗
  • 金税三期个人所得税扣缴系统网络设置
  • 弥补以前年度亏损从哪里取数
  • 小规模核定征收税率
  • 中小企业货币资金内部控制案例
  • 电子发票服务平台诺诺发票官网
  • 租房合同开发票的金额要和合同一致吗
  • 成本分摊会计
  • 技术服务的大类包括
  • Mac怎么禁用icloud
  • 增值税当月申报次月缴纳吗
  • 贷款本息转本金
  • 2020工资计税基数怎么算
  • PHP:pg_num_rows()的用法_PostgreSQL函数
  • 绣球花的叶子出现了斑点,这是怎么了?
  • php的数据类型主要有哪几种
  • 工业企业的费用
  • 发票已认证当月未申报怎么办
  • 盛开的樱花和姬子的故事
  • 公司内部往来双向挂账
  • php判断useragent
  • 后浪是什么意思网络用语
  • 什么是合伙企业?它的特点有哪些?
  • 母公司为子公司提供担保是利好吗
  • wind安装
  • okhttp源码
  • 金融商品转让如何确定销售额
  • 汇兑应该计入什么科目
  • sqlserver2019性能
  • 分公司是否能开劳务发票
  • 计提工会经费会计分录怎么写
  • 成本法的处置
  • 同级财政和本级财政
  • 汇算清缴是不是一定要做
  • 建安企业费用有哪些
  • 公司卖出货物没有发票
  • 长投对方亏损
  • 契税的计税金额是什么
  • 基金赎回可以赎回部分吗
  • 非税收入包括哪几种
  • 增值税增量留抵退税进项构成比例
  • 怎么查企业适用的会计准则
  • 慧通年终奖怎么计算
  • 制造费用和直接人工的区别
  • 银行存款日记账与银行对账单之间的核对属于
  • 商业企业注销应检查哪方面的问题
  • win10下mysql 5.7.17 zip压缩包版安装教程
  • windowsxp电脑开机
  • microsoft window vista
  • 快速复制一张同样的幻灯片
  • mac装xp系统
  • 导入extjs、jquery 文件时$使用冲突问题解决方法
  • linux中shell命令
  • shell脚本clear
  • linux装python环境
  • javascript获取数据
  • jquery中有几种方法可以来设置和获取样式
  • 交管12123怎么打电话
  • 差额征税可以全部抵成本么?
  • 资源税原矿和选矿的区别
  • 汉口市中心
  • 电子发票查询官网入口国家税务局重庆电子税务
  • 小规模纳税人和一般纳税人的区别
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设