位置: IT常识 - 正文

yolox改进--添加Coordinate Attention模块(CVPR2021)(yolo改进方法)

编辑:rootadmin
yolox改进--添加Coordinate Attention模块(CVPR2021) yolox改进--添加Coordinate Attention模块Coordinate Attention代码建立包含CAM代码的attention.py在yolo_pafpn.py中添加CAM总结

推荐整理分享yolox改进--添加Coordinate Attention模块(CVPR2021)(yolo改进方法),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:yolov2改进,yolo增加检测层,yolov5如何改进,yolo增加检测层,改进yolov3,改进yolov3,yolov2改进,改进yolov3,内容如对您有帮助,希望把文章链接给更多的朋友!

因为项目需要,尝试魔改一下yolox-s,看看能不能在个人数据集上刷高点mAP。因为Coordinate Attention模块(以下简称CAM)的作者提供了代码,并且之前不少博主公开了CAM用在yolov5或者yolox等模型的代码,所以一开始我直接当了搬运工,但在搬运过程,我发现官方的代码不能直接用在yolox上,且之前公开CAM用在yolox的代码根本跑不通。在debug之后,发现问题是出现在官方的代码上,于是心血来潮写下这篇文章,废话不多说,来看修改后的代码吧!

Coordinate Attentionyolox改进--添加Coordinate Attention模块(CVPR2021)(yolo改进方法)

论文来源: http://arxiv.org/abs/2103.02907 官方代码:https://github.com/Andrew-Qibin/CoordAttention

注意力机制广泛用于深度神经网络中来提高模型的性能。然而,因为其昂贵的计算代价,很难应用在一些轻量级网络,但不乏有一些注意力模块脱颖而出,具有代表性的有SE、CBAM等。SE模块通过2D全局池化来计算通道注意力,在非常低的计算成本下达到了提升网络性能的目的,遗憾的是,SE模块忽视了捕获位置信息的注意力;CBAM模块通过使用大尺寸卷积来获得位置信息的注意力,但只偏向于捕获局部的位置信息。 CAM模块来源于2021CVPR,该模块通过将位置信息嵌入到通道注意力中,因为其较少的计算代价,使轻量级网可以较大的区域中获得注意力。为了缓解位置信息丢失的问题,论文作者将2D全局池化替换成分别在特征的w和h并行提取特征的两个1D池化,可以有效捕获空间坐标信息;而后这两个并行的特征图通过两个卷积来生成两个独立方向的注意力图;通过将两个注意力图乘入到原始特征图中,以达到增强特征图的表征能力。

代码建立包含CAM代码的attention.py

在./yolox/models/文件夹下建立attention.py,CAM代码如下。相较于官方的代码,为了适配yolox,这里将nn.AdaptiveAvgPool2d直接用于forward。

