位置: IT常识 - 正文

计算机视觉中的注意力机制(计算机视觉中的数学方法)

编辑:rootadmin
计算机视觉中的注意力机制 计算机视觉中的注意力机制什么是注意力机制常用的简单的注意力机制SE AttentionCBAM Attention其他注意力机制注意力机制该加到网络的哪里什么是注意力机制

推荐整理分享计算机视觉中的注意力机制(计算机视觉中的数学方法),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:计算机视觉中的人脸识别过程包括,计算机视觉中的多视图几何pdf,计算机视觉中的数学方法pdf,计算机视觉中的多视图几何,计算机视觉中的人脸识别过程包括,计算机视觉中的数学方法,计算机视觉中的多视图几何中文版pdf,计算机视觉中的多视图几何中文版pdf,内容如对您有帮助,希望把文章链接给更多的朋友!

注意力机制(Attention Mechanism)源于对人类视觉的研究。 在认知科学中,由于信息处理的瓶颈,人类会选择性地关注所有信息的一部分,同时忽略其他可见的信息。 上述机制通常被称为注意力机制。 人类视网膜不同的部位具有不同程度的信息处理能力,即敏锐度(Acuity),只有视网膜中央凹部位具有最强的敏锐度。(以上为官方解释:个人的理解是注意力机制就是通过一通操作,将数据中关键的特征标识出来,让网络学到数据中需要关注的区域,也就形成了注意力。从而起到突出重要特征的作用。)

常用的简单的注意力机制计算机视觉中的注意力机制(计算机视觉中的数学方法)

常用的注意力机制多为SE Attention和CBAM Attention。为什么常用的是它们呢?其实回看所有注意力机制的代码,都不难发现,它们基本都可以当成一个简单的网络。例如SE注意力机制,它主要就是由两个全连接层组成,这就是一个简单的MLP模型,只是它的输出变了样。所以,在我们把注意力机制加入主干网络里时,所选注意力机制的复杂程度也是我们要考虑的一个方面,因为增加注意力机制,也变相的增加了我们网络的深度,大小。下面我们将介绍两个比较简单的注意力机制。

SE Attention

SE Attentionq其实称为SENet,它全称是Squeeze-and-Excitation Networks。它是2017ImageNet冠军模型。SENet它注意的是我们的通道重要性。SENet的结构如下面两张图。它的结构resnet相识,同样采用短路连接的方式来避免梯度消失,从而可以加深网络和更好地训练模型。不同的是SENet增加了一个由两个FC层,一个池化和两个激活函数组成的block来学习到不同通道特征的重要程度。这种注意力机制可以让模型更加关注信息量最大的通道特征,而抑制那些不重要的通道特征。还有一点是SE模块是通用的,这意味着其可以嵌入到现有的网络架构中,而且非常简洁有效。想更深入了解的可以阅读SENet的论文。论文地址 SE Attention的pytorch代码:

import torch.nn as nnimport torchclass SEAttention(nn.Module): def __init__(self, channel, reduction=16): # channel为输入通道数,reduction压缩大小 super().__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(channel, channel // reduction), nn.ReLU(), nn.Linear(channel // reduction, channel), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * yCBAM Attention

CBAM(Convolutional Block Attention Module) 表示卷积模块的注意力机制模块。是一种结合了空间(spatial)和通道(channel)的注意力机制模块。一般情况下,相比于SEnet只关注通道(channel)的注意力机制可以取得更好的效果。其中CBAM的结构如下面两张图,由Channel Attention和 Spatial Attention这两个模块组成,其中Channel Attention模块和SENet是十分相似的,只是在池化上做了最大和平均池化,把FC层换成了卷积。至于Spatial Attention模块,这个更为简单,本质上就是一个卷积层。想要更深入的理解可以去阅读原论文。论文地址

CBAM Attention 代码(Pytorch版):

