位置: IT常识 - 正文

Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

编辑:rootadmin
Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型


1--动态输入和静态输入

推荐整理分享Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch基础,pytorch技巧,pytorch技巧,pytorch基础,pytorch零基础入门,pytorch 快速入门,pytorch基础教程,pytorch基础教程,内容如对您有帮助,希望把文章链接给更多的朋友!

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2--Pytorch APIPytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入input_name = 'input'output_name = 'output'torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})3--完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torchimport torch.nn as nnclass Model_Net(nn.Module): def __init__(self): super(Model_Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, data): data = self.layer1(data) return dataif __name__ == "__main__": # 设置输入参数 Batch_size = 8 Channel = 3 Height = 256 Width = 256 input_data = torch.rand((Batch_size, Channel, Height, Width)) # 实例化模型 model = Model_Net() # 导出为静态输入 input_name = 'input' output_name = 'output' torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name]) # 导出为动态输入 torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netronnetron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型import numpy as npimport onnximport onnxruntimeif __name__ == "__main__": input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32) input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32) # 导入 Onnx 模型 Onnx_file = "./Dynamics_InputNet.onnx" Model = onnx.load(Onnx_file) onnx.checker.check_model(Model) # 验证Onnx模型是否准确 # 使用 onnxruntime 推理 model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) input_name = model.get_inputs()[0].name output_name = model.get_outputs()[0].name output1 = model.run([output_name], {input_name:input_data1}) output2 = model.run([output_name], {input_name:input_data2}) print('output1.shape: ', np.squeeze(np.array(output1), 0).shape) print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。

本文链接地址:https://www.jiuchutong.com/zhishi/295970.html 转载请保留说明!

上一篇:vue中使用wangeditor富文本编辑器(vue中使用require报错)

下一篇:图文详解vue.js devtools插件使用方法(图文详解一本通)

  • 增值税查询校验码是什么
  • 增值税和附加税如何计算
  • 工商年报中的资金数额怎么填
  • 水泥建材公司
  • 金蝶软件中怎么登记应该税费
  • 小规模纳税人销售自己使用过固定资产
  • 生产设备保险费会计分录
  • 为什么要计提递延所得税
  • 开票系统技术维护费怎么抵扣
  • 企业购入固定资产在每期末应使用公允价值法进行计量
  • 出口货物如何申报
  • 小规模季报财务报表怎么填写
  • 外币实收资本入账汇率
  • 企业是否可以查员工亲属关系
  • 多余的实收资本可以转到其他应付款吗
  • 代收污水处理费免税
  • 个人挂靠利润如何提取
  • 原材料的可变现净值等于产品可变现净值减加工费么
  • 为职工提供免费午餐
  • 企业所得税弥补亏损年限
  • 在途物资的运费放在哪个科目
  • 工会经费拨缴是什么意思
  • 经营结余年末结转
  • 期间费用构成产品成本嘛
  • 生产企业固定资产折旧
  • 收到进项专用发票怎么做
  • 多扣了离职人员的钱
  • 免抵退附加
  • 月末计算各种税费表格模版
  • 股权转让印花税是双方都要缴纳吗
  • 筹建期装修费用计入什么科目
  • 矿产资源税是多少
  • 支付印花税计入什么科目
  • 稽查查补的税款可以享受即征即退吗
  • 增资印花税税目
  • 盘亏机器设备会计分录
  • 劳务公司差额发票账务处理
  • 其他业务收入借贷方向会计分录
  • 水利基金减免了还用计提吗
  • php数组查找
  • php压缩包
  • 增值税中进项税额比对异常能作废申报吗
  • 吃鸡显卡推荐配置1060 5g
  • 公司已开票给客户,但客户未打款怎么办?
  • PHP:mcrypt_module_get_supported_key_sizes()的用法_Mcrypt函数
  • php多线程curl
  • 捐钱扶贫
  • Vue Admin Template关闭eslint校验,lintOnSave:false设置无效解决办法
  • 公司注销账面实收资本如何处理
  • php cache缓存
  • 新手入门指南
  • es restful api文档
  • 命令行窗口
  • 资本化利息支出现金流量表计入哪里
  • 织梦前台的菜单怎么换
  • 做项目前期
  • 高速费发票可以抵税吗
  • 损益类账户期末有余额吗
  • 网吧相关规定
  • 国债收益率如何查看
  • 企业应付账款的借方登记
  • 增值税防伪税控系统
  • 公司在银行买的金条怎么入账
  • 17增值税发票怎么抵扣
  • 主营业务成本可以设明细科目吗
  • 年度汇算清缴交税怎么做账
  • 流动比率多少合理
  • win10 更新 蓝屏
  • mac上dns设置
  • win8默认输入法设置
  • 如何显示文件后缀win10
  • Windows 7 OpenGL配置,解决“无法启动此程序,因为计算机中丢失glut32.dll。”
  • javascriptjs
  • Android之Android apk动态加载机制的研究
  • linux中的shell编程
  • 举例详解民法典第502条
  • javascript设置字体
  • android混淆后怎么破解
  • 重庆电子税务局网页版登录
  • A级纳税人和一般纳税人区别
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设