位置: IT常识 - 正文

YOLOv5源码逐行超详细注释与解读(6)——网络结构(1)yolo.py(yolov1 实现)

编辑:rootadmin
YOLOv5源码逐行超详细注释与解读(6)——网络结构(1)yolo.py

推荐整理分享YOLOv5源码逐行超详细注释与解读(6)——网络结构(1)yolo.py(yolov1 实现),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:yolov3源码,yolo 源码,yolo 源码,yolov5实现,yolov5实现,yolo 源码,yolov2源码,yolov3源码,内容如对您有帮助,希望把文章链接给更多的朋友!

前言

在上一篇中,我们简单介绍了YOLOv5的配置文件之一 yolov5s.yaml,这个文件中涉及很多参数,它们的调用会在这篇 yolo.py 和下一篇 common.py 中具体实现。

本篇我们会介绍 yolo.py,这是YOLO的特定模块,和网络构建有关。在 YOLOv5源码中,模型的建立是依靠 yolo.py 中的函数和对象完成的,这个文件主要由三个部分:parse_model函数、Detect类和Model类组成。

yolo.py文件位置在./models/yolo.py

文章代码逐行手打注释,每个模块都有对应讲解,一文帮你梳理整个代码逻辑!

友情提示:全文4万字,可以先点再慢慢看哦~

源码下载地址:mirrors / ultralytics / yolov5 · GitCode

   🍀本人YOLOv5源码详解系列:  

YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析

​​​​​​YOLOv5源码逐行超详细注释与解读(2)——推理部分detect.py

YOLOv5源码逐行超详细注释与解读(3)——训练部分train.py

YOLOv5源码逐行超详细注释与解读(4)——验证部分val(test).py

YOLOv5源码逐行超详细注释与解读(5)——配置文件yolov5s.yaml

YOLOv5源码逐行超详细注释与解读(7)——网络结构(2)common.py 

目录

前言

🚀一、 导包和基本配置

1.1 导入安装好的python库 

1.2 获取当前文件的绝对路径

1.3 加载自定义模块

🚀二、parse_model函数

2.1 获取对应参数

2.2 搭建网络前准备

2.3 更新当前层的参数,计算c2

2.4 使用当前层的参数搭建当前层

2.5 打印和保存layers 

🚀 三、Detect模块

3.1 获取预测得到的参数

3.2 向前传播

3.3 相对坐标转换到grid绝对坐标系

🚀四、Model类

4.1 __init__函数

4.2 数据增强相关函数

4.2.1 forward():管理前向传播函数

4.2.2 _forward_augment():推理的forward

4.2.3 _forward_once():训练的forward

4.2.4 _descale_pred():将推理结果恢复到原图尺寸

4.2.5 _clip_augmented():TTA的时候对原图片进行裁剪

4.2.6 _profile_one_layer():打印日志信息

4.2.7 _initialize_biases():初始化偏置biases信息

4.2.8 _print_biases():打印偏置biases信息

4.2.9 fuse():将Conv2d+BN进行融合

4.2.10 autoshape():扩展模型功能

4.2.11 info():打印模型结构信息

4.2.12 _apply():将模块转移到 CPU/ GPU上

🚀五、yolo.py全部注释

🚀一、 导包和基本配置1.1 导入安装好的python库 '''======================1.导入安装好的python库====================='''import argparse # 解析命令行参数模块import sys # sys系统模块 包含了与Python解释器和它的环境有关的函数from copy import deepcopy # 数据拷贝模块 深拷贝from pathlib import Path # Path将str转换为Path对象 使字符串路径易于操作的模块

首先,导入一下常用的python库:

argparse:  它是一个用于命令项选项与参数解析的模块,通过在程序中定义好我们需要的参数,argparse 将会从 sys.argv 中解析出这些参数,并自动生成帮助和使用信息sys: 它是与python解释器交互的一个接口,该模块提供对解释器使用或维护的一些变量的访问和获取,它提供了许多函数和变量来处理 Python 运行时环境的不同部分copy:  Python 中赋值语句不复制对象,而是在目标和对象之间创建绑定关系。copy模块提供了通用的浅层复制和深层复制操作pathlib:  这个库提供了一种面向对象的方式来与文件系统交互,可以让代码更简洁、更易读1.2 获取当前文件的绝对路径'''===================2.获取当前文件的绝对路径========================'''FILE = Path(__file__).resolve() # __file__指的是当前文件(即val.py),FILE最终保存着当前文件的绝对路径,比如D://yolov5/modles/yolo.pyROOT = FILE.parents[1] # YOLOv5 root directory 保存着当前项目的父目录,比如 D://yolov5if str(ROOT) not in sys.path: # sys.path即当前python环境可以运行的路径,假如当前项目不在该路径中,就无法运行其中的模块,所以就需要加载路径 sys.path.append(str(ROOT)) # add ROOT to PATH 把ROOT添加到运行路径上# ROOT = ROOT.relative_to(Path.cwd()) # relative ROOT设置为相对路径

