位置: IT常识 - 正文

基于chatGPT设计卷积神经网络

编辑:rootadmin
基于chatGPT设计卷积神经网络 1. 简介

推荐整理分享基于chatGPT设计卷积神经网络,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

本文主要介绍基于chatGPT,设计一个针对骁龙855芯片设计的友好型神经网络。

提问->跑通总共花了5min左右,最终得到的网络在Cifar100数据集上与ResNet18的精度对比如下。

模型flopsparamstrain acc1/5test acc1/5ResNet18(timm)1.8211.18~98%/100%63.11%/86.11%CNN-GPT1.516.66~98%/100%68.44%/89.14%

此外,GPT生成的模型训练速度更快,ResNet18才训练170个epoch,CNN-GPT已经完整地训练完了,具体的模型结构在文末,欢迎尝试使用~

个人认为比较有意思的一点是,GPT生成的模型并没有在stem block中就进行降采样,而是卷了一会之后才开始,这个与之前在文本任务中的经验结果类似,特征图越大,会避免信息损失。 基于chatGPT生成模型结构的方式或许在某些情况下可以解决我们的问题,但是基础的神经网络知识还是需要的,否则它给出的答案有些小bug,我们也不知道怎么修改。 模型训练结果与超参设置强相关,在这里仅设置默认的训练策略作为对比,根据训练结果看出,两个模型均过拟合严重。而且,不管是ResNet18或者CNN-GPT,均与Cifar100数据集的SOTA精度相差很远,此处需要注意。

2. 具体步骤Step1 骁龙855芯片上可以正常运行的算子

基于上述问题,chatGPT给出的回答如下所示。

以下是在骁龙855芯片上比较友好的神经网络算子:1. Conv2d2. Depthwise Conv2d3. Maxpool2d4. Avgpool2d5. BatchNorm2d6. ReLU7. Flatten8. Linear基于chatGPT设计卷积神经网络

整体看起来还算比较靠谱。

Step2 构建一个模型大小在10M左右的模型

针对上面的问题,chatGPT首先给出了思路。

如果想设计一个在骁龙855芯片上运行速度很快,模型大小在10M以内的模型,可以考虑使用以下方法来进行优化:1. 使用深度可分离卷积代替标准卷积,以获得更快的速度和更小的模型尺寸。2. 使用深度方向的滤波器来替换全连接层,以减少参数数量并减少内存使用。3. 对于每个卷积层,使用Batch Norm层来加速收敛。4. 使用ReLU代替传统的Sigmoid或Tanh激活函数,因为ReLU的计算效率更高。

然后给出了一个模型示例。

以下是一个可以作为参考的模型:Conv2d(3, 32, kernel_size=3, stride=1, padding=1)BatchNorm2d(32)ReLU()DepthwiseConv2d(32, kernel_size=3, stride=1, padding=1)BatchNorm2d(32)ReLU()Conv2d(32, 64, kernel_size=3, stride=2, padding=1)BatchNorm2d(64)ReLU()DepthwiseConv2d(64, kernel_size=3, stride=1, padding=1)BatchNorm2d(64)ReLU()Conv2d(64, 128, kernel_size=3, stride=2, padding=1)BatchNorm2d(128)ReLU()DepthwiseConv2d(128, kernel_size=3, stride=1, padding=1)BatchNorm2d(128)ReLU()Conv2d(128, 256, kernel_size=3, stride=2, padding=1)BatchNorm2d(256)ReLU()DepthwiseConv2d(256, kernel_size=3, stride=1, padding=1)BatchNorm2d(256)ReLU()Conv2d(256, 512, kernel_size=3, stride=2, padding=1)BatchNorm2d(512)ReLU()DepthwiseConv2d(512, kernel_size=3, stride=1, padding=1)BatchNorm2d(512)ReLU()Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)BatchNorm2d(1024)ReLU()Avgpool2d(kernel_size=4)Flatten()Linear(1024, 10)

上面生成的模型有几个小问题。

DepthwiseConv2d没有被定义DepthwiseConv2d的参数列表中没有out channel的参数,需要手动补充。Step3 生成一个DepthwiseConv2d结构

chatGPT回答如下

以下是一个非常简单的torch的DepthwiseConv2d实现:

import torch.nn as nnclass DepthwiseConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None): super(DepthwiseConv2d, self).__init__() if groups is None: groups = in_channels self.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups) self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1) def forward(self, x): x = self.depthwise_conv(x) x = self.pointwise_conv(x) return x

加入到模型结构定义文件中,同时修改类别数,就可以进行模型训练了。

