位置: IT常识 - 正文

自注意力(Self-Attention)与Multi-Head Attention机制详解(自注意力机制是什么)

编辑:rootadmin
自注意力(Self-Attention)与Multi-Head Attention机制详解

推荐整理分享自注意力(Self-Attention)与Multi-Head Attention机制详解(自注意力机制是什么),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:自注意力机制的作用,自注意力机制和多头注意力机制的区别,自注意力机制和多头注意力机制的区别,自注意力机制权重怎么求的,自注意力机制和卷积神经网络,自注意力机制原理,自注意力机制原理,自注意力机制是什么,内容如对您有帮助,希望把文章链接给更多的朋友!

  自注意力机制属于注意力机制之一。与传统的注意力机制作用相同,自注意力机制可以更多地关注到输入中的关键信息。self-attention可以看成是multi-head attention的输入数据相同时的一种特殊情况。所以理解self attention的本质实际上是了解multi-head attention结构。

一:基本原理  

  对于一个multi-head attention,它可以接受三个序列query、key、value,其中key与value两个序列长度一定相同,query序列长度可以与key、value长度不同。multi-head attention的输出序列长度与输入的query序列长度一致。兔兔这里记query的长度为Lq,key与value的长度记为Lk。

  其次,对于输入序列query、key、value,它们特征长度(每个元素维度dim)是可以不同的,记这三个序列的dim分别为Dq、Dk、Dv。在这些序列输入multi-head attention后,内部的序列的dim是可以与Dq、Dk与Dv不同的,我们称之为嵌入(embedding)维度,记为De,输出的序列dim也是De。

  multi-head attention是由一个或多个平行的单元结构组合而成,我们称每个这样的单元结构为一个head(one head,实际上也可以称为一个layer),为了方便,兔兔暂且命名这个单元结构为one-head attention,广义上head数为1 时也是multi-head attention。one-head attention结构是scaled dot-product attention与三个权值矩阵(或三个平行的全连接层)的组合,结构如下图所示

二:Scale Dot-Product Attention具体结构

  对于上图,我们把每个输入序列q,k,v看成形状是(Lq,Dq),(Lk,Dk),(Lk,Dv)的矩阵,即每个元素向量按行拼接得到的矩阵。Linear层的参数分别为(Dq,De),(Dk,De),(Dv,De),则通过全连接层,输出矩阵形状为(Lq,De),(Lk,De),(Lv,De),我们令通过全连接层得到的矩阵为Q、K、V。

  Linear层的本质是权值矩阵W与输入矩阵相乘(有时也可以加上偏置bias),在one-head attention中,我们令与Q、K、V相乘的权值矩阵分别为,它们的形状为(Dq,De),(Dk,De),(Dv,De)。bias的使用与否对后面的结构并无影响,在一些深度学习框架中默认加bias,但是《Attention Is All You Need》原文公式中并未体现bias,只有W,所以兔兔在后面讲解部分,不考虑bias。

  在输入数据通过Linear操作得到Q、K、V矩阵后,我们才真正来到Scale dot-product attention部分。

  Scale dot-product attention可以由一个简洁的公式来表示,其中dk即为我们前面的Dk:

  这个公式得到的输出即为onehead-attention的输出,它是一个形状为(Lq,De)的矩阵,表示长度为Lq,维度为De的输出序列。公式中:

