位置: IT常识 - 正文

详解Transformer中Self-Attention以及Multi-Head Attention(transformer for)

编辑:rootadmin
详解Transformer中Self-Attention以及Multi-Head Attention

推荐整理分享详解Transformer中Self-Attention以及Multi-Head Attention(transformer for),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:transformer的原理,transformer with,transformer in transformer,transformer with,transformer with,transformer s,transformer.transform,transformer.transform,内容如对您有帮助,希望把文章链接给更多的朋友!

原文名称:Attention Is All You Need 原文链接:https://arxiv.org/abs/1706.03762

如果不想看文章的可以看下我在b站上录的视频:https://b23.tv/gucpvt

最近Transformer在CV领域很火,Transformer是2017年Google在Computation and Language上发表的,当时主要是针对自然语言处理领域提出的(之前的RNN模型记忆长度有限且无法并行化,只有计算完tit_iti​时刻后的数据才能计算ti+1t_{i+1}ti+1​时刻的数据,但Transformer都可以做到)。在这篇文章中作者提出了Self-Attention的概念,然后在此基础上提出Multi-Head Attention,所以本文对Self-Attention以及Multi-Head Attention的理论进行详细的讲解。在阅读本文之前,建议大家先去看下李弘毅老师讲的Transformer的内容。本文的内容是基于李宏毅老师讲的内容加上自己阅读一些源码进行的总结。

文章目录前言Self-AttentionMulti-Head AttentionSelf-Attention与Multi-Head Attention计算量对比Positional Encoding超参对比前言

如果之前你有在网上找过self-attention或者transformer的相关资料,基本上都是贴的原论文中的几张图以及公式,如下图,讲的都挺抽象的,反正就是看不懂(可能我太菜的原因)。就像李弘毅老师课程里讲到的"不懂的人再怎么看也不会懂的"。那接下来本文就结合李弘毅老师课上的内容加上原论文的公式来一个个进行详解。

Self-Attention

下面这个图是我自己画的,为了方便大家理解,假设输入的序列长度为2,输入就两个节点x1,x2x_1, x_2x1​,x2​,然后通过Input Embedding也就是图中的f(x)f(x)f(x)将输入映射到a1,a2a_1, a_2a1​,a2​。紧接着分别将a1,a2a_1, a_2a1​,a2​分别通过三个变换矩阵Wq,Wk,WvW_q, W_k, W_vWq​,Wk​,Wv​(这三个参数是可训练的,是共享的)得到对应的qi,ki,viq^i, k^i, v^iqi,ki,vi(这里在源码中是直接使用全连接层实现的,这里为了方便理解,忽略偏执)。

其中

qqq代表query,后续会去和每一个kkk进行匹配kkk代表key,后续会被每个qqq匹配vvv代表从aaa中提取得到的信息后续qqq和kkk匹配的过程可以理解成计算两者的相关性,相关性越大对应vvv的权重也就越大