这段代码会获取当前文件的绝对路径,并使用Path库将其转换为Path对象。

这一部分的主要作用有两个:

将当前项目添加到系统路径上,以使得项目中的模块可以调用。将当前项目的相对路径保存在ROOT中,便于寻找项目中的文件。1.3 加载自定义模块'''===================3..加载自定义模块============================'''from models.common import * # yolov5的网络结构(yolov5)from models.experimental import * # 导入在线下载模块from utils.autoanchor import check_anchor_order # 导入检查anchors合法性的函数from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args # 定义了一些常用的工具函数from utils.plots import feature_visualization # 定义了Annotator类,可以在图像上绘制矩形框和标注信息from utils.torch_utils import (copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device, time_sync) # 定义了一些与PyTorch有关的工具函数# 导入thop包 用于计算FLOPstry: import thop # for FLOPs computationexcept ImportError: thop = None

这些都是用户自定义的库,由于上一步已经把路径加载上了,所以现在可以导入,这个顺序不可以调换。具体来说,代码从如下几个文件中导入了部分函数和类:

models.common:  这个是yolov5的网络结构models.experimental:  实验性质的代码,包括MixConv2d、跨层权重Sum等utils.autoanchor:  定义了自动生成锚框的方法utils.general:  定义了一些常用的工具函数,比如检查文件是否存在、检查图像大小是否符合要求、打印命令行参数等等utils.plots:    定义了Annotator类,可以在图像上绘制矩形框和标注信息utils.torch_utils:   定义了一些与PyTorch有关的工具函数,比如选择设备、同步时间等

通过导入这些模块,可以更方便地进行目标检测的相关任务,并且减少了代码的复杂度和冗余。

🚀二、parse_model函数YOLOv5源码逐行超详细注释与解读(6)——网络结构(1)yolo.py(yolov1 实现)

parse_model函数用在DetectionModel模块中,主要作用是解析模型yaml的模块,通过读取yaml文件中的配置,并且到common.py中找到相对于的模块,然后组成一个完整的模型解析模型文件(字典形式),并搭建网络结构。简单来说,就是把yaml文件中的网络结构实例化成对应的模型。后续如果需要动模型框架的话,需要对这个函数做相应的改动。

2.1 获取对应参数def parse_model(d, ch): # model_dict, input_channels(3) '''===================1. 获取对应参数============================''' # 使用 logging 模块输出列标签 LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}") # 获取anchors,nc,depth_multiple,width_multiple anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'] # na: 每组先验框包含的先验框数 na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors # no: na * 属性数 (5 + 分类数) no = na * (nc + 5) # number of outputs = anchors * (classes + 5)

这段代码主要是获取配置dict里面的参数,并打印最开始展示的网络结构表的表头。

我们先解释几个参数,d和ch,na和no: 

 d:  yaml 配置文件(字典形式),yolov5s.yaml中的6个元素 + chch:  记录模型每一层的输出channel,初始ch=[3],后面会删除na:  判断anchor的数量no:  根据anchor数量推断的输出维度

这里有一行代码我们上篇YOLOv5源码逐行超详细注释与解读(5)——配置文件yolov5s.yaml就见过了:

这里就是读取了 yaml 文件的相关参数(参数含义忘了的话再看看上篇哦)

 2.2 搭建网络前准备 '''===================2. 搭建网络前准备============================''' # 网络单元列表, 网络输出引用列表, 当前的输出通道数 layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out # 读取 backbone, head 中的网络单元 for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args # 利用 eval 函数, 读取 model 参数对应的类名 如‘Focus’,'Conv'等 m = eval(m) if isinstance(m, str) else m # eval strings # 利用 eval 函数将字符串转换为变量 如‘None’,‘nc’,‘anchors’等 for j, a in enumerate(args): try: args[j] = eval(a) if isinstance(a, str) else a # eval strings except NameError: pass

 这段代码主要是遍历backbone和head的每一层,获取搭建网络前的一系列信息。

 我们还是先解释参数,layers、save和c2:

layers:   保存每一层的层结构save:   记录下所有层结构中from不是-1的层结构序号c2:  保存当前层的输出channel 

然后开始迭代循环backbone与head的配置。for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):中有几个参数

f: from,当前层输入来自哪些层n: number,当前层次数 初定m: module,当前层类别args: 当前层类参数 初定

接着还用到一个函数eval(),主要作用是将字符串当成有效的表达式来求值,并且返回执行的结果。在这里简单来说,就是实现list、dict、tuple与str之间的转化。

2.3 更新当前层的参数,计算c2 '''===================3. 更新当前层的参数,计算c2============================''' # depth gain: 控制深度,如yolov5s: n*0.33 # n: 当前模块的次数(间接控制深度) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain # 当该网络单元的参数含有: 输入通道数, 输出通道数 if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3TR, C3SPP, C3Ghost]: # c1: 当前层的输入channel数; c2: 当前层的输出channel数(初定); ch: 记录着所有层的输出channel数 c1, c2 = ch[f], args[0] # no=75,只有最后一层c2=no,最后一层不用控制宽度,输出channel必须是no if c2 != no: # if not output # width gain: 控制宽度,如yolov5s: c2*0.5; c2: 当前层的最终输出channel数(间接控制宽度) c2 = make_divisible(c2 * gw, 8)

这段代码主要是更新当前层的args,计算c2(当前层的输出channel)

首先网络将C3中的BottleNeck数量乘以模型缩放倍数n*gd控制模块的深度缩放,举个栗子,对于yolo5s来讲,gd为0.33,那么就是n*0.33,也就是把默认的深度缩放为原来的1/3。 

然后将m实例化成同名模块,别看列举了那么多模块,目前只用到Conv,SPP,Focus,C3,nn.Upsample。对于以上的这几种类型的模块,ch是一个用来保存之前所有的模块输出的channle,ch[-1]代表着上一个模块的输出通道。args[0]是默认的输出通道。

这样以来,c1=ch[f]就代表输入通道c1为f指向的层的输出通道,c2=args[0]就代表输出通道c2为yaml的args中的第一个变量。注意,如果输出通道不等于255即Detect层的输出通道, 则将通道数乘上width_multiple,并调整为8的倍数。通过函数make_divisible来实现

make_divisible()代码如下:   

   # 使得X能够被divisor整除     def make_divisible(x, divisor):         return math.ceil(x / divisor) * divisor2.4 使用当前层的参数搭建当前层 '''===================4.使用当前层的参数搭建当前层============================''' # 在初始args的基础上更新,加入当前层的输入channel并更新当前层 # [in_channels, out_channels, *args[1:]] args = [c1, c2, *args[1:]] # 如果当前层是BottleneckCSP/C3/C3TR/C3Ghost/C3x,则需要在args中加入Bottleneck的个数 # [in_channels, out_channels, Bottleneck个数, Bool(shortcut有无标记)] if m in [BottleneckCSP, C3, C3TR, C3Ghost]: # 在第二个位置插入bottleneck个数n args.insert(2, n) # number of repeats # 恢复默认值1 n = 1 # 判断是否是归一化模块 elif m is nn.BatchNorm2d: # BN层只需要返回上一层的输出channel args = [ch[f]] # 判断是否是tensor连接模块 elif m is Concat: # Concat层则将f中所有的输出累加得到这层的输出channel c2 = sum(ch[x] for x in f) # 判断是否是detect模块 elif m is Detect: # 在args中加入三个Detect层的输出channel args.append([ch[x] for x in f]) if isinstance(args[1], int): # number of anchors 几乎不执行 args[1] = [list(range(args[1] * 2))] * len(f) elif m is Contract: # 不怎么用 c2 = ch[f] * args[0] ** 2 elif m is Expand: # 不怎么用 c2 = ch[f] // args[0] ** 2 else: c2 = ch[f] # args不变

这段代码主要是使用当前层的参数搭建当前层。

经过以上处理,args里面保存的前两个参数就是module的输入通道数、输出通道数。只有BottleneckCSP和C3这两种module会根据深度参数n调整该模块的重复迭加次数。

然后进行的是其他几种类型的Module判断:

如果是BN层,只需要返回上一层的输出channel,通道数保持不变。如果是Concat层,则将f中所有的输出累加得到这层的输出channel,f是所有需要拼接层的index,输出通道c2是所有层的和。如果是Detect层,则对应检测头部分,这块下一小节细讲。

