位置: IT常识 - 正文

基于chatGPT设计卷积神经网络

编辑:rootadmin
基于chatGPT设计卷积神经网络 1. 简介

推荐整理分享基于chatGPT设计卷积神经网络,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

本文主要介绍基于chatGPT,设计一个针对骁龙855芯片设计的友好型神经网络。

提问->跑通总共花了5min左右,最终得到的网络在Cifar100数据集上与ResNet18的精度对比如下。

模型flopsparamstrain acc1/5test acc1/5ResNet18(timm)1.8211.18~98%/100%63.11%/86.11%CNN-GPT1.516.66~98%/100%68.44%/89.14%

此外,GPT生成的模型训练速度更快,ResNet18才训练170个epoch,CNN-GPT已经完整地训练完了,具体的模型结构在文末,欢迎尝试使用~

个人认为比较有意思的一点是,GPT生成的模型并没有在stem block中就进行降采样,而是卷了一会之后才开始,这个与之前在文本任务中的经验结果类似,特征图越大,会避免信息损失。 基于chatGPT生成模型结构的方式或许在某些情况下可以解决我们的问题,但是基础的神经网络知识还是需要的,否则它给出的答案有些小bug,我们也不知道怎么修改。 模型训练结果与超参设置强相关,在这里仅设置默认的训练策略作为对比,根据训练结果看出,两个模型均过拟合严重。而且,不管是ResNet18或者CNN-GPT,均与Cifar100数据集的SOTA精度相差很远,此处需要注意。

2. 具体步骤Step1 骁龙855芯片上可以正常运行的算子

基于上述问题,chatGPT给出的回答如下所示。

以下是在骁龙855芯片上比较友好的神经网络算子:1. Conv2d2. Depthwise Conv2d3. Maxpool2d4. Avgpool2d5. BatchNorm2d6. ReLU7. Flatten8. Linear基于chatGPT设计卷积神经网络

整体看起来还算比较靠谱。

Step2 构建一个模型大小在10M左右的模型

针对上面的问题,chatGPT首先给出了思路。

如果想设计一个在骁龙855芯片上运行速度很快,模型大小在10M以内的模型,可以考虑使用以下方法来进行优化:1. 使用深度可分离卷积代替标准卷积,以获得更快的速度和更小的模型尺寸。2. 使用深度方向的滤波器来替换全连接层,以减少参数数量并减少内存使用。3. 对于每个卷积层,使用Batch Norm层来加速收敛。4. 使用ReLU代替传统的Sigmoid或Tanh激活函数,因为ReLU的计算效率更高。

然后给出了一个模型示例。

以下是一个可以作为参考的模型:Conv2d(3, 32, kernel_size=3, stride=1, padding=1)BatchNorm2d(32)ReLU()DepthwiseConv2d(32, kernel_size=3, stride=1, padding=1)BatchNorm2d(32)ReLU()Conv2d(32, 64, kernel_size=3, stride=2, padding=1)BatchNorm2d(64)ReLU()DepthwiseConv2d(64, kernel_size=3, stride=1, padding=1)BatchNorm2d(64)ReLU()Conv2d(64, 128, kernel_size=3, stride=2, padding=1)BatchNorm2d(128)ReLU()DepthwiseConv2d(128, kernel_size=3, stride=1, padding=1)BatchNorm2d(128)ReLU()Conv2d(128, 256, kernel_size=3, stride=2, padding=1)BatchNorm2d(256)ReLU()DepthwiseConv2d(256, kernel_size=3, stride=1, padding=1)BatchNorm2d(256)ReLU()Conv2d(256, 512, kernel_size=3, stride=2, padding=1)BatchNorm2d(512)ReLU()DepthwiseConv2d(512, kernel_size=3, stride=1, padding=1)BatchNorm2d(512)ReLU()Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)BatchNorm2d(1024)ReLU()Avgpool2d(kernel_size=4)Flatten()Linear(1024, 10)

上面生成的模型有几个小问题。

DepthwiseConv2d没有被定义DepthwiseConv2d的参数列表中没有out channel的参数,需要手动补充。Step3 生成一个DepthwiseConv2d结构

chatGPT回答如下

以下是一个非常简单的torch的DepthwiseConv2d实现:

import torch.nn as nnclass DepthwiseConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None): super(DepthwiseConv2d, self).__init__() if groups is None: groups = in_channels self.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups) self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1) def forward(self, x): x = self.depthwise_conv(x) x = self.pointwise_conv(x) return x

加入到模型结构定义文件中,同时修改类别数,就可以进行模型训练了。

设置训练超参数