有一个名字:attention weights,形状为(Lq,Lk),它可以大概理解为q序列与k序列各个对应元素之间相关性,类似于你在网页上输入关键词query,网页中之前存在的索引key,根据query与key的相关与否来决定选哪些索引key,并根据key来推荐相应的value。

  讲到这里,实际上已经介绍完multihead-attention的单元结构了。但是这个过程还可以更加深入地理解,下图是Lq与Lk相同时Scale dot-product attention的详细结构(一般Lq和Lk相等很可能Q,K,v来自同一序列,此时即为self attention,兔兔后面会讲到)。

  上图展示的是一个接收Q,K,V形状都是(3,De)的一个scale dot-product attention结构,我们把Q、K、V都拆解成长度为3,维度为De的序列。每次q与各个k计算内积得到一个数a,这些数通过softmax得到新的数a'(这里softmax是整体)。得到的a'与各自的v向量相乘得到新的向量,最终这些新的向量相加得到一个长度为De的向量,之后依次计算得到向量b1、b2,把这些向量b拼成矩阵即为最终的输出。对于这个过程,如果把序列q、k、v用前面的矩阵Q、K、V整体表示,实际上就是前面兔兔给出的那个公式,只不过该该公式以矩阵的形式并行运算,使整个计算过程简洁并且速度更快。

   当然,Lq在很多情况不一定等于Lk,此时若再用上图表示该过程会很乱。所以兔兔用下图来表示scale dot-product attention过程。

三:Scale Dot-Product Attention中的掩码mask问题自注意力(Self-Attention)与Multi-Head Attention机制详解(自注意力机制是什么)

  mask在scale dot-product attention中是可有可无的,在有些情况下使用mask效果会更好,有时则不需要mask。mask作用于scale dot-product attention中的attention weight。前面讲到atttention weights形状是(Lq,Lk),而使用mask时一般是self-attention的情况,此时Lq=Lk,attention weights 为方阵。mask的目的是使方阵上三角为负无穷(或是一个很小的负数),只保留下三角,这样通过softmax后矩阵上三角趋近于0。这样处理的目的是考虑到实际应用中的情况,例如翻译任务中,我们希望在读取句子序列时每次只利用前面读过的词,与后面还没有读到的词句无关。

   实际上,mask的种类可以不止是掩去上三角,根据实际情况也可以使矩阵右侧某些列或任意某些位置为-inf,来掩掉这些位置的信息。

  对于multi-head attention,如果使用mask,则每个head一般都使用相同的mask,此时该模型也称为masked multihead-attention

import numpy as npimport torchweight=torch.randint(0,5,size=(5,5))mask=torch.tensor(np.array([[False,True,True,True,True], [False,False,True,True,True], [False,False,False,True,True], [False,False,False,False,True], [False,False,False,False,False]]))masked_weight=weight.masked_fill(mask,-1000)out=nn.Sigmoid()(masked_weight)print(masked_weight)print(out)'''-------------------------------'''>>>tensor([[ 0, -1000, -1000, -1000, -1000], [ 3, 4, -1000, -1000, -1000], [ 3, 2, 0, -1000, -1000], [ 4, 3, 1, 2, -1000], [ 2, 3, 0, 2, 3]])>>>tensor([[0.5000, 0.0000, 0.0000, 0.0000, 0.0000], [0.9526, 0.9820, 0.0000, 0.0000, 0.0000], [0.9526, 0.8808, 0.5000, 0.0000, 0.0000], [0.9820, 0.9526, 0.7311, 0.8808, 0.0000], [0.8808, 0.9526, 0.5000, 0.8808, 0.9526]])四:Multi-Head Attention结构

  multi-head attention由多个one-head attention组成。我们记一个multi-head attention有n个head,第i个head的权值分别为,则:

这个过程为:输入q,k,v矩阵分别输入各one-head attention,各个head输出矩阵按特征(dim)维度拼接得到新的矩阵,再与矩阵相乘即得到输出(实际上也可以是一个全连接层Linear),并且输出形状仍是(Lq,De)。

  关于其中的参数W,实际上可能会有两种情况,

(1)的形状为:(Lq,De),(Lk,De),(Lk,De),则每个head形状为(Lq,De),拼接后得到的矩阵形状(Lq,n×De),形状为:(n×De,De)。

(2)的形状为:(Lq,De/n),(Lk,De/n),(Lk,De/n)(此时要保证嵌入维度De能整除head数n),则每个head的形状为(Lq,De/n),拼接后得到的矩阵形状(Lq,De),形状为:(De,De)。

虽然这两种方式内部参数不同,但输入与输出数据形状不变。Pytorch中的MuitiheadAttention使用的是方法(2)。