假设a1=(1,1),a2=(1,),Wq=(1,1,1)a_1=(1, 1), a_2=(1,0), W^q= \binom{1, 1}{0, 1}a1​=(1,1),a2​=(1,0),Wq=(0,11,1​)那么: q1=(1,1)(1,1,1)=(1,2),   q2=(1,)(1,1,1)=(1,1)q^1 = (1, 1) \binom{1, 1}{0, 1} =(1, 2) , \ \ \ q^2 = (1, 0) \binom{1, 1}{0, 1} =(1, 1)q1=(1,1)(0,11,1​)=(1,2),   q2=(1,0)(0,11,1​)=(1,1) 前面有说Transformer是可以并行化的,所以可以直接写成: (q1q2)=(1,11,)(1,1,1)=(1,21,1)\binom{q^1}{q^2} = \binom{1, 1}{1, 0} \binom{1, 1}{0, 1} = \binom{1, 2}{1, 1}(q2q1​)=(1,01,1​)(0,11,1​)=(1,11,2​) 同理我们可以得到(k1k2)\binom{k^1}{k^2}(k2k1​)和(v1v2)\binom{v^1}{v^2}(v2v1​),那么求得的(q1q2)\binom{q^1}{q^2}(q2q1​)就是原论文中的QQQ,(k1k2)\binom{k^1}{k^2}(k2k1​)就是KKK,(v1v2)\binom{v^1}{v^2}(v2v1​)就是VVV。接着先拿q1q^1q1和每个kkk进行match,点乘操作,接着除以d\sqrt{d}d​得到对应的α\alphaα,其中ddd代表向量kik^iki的长度,在本示例中等于2,除以d\sqrt{d}d​的原因在论文中的解释是“进行点乘后的数值很大,导致通过softmax后梯度变的很小”,所以通过除以d\sqrt{d}d​来进行缩放。比如计算α1,i\alpha_{1, i}α1,i​: α1,1=q1⋅k1d=1×1+2×2=0.71α1,2=q1⋅k2d=1×+2×12=1.41\alpha_{1, 1} = \frac{q^1 \cdot k^1}{\sqrt{d}}=\frac{1\times 1+2\times 0}{\sqrt{2}}=0.71 \\ \alpha_{1, 2} = \frac{q^1 \cdot k^2}{\sqrt{d}}=\frac{1\times 0+2\times 1}{\sqrt{2}}=1.41α1,1​=d​q1⋅k1​=2​1×1+2×0​=0.71α1,2​=d​q1⋅k2​=2​1×0+2×1​=1.41 同理拿q2q^2q2去匹配所有的kkk能得到α2,i\alpha_{2, i}α2,i​,统一写成矩阵乘法形式: (α1,1  α1,2α2,1  α2,2)=(q1q2)(k1k2)Td\binom{\alpha_{1, 1} \ \ \alpha_{1, 2}}{\alpha_{2, 1} \ \ \alpha_{2, 2}}=\frac{\binom{q^1}{q^2}\binom{k^1}{k^2}^T}{\sqrt{d}}(α2,1​  α2,2​α1,1​  α1,2​​)=d​(q2q1​)(k2k1​)T​ 接着对每一行即(α1,1,α1,2)(\alpha_{1, 1}, \alpha_{1, 2})(α1,1​,α1,2​)和(α2,1,α2,2)(\alpha_{2, 1}, \alpha_{2, 2})(α2,1​,α2,2​)分别进行softmax处理得到(α^1,1,α^1,2)(\hat\alpha_{1, 1}, \hat\alpha_{1, 2})(α1,1​,α1,2​)和(α^2,1,α^2,2)(\hat\alpha_{2, 1}, \hat\alpha_{2, 2})(α2,1​,α2,2​),这里的α^\hat{\alpha}α相当于计算得到针对每个vvv的权重。到这我们就完成了Attention(Q,K,V){\rm Attention}(Q, K, V)Attention(Q,K,V)公式中softmax(QKTdk){\rm softmax}(\frac{QK^T}{\sqrt{d_k}})softmax(dk​​QKT​)部分。

上面已经计算得到α\alphaα,即针对每个vvv的权重,接着进行加权得到最终结果: b1=α^1,1×v1+α^1,2×v2=(0.33,0.67)b2=α^2,1×v1+α^2,2×v2=(0.50,0.50)b_1 = \hat{\alpha}_{1, 1} \times v^1 + \hat{\alpha}_{1, 2} \times v^2=(0.33, 0.67) \\ b_2 = \hat{\alpha}_{2, 1} \times v^1 + \hat{\alpha}_{2, 2} \times v^2=(0.50, 0.50)b1​=α1,1​×v1+α1,2​×v2=(0.33,0.67)b2​=α2,1​×v1+α2,2​×v2=(0.50,0.50) 统一写成矩阵乘法形式: (b1b2)=(α^1,1  α^1,2α^2,1  α^2,2)(v1v2)\binom{b_1}{b_2} = \binom{\hat\alpha_{1, 1} \ \ \hat\alpha_{1, 2}}{\hat\alpha_{2, 1} \ \ \hat\alpha_{2, 2}}\binom{v^1}{v^2}(b2​b1​​)=(α2,1​  α2,2​α1,1​  α1,2​​)(v2v1​) 到这,Self-Attention的内容就讲完了。总结下来就是论文中的一个公式: Attention(Q,K,V)=softmax(QKTdk)V{\rm Attention}(Q, K, V)={\rm softmax}(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dk​​QKT​)V