import numpy as npimport torchfrom torch import nnfrom torch.nn import initclass ChannelAttention(nn.Module): def __init__(self,channel,reduction=16): super().__init__() self.maxpool=nn.AdaptiveMaxPool2d(1) self.avgpool=nn.AdaptiveAvgPool2d(1) self.se=nn.Sequential( nn.Conv2d(channel,channel//reduction,1,bias=False), nn.ReLU(), nn.Conv2d(channel//reduction,channel,1,bias=False) ) self.sigmoid=nn.Sigmoid() def forward(self, x) : max_result=self.maxpool(x) avg_result=self.avgpool(x) max_out=self.se(max_result) avg_out=self.se(avg_result) output=self.sigmoid(max_out+avg_out) return outputclass SpatialAttention(nn.Module): def __init__(self,kernel_size=7): super().__init__() self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2) self.sigmoid=nn.Sigmoid() def forward(self, x) : max_result,_=torch.max(x,dim=1,keepdim=True) avg_result=torch.mean(x,dim=1,keepdim=True) result=torch.cat([max_result,avg_result],1) output=self.conv(result) output=self.sigmoid(output) return outputclass CBAMBlock(nn.Module): def __init__(self, channel=512,reduction=16,kernel_size=49): super().__init__() self.ca=ChannelAttention(channel=channel,reduction=reduction) self.sa=SpatialAttention(kernel_size=kernel_size) def forward(self, x): b, c, _, _ = x.size() residual=x out=x*self.ca(x) out=out*self.sa(out) return out+residual其他注意力机制

除了上面介绍的两种视觉的注意力机制,其实还有很多各种花样且“有效”的注意力机制。如:SK Attention(Selective Kernel Networks) ,Shuffle Attention (Sa-net: Shuffle attention for deep convolutional neural networks) , Pyramid Split Attention(论文地址)等等。想要了解和使用更多的视觉注意力代码可以到这个github里找,这里面复现了许多注意力机制的代码,而且不限与视觉的,有兴趣的同学可以去了解下。 github地址

注意力机制该加到网络的哪里

当我们的网络需要加注意力机制时,它该加在哪里呢?虽然注意机制是一个独立的块,一般来说加在哪里都是可以的,但是,注意机制加入我们的网络中时,他是会影响我们网络的特征提取的,即它注意的特征不一定都是我们重要的特征。所以注意力机制加入我们网络的位置就比较重要了。其实注意力提出的作者就告诉了我们注意力机制加在哪里比较合适。所以,当我我们使用一个注意力机制不知道加在哪里时可以去看看提出注意力机制作者的源代码。下面的代码就是CBAM注意力机制的源代码,可以看出,注意力机制加在了残差网络(以resnet18为例)的残差块后面。当然,如果我们使用的网络不是注意力机制作者使用的网络,那这是注意力机制该加在哪里呢?这里个人建议加在最后一个卷积层后面或者第一个全连接层前面。当然并不是每个注意力机制或者每个网络都适用,因为不同的注意力机制注意的地方可能都不一样,所以加到主干网络的地方可能也不一样。

import torchimport torch.nn as nnimport torch.nn.functional as Fimport mathfrom torch.nn import initfrom .cbam import *from .bam import *def conv3x3(in_planes, out_planes, stride=1): "3x3 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = nn.BatchNorm2d(planes) self.downsample = downsample self.stride = stride if use_cbam: self.cbam = CBAM( planes, 16 ) else: self.cbam = None def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) if not self.cbam is None: out = self.cbam(out) out += residual out = self.relu(out) return outclass Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride if use_cbam: self.cbam = CBAM( planes * 4, 16 ) else: self.cbam = None def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) if not self.cbam is None: out = self.cbam(out) out += residual out = self.relu(out) return outclass ResNet(nn.Module): def __init__(self, block, layers, network_type, num_classes, att_type=None): self.inplanes = 64 super(ResNet, self).__init__() self.network_type = network_type # different model config between ImageNet and CIFAR if network_type == "ImageNet": self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.avgpool = nn.AvgPool2d(7) else: self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) if att_type=='BAM': self.bam1 = BAM(64*block.expansion) self.bam2 = BAM(128*block.expansion) self.bam3 = BAM(256*block.expansion) else: self.bam1, self.bam2, self.bam3 = None, None, None self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type) self.fc = nn.Linear(512 * block.expansion, num_classes) init.kaiming_normal(self.fc.weight) for key in self.state_dict(): if key.split('.')[-1]=="weight": if "conv" in key: init.kaiming_normal(self.state_dict()[key], mode='fan_out') if "bn" in key: if "SpatialGate" in key: self.state_dict()[key][...] = 0 else: self.state_dict()[key][...] = 1 elif key.split(".")[-1]=='bias': self.state_dict()[key][...] = 0 def _make_layer(self, block, planes, blocks, stride=1, att_type=None): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type=='CBAM')) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes, use_cbam=att_type=='CBAM')) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) if self.network_type == "ImageNet": x = self.maxpool(x) x = self.layer1(x) if not self.bam1 is None: x = self.bam1(x) x = self.layer2(x) if not self.bam2 is None: x = self.bam2(x) x = self.layer3(x) if not self.bam3 is None: x = self.bam3(x) x = self.layer4(x) if self.network_type == "ImageNet": x = self.avgpool(x) else: x = F.avg_pool2d(x, 4) x = x.view(x.size(0), -1) x = self.fc(x) return xdef ResidualNet(network_type, depth, num_classes, att_type): assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100" assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101' if depth == 18: model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type) elif depth == 34: model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type) elif depth == 50: model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type) elif depth == 101: model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type) return
本文链接地址:https://www.jiuchutong.com/zhishi/300694.html 转载请保留说明!

上一篇:Mask RCNN详解(mask rcnn优点)

下一篇:手把手带你做一套毕业设计-征程开启(手把手带你做一件事)

  • 淘宝新店使用的推广方法(淘宝新店的DSR多久会发生变化)

    淘宝新店使用的推广方法(淘宝新店的DSR多久会发生变化)

  • 手机qq如何分享屏幕(手机QQ如何分享文件)

    手机qq如何分享屏幕(手机QQ如何分享文件)

  • 拼多多没有自然流量的原因(拼多多没有自然流量怎么提升)

    拼多多没有自然流量的原因(拼多多没有自然流量怎么提升)

  • 微信朋友圈不发东西是怎么显示(微信朋友圈不发图片只发文字怎么弄)

    微信朋友圈不发东西是怎么显示(微信朋友圈不发图片只发文字怎么弄)

  • qq号被永久冻结了还可以解封吗(qq号被永久冻结怎么办)

    qq号被永久冻结了还可以解封吗(qq号被永久冻结怎么办)

  • 抖音亮灯牌是什么意思(抖音亮灯牌是什么意思呀)

    抖音亮灯牌是什么意思(抖音亮灯牌是什么意思呀)

  • 苹果手机如何上两个微信号(苹果手机如何上传视频到电脑)

    苹果手机如何上两个微信号(苹果手机如何上传视频到电脑)

  • qq群聊天记录多久会消失(qq群聊天记录多了一个不认识的人)

    qq群聊天记录多久会消失(qq群聊天记录多了一个不认识的人)

  • 手机掉水里声音变小了(手机掉水里声音很小怎么办)

    手机掉水里声音变小了(手机掉水里声音很小怎么办)

  • 抖音私信显示已读(抖音私信显示已送达是什么意思)

    抖音私信显示已读(抖音私信显示已送达是什么意思)

  • 小米10屏幕不灵敏(小米屏幕不灵敏怎么修复)

    小米10屏幕不灵敏(小米屏幕不灵敏怎么修复)

  • 5G手机关掉5G还耗电吗(5g手机关掉5g后流量用得快吗)

    5G手机关掉5G还耗电吗(5g手机关掉5g后流量用得快吗)

  • 手机摔过后有什么影响(手机摔了后会影响性能吗)

    手机摔过后有什么影响(手机摔了后会影响性能吗)

  • 华为荣耀6plus是不是全网通(华为 荣耀 6 plus)

    华为荣耀6plus是不是全网通(华为 荣耀 6 plus)

  • 喵喵机充电多久才会满(喵喵机充电注意事项)

    喵喵机充电多久才会满(喵喵机充电注意事项)

  • oppor11支持5g网络吗(oppor11支持5g频段wifi吗)

    oppor11支持5g网络吗(oppor11支持5g频段wifi吗)

  • 荣耀20和20s的区别(荣耀20和20S的区别)

    荣耀20和20s的区别(荣耀20和20S的区别)

  • vivo手机怎么使用内存卡(vivo手机怎么使用空调万能遥控器)

    vivo手机怎么使用内存卡(vivo手机怎么使用空调万能遥控器)

  • 什么叫4k超高清(什么叫4k超高清视频)

    什么叫4k超高清(什么叫4k超高清视频)

  • 抖音密码忘了怎么找回(抖音密码忘了怎么设置修改)

    抖音密码忘了怎么找回(抖音密码忘了怎么设置修改)

  • z5x有微信视频美颜功能吗(vivo z5i微信视频美颜在哪里)

    z5x有微信视频美颜功能吗(vivo z5i微信视频美颜在哪里)

  • 微信头像点两下为什么会动(微信头像点两下拍了拍怎么设置)

    微信头像点两下为什么会动(微信头像点两下拍了拍怎么设置)

  • 斗鱼tv投屏怎么总断开(斗鱼tv投屏怎么看弹幕)

    斗鱼tv投屏怎么总断开(斗鱼tv投屏怎么看弹幕)

  • svoice是什么意思(svon什么意思)

    svoice是什么意思(svon什么意思)

  • 网易考拉如何查看资料(网易考拉怎么查真伪)

    网易考拉如何查看资料(网易考拉怎么查真伪)

  • 推荐国内免费使用chatGPT的工具(推荐国内免费使用的电影)

    推荐国内免费使用chatGPT的工具(推荐国内免费使用的电影)

  • uniapp宽屏开发PC端方案,及衍生问题解决(uniapp宽度)

    uniapp宽屏开发PC端方案,及衍生问题解决(uniapp宽度)

  • vue中组件间通信的6种方式(vue之间的组件通信)

    vue中组件间通信的6种方式(vue之间的组件通信)

  • 牵引车和挂车都要购买交强险吗
  • 资产减值准备的计提方法
  • 工会经费按什么交
  • 外地工程款没有预缴会怎么样
  • 分公司可以享受小规模纳税人优惠
  • 市政押金无法收回的损失可以税前扣除吗
  • 个人所得税可以退几年前的?
  • 固定资产加速折旧税收优惠政策
  • 上市公司现金流充足说明什么
  • 自建生产用机器设备领用本企业生产的产品
  • 公司预支了然后来报销的帐怎么做?
  • 建筑企业预缴增值税计算
  • 免单计入什么科目
  • 维护费开的普票能全额抵扣吗?
  • 建筑业商业保险受益人可以是公司吗
  • 增值税普通发票有什么用
  • 采购合同清单的安装调试费如何开具发票?税率是多少
  • 个人年终奖如何交税
  • 计提所得税费用会计分录
  • 企业提取的盈余公积是什么会计科目
  • 消防改造费用
  • 多发工资还给老板是傻吗
  • 积分兑换现金的软件
  • 税务申报系统叫什么
  • 企业停产后员工怎么办
  • 财务报告成本
  • win10系统怎么设置锁屏壁纸
  • win11比win10是更流畅了吗
  • 支付项目工程款可以借流贷吗
  • 对某公司的了解
  • 工会经费,职工福利费,教育经费的扣除标准
  • php apc
  • win11测试版和正式版区别
  • php简单实例
  • 实收资本可以大于注册资本嘛
  • PHP:preg_split()的用法_PCRE正则函数
  • 土方工程公司账务实例
  • 如何自己搭建一个邮箱服务器
  • 企业出租房屋怎么做账
  • php的序列化操作生成的哪种格式
  • vue 跳转页面
  • 财务人离职了怎么说
  • 自动驾驶决策规划技术理论与实践电子版
  • 用more命令查看文件内容
  • 关于低值易耗品的说法中不正确的是
  • 施工项目的费用包括
  • 公司租入厂房怎么做账
  • 贸易企业出口退税计算方法
  • sqlserver触发器在哪个位置
  • 进出车间管理规定适用
  • 预收账款要预交税金吗
  • 固定资产多少钱算固定资产
  • 出口没做免税申请怎么办
  • 存货的盘盈
  • 小规模企业税金怎么做账
  • 小规模纳税人哪里可以查
  • 印花税减免税额怎么填
  • 无建账能力的纳税人是什么意思
  • win7如何设置自动锁屏时间
  • centos备份文件夹
  • freebsd6.2 nginx+php+mysql+zend系统优化防止ddos攻击
  • VMware虚拟机安装苹果Mac OS
  • win10预览版最新
  • dsapi.exe是什么
  • windows7开机
  • win8.1怎么设置
  • win1020h2正式版
  • xp取消开机启动项
  • windows 10预览版
  • ubuntu怎么将文件传送到电脑
  • windows 10 升级
  • 表单验证插件
  • easyui combobox设置值
  • unity3d教学视频
  • python生成器有几种写法
  • unity3ds
  • apk反编译去广告教程
  • python中fd
  • 养殖业免税用报税吗
  • 免税企业可以收增值税专票吗
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设