位置: IT常识 - 正文

Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

编辑:rootadmin
Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型


1--动态输入和静态输入

推荐整理分享Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch基础,pytorch技巧,pytorch技巧,pytorch基础,pytorch零基础入门,pytorch 快速入门,pytorch基础教程,pytorch基础教程,内容如对您有帮助,希望把文章链接给更多的朋友!

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2--Pytorch APIPytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入input_name = 'input'output_name = 'output'torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})3--完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torchimport torch.nn as nnclass Model_Net(nn.Module): def __init__(self): super(Model_Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, data): data = self.layer1(data) return dataif __name__ == "__main__": # 设置输入参数 Batch_size = 8 Channel = 3 Height = 256 Width = 256 input_data = torch.rand((Batch_size, Channel, Height, Width)) # 实例化模型 model = Model_Net() # 导出为静态输入 input_name = 'input' output_name = 'output' torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name]) # 导出为动态输入 torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netronnetron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型import numpy as npimport onnximport onnxruntimeif __name__ == "__main__": input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32) input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32) # 导入 Onnx 模型 Onnx_file = "./Dynamics_InputNet.onnx" Model = onnx.load(Onnx_file) onnx.checker.check_model(Model) # 验证Onnx模型是否准确 # 使用 onnxruntime 推理 model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) input_name = model.get_inputs()[0].name output_name = model.get_outputs()[0].name output1 = model.run([output_name], {input_name:input_data1}) output2 = model.run([output_name], {input_name:input_data2}) print('output1.shape: ', np.squeeze(np.array(output1), 0).shape) print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。

本文链接地址:https://www.jiuchutong.com/zhishi/295970.html 转载请保留说明!

上一篇:vue中使用wangeditor富文本编辑器(vue中使用require报错)

