位置: IT常识 - 正文

Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

编辑:rootadmin
Pytorch复习笔记--导出Onnx模型为动态输入和静态输入

目录

1--动态输入和静态输入

2--Pytorch API

3--完整代码演示

4--模型可视化

5--测试动态导出的Onnx模型


1--动态输入和静态输入

推荐整理分享Pytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch基础,pytorch技巧,pytorch技巧,pytorch基础,pytorch零基础入门,pytorch 快速入门,pytorch基础教程,pytorch基础教程,内容如对您有帮助,希望把文章链接给更多的朋友!

        当使用 Pytorch 将网络导出为 Onnx 模型格式时,可以导出为动态输入和静态输入两种方式。动态输入即模型输入数据的部分维度是动态的,可以由用户在使用模型时自主设定;静态输入即模型输入数据的维度是静态的,不能够改变,当用户使用模型时只能输入指定维度的数据进行推理。

        显然,动态输入的通用性比静态输入更强。

2--Pytorch APIPytorch复习笔记--导出Onnx模型为动态输入和静态输入(pytorch基础教程)

        在 Pytorch 中,通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入,dynamic_axes 的默认值为 None,即默认为静态输入。

        以下展示动态导出的用法,通过定义 dynamic_axes 参数来设置动态导出输入。dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值;

# 导出为动态输入input_name = 'input'output_name = 'output'torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})3--完整代码演示

        在以下代码中,定义了一个网络,并使用动态导出和静态导出两种方式,将网络导出为 Onnx 模型格式。

import torchimport torch.nn as nnclass Model_Net(nn.Module): def __init__(self): super(Model_Net, self).__init__() self.layer1 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ) def forward(self, data): data = self.layer1(data) return dataif __name__ == "__main__": # 设置输入参数 Batch_size = 8 Channel = 3 Height = 256 Width = 256 input_data = torch.rand((Batch_size, Channel, Height, Width)) # 实例化模型 model = Model_Net() # 导出为静态输入 input_name = 'input' output_name = 'output' torch.onnx.export(model, input_data, "Static_InputNet.onnx", verbose=True, input_names=[input_name], output_names=[output_name]) # 导出为动态输入 torch.onnx.export(model, input_data, "Dynamics_InputNet.onnx", opset_version=11, input_names=[input_name], output_names=[output_name], dynamic_axes={ input_name: {0: 'batch_size', 2: 'input_height', 3: 'input_width'}, output_name: {0: 'batch_size', 2: 'output_height', 3: 'output_width'}})4--模型可视化

        通过 netron 库可视化导出的静态模型和动态模型,代码如下:

import netronnetron.start("./Dynamics_InputNet.onnx")

        静态模型可视化:

         动态模型可视化:

5--测试动态导出的Onnx模型import numpy as npimport onnximport onnxruntimeif __name__ == "__main__": input_data1 = np.random.rand(4, 3, 256, 256).astype(np.float32) input_data2 = np.random.rand(8, 3, 512, 512).astype(np.float32) # 导入 Onnx 模型 Onnx_file = "./Dynamics_InputNet.onnx" Model = onnx.load(Onnx_file) onnx.checker.check_model(Model) # 验证Onnx模型是否准确 # 使用 onnxruntime 推理 model = onnxruntime.InferenceSession(Onnx_file, providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) input_name = model.get_inputs()[0].name output_name = model.get_outputs()[0].name output1 = model.run([output_name], {input_name:input_data1}) output2 = model.run([output_name], {input_name:input_data2}) print('output1.shape: ', np.squeeze(np.array(output1), 0).shape) print('output2.shape: ', np.squeeze(np.array(output2), 0).shape)

         由输出结果可知,对应动态输入 Onnx 模型,其输出维度也是动态的,并且为对应关系,则表明导出的 Onnx 模型无误。

本文链接地址:https://www.jiuchutong.com/zhishi/295970.html 转载请保留说明!

上一篇:vue中使用wangeditor富文本编辑器(vue中使用require报错)