四:对self-attention的理解

  self-attention是multi-head attention三个输入序列都来源于同一序列的情况。设输入序列为input,此时输入的q,k,v三个序列全是input,所以此时Lq=Lk,Dq=Dk=Dv。由于所有输入都是同一个序列,所以也很好理解为什么叫做自注意力。

五:query、key、value的理解与来源

  query、key、value分别为查询、键、值。它们可以由同一个序列得到,也可以是具有实际意义的不同序列。从检索的角度来看,query是需要检索的内容,key是索引,value为待检索的值,attention的过程是计算query与key的相关性,获得attention map,在利用 attention map获取value中的特征值。在self-attention中,query,key,value为同一序列,一般情况下,query为一个序列,key与value为同一序列,更一般情况,query,key,value为三个不同的序列。

六:应用实例1.使用Pytorch构建multi-head attentionclass attention(nn.Module): def __init__(self,embed_dim,num_heads): ''' :param embed_dim: 嵌入特征个数 :param num_heads: scale dot-product attention层数 ''' super(attention, self).__init__() self.embed_dim=embed_dim self.num_heads=num_heads self.w_q=[nn.Linear(embed_dim,embed_dim) for i in range(num_heads)] self.w_k=[nn.Linear(embed_dim,embed_dim) for i in range(num_heads)] self.w_v=[nn.Linear(embed_dim,embed_dim) for i in range(num_heads)] self.w_o=nn.Linear(embed_dim*num_heads,embed_dim) self.softmax=nn.Softmax() def single_head(self,q,k,v,head_idx): '''scale dot-scale attention ''' q=self.w_q[head_idx](q) k=self.w_k[head_idx](k) v=self.w_v[head_idx](v) out=torch.matmul(torch.matmul(q,k.permute(0,2,1)),v)/self.embed_dim return out def forward(self,q,k,v): output=[] for i in range(self.num_heads): out=self.single_head(q,k,v,i) output.append(out) output=torch.cat(output,dim=2) output=self.w_o(output) print(output.shape) return outputif __name__=='__main__': x=torch.randn(size=(3,2,8),dtype=torch.float32) q,k,v=x,x,x att=attention(embed_dim=8,num_heads=4) output,attention_weight=att(q,k,v)2.使用Pytoch中nn.MultiheadAttention方法

在Pytorch中,MultiheadAttention方法中必需参数有2个:

  embed_dim:嵌入维度,即De。

  num_heads:head数

  虽然前面讲到Dq、Dk、Dv、De是可以不等的,但是pytorch中输入的Dq要等于De,并且默认Dv、De也等于De,如果k,v的特征dim不等于De,需要修改kdim,vdim参数。对于接收的数据,pytorch默认形式是(seq,batch,feature),即第一个维度是序列长度,第二个是batch size,第三个是特征dim。如果我们习惯于(batch,seq,feature)形式,可以修改参数batch_first=True。

import torchfrom torch import nnq=torch.randint(0,10,size=(10,9,8),dtype=torch.float32) #batch_size,seq_length,dimk=torch.randint(0,10,size=(10,7,4),dtype=torch.float32)v=torch.randint(0,10,size=(10,7,3),dtype=torch.float32)attention=nn.MultiheadAttention(embed_dim=8,num_heads=4,kdim=4,vdim=3,batch_first=True)attn_output, attn_output_weights=attention(q,k,v)print(attn_output.shape)print(attn_output_weights.shape)

当然,除了这些参数,pytorch的MultiheadAttention中还有更多的参数,例如各种bias,表示是否加入偏置。

七:总结

  自注意力机是multi-head attention模型在所有输入都是同一序列一种情况。multi-head attention结构上是一个或多个one head  attention 平行组合。每个one head attention由scale dot-product attention与三个相应的权值矩阵组成。multi-head attention作为神经网络的单元层种类之一,在许多神经网络模型中具有重要应用,并且它也是当今十分火热的transformer模型的核心结构之一,掌握好这部分内容对transformer的理解具有重要意义。  

本文链接地址:https://www.jiuchutong.com/zhishi/300728.html 转载请保留说明!