Multi-Head Attention

刚刚已经聊完了Self-Attention模块,接下来再来看看Multi-Head Attention模块,实际使用中基本使用的还是Multi-Head Attention模块。原论文中说使用多头注意力机制能够联合来自不同head部分学习到的信息。Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.其实只要懂了Self-Attention模块Multi-Head Attention模块就非常简单了。

详解Transformer中Self-Attention以及Multi-Head Attention(transformer for)

首先还是和Self-Attention模块一样将aia_iai​分别通过Wq,Wk,WvW^q, W^k, W^vWq,Wk,Wv得到对应的qi,ki,viq^i, k^i, v^iqi,ki,vi,然后再根据使用的head的数目hhh进一步把得到的qi,ki,viq^i, k^i, v^iqi,ki,vi均分成hhh份。比如下图中假设h=2h=2h=2然后q1q^1q1拆分成q1,1q^{1,1}q1,1和q1,2q^{1,2}q1,2,那么q1,1q^{1,1}q1,1就属于head1,q1,2q^{1,2}q1,2属于head2。

看到这里,如果读过原论文的人肯定有疑问,论文中不是写的通过WiQ,WiK,WiVW^Q_i, W^K_i, W^V_iWiQ​,WiK​,WiV​映射得到每个head的Qi,Ki,ViQ_i, K_i, V_iQi​,Ki​,Vi​吗: headi=Attention(QWiQ,KWiK,VWiV)head_i = {\rm Attention}(QW^Q_i, KW^K_i, VW^V_i)headi​=Attention(QWiQ​,KWiK​,VWiV​) 但我在github上看的一些源码中就是简单的进行均分,其实也可以将WiQ,WiK,WiVW^Q_i, W^K_i, W^V_iWiQ​,WiK​,WiV​设置成对应值来实现均分,比如下图中的Q通过W1QW^Q_1W1Q​就能得到均分后的Q1Q_1Q1​。

通过上述方法就能得到每个headihead_iheadi​对应的Qi,Ki,ViQ_i, K_i, V_iQi​,Ki​,Vi​参数,接下来针对每个head使用和Self-Attention中相同的方法即可得到对应的结果。 Attention(Qi,Ki,Vi)=softmax(QiKiTdk)Vi{\rm Attention}(Q_i, K_i, V_i)={\rm softmax}(\frac{Q_iK_i^T}{\sqrt{d_k}})V_iAttention(Qi​,Ki​,Vi​)=softmax(dk​​Qi​KiT​​)Vi​

接着将每个head得到的结果进行concat拼接,比如下图中b1,1b_{1,1}b1,1​(head1head_1head1​得到的b1b_1b1​)和b1,2b_{1,2}b1,2​(head2head_2head2​得到的b1b_1b1​)拼接在一起,b2,1b_{2,1}b2,1​(head1head_1head1​得到的b2b_2b2​)和b2,2b_{2,2}b2,2​(head2head_2head2​得到的b2b_2b2​)拼接在一起。

接着将拼接后的结果通过WOW^OWO(可学习的参数)进行融合,如下图所示,融合后得到最终的结果b1,b2b_1, b_2b1​,b2​。

到这,Multi-Head Attention的内容就讲完了。总结下来就是论文中的两个公式: MultiHead(Q,K,V)=Concat(head1,...,headh)WOwhere headi=Attention(QWiQ,KWiK,VWiV){\rm MultiHead}(Q, K, V) = {\rm Concat(head_1,...,head_h)}W^O \\ {\rm where \ head_i = Attention}(QW_i^Q, KW_i^K, VW_i^V)MultiHead(Q,K,V)=Concat(head1​,...,headh​)WOwhere headi​=Attention(QWiQ​,KWiK​,VWiV​)

Self-Attention与Multi-Head Attention计算量对比

在原论文章节3.2.2中最后有说两者的计算量其实差不多。Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.下面做了个简单的实验,这个model文件大家先忽略哪来的。这个Attention就是实现Multi-head Attention的方法,其中包括上面讲的所有步骤。

