位置: IT常识 - 正文

【目标检测】YOLOv5遇上知识蒸馏(目标检测如何入门)

编辑:rootadmin
【目标检测】YOLOv5遇上知识蒸馏 前言

推荐整理分享【目标检测】YOLOv5遇上知识蒸馏(目标检测如何入门),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:目标检测tricks,目标检测如何入门,目标检测教程,目标检测yolo,目标检测怎么入门,目标检测yolov5,目标检测怎么入门,目标检测yolov5,内容如对您有帮助,希望把文章链接给更多的朋友!

模型压缩方法主要4种:

网络剪枝(Network pruning)稀疏表示(Sparse representation)模型量化(Model quantification)知识蒸馏(Konwledge distillation)

本文主要来研究知识蒸馏的相关知识,并尝试用知识蒸馏的方法对YOLOv5进行改进。

知识蒸馏理论简介概述

知识蒸馏(Knowledge Distillation)由深度学习三巨头Hinton在2015年提出。

论文标题:Distilling the knowledge in a neural network 论文地址:https://arxiv.org/pdf/1503.02531.pdf

“蒸馏”是个化工学科中的术语,本身指的是将液体混合物加热沸腾,使其中沸点较低的组分首先变成蒸气,再冷凝成液体,用来分离混合物。而知识蒸馏的含义和蒸馏本身相似但并不完全相同,知识蒸馏指的是同时训练两个网络,一个较复杂的网络作为教师网络,另一个较简单的网络作为学生网络,将教师网络训练得到的结果提炼出来,用来引导学生网络的结果,从而让学生网络学习得更好。

一个公认前提是小模型相比于大模型更容易陷入局部最优,下图[1]中,中间绿色的椭圆表示小网络模型的收敛空间,红色的椭圆表示大网络模型的收敛空间;如果不用知识蒸馏,直接训练小网络,它只会在绿色椭圆区域收敛,而使用知识蒸馏之后,小网络可以收敛到橙色椭圆区域,收敛到更小的最优点。

软标签

有了上面的概念,自然而然想到的一个问题就是,教师模型如何引导学生模型进行学习。这就涉及到论文中提及的一个概念——软标签(Soft target)

如上图[1]所示,以手写数字识别为例,这是一个10分类任务,左边这幅图是采用硬标签(Hard target),输出独热向量,概率最高的类别为1,其它类别为0;右边这幅图采用的是软标签(Soft target),通过softmax层输出的各类别概率,这样的输出具有更高的信息熵,即包含更多信息量。 教师模型输出软标签,从而指导学生模型学习。

softmax的原始公式是这样:

qi=exp⁡(zi)∑jexp⁡(zj)q_{i}=\frac{\exp \left(z_{i}\right)}{\sum_{j} \exp \left(z_{j}\right)}qi​=∑j​exp(zj​)exp(zi​)​

在论文中,作者对这个公式又加以改进,引入了一个新的温度变量T,公式如下:

qi=exp⁡(zi/T)∑jexp⁡(zj/T)q_{i}=\frac{\exp \left(z_{i} / T\right)}{\sum_{j} \exp \left(z_{j} / T\right)}qi​=∑j​exp(zj​/T)exp(zi​/T)​

加入这个变量,能使各类别之间的输出更均衡,如下图[2]所示,T=1为softmax,但是当T过大时,会发现输出向量会趋于一条直线,因此,T通常取中间较小值。

蒸馏温度

上面引入了一个新的变量温度T,这个T也可以称为蒸馏温度,原论文中给出了关于T的进一步讨论,随着T的增加,信息熵会越来越大,如下图[1]所示:

实际上,温度的高低改变的是Student模型训练过程中对负标签的关注程度。当温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Student模型会相对更多地关注到负标签[1]。

因此,T的取值可以遵循如下策略:

当想从负标签中学到一些信息量的时候,温度T应调高一些当想减少负标签的干扰的时候,温度T应调低一些