为保证对比的公平性,下面是训练过程中的一些超参数,

parser.add_argument("--epoch", type=int, default=240) parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--lr", type=float, default=0.05) parser.add_argument("--lr_decay_rate", type=float, default=0.1) parser.add_argument("--lr_decay_stages", type=str, default="150,180,210") parser.add_argument("--wd", type=float, default=0.0005)3. 完整的模型结构

完整的模型结构如下,也欢迎大家在自己的数据集上尝试。

import torchimport torch.nn as nnclass DepthwiseConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=None): super(DepthwiseConv2d, self).__init__() if groups is None: groups = in_channels self.depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups=groups) self.pointwise_conv = nn.Conv2d(out_channels, out_channels, 1, 1, 0, 1, groups=1) def forward(self, x): x = self.depthwise_conv(x) x = self.pointwise_conv(x) return xclass CNNGPT(nn.Module): def __init__(self) -> None: super().__init__() self.model = nn.Sequential( nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(), DepthwiseConv2d(32, 32, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(), DepthwiseConv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(128), nn.ReLU(), DepthwiseConv2d(128, 128, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(256), nn.ReLU(), DepthwiseConv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(512), nn.ReLU(), DepthwiseConv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(512), nn.ReLU(), nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(1024), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(1024, 100), ) def forward(self, x): y = self.model(x) return ydef get_flops_params(model): from thop import profile model.eval() flops, params = profile( model, inputs=[ torch.randn([1, 3, 224, 224]), ], ) print(f"flops: {flops/1000**3} G, params: {params/1000**2} M") return flops, paramsif __name__ == "__main__": model = CNNGPT() get_flops_params(model)
本文链接地址:https://www.jiuchutong.com/zhishi/299522.html 转载请保留说明!

上一篇:ElementUI实现在下拉列表里面进行搜索(elementui怎么样)

下一篇:Java实现Token登录验证(基于JWT的token认证实现)(java单点登录token)

  • 华为平板如何连接投影仪(华为平板如何连接热点)

    华为平板如何连接投影仪(华为平板如何连接热点)

  • 打印机不打印怎么回事(打印机不打印怎么回事显示脱机)

    打印机不打印怎么回事(打印机不打印怎么回事显示脱机)

  • 为什么苹果慢动作屏闪(为什么苹果慢动作视频发出去不是慢动作)

    为什么苹果慢动作屏闪(为什么苹果慢动作视频发出去不是慢动作)

  • 华为p30pro电池不耐用(华为p30pro电池不耐用是更换电池还是用背夹电池)

    华为p30pro电池不耐用(华为p30pro电池不耐用是更换电池还是用背夹电池)

  • 微机中主机主要包括什么之外的设备(微机中主机主要包括以下什么的设备)

    微机中主机主要包括什么之外的设备(微机中主机主要包括以下什么的设备)

  • vivo新机充电须知(vivo新手机充电需要关机吗)

    vivo新机充电须知(vivo新手机充电需要关机吗)

  • 按home键有咔咔响声(home键一按咔咔响震动)

    按home键有咔咔响声(home键一按咔咔响震动)

  • 红米儿童模式不见了(红米手机儿童锁如何设置)

    红米儿童模式不见了(红米手机儿童锁如何设置)

  • 微信群里的人能看到我的朋友圈吗(微信群里的人能私聊吗)

    微信群里的人能看到我的朋友圈吗(微信群里的人能私聊吗)

  • 为什么苹果11信号不稳定(为什么苹果11信号下面有四个点)

    为什么苹果11信号不稳定(为什么苹果11信号下面有四个点)

  • 惠普的平板模式可以触屏吗(惠普的平板模式怎么用)

    惠普的平板模式可以触屏吗(惠普的平板模式怎么用)

  • 三星w20和fold有什么区别(三星w2020和三星fold2对比视频)

    三星w20和fold有什么区别(三星w2020和三星fold2对比视频)

  • 电脑主机亮红灯是什么原因(电脑主机亮红灯显示器无信号)

    电脑主机亮红灯是什么原因(电脑主机亮红灯显示器无信号)

  • b站里ul怎么升级(b站ul等级怎么升级)

    b站里ul怎么升级(b站ul等级怎么升级)

  • iphonexsmax电池多少毫安(iphonexsmax电池多少hm)

    iphonexsmax电池多少毫安(iphonexsmax电池多少hm)

  • 苹果快充必须响两声吗(苹果手机快充是不是都会响二声)

    苹果快充必须响两声吗(苹果手机快充是不是都会响二声)

  • 苹果下载的字体怎么在手机中使用(苹果下载的字体怎么使用)

    苹果下载的字体怎么在手机中使用(苹果下载的字体怎么使用)

  • pr渲染是啥意思(pr渲染是啥意思呀)

    pr渲染是啥意思(pr渲染是啥意思呀)

  • 苹果手机电话本怎么导入卡里(苹果手机电话本联系人不见了怎么恢复)

    苹果手机电话本怎么导入卡里(苹果手机电话本联系人不见了怎么恢复)

  • 怎么投诉滴滴出行司机(如何投诉滴滴打车)

    怎么投诉滴滴出行司机(如何投诉滴滴打车)

  • 京东安全中心如何绑定银行卡(京东安全中心如何解绑)

    京东安全中心如何绑定银行卡(京东安全中心如何解绑)

  • 百度经验怎么发(百度经验怎么发私信)

    百度经验怎么发(百度经验怎么发私信)

  • 华为p系列手机有几款(华为p系列手机有5g吗)

    华为p系列手机有几款(华为p系列手机有5g吗)

  • 国美如何用易卡支付(国美怎么用易卡支付)

    国美如何用易卡支付(国美怎么用易卡支付)

  • 【新星计划】如何写好你的博客,涨粉技巧总结(新星计划会限流吗)

    【新星计划】如何写好你的博客,涨粉技巧总结(新星计划会限流吗)

  • 微服务网关Gateway实践总结(微服务网关作用)

    微服务网关Gateway实践总结(微服务网关作用)

  • 免抵退税办法不得抵扣的进项税额
  • 平价转让股权如何做
  • 小规模纳税人到底能不能收专票
  • 政府会计制度等式
  • 购买土地自建厂房,土地怎样摊销
  • 年终奖可以税前扣除嘛
  • 租赁企业可以开具电费发票吗
  • 固定资产清理期末余额在借方是什么意思
  • 企业年末要写财务报告吗
  • 货没出,发票先开
  • 普通发票作废要交税吗
  • 申报工资总额填错了
  • 公司组织旅游费用报个税吗
  • 所得税汇算清缴分录怎么做
  • 自然人税务系统扣缴客户端怎么申报
  • 营改增后计税依据
  • 支付增值税税控系统技术维护费用500元
  • 增值税留抵税额抵减欠税
  • 劳务发票税率营改增后是多少?
  • 委托加工农产品的扣除率
  • 劳务公司开的3%专用发票能抵扣吗
  • 没有缴纳契税
  • 金税盘的年费怎么做抵扣帐
  • 企业收入代扣税金应该怎么做会计处理?
  • 回盘的模板
  • 资产负债表资产总额在哪
  • 环境保护税法是什么意思
  • 个人所得汇算清缴是什么
  • 收到员工水电费怎么做账
  • 进项票未认证怎么入账
  • 土地价款扣除会计分录
  • 园林绿化工程公司电话
  • 技术服务合同要盖骑缝章吗
  • widows11预览版
  • kb4586853更新
  • windows7旗舰版最新版本
  • 织梦采集器图片本地化
  • 对方开过来的专票丢了怎么办
  • 劳务费发票和建筑劳务费发票
  • 汽车客运补充车道标志
  • 海岸边上
  • lstm输出多个预测值
  • bash的主要功能
  • 收取的标书收入如何入账
  • 本地住宿费能报销吗
  • 管理不善造成的货物损失进项税可以抵扣吗
  • 个体工商户开税户
  • 策划费多少钱
  • 去年的进项发票可以抵今年的吗
  • 固定资产可收回金额怎么算
  • 劳务费合同模版
  • SQLite数据库安装及基本操作指南
  • 动态sql怎么执行
  • pygame编程
  • 小规模纳税人劳保用品可以抵扣吗
  • 企业所得税季末资产总额
  • 独立核算的分公司可以享受六税两费减半吗
  • 转让不动产与销售不符
  • 广告费应归于什么费用
  • 从政府取得土地使用权缴增值税吗
  • 一般纳税人会计分录
  • 加计扣除声明怎么填
  • 电子承兑汇票接收不了怎么办
  • sql语句实现查询示例
  • sql server 不存在或者拒绝访问
  • linux下mysql开启远程访问权限 防火墙开放3306端口
  • windows组合快捷键
  • u盘ghost装机
  • 如何禁止windows7自动更新
  • winxp任务栏消失
  • mac expose
  • 如何关闭危险端口
  • 零基础玩转python
  • node.js用途
  • 利用python绘图
  • vue实现数据导出为excel
  • js中tolocalestring
  • AssetBundle.Unload(false)的作用
  • 稽查局积案清理工作成效显著
  • 电力企业所得税按期
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设