首先创建了一个Self-Attention模块(单头)a1,然后把proj变量置为Identity(Identity对应的是Multi-Head Attention中最后那个WoW^oWo的映射,单头中是没有的,所以置为Identity即不做任何操作)。再创建一个Multi-Head Attention模块(多头)a2,然后设置8个head。创建一个随机变量,注意shape使用fvcore分别计算两个模块的FLOPsimport torchfrom fvcore.nn import FlopCountAnalysisfrom model import Attentiondef main(): # Self-Attention a1 = Attention(dim=512, num_heads=1) a1.proj = torch.nn.Identity() # remove Wo # Multi-Head Attention a2 = Attention(dim=512, num_heads=8) # [batch_size, num_tokens, total_embed_dim] t = (torch.rand(32, 1024, 512),) flops1 = FlopCountAnalysis(a1, t) print("Self-Attention FLOPs:", flops1.total()) flops2 = FlopCountAnalysis(a2, t) print("Multi-Head Attention FLOPs:", flops2.total())if __name__ == '__main__': main()

终端输出如下, 可以发现确实两者的FLOPs差不多,Multi-Head Attention比Self-Attention略高一点:

Self-Attention FLOPs: 60129542144Multi-Head Attention FLOPs: 68719476736

其实两者FLOPs的差异只是在最后的WOW^OWO上,如果把Multi-Head Attentio的WOW^OWO也删除(即把a2的proj也设置成Identity),可以看出两者FLOPs是一样的:

Self-Attention FLOPs: 60129542144Multi-Head Attention FLOPs: 60129542144Positional Encoding

如果仔细观察刚刚讲的Self-Attention和Multi-Head Attention模块,在计算中是没有考虑到位置信息的。假设在Self-Attention模块中,输入a1,a2,a3a_1, a_2, a_3a1​,a2​,a3​得到b1,b2,b3b_1, b_2, b_3b1​,b2​,b3​。对于a1a_1a1​而言,a2a_2a2​和a3a_3a3​离它都是一样近的而且没有先后顺序。假设将输入的顺序改为a1,a3,a2a_1, a_3, a_2a1​,a3​,a2​,对结果b1b_1b1​是没有任何影响的。下面是使用Pytorch做的一个实验,首先使用nn.MultiheadAttention创建一个Self-Attention模块(num_heads=1),注意这里在正向传播过程中直接传入QKVQKVQKV,接着创建两个顺序不同的QKVQKVQKV变量t1和t2(主要是将q2,k2,v2q^2, k^2, v^2q2,k2,v2和q3,k3,v3q^3, k^3, v^3q3,k3,v3的顺序换了下),分别将这两个变量输入Self-Attention模块进行正向传播。

