位置: IT常识 - 正文

YOLOv5 6.0/6.1结合ASFF(yolov5 教程)

编辑:rootadmin
YOLOv5 6.0/6.1结合ASFF

推荐整理分享YOLOv5 6.0/6.1结合ASFF(yolov5 教程),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:yolov2结构,yolov2结构,yolov5结构解析,yolov5结构解析,yolov5结构解析,yolov3.cfg,yolov5搭建,yolov5 教程,内容如对您有帮助,希望把文章链接给更多的朋友!

YOLOv5 6.0/6.1结合ASFF

前言

YOLO小白纯干货分享!!!

一、主要修改代码YOLOv5 6.0/6.1结合ASFF(yolov5 教程)

二、使用步骤1. models/common.py:加入要修改的代码, 类ASFFV5 class ASFFV5(nn.Module): class ASFFV5(nn.Module): def __init__(self, level, multiplier=1, rfb=False, vis=False, act_cfg=True): """ ASFF version for YoloV5 only. Since YoloV5 outputs 3 layer of feature maps with different channels which is different than YoloV3 normally, multiplier should be 1, 0.5 which means, the channel of ASFF can be 512, 256, 128 -> multiplier=1 256, 128, 64 -> multiplier=0.5 For even smaller, you gonna need change code manually. """ super(ASFFV5, self).__init__() self.level = level self.dim = [int(1024*multiplier), int(512*multiplier), int(256*multiplier)] #print("dim:",self.dim) self.inter_dim = self.dim[self.level] if level == 0: self.stride_level_1 = Conv(int(512*multiplier), self.inter_dim, 3, 2) #print(self.dim) self.stride_level_2 = Conv(int(256*multiplier), self.inter_dim, 3, 2) self.expand = Conv(self.inter_dim, int( 1024*multiplier), 3, 1) elif level == 1: self.compress_level_0 = Conv( int(1024*multiplier), self.inter_dim, 1, 1) self.stride_level_2 = Conv( int(256*multiplier), self.inter_dim, 3, 2) self.expand = Conv(self.inter_dim, int(512*multiplier), 3, 1) elif level == 2: self.compress_level_0 = Conv( int(1024*multiplier), self.inter_dim, 1, 1) self.compress_level_1 = Conv( int(512*multiplier), self.inter_dim, 1, 1) self.expand = Conv(self.inter_dim, int( 256*multiplier), 3, 1) # when adding rfb, we use half number of channels to save memory compress_c = 8 if rfb else 16 self.weight_level_0 = Conv( self.inter_dim, compress_c, 1, 1) self.weight_level_1 = Conv( self.inter_dim, compress_c, 1, 1) self.weight_level_2 = Conv( self.inter_dim, compress_c, 1, 1) self.weight_levels = Conv( compress_c*3, 3, 1, 1) self.vis = vis def forward(self, x_level_0, x_level_1, x_level_2): #s,m,l """ # 128, 256, 512 512, 256, 128 from small -> large """ # print('x_level_0: ', x_level_0.shape) # print('x_level_1: ', x_level_1.shape) # print('x_level_2: ', x_level_2.shape) x_level_0=x[2] x_level_1=x[1] x_level_2=x[0] if self.level == 0: level_0_resized = x_level_0 level_1_resized = self.stride_level_1(x_level_1) level_2_downsampled_inter = F.max_pool2d( x_level_2, 3, stride=2, padding=1) level_2_resized = self.stride_level_2(level_2_downsampled_inter) #print('X——level_0: ', level_2_downsampled_inter.shape) elif self.level == 1: level_0_compressed = self.compress_level_0(x_level_0) level_0_resized = F.interpolate( level_0_compressed, scale_factor=2, mode='nearest') level_1_resized = x_level_1 level_2_resized = self.stride_level_2(x_level_2) elif self.level == 2: level_0_compressed = self.compress_level_0(x_level_0) level_0_resized = F.interpolate( level_0_compressed, scale_factor=4, mode='nearest') x_level_1_compressed = self.compress_level_1(x_level_1) level_1_resized = F.interpolate( x_level_1_compressed, scale_factor=2, mode='nearest') level_2_resized = x_level_2 # print('level: {}, l1_resized: {}, l2_resized: {}'.format(self.level, # level_1_resized.shape, level_2_resized.shape)) level_0_weight_v = self.weight_level_0(level_0_resized) level_1_weight_v = self.weight_level_1(level_1_resized) level_2_weight_v = self.weight_level_2(level_2_resized) # print('level_0_weight_v: ', level_0_weight_v.shape) # print('level_1_weight_v: ', level_1_weight_v.shape) # print('level_2_weight_v: ', level_2_weight_v.shape) levels_weight_v = torch.cat( (level_0_weight_v, level_1_weight_v, level_2_weight_v), 1) levels_weight = self.weight_levels(levels_weight_v) levels_weight = F.softmax(levels_weight, dim=1) fused_out_reduced = level_0_resized * levels_weight[:, 0:1, :, :] +\ level_1_resized * levels_weight[:, 1:2, :, :] +\ level_2_resized * levels_weight[:, 2:, :, :] out = self.expand(fused_out_reduced) if self.vis: return out, levels_weight, fused_out_reduced.sum(dim=1) else: return out2. models/yolo.py:添加 类ASFF_Detect

