位置: IT常识 - 正文

Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

编辑:rootadmin
Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型


1--动态输入和静态输入

推荐整理分享Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch基础,pytorch技巧,pytorch技巧,pytorch基础,pytorch零基础入门,pytorch 快速入门,pytorch基础教程,pytorch基础教程,内容如对您有帮助,希望把文章链接给更多的朋友!

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2--Pytorch APIPytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入input_name = 'input'output_name = 'output'torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})3--完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torchimport torch.nn as nnclass Model_Net(nn.Module): def __init__(self): super(Model_Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, data): data = self.layer1(data) return dataif __name__ == "__main__": # 设置输入参数 Batch_size = 8 Channel = 3 Height = 256 Width = 256 input_data = torch.rand((Batch_size, Channel, Height, Width)) # 实例化模型 model = Model_Net() # 导出为静态输入 input_name = 'input' output_name = 'output' torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name]) # 导出为动态输入 torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netronnetron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型import numpy as npimport onnximport onnxruntimeif __name__ == "__main__": input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32) input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32) # 导入 Onnx 模型 Onnx_file = "./Dynamics_InputNet.onnx" Model = onnx.load(Onnx_file) onnx.checker.check_model(Model) # 验证Onnx模型是否准确 # 使用 onnxruntime 推理 model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) input_name = model.get_inputs()[0].name output_name = model.get_outputs()[0].name output1 = model.run([output_name], {input_name:input_data1}) output2 = model.run([output_name], {input_name:input_data2}) print('output1.shape: ', np.squeeze(np.array(output1), 0).shape) print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。

本文链接地址:https://www.jiuchutong.com/zhishi/295970.html 转载请保留说明!

上一篇:vue中使用wangeditor富文本编辑器(vue中使用require报错)

