位置: IT常识 - 正文

yolov5源码解析(9)--输出(yolov4源码解读)

编辑:rootadmin
yolov5源码解析(9)--输出

推荐整理分享yolov5源码解析(9)--输出(yolov4源码解读),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:yolo v1 pytorch源代码,yolo 源码,yolov2源码,yolov5源码解析修改,yolo 源码,yolov4源码解读,yolov5源码解读,yolov3源码详解,内容如对您有帮助,希望把文章链接给更多的朋友!

本文章基于yolov5-6.2版本。主要讲解的是yolov5是怎么在最终的特征图上得出物体边框、置信度、物体分类的。

一。总体框架

首先贴出总体框架,直接就拿官方文档的图了,本文就是接着右侧的那三层输出开始讨论。

Backbone: New CSP-Darknet53Neck: SPPF, New CSP-PANHead: YOLOv3 Head

这三个输出层分别就是浅、中、深层啦,浅层特征图分辨率是80乘80,中层是40乘40,深层是20乘20,一般来说浅层用于预测小物体,深层用于预测大物体。另外说明一下,浅、中、深三层的特征图输出通道数不一定是256、512、1024,要看你用的是哪一种规格的模型。比如yolov5s的话,那这三层的通道数分别是128、256、512,可以导出onnx格式用Netron看一下模型结构来确定。

 

简要说一下原因,这个是由对应的模型配置文件,即models目录里的yolov5s.yaml,yolov5m.yaml等等来决定的,看你用哪一个,第二个红框里的就是每一层的输出通道数了,但是它是要乘上第一个红框里的值的,即width_multiple这个配置,你会发现几个模型配置文件的内容都差不多,区别就区别在这里的depth_multiple和width_multiple。

 二。输出物体边框、置信度、物体分类

接下来进入正题,每层特征图最终都会经过1乘1卷积,变成(5+分类数)乘3个通道:

0)首先为什么乘以3,因为每一层都有3个anchor,后面再细讲

下面讲的是每一anchor对应的(5+分类数)个通道,假设分类数为2,那一共就是7个通道了,这7个通道分别是xywh(4个通道),置信度(1个通道),分类(此处2分类,就是2个通道)

1)物体边框的4个值,x,y,w,h啦,不过这个x,y并不直接是物体框中心点的坐标,而是它相对于自身所处的格子左上角的偏移,比如下图红色的这个格子(假设现在特征图就是4乘4),这个格子预测出7个值,前4个就是xywh,然后x是0.2,y是0.2,那么中心点就差不多在蓝点所处的位置了(其实这其中还有玄机,一步步来)。然后再把这个中心点的相对值作用到原图的尺度得到最终的坐标。

但是呢如果像上面这样直接预测一个相对格子左上角的偏移的这样一个值呢,会比较不稳定,它可能预测的值很大,比如x给你预测一个10出来,那就是往右数10个格子了,偏差这么大不利用网络收敛,也没有意义,因为这个格子里的特征跟右边第10个格子的特征相差可能很大了。

所以要加一个限制,首先给它sigmoid一下,这样其值范围就变成0-1了(小数),此时它的波动就在自己的这个格子内,然后乘以2再减0.5,如下图(直接拿官方文档的图了~~)

 

 这样它的波动范围就是下图的黄框的范围。

yolov5源码解析(9)--输出(yolov4源码解读)

 限制为0-1好理解,自己这个格子的预测范围就在自己格子内麻,为啥又变成了-0.5-1.5呢,因为这样更容易得到0-1范围内的值。如果的范围限制为0-1,而且是用sigmoid来限制的话,那接近0和1这两个位置的导数就会很小,梯度更新的时候就会慢。

然后就是宽高,宽高也不是直接预测出物体边框的宽高啦,而是基于anchor的,预测出来的值会乘上anchor的宽高得出最终的宽高,并且,这里仍然是先用sigmoid将输出值限制为0-1,然后再乘以2,再来个平方,这样最终的值的范围就是0-4了。

之前说了每一层有3个anchor,这些anchor还是配置在模型的配置文件里的,比如models/yolov5s.yaml,P3就是浅层的(80乘80的格子),P4是中层的(40乘40),P5是深层的(20乘20),然后这里的anchor的大小呢就是绝对值(按照640乘640的图来算的,如果你的输入图不是640乘640,那输入图是会resize一下再进行推理的)