然后在yolo.py 中 Detect 类下面,添加一个ASFF_Detect类

class ASFF_Detect(nn.Module): #add ASFFV5 layer and Rfb stride = None # strides computed during build export = False # onnx export def __init__(self, nc=80, anchors=(), multiplier=0.5,rfb=False,ch=()): # detection layer super(ASFF_Detect, self).__init__() self.nc = nc # number of classes self.no = nc + 5 # number of outputs per anchor self.nl = len(anchors) # number of detection layers self.na = len(anchors[0]) // 2 # number of anchors self.grid = [torch.zeros(1)] * self.nl # init grid self.l0_fusion = ASFFV5(level=0, multiplier=multiplier,rfb=rfb) self.l1_fusion = ASFFV5(level=1, multiplier=multiplier,rfb=rfb) self.l2_fusion = ASFFV5(level=2, multiplier=multiplier,rfb=rfb) a = torch.tensor(anchors).float().view(self.nl, -1, 2) self.register_buffer('anchors', a) # shape(nl,na,2) self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv

接着在 yolo.py的parse_model 中把函数放到模型的代码里: (大概在283行左右)

if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,CBAM,ResBlock_CBAM, C3]: c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) args = [c1, c2, *args[1:]] if m in [BottleneckCSP, C3]: args.insert(2, n) # number of repeats n = 1 elif m is nn.BatchNorm2d: args = [ch[f]] elif m is Concat: c2 = sum([ch[x] for x in f]) elif m is ASFF_Detect: args.append([ch[x] for x in f]) if isinstance(args[1], int): # number of anchors args[1] = [list(range(args[1] * 2))] * len(f) elif m is Contract: c2 = ch[f] * args[0] ** 2 elif m is Expand: c2 = ch[f] // args[0] ** 2 elif m is ASFFV5: c2=args[1] else: c2 = ch[f]3.models/yolov5s-asff.yaml

在models文件夹下新建对应的yolov5s-asff.yaml 文件 然后将yolov5s.yaml的内容复制过来,将 head 部分的最后一行进行修改; 将[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) ] 修改成下面:

[[17, 20, 23], 1, ASFF_Detect, [nc, anchors]], # Detect(P3, P4, P5) ]4.查看网络结构

修改 models/yolo.py --cfg models/yolov5s-asff.yaml 接下来run yolo.py 即可查看网络结构

5.将train.py 中 --cfg中的 yaml 文件修改成本文文件即可,开始训练总结

本人在多个数据集上做了大量实验,针对不同的数据集效果不同,需要大家进行实验。有效果有提升的情况占大多数。

最后,希望能互粉一下,做个朋友,一起学习交流。

本文链接地址:https://www.jiuchutong.com/zhishi/300716.html 转载请保留说明!

上一篇:2022年微信小程序授权登录的最新实现方案(2022年微信小程序游戏)