下一篇:图文详解vue.js devtools插件使用方法(图文详解一本通)

  • 红米note11pro怎么录屏(红米note11pro怎么打开OTG功能)

    红米note11pro怎么录屏(红米note11pro怎么打开OTG功能)

  • 红米9a是什么处理器(红米9a是啥)

    红米9a是什么处理器(红米9a是啥)

  • 充电宝充不满(充电宝充不满电一直闪)

    充电宝充不满(充电宝充不满电一直闪)

  • 通过抖音号能查到微博吗(通过抖音号能查到手机号码吗)

    通过抖音号能查到微博吗(通过抖音号能查到手机号码吗)

  • 华为荣耀9xOTG设置在哪里

    华为荣耀9xOTG设置在哪里

  • 苹果辅助触控自动消失什么原因(苹果辅助触控自动关闭怎样解决)

    苹果辅助触控自动消失什么原因(苹果辅助触控自动关闭怎样解决)

  • 在手机上怎么退出在电脑上登录QQ(在手机上怎么退流量包)

    在手机上怎么退出在电脑上登录QQ(在手机上怎么退流量包)

  • amda86600k相当于i几(amda106800k相当于i几)

    amda86600k相当于i几(amda106800k相当于i几)

  • 魅族17是曲面屏吗(魅族17pro是曲面屏)

    魅族17是曲面屏吗(魅族17pro是曲面屏)

  • 微信登录失败4-34是什么原因(微信登录失败4-34要怎么解决)

    微信登录失败4-34是什么原因(微信登录失败4-34要怎么解决)

  • 电话线接口能直接接网线吗(电话线能直接接吗)

    电话线接口能直接接网线吗(电话线能直接接吗)

  • 聊天记录可以免费恢复吗(聊天记录免费版)

    聊天记录可以免费恢复吗(聊天记录免费版)

  • pr与显卡驱动不兼容(pr2020与显卡驱动不兼容)

    pr与显卡驱动不兼容(pr2020与显卡驱动不兼容)

  • 幻灯片怎么点一下出一个(幻灯片怎么点一下出来一张图片)

    幻灯片怎么点一下出一个(幻灯片怎么点一下出来一张图片)

  • 苹果11能放几张卡(苹果14pro max拍照)

    苹果11能放几张卡(苹果14pro max拍照)

  • 手机指纹键怎么更换(手机指纹键怎么拆下来)

    手机指纹键怎么更换(手机指纹键怎么拆下来)

  • 解除唯品会绑定手机号(怎么解绑唯品会)

    解除唯品会绑定手机号(怎么解绑唯品会)

  • 小米电脑键盘锁快捷键(小米电脑键盘锁住了打不了字)

    小米电脑键盘锁快捷键(小米电脑键盘锁住了打不了字)

  • 华为p30pro的耳机插孔在哪里(华为p30pro的耳机模式在哪里)

    华为p30pro的耳机插孔在哪里(华为p30pro的耳机模式在哪里)

  • iptv包月费是什么意思(iptv包月-包月费)

    iptv包月费是什么意思(iptv包月-包月费)

  • 快手对战是干嘛呢(快手对战pk是什么意思)

    快手对战是干嘛呢(快手对战pk是什么意思)

  • iphonexrhome键怎么调出来(iphonexrhome键怎么隐藏)

    iphonexrhome键怎么调出来(iphonexrhome键怎么隐藏)

  • 苹果xsmax为什么总发烫(苹果xsmax为什么拍照好看)

    苹果xsmax为什么总发烫(苹果xsmax为什么拍照好看)

  • mp2j2cha是ipad第几代(mp2g2cha是ipad几代)

    mp2j2cha是ipad第几代(mp2g2cha是ipad几代)

  • 多数路由器的ip地址和网关192.168.1.1(路由器ip分配数量)

    多数路由器的ip地址和网关192.168.1.1(路由器ip分配数量)

  • 税控盘网上申请解锁
  • 个人所得税累计收入
  • 政府非税收入的种类
  • 本期应纳税额减征额怎么填写
  • 如何让自己公司成为供应商
  • 债务重组的会计准则
  • 公司广告法违规交不起罚款怎么办
  • 在产品,半成品,产成品是什么意思
  • 标准的现金流量表格式
  • 进项增值税发票怎么认证
  • 超市购物卡开票可以做账吗
  • 专票和普票的税率哪个高
  • 总分类账的账簿启用表怎么填
  • 幼儿园收的餐费必须与食谱做平账怎么调账
  • 公司购买投影仪的必要性?
  • 长期待摊费用的内容和特征
  • 展厅设计费用计什么科目
  • 生产企业成本会计科目
  • 发票未到的费用怎么处理
  • 接受捐赠物品的增值税
  • 转让二手宾馆需要注意事项
  • win7如何禁用wifi
  • 期货公司保证金怎么算的
  • 美团提现手续费入哪个会计科目
  • 增值税专用发票有几联?
  • win10蓝牙无法连接,有解决方法吗
  • oeloader.exe - oeloader是什么进程 有什么用
  • data.dataloader
  • 零售业的进货帐务怎么做
  • 建设期需要流动资金吗
  • vueconfigjs配置proxy 无效
  • 达特穆尔动物园
  • thinkphp post
  • php示例代码大全
  • php使用( )关键字来创建对象
  • Yii2创建多界面主题(Theme)的方法
  • 其他应付款贷方余额表示谁欠谁
  • 承税汇票个人能用吗
  • 交所得税怎么记账
  • mysql中事务的作用
  • 销售旧货和销售使用过的固定资产区别
  • 非税收入票据可以跨年度使用吗
  • 差旅费分摊到各部门
  • 技术服务费可以计入成本吗
  • 利润表的期末余额怎么算出来的
  • 递延所得税的会计核算
  • 其他收益和其他综合收益属于什么科目
  • 应收账款周转率高说明
  • 销售材料购买方会计分录
  • 新建厂房房产证办理流程
  • 普票丢失可以以照片入账么
  • 年度纳税总额包括个税吗
  • 员工垫付的费用怎样记账
  • 电子承兑汇票接收不了怎么办
  • 收到福利费的专用发票
  • 月底计提工资的会计处理
  • 数据库性能优化方法论和最佳实践
  • mysql 启动报错
  • ubuntu选择语言
  • openeuler操作系统安装方法
  • 清华同方笔记本无线网络开关在哪
  • find linux命令详解
  • win10预览版选哪个
  • windows 10预览版
  • 进程cmd.exe
  • centos7 本地yum
  • win10移动版和win10区别
  • javascript到c
  • pythonmatch函数
  • linux的open
  • unity连接
  • js修改css文件
  • Python实现以时间换空间的缓存替换算法
  • javascript怎么样
  • python数据类型详细介绍
  • 临时占地耕地占用税纳税义务发生时间
  • 深圳地税电子税务局
  • 狠抓组织收入工作
  • 山东网上信访投诉平台
  • 济南保安证查询系统
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设