需要注意的是,这个T只作用于教师网络和学生网络的蒸馏过程,学生网络正常输出仍使用softmax,即T取值为1,就像蒸馏过程一样,需要先进行升温,将知识蒸馏出来,然后输出的时候要冷却降温(T=1)

知识蒸馏过程

从原理上来讲,知识蒸馏没有想象中那么复杂,其流程如下图[1]所示:

【目标检测】YOLOv5遇上知识蒸馏(目标检测如何入门)

在T下,训练教师网络得到 soft targets1在T下,训练学生网络得到 soft targets2通过 soft targets1 和 soft targets2 得到 distillation loss在温度1下,训练学生网络得到 soft targets3通过 soft targets3 和 ground truth 得到 student loss

通过这五个步骤,就得到了两个损失值 distillation loss 和 student loss,那么训练的整体损失,就是这两个损失值的加权和,公式[2]如下:

注:

这里的蒸馏损失系数乘了一个T2T^2T2 这是由于soft targets产生的梯度大小按照1/T21/T^21/T2进行了缩放,这里需要补充回来α\alphaα应远小于β\betaβ 即需要让知识蒸馏损失权重大一些,否则没有蒸馏效果

后面,论文作者分别做了手写数字识别和声音识别实验,这里主要来看作者在MNIST数据集上的实验结果,结果如下表所示:

10xEnsemble是10个教师模型的平均值,Distilled Single model是Baseline模型经过蒸馏之后的结果,可以看到蒸馏出来的准确率提升了1.9%.

YOLOv5加上知识蒸馏

下面就将知识蒸馏融入到YOLOv5目标检测任务中,使用的是YOLOv5-6.0版本。 相关代码参考自:https://github.com/Adlik/yolov5

代码修改

其实知识蒸馏的想法很简单,在仓库作者的代码版本中,修改的内容也并不多,主要是模型加载和损失计算部分。

下面按照顺序来解读一下修改内容。

首先是train_distillation.py这个文件,通过修改train.py得到。

新增四个参数:

parser.add_argument('--t_weights', type=str, default='./weights/yolov5s.pt', help='initial teacher model weights path')parser.add_argument('--t_cfg', type=str, default='models/yolov5s.yaml', help='teacher model.yaml path')parser.add_argument('--d_output', action='store_true', default=False, help='if true, only distill outputs')parser.add_argument('--d_feature', action='store_true', default=False, help='if true, distill both feature and output layers')

t_weights 教师模型权重,和学生模型加载类似

t_cfg 教师模型配置,和学生模型配置类似

d_output 这个参数写在这里但不起作用,应该是作者调试时用到的参数,默认是只蒸馏结果

d_feature 这个参数默认是关闭,如果开启,蒸馏损失计算将不仅仅是计算两个模型输出的结果,并且中间特征层也会参与计算(不过这个作者没写完整,可能写到一半弃坑了)

模型加载: 这部分需要多加载一个教师模型,相关代码如下:

# Modelcheck_suffix(weights, '.pt') # check weightspretrained = weights.endswith('.pt')if pretrained: with torch_distributed_zero_first(LOCAL_RANK): weights = attempt_download(weights) # download if not found locally ckpt = torch.load(weights, map_location=device) # load checkpoint model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect model.load_state_dict(csd, strict=False) # load LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report# 这里添加加载教师模型 # Teacher model LOGGER.info(f'Loaded teacher model {t_cfg}') # report t_ckpt = torch.load(t_weights, map_location=device) # load checkpoint t_model = Model(t_cfg or t_ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) exclude = ['anchor'] if (t_cfg or hyp.get('anchors')) and not resume else [] # exclude keys csd = t_ckpt['model'].float().state_dict() # checkpoint state_dict as FP32 csd = intersect_dicts(csd, t_model.state_dict(), exclude=exclude) # intersect t_model.load_state_dict(csd, strict=False) # load

损失计算: 这里多了一个d_outputs_loss,也就是计算蒸馏损失

s_loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_sized_outputs_loss = compute_distillation_output_loss(pred, t_pred, model, d_weight=10)loss = d_outputs_loss + s_loss