Contract和Expand目前未在模型中使用。

 2.5 打印和保存layers  '''===================5.打印和保存layers信息============================''' # m_: 得到当前层的module,将n个模块组合存放到m_里面 m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module # 打印当前层结构的一些基本信息 t = str(m)[8:-2].replace('__main__.', '') # module type # 计算这一层的参数量 np = sum(x.numel() for x in m_.parameters()) # number params m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print # 把所有层结构中的from不是-1的值记下 [6,4,14,10,17,20,23] save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist # 将当前层结构module加入layers中 layers.append(m_) if i == 0: ch = [] # 去除输入channel[3] # 把当前层的输出channel数加入ch ch.append(c2) return nn.Sequential(*layers), sorted(save)

这段代码主要是打印当前层结构的一些基本信息并保存。

把构建的模块保存到layers里,把该层的输出通道数写入ch列表里。待全部循环结束后再构建成模型。

返回值:

return nn.Sequential(*layers):  网络的每一层的层结构return sorted(save):   把所有层结构中from不是-1的值记下 并排序 [4, 6, 10, 14, 17, 20, 23]

至此模型就全部构建完毕了。

下面详细介绍一下各个模块。

🚀 三、Detect模块

Detect 模块是 YOLO 网络模型的最后一层 (对应 yaml 文件最后一行),通过 yaml 文件进行

本文链接地址:https://www.jiuchutong.com/zhishi/300423.html 转载请保留说明!

上一篇:ChatGPT的了解与初体验