下一篇:图文详解vue.js devtools插件使用方法(图文详解一本通)

  • 怎么免费让短视频播放量达到100万?(怎样免费视频)

    怎么免费让短视频播放量达到100万?(怎样免费视频)

  • QQ群怎么设置群禁言(QQ群怎么设置群员不能改群名)

    QQ群怎么设置群禁言(QQ群怎么设置群员不能改群名)

  • ipad2021会支持二代笔吗(ipad 2020支持二代笔吗)

    ipad2021会支持二代笔吗(ipad 2020支持二代笔吗)

  • office是什么意思(offer是什么意思)

    office是什么意思(offer是什么意思)

  • 屏幕出现黄斑伤内屏了吗(屏幕出现黄斑伤内屏了能修复吗)

    屏幕出现黄斑伤内屏了吗(屏幕出现黄斑伤内屏了能修复吗)

  • 开启勿扰模式别人打电话提示什么(开启勿扰模式别人打电话进来会有显示吗)

    开启勿扰模式别人打电话提示什么(开启勿扰模式别人打电话进来会有显示吗)

  • cpu分为哪几种类型(cpu有什么分类)

    cpu分为哪几种类型(cpu有什么分类)

  • 计算机什么是承载CPU(计算机什么是承载CPU、BIOS和内存等器件的部分)

    计算机什么是承载CPU(计算机什么是承载CPU、BIOS和内存等器件的部分)

  • 小米电脑键盘打不出字怎么回事(小米电脑键盘打字要打两遍)

    小米电脑键盘打不出字怎么回事(小米电脑键盘打字要打两遍)

  • 华为怎么设置闹钟铃声本地音乐(华为怎么设置闹钟在耳机里响)

    华为怎么设置闹钟铃声本地音乐(华为怎么设置闹钟在耳机里响)

  • 支付宝被对方拉黑怎么转账(支付宝被对方拉黑了怎么办怎么恢复)

    支付宝被对方拉黑怎么转账(支付宝被对方拉黑了怎么办怎么恢复)

  • 下标怎么打快捷键(快捷键打下标)

    下标怎么打快捷键(快捷键打下标)

  • 手机截屏是什么意思(手机截屏是什么格式)

    手机截屏是什么意思(手机截屏是什么格式)

  • 微信通话中断什么意思(接电话时微信怎么没网络)

    微信通话中断什么意思(接电话时微信怎么没网络)

  • 软件更新在哪里(oppo软件更新在哪里)

    软件更新在哪里(oppo软件更新在哪里)

  • ps径向渐变在哪(ps径向渐变在哪里)

    ps径向渐变在哪(ps径向渐变在哪里)

  • 抖音发表的视频怎么删(抖音发表的视频删除了还能找到吗)

    抖音发表的视频怎么删(抖音发表的视频删除了还能找到吗)

  • qq字体怎么变成系统字体(qq字体怎么变成手机系统字体)

    qq字体怎么变成系统字体(qq字体怎么变成手机系统字体)

  • 第一弹是什么样的软件(第一弹是怎么了)

    第一弹是什么样的软件(第一弹是怎么了)

  • ipad air3耳机孔是圆孔吗(ipad air3 3.5mm耳机孔)

    ipad air3耳机孔是圆孔吗(ipad air3 3.5mm耳机孔)

  • 横线字怎么打出来(怎么打横线上面写字)

    横线字怎么打出来(怎么打横线上面写字)

  • 华为nova系列主打什么(华为nova11系列)

    华为nova系列主打什么(华为nova11系列)

  • dedecms织梦会员中心调用会员最后登录时间和IP(织梦收费5800的解决方法)

    dedecms织梦会员中心调用会员最后登录时间和IP(织梦收费5800的解决方法)

  • 员工扣了个税但没交给税务局
  • 当月未抵扣的进项税
  • 附加税填表说明
  • 报税是怎么操作的
  • 代理税务有哪些机构
  • 企业所得税汇算清缴时间
  • 应付账款不需要函证
  • 即征即退的增值税属于政府补助
  • 企业之间交换房屋 契税
  • 转账支票有没有密码
  • 运输服务增值税纳税义务发生时间
  • 30人以上的企业有哪些
  • 企业所得税季报是全年累计吗
  • 建筑企业预缴增值税计算
  • 个体工商户多久不用自动注销
  • 收入成本以前年度损益调整账务处理是怎样的?
  • 计提的增值税比例怎么算
  • 国外汇款 用什么理由
  • 印花税销售分录
  • 公户转账给个人没有票
  • 国税2016年第53号公告解读
  • 应交税金增值税明细账怎么登记
  • 股东借款利息计入利润表哪个科目
  • 跨年度残保金退回做什么
  • 一般纳税人简易计税会计分录
  • 股东认缴和实缴不一致
  • 支付商品展览费计入
  • 当期应交所得税怎么计算
  • 个体户开劳务费发票需要交哪些税
  • 行纪合同的效力
  • 先发货后开票的销售业务流程
  • 收到税务汇算清缴怎么办
  • 固定资产增值税税率
  • thinkphp yii
  • 劳务报酬所得与经营所得
  • 企业签订的技术合同
  • win7纯净版系统怎么安装
  • 手机忘记密码怎么解开锁华为
  • 出口货物免抵退税额的计算方法
  • vue企业开发实战
  • zip 压缩命令
  • python去掉文本的指定符号
  • 防伪税控系统该如何操作
  • 小规模纳税人0申报汇算清缴
  • 帝国cms栏目分类
  • 如何修改php网页内容
  • 织梦栏目描述调用
  • 软件使用权怎么入账
  • 小规模纳税人开具增值税专用发票
  • 企业期末预收账款怎么算
  • 建筑工程租赁费属于什么费用
  • 未开发票如何确认收入并进行申报?
  • 金税四期是什么意思
  • 软件研发的整个流程
  • 农产品开具发票税率是多少?
  • 企业代扣税费会计分录
  • 建筑劳务公司的会计账务处理
  • 辅助生产成本如何结转
  • 开票6个点怎么计算
  • 统计得到的一组数据有80个
  • ultraiso刻录音乐到dvd
  • ubuntu20.0安装
  • linux中添加用户和组的操作
  • vnc for linux
  • centos安装nmtui
  • gwsystemservice.exe是什么进程 有什么作用 gwsystemservice进程查询
  • win7 txt文件属性更改
  • 基于stm32的100个毕业设计
  • cocos2djs教程
  • javascript语言中,以下关于array
  • python爬虫系统
  • js实现拖拽div的弹出框
  • 内蒙古电子税务局app官方下载
  • 怎么查询开票信息呢
  • 税务局冬季作息时间
  • 国家税务总局网站官网浙江
  • 国家税务总局上海市电子税务局
  • 10%加计抵减政策条件
  • 贵州省微企补助政策
  • 地税服务大厅上班时间
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设