下一篇:【windows Server 2019系列】 构建IIS服务器(windowsserver2012r2远程协助灰色)

  • qq改实名认证了游戏也会改吗(QQ改实名认证了为什么王者还登不上)

    qq改实名认证了游戏也会改吗(QQ改实名认证了为什么王者还登不上)

  • 聊天界面有个耳朵(聊天时出现一个耳朵是怎么设置?)

    聊天界面有个耳朵(聊天时出现一个耳朵是怎么设置?)

  • 改qq签名不提醒其他人(改qq签名不提醒怎么回事)

    改qq签名不提醒其他人(改qq签名不提醒怎么回事)

  • sai内存使用量不足怎么解决(sai 内存容量不足)

    sai内存使用量不足怎么解决(sai 内存容量不足)

  • 软件授权服务报告无法激活计算机(软件授权服务报告许可证评估失败)

    软件授权服务报告无法激活计算机(软件授权服务报告许可证评估失败)

  • 公众号专辑怎么用(公众号专辑怎么删除)

    公众号专辑怎么用(公众号专辑怎么删除)

  • 为啥流量卡开不了微信(流量卡开不起流量怎么回事)

    为啥流量卡开不了微信(流量卡开不起流量怎么回事)

  • nova7pro是双扬声器吗(nova7pro支持双扬声器)

    nova7pro是双扬声器吗(nova7pro支持双扬声器)

  • 红米8pro上市时间(红米8 pro多少钱)

    红米8pro上市时间(红米8 pro多少钱)

  • p40刷新率多少(p40pro手机刷新率)

    p40刷新率多少(p40pro手机刷新率)

  • 电话有回音是被监听了吗(电话有回音是被监控了吗)

    电话有回音是被监听了吗(电话有回音是被监控了吗)

  • 手机手写不了怎么回事(手机手写不好使)

    手机手写不了怎么回事(手机手写不好使)

  • 华为agsw09是什么型号(agsw09是华为哪一款)

    华为agsw09是什么型号(agsw09是华为哪一款)

  • 手机电话卡怎么装(手机电话卡怎么换)

    手机电话卡怎么装(手机电话卡怎么换)

  • id地址是什么(picacgid地址是什么)

    id地址是什么(picacgid地址是什么)

  • gif怎么渐变(动态渐变图)

    gif怎么渐变(动态渐变图)

  • word邮件合并全部记录(word2020邮件合并)

    word邮件合并全部记录(word2020邮件合并)

  • 使用微信支付的步骤(使用微信支付的国家)

    使用微信支付的步骤(使用微信支付的国家)

  • 荣耀9xpro怎么截屏

    荣耀9xpro怎么截屏

  • 简述软件危机的具体表现(简述软件危机的产生原因)

    简述软件危机的具体表现(简述软件危机的产生原因)

  • 快手本地作品删了怎么找回(删除快手本地作品快手作品还在吗)

    快手本地作品删了怎么找回(删除快手本地作品快手作品还在吗)

  • synchronic和diachronic的区别(synchronic和diachronic的读音)

    synchronic和diachronic的区别(synchronic和diachronic的读音)

  • 陀螺仪灵敏度怎么调(陀螺仪灵敏度怎么调好压枪)

    陀螺仪灵敏度怎么调(陀螺仪灵敏度怎么调好压枪)

  • 微博怎么调夜间模式(微博怎么设置夜间模式2020)

    微博怎么调夜间模式(微博怎么设置夜间模式2020)

  • relx充电多长时间(relx充一次电能用多久)

    relx充电多长时间(relx充一次电能用多久)

  • vue3项目使用样式穿透修改elementUI默认样式(vue3.0用法)

    vue3项目使用样式穿透修改elementUI默认样式(vue3.0用法)

  • 广告牌制作加盟厂家
  • 企业缴纳季度所得税
  • 承兑贴现几个点是月息还是年息
  • 多少金额以下可以一次性费用
  • 代缴水电费如何做账
  • 劳保费属于什么会计科目
  • 未开具发票负数的原因
  • 固定资产折旧年限的最新规定2022
  • 事业单位大型修缮会计分录
  • 当月发票未收到怎么办
  • 管理费用现金流量表中属于
  • 车辆购置税的会计处理
  • 房地产开发企业销售自行开发的房地产项目
  • 核定征收的小型微利企业
  • 小规模企业资本结构
  • 采购材料差旅费怎么入账
  • 两年前少缴的税款是否应补缴?
  • 公司如何为员工缴纳社保
  • 开启网络共享后怎么使用
  • 银行承兑汇票贴现率是多少
  • 网络限速数值
  • 认缴出资额和实缴出资额的时间
  • 应纳消费税包不包括代收代缴
  • php字符串函数有哪些
  • 公司融资a轮说明什么
  • 计算应缴房产税的公式
  • 公司租赁办公室要注意什么
  • php和mysql的联合使用
  • php判断https
  • php url函数
  • 视频监控接入方式有哪几种
  • 【AI大比拼】文心一言 VS ChatGPT-4
  • js生成随机数字和字母组合
  • 西安微信公众号开发
  • linux扫描命令
  • 公司购买空调属于电子设备吗
  • 注册资本与注册资金、出资额的区别
  • 残疾人就业保障金怎么申报
  • 住宿费补贴
  • 印花税不减免
  • 买新车检测
  • 典当行借款合同需交印花税吗
  • 处置资产增值税纳税义务发生时间
  • 商业银行提取的盈余公积可用于
  • 收到的稳岗补贴是否需要交税
  • 公司注销前的发票怎么查
  • 交纳增值税的账务处理PPT
  • 社保补缴收滞纳金吗
  • 损益类科目没有结平是什么意思
  • 什么是无形资产包括哪些
  • 红字发票可以跨月入账吗
  • 先开票后发货是什么意思
  • 发票作废了还能恢复吗?
  • 所得税汇算清缴前取得跨年发票
  • sql经常用的语句
  • 电脑bios找不到硬盘怎么办
  • ubuntu 安装zsh
  • 在linux系统中安装软件
  • centos ssh permission denied
  • window10的dns异常
  • diskgenius_winpe文件夹能删吗
  • PACKAGER.EXE - PACKAGER是什么进程 有什么用
  • 在linux中使用什么命令可以给命令起别名
  • win10系统如何关闭
  • 在Linux系统中安装pacman
  • win8.1怎么样
  • 解决Extjs4中form表单提交后无法进入success函数问题
  • cocos2dx4.0入门
  • cocos2d rpg
  • json查询语句
  • OnApplicationFocus
  • node.js常用命令
  • jquery中如何获取元素?
  • Unity uGui RawImage 渲染小地图
  • Android 自定义view炫酷动画
  • javascript例题
  • 计提税金的公式
  • 小规模纳税人房土两税优惠政策
  • 贵州省增值税普通发票图片
  • 物业收取水电费的通知范文
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设