蒸馏损失在loss.py中进行定义:

def compute_distillation_output_loss(p, t_p, model, d_weight=1): t_ft = torch.cuda.FloatTensor if t_p[0].is_cuda else torch.Tensor t_lcls, t_lbox, t_lobj = t_ft([0]), t_ft([0]), t_ft([0]) h = model.hyp # hyperparameters red = 'mean' # Loss reduction (sum or mean) if red != "mean": raise NotImplementedError("reduction must be mean in distillation mode!") DboxLoss = nn.MSELoss(reduction="none") DclsLoss = nn.MSELoss(reduction="none") DobjLoss = nn.MSELoss(reduction="none") # per output for i, pi in enumerate(p): # layer index, layer predictions t_pi = t_p[i] t_obj_scale = t_pi[..., 4].sigmoid() # BBox b_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, 4) t_lbox += torch.mean(DboxLoss(pi[..., :4], t_pi[..., :4]) * b_obj_scale) # Class if model.nc > 1: # cls loss (only if multiple classes) c_obj_scale = t_obj_scale.unsqueeze(-1).repeat(1, 1, 1, 1, model.nc) # t_lcls += torch.mean(c_obj_scale * (pi[..., 5:] - t_pi[..., 5:]) ** 2) t_lcls += torch.mean(DclsLoss(pi[..., 5:], t_pi[..., 5:]) * c_obj_scale) # t_lobj += torch.mean(t_obj_scale * (pi[..., 4] - t_pi[..., 4]) ** 2) t_lobj += torch.mean(DobjLoss(pi[..., 4], t_pi[..., 4]) * t_obj_scale) t_lbox *= h['box'] t_lobj *= h['obj'] t_lcls *= h['cls'] # bs = p[0].shape[0] # batch size loss = (t_lobj + t_lbox + t_lcls) * d_weight return loss

因为目标检测和原论文中的分类问题有所区别,并不能直接简单套用原论文提出的soft-target,那么这里的处理方式就是将三个损失(位置损失、目标损失、类别损失)简单粗暴地用MSELoss进行计算,然后蒸馏损失就是这三部分之和。

值得注意的是,理论部分我们提到过,蒸馏损失需要比学生损失的权重更大,因此,这里在计算蒸馏损失中,加入了一个权重d_weight,权重计算时取10.

下面是代码作者给出的一个实验结果:

ModelCompressionstrategyInput size [h, w]mAPval0.5:0.95Pretrain weightyolov5sbaseline[640, 640]37.2pth | onnxyolov5sdistillation[640, 640]39.3pth | onnxyolov5squantization[640, 640]36.5xml | binyolov5sdistillation + quantization[640, 640]38.6xml | bin

他采用的是coco数据集,用yolov5m作为教师模型,yolov5s作为学生模型,表格第二行展示了蒸馏之后的效果,mAP提升了2.1.

实验验证

为了验证蒸馏是否有效,我在VisDrone数据集上进行了实验,训练了100epoch,实验结果如下表所示:

Student ModelTeacher ModelInput size [h, w]mAPtest0.5mAPtest0.5:0.95yolov5m-[640, 640]0.320.181yolov5myolov5m[640, 640]0.3050.163yolov5myolov5x[640, 640]0.3020.161yolov5m-[1280, 1280]0.4480.261yolov5myolov5x[1280, 1280]0.4010.23

结果挺意外的,使用蒸馏训练之后,mAP反而下降了,严重怀疑蒸馏出来的是糟粕😵

结论

知识蒸馏理论上并不复杂,但经过实验,基本判断这玩意理论价值大于应用价值,用来讲故事可以,实际上提升效果非常有限。当然这是我做了有限实验得出的初步结论,如果读者有更好的思路,可以在评论区留言和我讨论。

参考

[1]【论文泛读】 知识蒸馏:Distilling the knowledge in a neural network:https://www.bilibili.com/read/cv16841475 [2]【论文精讲|无废话版】知识蒸馏:https://www.bilibili.com/video/BV1h8411t7SA