下一篇:Pytorch实现GAT(基于PyTorch实现)(pytorch基础)

  • iqoo新系统originos ocean怎么切换

    iqoo新系统originos ocean怎么切换

  • 小米10支持红外遥控吗(小米10支持红外线功能吗)

    小米10支持红外遥控吗(小米10支持红外线功能吗)

  • oppo手机声音太大怎么调整(OPPO手机声音太小)

    oppo手机声音太大怎么调整(OPPO手机声音太小)

  • shadowrocket显示超时

    shadowrocket显示超时

  • icloud已停止响应是什么意思(icloud停止响应怎么办)

    icloud已停止响应是什么意思(icloud停止响应怎么办)

  • 微信语音播放失败怎么办(微信语音播放失败怎么办oppo)

    微信语音播放失败怎么办(微信语音播放失败怎么办oppo)

  • 苹果se有什么特点(苹果se有什么特别功能)

    苹果se有什么特点(苹果se有什么特别功能)

  • 华为手机充电慢了怎么解决(华为手机充电慢怎么办 解决)

    华为手机充电慢了怎么解决(华为手机充电慢怎么办 解决)

  • sony 8000g 8500g区别

    sony 8000g 8500g区别

  • 华为荣耀20有双击亮屏吗(华为荣耀20有双卡双待吗)

    华为荣耀20有双击亮屏吗(华为荣耀20有双卡双待吗)

  • 新换的苹果不显示通讯录怎么办(快速开始 旧iphone没显示)

    新换的苹果不显示通讯录怎么办(快速开始 旧iphone没显示)

  • 苹果微信恢复聊天记录(苹果微信恢复聊天记录指令)

    苹果微信恢复聊天记录(苹果微信恢复聊天记录指令)

  • 手机qq怎么隐藏黄钻(手机qq怎么隐藏图标)

    手机qq怎么隐藏黄钻(手机qq怎么隐藏图标)

  • 小米2s是不是4g手机(小米2s支不支持4g)

    小米2s是不是4g手机(小米2s支不支持4g)

  • 怎样把pdf拆分(怎样把pdf拆分成图片)

    怎样把pdf拆分(怎样把pdf拆分成图片)

  • 苹果xr定位不准确怎么调(苹果xr为什么定位不准确)

    苹果xr定位不准确怎么调(苹果xr为什么定位不准确)

  • 苹果手机暗黑模式怎么开启(苹果手机暗黑模式怎么设置)

    苹果手机暗黑模式怎么开启(苹果手机暗黑模式怎么设置)

  • 无法重新连接所有网络驱动器(无法重新连接所有网络驱动器鼠标键盘)

    无法重新连接所有网络驱动器(无法重新连接所有网络驱动器鼠标键盘)

  • 设置强提醒对方发现吗(强提醒对方会出现什么页面)

    设置强提醒对方发现吗(强提醒对方会出现什么页面)

  • 飞猪能买儿童半价票吗(飞猪购买儿童票怎么下单)

    飞猪能买儿童半价票吗(飞猪购买儿童票怎么下单)

  • 华为mate20照相机的使用方法(华为mate20照相机分辨率多少为好)

    华为mate20照相机的使用方法(华为mate20照相机分辨率多少为好)

  • a1431是苹果几(a1431是苹果几多少钱)

    a1431是苹果几(a1431是苹果几多少钱)

  • 苹果手机中毒有什么表现(iphone 手机中毒)

    苹果手机中毒有什么表现(iphone 手机中毒)

  • 腾讯电视投屏怎么设置(腾讯电视投屏怎么弄)

    腾讯电视投屏怎么设置(腾讯电视投屏怎么弄)

  • QQ音乐界面模糊怎么回事(qq音乐画面)

    QQ音乐界面模糊怎么回事(qq音乐画面)

  • 苹果xs怎么关闭后台程序(苹果xs怎么关闭拍照声音)

    苹果xs怎么关闭后台程序(苹果xs怎么关闭拍照声音)

  • Win10最新kb5007253补丁怎么安装?(win10最新版本22h2激活)

    Win10最新kb5007253补丁怎么安装?(win10最新版本22h2激活)

  • 税务局退回个税手续费会计分录
  • 佣金和手续费支出 纳税调整
  • 工人工资算生产总值吗
  • 会计中罚款属于什么处理
  • 商誉减值是在年报还是半年报
  • 土地成本包含什么
  • 增值税的附加税率是多少
  • 暂估发票一直未收回
  • 一般纳税人差额征税申报表怎么填
  • 收客户款现金折让发票怎么处理
  • 现金日记账支出和收入表格怎么做
  • 物流企业信用评级
  • 高新技术企业在增值税有什么优惠
  • 什么发票可以冲销
  • 前期物业开办费和承接费一样吗
  • 公允价值变动损益会计处理
  • 拆迁补偿款上交财政
  • 收到联营单位投入的设备一台
  • 企业接受捐赠是营业收入吗
  • 员工离职后收取客户钱款
  • 冲红发票金额大怎么办
  • 小规模季报营业税怎么算
  • 税收完税证明是契证吗
  • 公司名下的房产出租需要交哪些税
  • 现金支票丢了
  • 怎么简单快速的辨别是铝是锡
  • 苹果电脑屏幕键盘怎么去除
  • 补交以前年度的城建税会计分录
  • windows 11预览版
  • 惠普2600打印机故障排除
  • 补缴增值税和滞纳税区别
  • 顺流交易逆流交易未实现内部交易损益
  • 高新技术企业取消资格怎么处罚
  • 期间费用为何要摊销
  • php fgets
  • 保险中介市场现状和基本特点
  • 公司清算债权债务如何清理
  • 几款常用的表单设计软件
  • 代码怎么用?
  • 小规模纳税人收入会计分录
  • 非关联企业借款利息扣除
  • php ajax请求
  • 内外参标定
  • opencv dng
  • 毕业设计烦死了
  • java开源二次开发平台
  • php array_walk_recursive 使用自定的函数处理数组中的每一个元素
  • 如何验证工具坐标系
  • 看望职工家属可以计入福利费吗
  • 开通对公账号怎么办理
  • 固定资产年限折旧方法
  • mysql数据库优化及sql调优
  • 加工费如何开增票
  • 单位定期存款如遇利率调整,不论调高调低
  • 电子税务局申报流程
  • 股东投入固定资产怎么做账
  • 新医院会计制度什么时候实施
  • 采购费用属于什么会计分录
  • mysql配置怎么调出来
  • mysql怎么修改my.ini
  • 粘贴板有问题不能粘贴怎么处理
  • windows 08
  • easybcd修复ubuntu
  • 华硕电脑升级win11
  • 极限竞速中心应用程序
  • linux查看进程并杀死
  • 搭建android开发环境时为什么要先安装jdk
  • 编程中的python
  • vue分页组件page
  • win10安装node.js
  • Android Studio cvs 状态颜色
  • nodejs操作mysql数据库
  • javascript获取数据类型
  • 认识iu
  • 如何查发票是否作废
  • 天津市税务总局现任领导
  • 湖北省人民代表大会常务委员会关于深入开展
  • 贵阳市税务电话
  • 北京国税查询发票真伪查询系统
  • 宏酷集团创始人简介
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设