设置训练超参数

为保证对比的公平性,下面是训练过程中的一些超参数,

parser.add_argument("--epoch", type=int, default=240) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--lr", type=float, default=0.05) parser.add_argument("--lr_decay_rate", type=float, default=0.1) parser.add_argument("--lr_decay_stages", type=str, default="150,180,210") parser.add_argument("--wd", type=float, default=0.0005)3. 完整的模型结构

完整的模型结构如下,也欢迎大家在自己的数据集上尝试。

import torchimport torch.nn as nnclass DepthwiseConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None): super(DepthwiseConv2d, self).__init__() if groups is None: groups = in_channels self.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups) self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1) def forward(self, x): x = self.depthwise_conv(x) x = self.pointwise_conv(x) return xclass CNNGPT(nn.Module): def __init__(self) -> None: super().__init__() self.model = nn.Sequential( nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(), DepthwiseConv2d(32, 32, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(), DepthwiseConv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(), DepthwiseConv2d(128, 128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(), DepthwiseConv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(), DepthwiseConv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(512), nn.ReLU(), nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(1024), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(1024, 100), ) def forward(self, x): y = self.model(x) return ydef get_flops_params(model): from thop import profile model.eval() flops, params = profile( model, inputs=[ torch.randn([1, 3, 224, 224]), ], ) print(f"flops: {flops/1000**3} G, params: {params/1000**2} M") return flops, paramsif __name__ == "__main__": model = CNNGPT() get_flops_params(model)
本文链接地址:https://www.jiuchutong.com/zhishi/299522.html 转载请保留说明!

上一篇:ElementUI实现在下拉列表里面进行搜索(elementui怎么样)

下一篇:Java实现Token登录验证(基于JWT的token认证实现)(java单点登录token)

  • 华为nova7pro能支持前后双景录像吗(华为nova7pro可以刷门禁卡吗)

    华为nova7pro能支持前后双景录像吗(华为nova7pro可以刷门禁卡吗)

  • windows7一键还原在哪(windows7一键还原软件)

    windows7一键还原在哪(windows7一键还原软件)

  • 机顶盒是路由器吗(网络盒子和机顶盒区别)

    机顶盒是路由器吗(网络盒子和机顶盒区别)

  • 华为p30遥控器功能在哪(华为p30遥控器功能不见了)

    华为p30遥控器功能在哪(华为p30遥控器功能不见了)

  • hd5750相当于什么显卡(hd5750 5770)

    hd5750相当于什么显卡(hd5750 5770)

  • 电脑截图截不了是怎么回事(电脑截图截不了qq聊天记录)

    电脑截图截不了是怎么回事(电脑截图截不了qq聊天记录)

  • rockchip是什么设备(rockchip rga)

    rockchip是什么设备(rockchip rga)

  • 网易账号被锁定啥意思(网易账号被锁定或其他原因终止换号)

    网易账号被锁定啥意思(网易账号被锁定或其他原因终止换号)

  • 手机qq软件打不开了怎么办(手机qq软件打不开怎么回事)

    手机qq软件打不开了怎么办(手机qq软件打不开怎么回事)

  • 英特尔630核显相当于什么显卡(英特尔 630核显)

    英特尔630核显相当于什么显卡(英特尔 630核显)

  • 苹果6电信版能用移动卡吗(苹果6电信版能插卡吗)

    苹果6电信版能用移动卡吗(苹果6电信版能插卡吗)

  • a1691苹果6s是什么版(iphone 6s a1699)

    a1691苹果6s是什么版(iphone 6s a1699)

  • 拼多多揽件超时罚款多少(拼多多揽件超时是12小时还是24小时)

    拼多多揽件超时罚款多少(拼多多揽件超时是12小时还是24小时)

  • 华为手机激活后能退吗(华为手机激活后还能退吗)

    华为手机激活后能退吗(华为手机激活后还能退吗)

  • 显卡是用来连接主机和什么的接口电路(显卡是跟啥连接的)

    显卡是用来连接主机和什么的接口电路(显卡是跟啥连接的)

  • 滴滴车主可以异地接单吗(滴滴车主可以异地注册回当地出车吗)

    滴滴车主可以异地接单吗(滴滴车主可以异地注册回当地出车吗)

  • 户户通怎么重新安装(户户通怎么重新开通)

    户户通怎么重新安装(户户通怎么重新开通)

  • 小米9怎么开启wlan热点(小米9怎么开启无线反向充电)

    小米9怎么开启wlan热点(小米9怎么开启无线反向充电)

  • 怎么复制快手视频教程(怎么复制快手视频链接)

    怎么复制快手视频教程(怎么复制快手视频链接)

  • 苹果x显示edge怎么解决(苹果X显示edge怎么解决)

    苹果x显示edge怎么解决(苹果X显示edge怎么解决)

  • mix2发布时间(mix2啥时候出的)

    mix2发布时间(mix2啥时候出的)

  • 华为p30呼叫等待怎么设置(华为p30呼叫等待怎么取消)

    华为p30呼叫等待怎么设置(华为p30呼叫等待怎么取消)

  • iphonex底下的横线隐藏(iphonex底部横条关闭了怎么打开)

    iphonex底下的横线隐藏(iphonex底部横条关闭了怎么打开)

  • Win10系统瞬间释放系统资源的方法(win10一键释放内存)

    Win10系统瞬间释放系统资源的方法(win10一键释放内存)

  • 20分钟,使用Amazon SageMaker快速搭建属于自己的AIGC应用

    20分钟,使用Amazon SageMaker快速搭建属于自己的AIGC应用

  • 用WP Downgrade降级WordPress版本(wndgui降级)

    用WP Downgrade降级WordPress版本(wndgui降级)

  • 小规模纳税人一般纳税人区别
  • 应税劳务销售额怎么算
  • 车辆使用费属于什么税收分类
  • 餐补如何入账
  • 主办会计跟会计有什么区别
  • 一次性开票分期确认收入的税务文件是什么
  • 合伙律师事务所的合伙人必须是
  • 差旅费处理会计分录
  • 现金预算属于营业预算吗
  • 记账软件收费吗
  • 营业成本在资产负债表哪里
  • 个税累计扣除项目合计2712.5是什么
  • 年末结转增值税之后增值税有余额嘛
  • 计提资产减值准备金的纳税调整
  • 筹建期间有收入怎么办
  • 联营和合营的区别
  • 合作社 注销
  • 股权收购溢价的原因
  • 出口退税的原则是什么
  • 固定资产销售税务处理
  • 采购比价流程图
  • 固定资产加速折旧方法
  • 未分配利润怎么处理
  • 视同销售如何纳税调整?
  • 去年盈利今年亏损怎么表达
  • 劳务公司的主营业务成本是人工工资需要个人开发票吗?
  • 支付境外咨询费代扣代缴增值税
  • 设备固定资产原值
  • 2020 php 薪资
  • 发票已经认证,发现错误了要怎么处理
  • 提前预支工资怎么报税
  • Win11 Build 22000.348更新补丁KB5007262预览版发布(附更新修复内容汇总)
  • ctfmon.exe文件在哪
  • php的运算符主要包括哪些?
  • 计划成本核算的账务处理
  • 收到现金股利或债券利息时会计分录
  • 收到借款利息会计科目
  • 房地产开发企业增值税怎么算
  • vue3加ts
  • yolov5m
  • 子公司破产对母公司是利好还是利空
  • SM1、SM2、SM3、SM4、同态加密、密态计算、隐私计算和安全多方计算的概念
  • 机器学习中的数学——距离定义(八):余弦距离(Cosine Distance)
  • map获取entry
  • 电子客票行程单怎么获取
  • 固定资产转为投资性房地产属于会计政策变更
  • python模块的扩展名
  • 织梦怎么添加相关
  • 公司注销了就可以不发工资么
  • mysql的服务器
  • 物业公司物业费计入什么会计科目
  • 员工自己全额承担社保费用要计提吗
  • 什么企业的应急预案需要备案
  • 折旧啥意思
  • 电子发票可以作为保修凭证吗
  • 土地租金计入什么会计科目
  • 小规模纳税人未开票收入如何申报增值税
  • 以旧换新价格怎么确认
  • 收到客户账款会计分录
  • 固定资产 金额
  • 小规模公司开票税率是多少
  • 教育机构的咨询
  • macbookpro隐藏
  • 如何利用excel的数据制作图表
  • bootcamp您的磁盘未能分区
  • 根据自己的需求配置一台电脑
  • win7注册表详解
  • linux中vi命令详解
  • mobile window
  • 360安全卫士检测出来高危漏洞需要修复吗
  • node session
  • 支持google play
  • jquery和dom对象之间怎么转换
  • 文章标题用几号字体
  • python基本用法
  • js页面点击怎么随机生成图片
  • bootstrap需要学多久
  • 大米可以抵扣增值税吗
  • 晋江电视
  • 增值税发票开具后多长时间可在查询平台查询到?
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设