本文链接地址:https://www.jiuchutong.com/zhishi/288835.html 转载请保留说明!

上一篇:潘塔纳尔湿地的风信子金刚鹦鹉,巴西 (© David Pattyn/Minden Pictures)(潘塔纳尔湿地的主要成因)

下一篇:最新大麦网抢票脚本-Python实战(最新大麦抢票脚本)

  • 如何快速增加微博粉丝?(如何快速增加微信群聊人数)

    如何快速增加微博粉丝?(如何快速增加微信群聊人数)

  • switch怎么连接蓝牙耳机(switch怎么连接蓝牙)

    switch怎么连接蓝牙耳机(switch怎么连接蓝牙)

  • 财付通怎么关闭财付通支付(财付通怎么关闭在哪里关闭)

    财付通怎么关闭财付通支付(财付通怎么关闭在哪里关闭)

  • 荣耀30青春版如何设置返回键(荣耀30青春版如何升级鸿蒙系统)

    荣耀30青春版如何设置返回键(荣耀30青春版如何升级鸿蒙系统)

  • 华为matebooke触摸板失灵怎么回事(华为matebooke触摸屏怎么开启)

    华为matebooke触摸板失灵怎么回事(华为matebooke触摸屏怎么开启)

  • 摄像头可以保存多久的视频(摄像头可以保存一年吗)

    摄像头可以保存多久的视频(摄像头可以保存一年吗)

  • 快手直播突然不进人了(快手直播突然不推流怎么回事)

    快手直播突然不进人了(快手直播突然不推流怎么回事)

  • ios什么时候支持改微信号(ios什么时候支持多开)

    ios什么时候支持改微信号(ios什么时候支持多开)

  • kindle更新慢(kindle更新速度)

    kindle更新慢(kindle更新速度)

  • 京东会员有几个等级(京东会员有几个账号)

    京东会员有几个等级(京东会员有几个账号)

  • 华为手机为什么不能安装软件了(华为手机为什么变成黑白屏了)

    华为手机为什么不能安装软件了(华为手机为什么变成黑白屏了)

  • 强行换行用什么键(强行换行什么意思)

    强行换行用什么键(强行换行什么意思)

  • 苹果手机爱奇艺看不了视频了怎么办(苹果手机爱奇艺怎么不能横屏播放)

    苹果手机爱奇艺看不了视频了怎么办(苹果手机爱奇艺怎么不能横屏播放)

  • 华为平板m6与m5的对比区别在哪(华为平板m6与matepad区别)

    华为平板m6与m5的对比区别在哪(华为平板m6与matepad区别)

  • 为什么抖音搜不到好友(为什么抖音搜不到店铺位置)

    为什么抖音搜不到好友(为什么抖音搜不到店铺位置)

  • 手机上有一个hd是什么意思(手机上有一个hd2什么意思)

    手机上有一个hd是什么意思(手机上有一个hd2什么意思)

  • flaal20华为是什么型号(华为型号fla-al20是什么手机)

    flaal20华为是什么型号(华为型号fla-al20是什么手机)

  • vivos1怎么没有语音唤醒(vivo语音没有声音)

    vivos1怎么没有语音唤醒(vivo语音没有声音)

  • 无法捕获屏幕截图怎么回事(无法捕获屏幕截图是什么意思)

    无法捕获屏幕截图怎么回事(无法捕获屏幕截图是什么意思)

  • word多出来的一页怎么删(word多出来的一夜如何删除)

    word多出来的一页怎么删(word多出来的一夜如何删除)

  • 电脑拷贝的软件怎么安装(电脑拷贝软件到另一台电脑)

    电脑拷贝的软件怎么安装(电脑拷贝软件到另一台电脑)

  • vivo的运动记步在哪(vivo计步数)

    vivo的运动记步在哪(vivo计步数)

  • 华为p30pro怎么拍星空(华为p30pro怎么拍星空夜景)

    华为p30pro怎么拍星空(华为p30pro怎么拍星空夜景)

  • 公交乘车码可以刷2次吗(公交乘车码可以坐地铁吗)

    公交乘车码可以刷2次吗(公交乘车码可以坐地铁吗)

  • vivo计算器怎么计算汇率(vivo计算器怎么开三次方)

    vivo计算器怎么计算汇率(vivo计算器怎么开三次方)

  • 华为p30pro来电闪光灯怎么设置(华为p30pro来电闪灯哪设置)

    华为p30pro来电闪光灯怎么设置(华为p30pro来电闪灯哪设置)

  • 苹果7手机严重卡顿(苹果手机严重发热原因)

    苹果7手机严重卡顿(苹果手机严重发热原因)

  • 商户扫码退款流程(商户扫码退款流程视频)

    商户扫码退款流程(商户扫码退款流程视频)

  • 装卸费发票怎么备注
  • 购买软件无形资产
  • 机器设备一般折旧几年
  • 个人所得税分摊方式月扣除金额修改
  • 新公司开基本户银行选择
  • 工程施工企业的账务处理
  • 如何确定商品交易价格?
  • 一般纳税人无票收入会计分录
  • 个人转让商铺个人所得税核定征收
  • 个税登记app
  • 零余额帐户如何转账
  • 广告位租赁交印花税吗
  • 调整已结转的税种有哪些
  • 公司注册核税后如何建账?
  • 小微企业免征税额
  • 本月出口下月开发票可以吗
  • 旅行社代订机票可以入差旅费报销吗
  • 免征土地增值税的有哪些
  • 商贸公司购买货物会计分录
  • 茶具可以作为固定资产吗
  • 公司注销资产负债表期末余额不能为0
  • 怎么判断分红前已提取足够法定公积金?
  • 电子发票增加开票项目
  • 苹果mac os x 怎样打开DVD播放程序
  • 税收是财政政策传导机制中重要的媒介之一
  • 农产品成本法计算抵扣
  • 360pci.exe
  • PHP:time_nanosleep()的用法_misc函数
  • 短缺的材料算不算入账价值
  • 房贷每月利息如何算
  • 埃托沙国家公园发展观兽旅游的优势条件
  • 租入固定资产改良支出属于资本性支出吗
  • novelai本地部署电脑要求
  • java web中的转发和重定向
  • 区块链技术开发入门
  • 网上怎么申请增驾摩托车
  • 企业发生的咨询费应计入哪个科目
  • SQLite之Autoincrement关键字(自动递增)
  • 在mysql中子查询是
  • 零申报未申报可以不处罚吗
  • 如何在税控盘上变更一般纳税人
  • 增值税买票卖票
  • 个体户开出的发票没跟对方说自己冲红了怎么办
  • 营业执照注销要钱吗
  • 项目费用有哪些
  • 企业现金购货限额
  • 增值税申报表中期初未缴税额指什么
  • 开公司的车出差违章算谁的
  • 免税农产品怎么开发票
  • 高新技术企业的税收优惠政策
  • 税金及附加是按什么基础交的
  • 收入与成本不配合
  • 其他业务收入冲减应付账款
  • 劳动就业失业金怎么申请
  • sql server随机数函数
  • mysql输入中文显示乱码
  • Win7 64位旗舰版系统中实现照片的批量重命名
  • windows service 2003
  • windows server 2008 r2激活密钥
  • ubuntu怎么录音
  • 通过修改注册表来修改chrome配置
  • debian7安装教程
  • tomcat调用servlet流程
  • rsrcmtr.exe - rsrcmtr是什么进程 有什么用
  • windows7开机提示盗版
  • android 资源管理器
  • 一键删除通讯录联系人
  • shell脚本cut -d
  • unity 3d ui
  • jquery怎么给文本框赋值
  • python中的条件判断和循环语句
  • 不同版本安卓控制台区别
  • jquery weui
  • python djang
  • 税务工作创新
  • 税控盘开票资料怎么导出几年前的开票信息
  • 劳务派遣服务计税
  • 南京税务局几点下班
  • 我国税收征收机关包括
  • 作废的发票验旧之后怎么领取新发票
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设