下一篇:图文详解vue.js devtools插件使用方法(图文详解一本通)

  • 诺基亚5130和5320(诺基亚5130论坛)(诺基亚5130和5320那个好一点)

    诺基亚5130和5320(诺基亚5130论坛)(诺基亚5130和5320那个好一点)

  • 华为手环7有体温检测吗(华为手环7有体温计功能吗)

    华为手环7有体温检测吗(华为手环7有体温计功能吗)

  • 钉钉接龙怎么发起(钉钉接龙怎么发起电脑操作)

    钉钉接龙怎么发起(钉钉接龙怎么发起电脑操作)

  • 小米mix3电池温度如何查看(miui 电池温度)

    小米mix3电池温度如何查看(miui 电池温度)

  • 哔哩哔哩如何设置息屏继续播放(哔哩哔哩如何设置自动连播)

    哔哩哔哩如何设置息屏继续播放(哔哩哔哩如何设置自动连播)

  • 微信怎么拉别人进自己群(微信怎么拉别人黑名单)

    微信怎么拉别人进自己群(微信怎么拉别人黑名单)

  • 笔记本桌面最下面一排失灵(笔记本桌面最下面一连网一排失灵)

    笔记本桌面最下面一排失灵(笔记本桌面最下面一连网一排失灵)

  • 微信朋友圈只对一个人可见怎么设置(微信朋友圈只对一个人可见对方知道吗)

    微信朋友圈只对一个人可见怎么设置(微信朋友圈只对一个人可见对方知道吗)

  • wps可以压缩文件吗(wps可以压缩文件夹发送吗)

    wps可以压缩文件吗(wps可以压缩文件夹发送吗)

  • 美团怎么确认收货(美团怎么确认收单)

    美团怎么确认收货(美团怎么确认收单)

  • windows是美国哪个公司的产品(windows是美国ibm公司的产品吗)

    windows是美国哪个公司的产品(windows是美国ibm公司的产品吗)

  • bd hd哪个清晰度高(bd清晰度高还是hd)

    bd hd哪个清晰度高(bd清晰度高还是hd)

  • 手机号被淘宝限制登录是什么意思(手机号被淘宝限制怎么办)

    手机号被淘宝限制登录是什么意思(手机号被淘宝限制怎么办)

  • 外国人注册微信需要别人辅助吗(外国人注册微信需要实名认证吗)

    外国人注册微信需要别人辅助吗(外国人注册微信需要实名认证吗)

  • android的开发语言是什么(android开发语言介绍)

    android的开发语言是什么(android开发语言介绍)

  • 微信为什么会自动掉线(微信为什么会自动关闭)

    微信为什么会自动掉线(微信为什么会自动关闭)

  • jpg怎么打印(jpg怎么打印不了)

    jpg怎么打印(jpg怎么打印不了)

  • 手机为什么打开软件会闪退(手机为什么打开微信才能接收到新消息)

    手机为什么打开软件会闪退(手机为什么打开微信才能接收到新消息)

  • 华为mate30为什么在德国发布(华为mate30为什么下架了)

    华为mate30为什么在德国发布(华为mate30为什么下架了)

  • 华为笔记本13与x的对比(华为笔记本13与13s哪个好)

    华为笔记本13与x的对比(华为笔记本13与13s哪个好)

  • vivox7电池在哪里设置(vivox7手机电池在哪里)

    vivox7电池在哪里设置(vivox7手机电池在哪里)

  • 小米Note 3通话声音变小的原因(小米note3手机通话声音小怎么解决方法)

    小米Note 3通话声音变小的原因(小米note3手机通话声音小怎么解决方法)

  • 安卓机怎么刷机(安卓手机怎么刷机更新系统)

    安卓机怎么刷机(安卓手机怎么刷机更新系统)

  • 最新2020.12win10 20H2专业版激活秘钥推荐 附激活工具+教程(最新双色球开奖号码)

    最新2020.12win10 20H2专业版激活秘钥推荐 附激活工具+教程(最新双色球开奖号码)

  • 前端获取mac地址(前端获取当前地址)

    前端获取mac地址(前端获取当前地址)

  • 金税盘开票软件服务电话
  • 资产的计税基础怎么计算
  • 净值型理财投资范围
  • 费用报销票跨月跨年可以吗
  • 公司汽车上牌费入什么科目
  • 资产负债表中应收账款
  • 总公司一般纳税多少
  • 委托其他公司开票收款
  • 加计扣除是什么意思啊举例
  • 多交增值税怎么调整
  • 收银系统已入库怎么操作
  • 老板想提取销售公积金
  • 个人所得税个税申报流程
  • 房屋租赁发票需要备注吗
  • 财产保险合同的主体变更
  • 外购的货物用于集体福利企业所得税
  • 广告公司固定资产有哪些?
  • 认证成功次月何时补发
  • 公司收到银行存款利息收入会计分录
  • 华为p60pro上市时间是几月
  • 进项税额转出会计处理
  • 苹果电脑mac系统怎么用
  • cpu天梯图2022最新版1240p
  • windows10一直刷屏
  • Msssrv.exe - Msssrv是什么进程 有什么用
  • PadExe.exe - PadExe是什么进程 有什么用
  • 股权转让所得怎么做账
  • thinkphp错误日志目录
  • 酒店如何核算成本
  • 待抵扣进项税的限额是什么
  • 库存现金月末怎么结转
  • php plates
  • 成本价低于现价 应该卖吗
  • 建设工程项目设计质量控制的内容
  • php互换两个变量的关系
  • 单位购买电水壶会计入账
  • 微信手续费由谁承担
  • 买汽车配件属于什么服务
  • 增值税专用发票丢了怎么补救
  • 个人所得税申报流程图
  • 企业补提以前年度未提的坏账准备
  • html 基础
  • linux db2安装与配置
  • 代收代缴水费收不上来怎么办
  • 原材料结转成本有几种方法
  • 内账会计成本是什么意思
  • sql命令语句
  • 报销发票啥意思
  • 进项大于销项的会计分录怎么做?
  • Windows下MySQL 5.6安装及配置详细图解(大图版)
  • 净利润率的计算方法公式
  • 购买农产品普通发票怎么做账
  • 法院去单位直接扣划单位薪酬
  • 优秀员工奖金领取表模板
  • 员工借款未还财务有责任吗
  • 收到保险公司赔款
  • 去年多摊销了怎么办
  • 开办费列支范围
  • 养老保险产生的利息怎么来的
  • 待结算财政款项是什么科目
  • 收到红字发票怎么做账怎么做进项税额转出
  • 残保金计算公式2023年
  • sql语句递归
  • 远程桌面连接xp系统
  • 如何设置macbook
  • mac怎么连接打印机设备
  • 微软正式推出wind...
  • win8的ie浏览器
  • js时间倒计时定时器怎么弄
  • ftp下载怎么用
  • 常见的css样式
  • unity游戏开发入门经典
  • linux中切换目录命令符
  • shell批量执行curl
  • jquery属性操作
  • js对象的constructor
  • onclick和onfocus
  • linux装python环境
  • 江西省税务局公众号
  • 河北税务云办税厅官方
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设