比如现在是深层的输出,2分类,那么深层的特征图经过最后的1乘1卷积后,会得到3乘(5+2)=21个通道,每7个通道就对应一个anchor了,现在看第2个7个通道(即7-13,从0开始算),那么它对应的anchor就应该是156,198这个,那么预测出来的宽高值经过sigmoid,再乘2,再平方之后,还分别要乘上156和198,得出最终的物体宽高(基于640乘640的图的),然后再按比例得到原图的物体宽高。

2)置信度

代表预测出的物体边框和分类的可信度,最终的范围肯定是0-1了(小数),跟前面的一样,会用sigmoid来把它的范围限制为0-1。

这边可能有一个问题,那个xy不是sigmoid()乘2减0.5吗,这里咋不这么干,那是因为xy的值真的是可以达到-0.5或1.5的,那样的话就变成预测的物体中心点跑到相邻格子里去了,这也不是不行的啦。但置信度只能是0-1!

3)分类

有几个分类,就会再加几个通道,分别代表对应分类的概率,都是用sigmoid把他们的概率限制为0-1,在计算损失的时候,标签对应分类所在通道的直值为1,其它都为0了,然后分别计算BCE损失。

三。源码

最终输出层的相关源码主要就是models/yolo.py的Detect类的源码了,添加了相应的注释。

class Detect(nn.Module): stride = None # strides computed during build onnx_dynamic = False # ONNX export parameter export = False # export mode def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer super().__init__() self.nc = nc # number of classes self.no = nc + 5 # number of outputs per anchor self.nl = len(anchors) # number of detection layers self.na = len(anchors[0]) // 2 # number of anchors,除以2是因为[10,13, 16,30, 33,23]这个长度是6,对应3个anchor self.grid = [torch.zeros(1)] * self.nl # init grid,下面会计算grid,grid就是每个格子的x,y坐标(整数,比如0-19) self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2),注意后面就可以通过self.anchors来访问它了 self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv,3个输出层最后的1乘1卷积 self.inplace = inplace # use inplace ops (e.g. slice assignment) def forward(self, x): z = [] # inference output for i in range(self.nl): # 三个输出层分别处理 x[i] = self.m[i](x[i]) # conv,经过这个1乘1卷积就变成(5+分类数)个通道了 bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)--这里的85对应coco数据集,5+80个分类 x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() if not self.training: # inference if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]: self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) y = x[i].sigmoid() if self.inplace: # 这里的grid[i]即对应输出层的3个anchor层的每个格子的坐标,方便进行批量计算,乘上对应的stride[i](下采样率),就得到基于640乘640的图的坐标了 y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i] # xy # anchor_grid[i]也是一样,不过它的形状是(1, self.na, 1, 1, 2),跟y[..., 2:4]计算时是会自动广播的,最终得到的宽高也是基于640乘640的图的宽高 y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953 # 这段是非inplace操作,计算方法是一样的 xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0 xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh y = torch.cat((xy, wh, conf), 4) z.append(y.view(bs, -1, self.no)) return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x) def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')): d = self.anchors[i].device t = self.anchors[i].dtype shape = 1, self.na, ny, nx, 2 # grid shape # grid其实就是特征图网络的坐标,比如20乘20的,其坐标分别是0,0 0,1...0,19 1,0 1,1...19,19,第2个维度na就是anchor数啦 y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t) if torch_1_10: # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility yv, xv = torch.meshgrid(y, x, indexing='ij') else: yv, xv = torch.meshgrid(y, x) # 注意这边先给它把0.5给减了 grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5 # anchor_grid即每个格子对应的anchor宽高,stride是下采样率,三层分别是8,16,32,这里为啥要乘呢,因为在外面已经把anchors给除了对应的下采样率,这里再乘回来 anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape) return grid, anchor_grid

此处单独说一下torch.meshgrid,它其实就是用于得到网格坐标的,简化代码如下,假设现在是2乘2的网络

y, x = torch.arange(2), torch.arange(2)yv, xv = torch.meshgrid(y, x, indexing='ij')print(f'yv={yv}')print(f'xv={xv}')grid = torch.stack((xv, yv), 2)print(f'grid={grid}')

 输出如下

 grid对应的就是如下图,得到这个网络坐标就可以直接跟输出层的x,y做批量运算了。

 四。NMS