import torchimport torch.nn as nnm = nn.MultiheadAttention(embed_dim=2, num_heads=1)t1 = [[[1., 2.], # q1, k1, v1 [2., 3.], # q2, k2, v2 [3., 4.]]] # q3, k3, v3t2 = [[[1., 2.], # q1, k1, v1 [3., 4.], # q3, k3, v3 [2., 3.]]] # q2, k2, v2q, k, v = torch.as_tensor(t1), torch.as_tensor(t1), torch.as_tensor(t1)print("result1: \n", m(q, k, v))q, k, v = torch.as_tensor(t2), torch.as_tensor(t2), torch.as_tensor(t2)print("result2: \n", m(q, k, v))

对比结果可以发现,即使调换了q2,k2,v2q^2, k^2, v^2q2,k2,v2和q3,k3,v3q^3, k^3, v^3q3,k3,v3的顺序,但对于b1b_1b1​是没有影响的。

为了引入位置信息,在原论文中引入了位置编码positional encodings。To this end, we add "positional encodings" to the input embeddings at the bottoms of the encoder and decoder stacks.如下图所示,位置编码是直接加在输入的a={a1,...,an}a=\{a_1,...,a_n\}a={a1​,...,an​}中的,即pe={pe1,...,pen}pe=\{pe_1,...,pe_n\}pe={pe1​,...,pen​}和a={a1,...,an}a=\{a_1,...,a_n\}a={a1​,...,an​}拥有相同的维度大小。关于位置编码在原论文中有提出两种方案,一种是原论文中使用的固定编码,即论文中给出的sine and cosine functions方法,按照该方法可计算出位置编码;另一种是可训练的位置编码,作者说尝试了两种方法发现结果差不多(但在ViT论文中使用的是可训练的位置编码)。

超参对比

关于Transformer中的一些超参数的实验对比可以参考原论文的表3,如下图所示。其中:

N表示重复堆叠Transformer Block的次数dmodeld_{model}dmodel​表示Multi-Head Self-Attention输入输出的token维度(向量长度)dffd_{ff}dff​表示在MLP(feed forward)中隐层的节点个数h表示Multi-Head Self-Attention中head的个数dk,dvd_k, d_vdk​,dv​表示Multi-Head Self-Attention中每个head的key(K)以及query(Q)的维度PdropP_{drop}Pdrop​表示dropout层的drop_rate

到这,关于Self-Attention、Multi-Head Attention以及位置编码的内容就全部讲完了,如果有讲的不对的地方希望大家指出。

本文链接地址:https://www.jiuchutong.com/zhishi/300341.html 转载请保留说明!

上一篇:有关optimizer.param_groups用法的示例分析(有关的拼音)

下一篇:换了vue3+alova后,老板被我整笑了(vue项目更新后还是老代码)

  • QQ被暂时冻结怎么解除(qq被暂时冻结怎么看恢复时间)

    QQ被暂时冻结怎么解除(qq被暂时冻结怎么看恢复时间)

  • 爱奇艺属于腾讯应用吗(爱奇艺属于腾讯王卡免流吗)

    爱奇艺属于腾讯应用吗(爱奇艺属于腾讯王卡免流吗)

  • 苹果手机抖音怎么看不到好友在线(苹果手机抖音怎么去掉抖音号水印)

    苹果手机抖音怎么看不到好友在线(苹果手机抖音怎么去掉抖音号水印)

  • oppo主题商店里的ART是什么(opop的主题商店)

    oppo主题商店里的ART是什么(opop的主题商店)

  • 华为nova7怎么关闭振动(华为nova7怎么关闭后应用运行)

    华为nova7怎么关闭振动(华为nova7怎么关闭后应用运行)

  • cad滚轮不能放大缩小(cad2012滚轮可以缩放但不能平移)

    cad滚轮不能放大缩小(cad2012滚轮可以缩放但不能平移)

  • qq没有访问操作权限是什么意思(qq没有访问操作权限是怎么回事)

    qq没有访问操作权限是什么意思(qq没有访问操作权限是怎么回事)

  • 三星s9 充电缓慢怎么回事(三星s9充电缓慢怎么办)

    三星s9 充电缓慢怎么回事(三星s9充电缓慢怎么办)

  • matlab是应用软件吗(matlab软件的应用)

    matlab是应用软件吗(matlab软件的应用)

  • 苹果支持多少w快充(苹果支持多少瓦无线快充)

    苹果支持多少w快充(苹果支持多少瓦无线快充)

  • 华为手机背后的标签可以撕掉吗(华为手机背后的玻璃屏碎了多少钱)

    华为手机背后的标签可以撕掉吗(华为手机背后的玻璃屏碎了多少钱)

  • 苹果官方标配都有什么(iphone官方标配都有什么)

    苹果官方标配都有什么(iphone官方标配都有什么)

  • ps怎么改图片格式(ps怎么改图片格式为透明)

    ps怎么改图片格式(ps怎么改图片格式为透明)

  • 苹果11怎么切换超广角(苹果11怎么切换输入法)

    苹果11怎么切换超广角(苹果11怎么切换输入法)

  • 一加7T慢动作录像怎么开启(一加9r慢动作)

    一加7T慢动作录像怎么开启(一加9r慢动作)

  • vivox27能不能防水(vivox27防触摸怎么设置)

    vivox27能不能防水(vivox27防触摸怎么设置)

  • 高通有5g芯片吗(高通有5g芯片吗苹果)

    高通有5g芯片吗(高通有5g芯片吗苹果)

  • 拼多多如何实名认证(拼多多如何实名更改成另外一个人的名字)

    拼多多如何实名认证(拼多多如何实名更改成另外一个人的名字)

  • 毒怎么用花呗分期(毒app花呗分期需要多少额度)

    毒怎么用花呗分期(毒app花呗分期需要多少额度)

  • 华为hry一al00a是什么型号(华为hry 一al00a什么型号)

    华为hry一al00a是什么型号(华为hry 一al00a什么型号)

  • 微信名片怎么显示电话号码(微信名片怎么显示个性签名)

    微信名片怎么显示电话号码(微信名片怎么显示个性签名)

  • wps电子表格如何设置密码(wps电子表格如何调整行间距)

    wps电子表格如何设置密码(wps电子表格如何调整行间距)

  • 王者荣耀电脑版百里守约怎么操作?(王者荣耀电脑版怎么键盘操作)

    王者荣耀电脑版百里守约怎么操作?(王者荣耀电脑版怎么键盘操作)

  • python操作微信客户端:WechatPCAPI库实现自动化回复(python 微信)

    python操作微信客户端:WechatPCAPI库实现自动化回复(python 微信)

  • 增值税电子发票可以作废吗
  • 应交税费账目处理
  • 个人所得税退的多好还是少好
  • 金税四期上线后对企业的影响
  • 公司法人必须办社保吗?
  • 生产型企业怎么退税
  • 预收货款存入银行分录
  • 废弃土地的使用年限
  • 有限责任公司减资的法律规定
  • 期末应交企业所得税怎么算
  • 运输费用 成本
  • 开票方与受票方的区别
  • 公司注销地税时其他应收款要缴纳个人所得税吗?
  • 房产税要来了!租金也要交,最高达到12%
  • 公司个税申报是什么意思
  • 简易计税的劳务公司员工社保可以抵扣增值税吗
  • 每月结转本年利润会计分录
  • mac怎么调整网页大小
  • 票据状态提示付款
  • 利润分配怎么核算
  • 100%控股有什么风险
  • 劳务分包费用组成比例
  • 赠送客户的商品怎么入账
  • linux is
  • 餐饮发票可以计入什么费用
  • linux鼠标左键失灵
  • 公司注册资金存在风险吗
  • 注册资金没有的怎么做账
  • unity导出webgl报错
  • nerf 怎么瞄准
  • php rsa
  • 项目竣工决算审计与工程结算审核的区别是
  • 细说php
  • 用php制作日历2020日历表
  • 银行承兑汇票贴现率是多少
  • 增值税业务发生的时间
  • Vue Element UI 中 el-table 树形数据 tree-props 多层级使用避坑
  • 为什么我的命令提示符里显示user
  • PHP中使用什么关键字声明变量的作用域为全局
  • 帝国cms使用手册
  • 库存商品和固定资产是单位会计资产核算的两项内容
  • 建筑安装的扩展性是指
  • MySQL导入导出命令
  • 盈余公积包括哪两个明细科目
  • 零申报是怎么回事
  • 养老保险减免退税政策
  • 低值易耗品费用计入产品成本的方式有哪几种
  • 预缴税款如何做账
  • 监控 固定资产
  • 在税收方面属于什么领域
  • 固定资产内部转移流程
  • 固定资产处理附件是什么
  • 车辆购置税完税证明图片
  • 投入的资金如何做账
  • 行政事业单位会计风险来源于日常的会计活动
  • 上个月的发票可以作废吗
  • 企业共同控制持股比例怎么算
  • mysql建索引有哪些策略和原则
  • 怎么在bios里设置usb开关
  • 电脑主板bios是什么意思
  • 如何关闭windows防火墙
  • 安全组件异常,请重新下载并安装
  • CentOS 6.2(32位/64位) 安装步骤图文详解
  • linux文件后缀名解释
  • centos运行程序
  • win10系统无法启动
  • Cocos2d-x c++和java相互调用
  • cocos2d rpg
  • centos7如何分区
  • javascript数据
  • js如何实现复制
  • 电子专票票种核定
  • 残疾人就业有哪些选择
  • 销售不动产增值税税率
  • 消费税申报流程图
  • 劳动仲裁受理通知
  • 电子税务局网上登录
  • 注册海外公司如何注册
  • 无锡市国家税务局
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设