上一篇:【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

下一篇:GPT-4:关于下一代人工智能模型的事实、谣言和期望

  • 魅族18spro是不是曲面屏(魅族18 spro)

    魅族18spro是不是曲面屏(魅族18 spro)

  • 红米usb调试在哪里打开(红米的usb调试)

    红米usb调试在哪里打开(红米的usb调试)

  • 小米10s蓝牙耳机怎么连接(小米10S蓝牙耳机设置)

    小米10s蓝牙耳机怎么连接(小米10S蓝牙耳机设置)

  • WPS中PPT模板怎么去水印(wps中ppt模板)

    WPS中PPT模板怎么去水印(wps中ppt模板)

  • 抖音头像变成红色音符是注销了吗(抖音头像变成红色音符怎么解封)

    抖音头像变成红色音符是注销了吗(抖音头像变成红色音符怎么解封)

  • 华为nova4e能不能指关节截屏(华为nova4e能不能语音叫小艺)

    华为nova4e能不能指关节截屏(华为nova4e能不能语音叫小艺)

  • imessag信息是什么(imessag信息怎么使用)

    imessag信息是什么(imessag信息怎么使用)

  • 苹果x怎样设置黑名单打不进电话来(苹果x怎样设置面容解锁)

    苹果x怎样设置黑名单打不进电话来(苹果x怎样设置面容解锁)

  • 抖音作品在什么时间段发布可以热门(抖音作品在什么条件下可以上同城推荐)

    抖音作品在什么时间段发布可以热门(抖音作品在什么条件下可以上同城推荐)

  • oppoa77什么时候上市(oppoa7什么时候上市的)

    oppoa77什么时候上市(oppoa7什么时候上市的)

  • 华为手机屏幕不灵敏是怎么回事(华为手机屏幕不停的自动跳)

    华为手机屏幕不灵敏是怎么回事(华为手机屏幕不停的自动跳)

  • 怎么把ipad声音放大(怎么把ipad声音调小)

    怎么把ipad声音放大(怎么把ipad声音调小)

  • 华为p20pro是多少寸(华为p20pro是多少像素的)

    华为p20pro是多少寸(华为p20pro是多少像素的)

  • 手机安装包损坏怎么办(为什么手机安装包损坏)

    手机安装包损坏怎么办(为什么手机安装包损坏)

  • 淘宝修改评价可以改评分吗(淘宝修改评价可以恢复吗)

    淘宝修改评价可以改评分吗(淘宝修改评价可以恢复吗)

  • 计算器上off代表什么(计算器上off是什么)

    计算器上off代表什么(计算器上off是什么)

  • 滴滴司机可以异地接单吗(滴滴司机可以异地跑车不)

    滴滴司机可以异地接单吗(滴滴司机可以异地跑车不)

  • word怎么做条形码(word怎么做条形统计图并标上数值)

    word怎么做条形码(word怎么做条形统计图并标上数值)

  • Flash动画中如何导入音乐(flash动画如何保存成swf格式)

    Flash动画中如何导入音乐(flash动画如何保存成swf格式)

  • 苹果xsmax屏幕是不是2k(苹果xsmax是什么屏幕)

    苹果xsmax屏幕是不是2k(苹果xsmax是什么屏幕)

  • 解决错误1907视频(错误2019)

    解决错误1907视频(错误2019)

  • 红米手机如何恢复出厂设置(红米手机如何恢复备份数据)

    红米手机如何恢复出厂设置(红米手机如何恢复备份数据)

  • u盘无法读取什么情况(u盘无法读取怎么办)

    u盘无法读取什么情况(u盘无法读取怎么办)

  • 500m宽带用什么无线路由器(500M宽带用什么光猫)

    500m宽带用什么无线路由器(500M宽带用什么光猫)

  • 优酷怎么设置只看主角(优酷怎么设置只看谁的片段)

    优酷怎么设置只看主角(优酷怎么设置只看谁的片段)

  • 手机反应慢是什么原因(手机反应慢是什么情况)

    手机反应慢是什么原因(手机反应慢是什么情况)

  • 在win10系统中控制面板打不开该怎么处理?(w10控制中心在哪)

    在win10系统中控制面板打不开该怎么处理?(w10控制中心在哪)

  • win10电脑记事本怎么保存(win10电脑记事本在哪)

    win10电脑记事本怎么保存(win10电脑记事本在哪)

  • Pradollano滑雪站,西班牙内华达山脉国家公园 (© NTCo/iStock/Getty Images Plus)(paul滑雪)

    Pradollano滑雪站,西班牙内华达山脉国家公园 (© NTCo/iStock/Getty Images Plus)(paul滑雪)

  • vue3 + ts: layout布局(vue3+ts+vite)

    vue3 + ts: layout布局(vue3+ts+vite)

  • 小型微利企业所得税计算公式2023
  • 企业所得税季度预缴
  • 支票退票怎么做账务处理
  • 承兑汇票属于货款吗
  • 企业所得税税率
  • 资本成本与财务风险的区别
  • 政府补助的会计核算形式有哪些
  • 发票报销的条件是什么?
  • 财产理赔收入怎么做账
  • 单位班车费用是福利费吗
  • 内部损益表
  • 专项费用包括哪两种
  • 代交社保费会计账务处理
  • 如何查询增值税申报表
  • 不付供应商尾款了怎么清账
  • 银行手续费要发票什么时候开始的
  • 增值税代扣代缴抵扣
  • 对外投资公司经营范围
  • 培训费增值税专用发票怎么开
  • 农民专业合作社属于什么企业类型
  • 股权转让的标的
  • 税控盘没清盘怎么处罚
  • 发票开错了要退税怎么操作
  • 交通费进项税抵扣计算
  • 公司注销清算债权委托另一股东处理
  • 现金流量表的编制依据
  • 个税汇算清缴包含退休金吗
  • 动产抵押交付生效还是登记生效
  • 城市生活垃圾处理方法
  • 财务收入怎么写
  • 营业外支出与营业收入之比应小于1%的说明
  • 其他业务收入和其他业务成本区别
  • 公司收入可以打折吗
  • PHP:Memcached::decrement()的用法_Memcached类
  • 项目提成比例
  • 出租不动产增值税纳税义务发生时间
  • 企业支付给员工的一次性伤残就业补助金计入哪项费用
  • el-tree方法
  • 社会保险费征收机构责令限期缴纳
  • 手续费在银行系统哪里查
  • 一般纳税人取得普票会计分录
  • 企业所得税季初数怎么填
  • 理解DALL·E 2, Stable Diffusion和 Midjourney工作原理
  • 金融机构存放的保证金存款
  • 固定资产包括无形资产吗?
  • Python中自定义异常
  • 增值税报表附表三
  • 购买材料增值税税率
  • 春节补贴是正数还是负数
  • 核电站弃置费用通常多少钱
  • 负商誉的分录
  • 营改增会计分录怎么做
  • 应付职工薪酬中社保费怎么记账
  • 驾校挂靠车辆账务处理是?
  • 已经认证的进项发票在哪里查询
  • 小规模纳税人怎么算税
  • 小规模纳税人应纳增值税额的计算
  • win8.1 下载
  • win8系统切换桌面
  • windows锁屏界面设置
  • securecrt设置英文
  • win8的运行在哪里打开
  • Msssrv.exe - Msssrv是什么进程 有什么用
  • win8.1卸载软件在哪里
  • win7显示ipv6无网络访问权限
  • 红石cpu教程
  • cocos2dx3.2 android平台APK打包
  • Tutorial 6:Translation Transformation
  • jquery 异步请求
  • shell脚本实例精讲
  • vue缓存数据
  • 以下关于js函数说法错误的是
  • jquery日历插件代码
  • android中的webview
  • 重庆市电子发票样式
  • 电子税务局怎么删除办税员
  • 个体户增值税怎么计算方法
  • 工会经费范围税率是多少
  • 村纪检书记主要工作
  • 沙子属于矿产资源
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设