Detect类foward之后确实是整个网络最终的输出,不过这个输出还得再经过NMS,提取出最终的答案,即这张图上到底有几个物体,边框、置信度、分类分别是什么。NMS后面再讨论~~

下一篇:

yolov5源码解析(10)--损失计算与anchor_扫地僧1234的博客-CSDN博客

本文链接地址:https://www.jiuchutong.com/zhishi/299638.html 转载请保留说明!

上一篇:【目标检测】YOLOv5模型从大变小,发生了什么?(目标检测yolo)

下一篇:过年回家,你是否也努力的给别人解释软件开发是干啥滴?(过年回家的你)

  • 一加9防水等级(一加9防水等级多少)

    一加9防水等级(一加9防水等级多少)

  • 红米k40Pro多少级亮度调节(红米k40pro最高版本)

    红米k40Pro多少级亮度调节(红米k40pro最高版本)

  • 华为手机如何给部分照片加密(华为手机如何给手表反向充电)

    华为手机如何给部分照片加密(华为手机如何给手表反向充电)

  • 抖音上能私聊吗(抖音上能私聊吗?)

    抖音上能私聊吗(抖音上能私聊吗?)

  • 收藏的视频删除了怎么恢复(我的收藏视频删除了怎么恢复)

    收藏的视频删除了怎么恢复(我的收藏视频删除了怎么恢复)

  • 自己微信昵称怎么加特殊符号(自己微信昵称怎么社置的)

    自己微信昵称怎么加特殊符号(自己微信昵称怎么社置的)

  • 蓝牙耳机对哪里说话(蓝牙耳机对哪里有要求)

    蓝牙耳机对哪里说话(蓝牙耳机对哪里有要求)

  • 天猫精灵是充电用还是一直插电源(天猫精灵充电款)

    天猫精灵是充电用还是一直插电源(天猫精灵充电款)

  • 门户网站是什么意思(门户网站是什么新媒体类型)

    门户网站是什么意思(门户网站是什么新媒体类型)

  • ie缓存异常怎么修复(ie缓存异常怎么修复穿越火线)

    ie缓存异常怎么修复(ie缓存异常怎么修复穿越火线)

  • 华为笔记本蓝屏重启不了(华为笔记本蓝屏怎么回事)

    华为笔记本蓝屏重启不了(华为笔记本蓝屏怎么回事)

  • 手机音孔里有灰尘怎么办(手机声音孔有灰怎么办)

    手机音孔里有灰尘怎么办(手机声音孔有灰怎么办)

  • 500万像素的分辨率是多少(500万像素的分辨率是几K)

    500万像素的分辨率是多少(500万像素的分辨率是几K)

  • 微信收藏夹缓存可以清理吗(微信收藏夹缓存删除了,收藏夹东西还在不在?)

    微信收藏夹缓存可以清理吗(微信收藏夹缓存删除了,收藏夹东西还在不在?)

  • 微信怎么发接龙格式(微信怎么发接龙表格)

    微信怎么发接龙格式(微信怎么发接龙表格)

  • 电脑半个月不关机会不会坏(电脑半个月不关机会怎么样)

    电脑半个月不关机会不会坏(电脑半个月不关机会怎么样)

  • 抖音举报后有什么后果(抖音举报后什么时候封号)

    抖音举报后有什么后果(抖音举报后什么时候封号)

  • 安卓微信怎么指纹支付(安卓手机微信指纹锁怎么设置)

    安卓微信怎么指纹支付(安卓手机微信指纹锁怎么设置)

  • 华为mate30输入法在哪里设置(华为mate30输入法怎么换行)

    华为mate30输入法在哪里设置(华为mate30输入法怎么换行)

  • 拼多多sku改动有影响吗(拼多多修改商品sku会降权吗)

    拼多多sku改动有影响吗(拼多多修改商品sku会降权吗)

  • 人声鼎沸的意思是什么(人声鼎沸的意思是什么(最佳答案))

    人声鼎沸的意思是什么(人声鼎沸的意思是什么(最佳答案))

  • ipadpro电池寿命怎么查(ipadpro电池寿命90能用多久)

    ipadpro电池寿命怎么查(ipadpro电池寿命90能用多久)

  • 抖音没有互相关注可以私信吗(抖音没有互相关注的人会有浏览记录吗)

    抖音没有互相关注可以私信吗(抖音没有互相关注的人会有浏览记录吗)

  • 三星充电器和华为通用吗(三星充电器和华为充电器一样吗)

    三星充电器和华为通用吗(三星充电器和华为充电器一样吗)

  • 手机储存内存影响速度吗(手机内存影响性能吗)

    手机储存内存影响速度吗(手机内存影响性能吗)

  • 苹果无线耳机华为手机可以用吗(苹果无线耳机华为手机怎么连接)

    苹果无线耳机华为手机可以用吗(苹果无线耳机华为手机怎么连接)

  • 一加7pro是什么牌子(一加7Pro是什么手机)

    一加7pro是什么牌子(一加7Pro是什么手机)

  • 苹果权限管理在哪里(iphone里的权限管理在哪)

    苹果权限管理在哪里(iphone里的权限管理在哪)

  • gfxacc.exe是什么进程 作用是什么 gfxacc进程查询(chcfg.exe是什么)

    gfxacc.exe是什么进程 作用是什么 gfxacc进程查询(chcfg.exe是什么)

  • 国际反避税措施
  • 支付电费未开具发票
  • 外币报表折算差额是一种未实现的汇兑损益
  • 用于维修安装服装的材料
  • 个人独资企业有章程没有
  • 连锁药店总部的首营资料
  • 什么叫税控盘清卡
  • 房产税按原值计算公式
  • 开出增值税普通发票需要交税吗
  • 通信服务费可以取消吗
  • 固定资产出售税务处理方法
  • 注册资金怎么提出来
  • 工会工费缴纳标准
  • 跨年多计提折旧的账务处理
  • 稻谷增值税税率多少
  • 投资利润率的计算结果不受建设期的长短
  • 税务缴纳滞纳金处罚依据
  • 如何获取windows最高权限
  • 腾讯电脑管家中的软件市场打不开
  • 成本核算的意义是什么
  • linux网络接口状态命令
  • 期末结转之前有哪些注意事项
  • 土地开发公司是国企吗
  • 善意取得增值税专用发票 企业所得税
  • vue写css
  • 受托加工要交消费税吗
  • 房地产企业如何计算土地使用税
  • Jetson Xavier NX配置全过程——安装jtop和OpenCV4.5.3(二)
  • 安装winsock
  • 应收账款为负数正常吗为什么
  • php json_encode与json_decode详解及实例
  • 进口付汇和出口收汇
  • 结转销售成本的方法
  • Python中tkinter的 Variable类
  • 权益法的比例是多少
  • 存根联明细是自动生成
  • 单位之间的争议由谁处理
  • 支付招聘网站费用怎么入账
  • mysql中的外键的作用
  • 增值税专用发票几个点
  • 小规模纳税人企业所得税2023
  • 应交税费余额是什么意思
  • 可明确区分的商品什么意思
  • 建筑公司包工包料提供建筑服务
  • 季度报税都是几月份
  • 来料加工的账务处理新收入准则
  • 研发费用账务调整合同怎么写
  • 固定资产公司
  • 应收账款科目如何核算
  • 公司办公室买的茶叶怎么入账
  • 申购费从哪里扣
  • 长期待摊费用的最新账务处理
  • 员工房屋租赁合同
  • 贷款罚息会计分录
  • 公对公房产过户
  • 收到商业承兑汇票的会计分录
  • 安卓系统强制竖屏
  • win7提示盗版怎样激活
  • macbook appstore在哪
  • linux获取操作命令的使用方法或参数选项内容
  • 四步制作的花
  • Tips(1)glewExperimental
  • 前端面试题及答案2023vue3
  • 嗌中怎么读
  • python结巴分词
  • js date对象构造方法
  • unity shader可视化编辑
  • jquery从左到右渐渐显示
  • node.js使用教程
  • android解析xml的方法中,将整个文件加载到内存
  • js数组添加元素的方法
  • 夜间模式图
  • python编程基础语法
  • 电子税务局开的发票怎么作废
  • 广州地税局官网办事点
  • 上缴财政总额是什么意思
  • 代理记账公司简介模板范文
  • 国家税务总局关于新型墙体材料增值税政策的通知
  • 大东地税局
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设