class CAM(nn.Module): def __init__(self, channels, reduction=32): super(CAM, self).__init__() self.conv_1x1 = nn.Conv2d(in_channels=channels, out_channels=channels // reduction, kernel_size=1, stride=1, bias=False) self.mish = Mish() # 可用自行选择激活函数 self.bn = nn.BatchNorm2d(channels // reduction) self.F_h = nn.Conv2d(in_channels=channels // reduction, out_channels=channels, kernel_size=1, stride=1, bias=False) self.F_w = nn.Conv2d(in_channels=channels // reduction, out_channels=channels, kernel_size=1, stride=1, bias=False) self.sigmoid_h = nn.Sigmoid() self.sigmoid_w = nn.Sigmoid() def forward(self, x): h, w = x.shape[2], x.shape[3] avg_pool_x = nn.AdaptiveAvgPool2d((h, 1)) avg_pool_y = nn.AdaptiveAvgPool2d((1, w)) x_h = avg_pool_x(x).permute(0, 1, 3, 2) x_w = avg_pool_y(x) x_cat_conv_relu = self.mish(self.conv_1x1(torch.cat((x_h, x_w), 3))) x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3) s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2))) s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w)) out = x * s_h.expand_as(x) * s_w.expand_as(x) return out在yolo_pafpn.py中添加CAM

CAM作为即插即用的注意力模块,添加位置可以完全替换例如CBAM等经典的注意力机制模块,具体可参考其他有关yolox在head中插入注意力机制的教程,这里给的代码以添加在pafpn为例,添加在哪效果好要取决于添加位置在特定数据集的表现。

#!/usr/bin/env python# -*- encoding: utf-8 -*-# Copyright (c) Megvii Inc. All rights reserved.import torchimport torch.nn as nnfrom .darknet import CSPDarknetfrom .network_blocks import BaseConv, CSPLayer, DWConvfrom .attention import CAMclass YOLOPAFPN(nn.Module): """ YOLOv3 model. Darknet 53 is the default backbone of this model. """ def __init__( self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"), in_channels=[256, 512, 1024], depthwise=False, act="silu", ): super().__init__() self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act) self.in_features = in_features self.in_channels = in_channels Conv = DWConv if depthwise else BaseConv self.upsample = nn.Upsample(scale_factor=2, mode="nearest") # self.upsample = nn.Upsample(scale_factor=2, mode="bilinear") self.lateral_conv0 = BaseConv( int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act ) self.C3_p4 = CSPLayer( int(2 * in_channels[1] * width), int(in_channels[1] * width), round(3 * depth), False, depthwise=depthwise, act=act, ) # cat self.reduce_conv1 = BaseConv( int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act ) self.C3_p3 = CSPLayer( int(2 * in_channels[0] * width), int(in_channels[0] * width), round(3 * depth), False, depthwise=depthwise, act=act, ) # bottom-up conv self.bu_conv2 = Conv( int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act ) self.C3_n3 = CSPLayer( int(2 * in_channels[0] * width), int(in_channels[1] * width), round(3 * depth), False, depthwise=depthwise, act=act, ) # bottom-up conv self.bu_conv1 = Conv( int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act ) self.C3_n4 = CSPLayer( int(2 * in_channels[1] * width), int(in_channels[2] * width), round(3 * depth), False, depthwise=depthwise, act=act, ) self.CAM0 = CAM(int(in_channels[2] * width)) self.CAM1 = CAM(int(in_channels[1] * width)) self.CAM2 = CAM(int(in_channels[0] * width)) # self.CAM3 = CAM(int(in_channels[0] * width)) # self.CAM4 = CAM(int(in_channels[1] * width)) # self.CAM5 = CAM(int(in_channels[2] * width)) def forward(self, input): """ Args: inputs: input images. Returns: Tuple[Tensor]: FPN feature. """ # backbone out_features = self.backbone(input) features = [out_features[f] for f in self.in_features] [x2, x1, x0] = features #############add CAM############## x0 = self.CAM0(x0) x1 = self.CAM1(x1) x2 = self.CAM2(x2) ################################## fpn_out0 = self.lateral_conv0(x0) # 1024->512/32 f_out0 = self.upsample(fpn_out0) # 512/16 f_out0 = torch.cat([f_out0, x1], 1) # 512->1024/16 f_out0 = self.C3_p4(f_out0) # 1024->512/16 fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16 f_out1 = self.upsample(fpn_out1) # 256/8 f_out1 = torch.cat([f_out1, x2], 1) # 256->512/8 pan_out2 = self.C3_p3(f_out1) # 512->256/8 # pan_out2 = self.CAM3(pan_out2) p_out1 = self.bu_conv2(pan_out2) # 256->256/16 p_out1 = torch.cat([p_out1, fpn_out1], 1) # 256->512/16 pan_out1 = self.C3_n3(p_out1) # 512->512/16 # p_out1 = self.CAM4(p_out1) p_out0 = self.bu_conv1(pan_out1) # 512->512/32 p_out0 = torch.cat([p_out0, fpn_out0], 1) # 512->1024/32 pan_out0 = self.C3_n4(p_out0) # 1024->1024/32 # pan_out0 = self.CAM5(pan_out0) outputs = (pan_out2, pan_out1, pan_out0) return outputs总结

CAM,同SE、CBAM等模块一样,作为即插即用的注意力机制,在yolov5、yolox等轻量级网络中有着重要的作用。本文介绍的CAM+yolox在我的数据集上,mAP比不添加的时候提高了0.02个点,相比使用CBAM提高了0.01个点,效果还是很可观的。

本文链接地址:https://www.jiuchutong.com/zhishi/297518.html 转载请保留说明!

上一篇:前端使用lottie-web,使用AE导出的JSON动画贴心教程(前端使用vue)

下一篇:下载、编译、安装、使用 vue-devtools(编译安装和普通安装)

  • App推广下半场怎么做?(app推广下半场怎么做)

    App推广下半场怎么做?(app推广下半场怎么做)

  • 小米手机锁屏密码忘了怎么解开(小米手机锁屏密码忘了不想清除数据怎么办)

    小米手机锁屏密码忘了怎么解开(小米手机锁屏密码忘了不想清除数据怎么办)

  • 爱奇艺二维码登录码在哪(爱奇艺二维码登录在哪)

    爱奇艺二维码登录码在哪(爱奇艺二维码登录在哪)

  • 苹果手机网易云音乐怎么悬浮歌词(苹果手机网易云的歌怎么下载到本地文件)

    苹果手机网易云音乐怎么悬浮歌词(苹果手机网易云的歌怎么下载到本地文件)

  • 爱奇艺vip可以同时几个人用呢(爱奇艺vip能共用吗)

    爱奇艺vip可以同时几个人用呢(爱奇艺vip能共用吗)

  • 如何通过淘宝昵称查人(如何通过淘宝昵称找到别人的旺旺账号)

    如何通过淘宝昵称查人(如何通过淘宝昵称找到别人的旺旺账号)

  • w2020和fold区别(w2021和fold3的区别)

    w2020和fold区别(w2021和fold3的区别)

  • 微信朋友圈怎么发完整信息(微信朋友圈怎么批量删除)

    微信朋友圈怎么发完整信息(微信朋友圈怎么批量删除)

  • 计算机中用来表示存储器容量的基本单位是( )。(计算机中用来表示内存容量大小的基本单位是)

    计算机中用来表示存储器容量的基本单位是( )。(计算机中用来表示内存容量大小的基本单位是)

  • 抖音0播放怎么恢复(抖音播放怎么静音)

    抖音0播放怎么恢复(抖音播放怎么静音)

  • mbr和guid哪个速度快(mbr快还是guid快)

    mbr和guid哪个速度快(mbr快还是guid快)

  • 微信红包限制200如何提高(微信红包限制多长时间能恢复)

    微信红包限制200如何提高(微信红包限制多长时间能恢复)

  • ipad微信发不了视频朋友圈(ipad微信发不了相册视频)

    ipad微信发不了视频朋友圈(ipad微信发不了相册视频)

  • 小米8支持4g 吗(小米8支持五g吗)

    小米8支持4g 吗(小米8支持五g吗)

  • 华为matebook14尺寸多大(华为matebook14尺寸图)

    华为matebook14尺寸多大(华为matebook14尺寸图)

  • 乐视手机怎么设置亮度(乐视手机怎么设置青少年模式)

    乐视手机怎么设置亮度(乐视手机怎么设置青少年模式)

  • 手机支付宝花呗怎么开通(手机支付宝花呗收款二维码怎么开通)

    手机支付宝花呗怎么开通(手机支付宝花呗收款二维码怎么开通)

  • 一加nfc怎么复制门禁卡(一加nfc怎么复制饭卡)

    一加nfc怎么复制门禁卡(一加nfc怎么复制饭卡)

  • 怎么删除苹果云盘资料(怎么删除苹果云备份)

    怎么删除苹果云盘资料(怎么删除苹果云备份)

  • 荣耀畅玩9x上市时间(荣耀畅玩9c即将发布)

    荣耀畅玩9x上市时间(荣耀畅玩9c即将发布)

  • 摇一摇为什么收不到打招呼的人(摇一摇为什么收不到打招呼的人怎么办)

    摇一摇为什么收不到打招呼的人(摇一摇为什么收不到打招呼的人怎么办)

  • xr关后台的方法

    xr关后台的方法

  • WAN口有IP地址上不了网,怎么办?(wan口ip地址和lan口ip地址不能)

    WAN口有IP地址上不了网,怎么办?(wan口ip地址和lan口ip地址不能)

  • tensorflow使用显卡gpu进行训练详细教程(tensorflow dlib)

    tensorflow使用显卡gpu进行训练详细教程(tensorflow dlib)

  • 如何vue使用ant design Vue中的select组件实现下拉分页加载数据,并解决存在的一个问题。(ant desgin-vue)

    如何vue使用ant design Vue中的select组件实现下拉分页加载数据,并解决存在的一个问题。(ant desgin-vue)

  • 注意力机制详解系列(四):混合注意力机制(注意力机制 q k v)

    注意力机制详解系列(四):混合注意力机制(注意力机制 q k v)

  • 增值税申报表上的销售收入
  • 公司有流水不申报会怎么样
  • 印花税减半征收吗
  • 每月记账报税客户怎么填
  • 专利在审可以入库吗
  • 出纳如何做好保密工作
  • 不得抵扣的进项税额的情形有
  • 商业企业库存商品和销售对不上
  • 专用发票跨年度能入账吗
  • 小规模公司退税
  • 已审核已过账已经生成凭证还能修改吗?
  • 预缴税款后怎么开票
  • 印花税为什么不计入资产成本
  • 采购合同中含税金额
  • 吊车租赁增值税税率最新2022
  • 增值税税负最终由谁承担
  • 契税和印花税入哪个科目
  • 个人土地征收款协议模板
  • 背书转让后的电子承兑怎么打印
  • 企业职工集资款的认定标准
  • mac os 10.15安装教程
  • 银行存款转定期存款计入什么科目
  • 出租人负责维修
  • 什么是货币资产负债表
  • 可转债 承销
  • 如何永久关闭win10系统更新
  • php array_search() 函数使用
  • win101903怎么查看
  • 固定资产改造时的账面价值
  • 何为职工
  • 网络通信的整个流程
  • 依夫城堡
  • 韦罗尼卡
  • php判断查询是否有结果
  • 季度申报残保金怎么计算
  • ts中如何定义一个数组
  • nlp自然语言处理框架
  • 新制度设置了应缴财政款科目原制度设置了什么科目
  • 差旅费的进项税额需要转出吗
  • 现金劳务收入会计分录
  • 电子产品报废清理是否缴纳教育附加税
  • 发票开具与小票的关系是怎样的
  • 企业残保金什么情况下可以减免
  • db2 798
  • mysql和mongo的区别
  • 施工企业结算单能不能入账
  • 资产负债表编制
  • 结转生产成本的数据从哪来的
  • 无形资产如何做账务处理
  • 计提个人经营所得税怎么算
  • 预收货款未发货怎么办
  • 财务费用具体包括
  • 小规模纳税人可以转为一般纳税人吗?
  • 预付款项为什么属于资产
  • 资产负债表的编制依据是会计恒等式
  • mysql基于什么模型
  • sqlserver日期范围
  • 微软补丁星期二更新吗
  • Win10 Mobile 10549 预览版新功能上手体验视频
  • windowsserver2008r2忘记开机密码怎么办
  • ubuntu 添加开机启动
  • fedora linux安装教程
  • mac系统怎么清理Adobe残留
  • win10系统如何删除账户
  • xp远程连接win7
  • win7开机自动弹出注册表编辑器怎么办
  • win8 Could not load type System.ServiceModel.Activation.HttpModule 错误解决方案
  • 批处理有何限制
  • 冒充咋写
  • opengl快速入门
  • js 模拟滑动
  • 从零基础开始学
  • python单链表输出1到10
  • 广东电子税务局官网登录入口
  • 百望税控盘电子发票怎么打
  • 车船税的纳税期限是
  • 四川农村信用社电话
  • 企业职工病退后一般能领多少钱
  • 非载货专项作业车属于什么车
  • 徐州市哪些区域有疫情
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设