位置: IT常识 - 正文

对Transformer中Add&Norm层的理解(transformer中的参数)

编辑:rootadmin
对Transformer中Add&Norm层的理解 对Add&Norm层的理解Add操作Norm操作Add操作

推荐整理分享对Transformer中Add&Norm层的理解(transformer中的参数),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:transform方法会对产生的标量值进行,transform方法会对产生的标量值进行,transformer add norm,transformer中的参数,transformer中的参数,transformer方法,transformer中的参数,transformer中的参数,内容如对您有帮助,希望把文章链接给更多的朋友!

首先我们还是先来回顾一下Transformer的结构:Transformer结构主要分为两大部分,一是Encoder层结构,另一个则是Decoder层结构,Encoder 的输入由 Input Embedding 和 Positional Embedding 求和输入Multi-Head-Attention,再通过Feed Forward进行输出。

由下图可以看出:在Encoder层和Decoder层中都用到了Add&Norm操作,即残差连接和层归一化操作。 什么是残差连接呢?残差连接就是把网络的输入和输出相加,即网络的输出为F(x)+x,在网络结构比较深的时候,网络梯度反向传播更新参数时,容易造成梯度消失的问题,但是如果每层的输出都加上一个x的时候,就变成了F(x)+x,对x求导结果为1,所以就相当于每一层求导时都加上了一个常数项‘1’,有效解决了梯度消失问题。

Norm操作

首先要明白Norm做了一件什么事,从刚开始接触Transformer开始,我认为所谓的Norm就是BatchNorm,但是有一天我看到了这篇文章,才明白了Norm是什么。

假设我们输入的词向量的形状是(2,3,4),2为批次(batch),3为句子长度,4为词向量的维度,生成以下数据:

[[w11, w12, w13, w14], [w21, w22, w23, w24], [w31, w32, w33, w34][w41, w42, w43, w44], [w51, w52, w53, w54], [w61, w62, w63, w64]]

如果是在做BatchNorm(BN)的话,其计算过程如下:BN1=(w11+w12+w13+w14+w41+ w42+w43+w44)/8,同理会得到BN2和BN3,最终得到[BN1,BN2,BN3] 3个mean

对Transformer中Add&Norm层的理解(transformer中的参数)

如果是在做LayerNorm(LN)的话,则会进如下计算:LN1=(w11+w12+w13+w14+w21+ w22+w23+w24+w31+w32+w33+w34)/12,同理会得到LN2,最终得到[LN1,LN2]两个mean

如果是在做InstanceNorm(IN)的话,则会进如下计算:IN1=(w11+w12+w13+w14)/4,同理会得到IN2,IN3,IN4,IN5,IN6,六个mean,[[IN1,IN2,IN3],[IN4,IN5,IN6]] 下图完美的揭示了,这几种Norm 接下来我们来看一下Transformer中的Norm:首先生成[2,3,4]形状的数据,使用原始的编码方式进行编码:

import torchfrom torch.nn import InstanceNorm2drandom_seed = 123torch.manual_seed(random_seed)batch_size, seq_size, dim = 2, 3, 4embedding = torch.randn(batch_size, seq_size, dim)layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)print("y: ", layer_norm(embedding))

输出:

y: tensor([[[ 1.5524, 0.0155, -0.3596, -1.2083], [ 0.5851, 1.3263, -0.7660, -1.1453], [ 0.2864, 0.0185, 1.2388, -1.5437]], [[ 1.1119, -0.3988, 0.7275, -1.4406], [-0.4144, -1.1914, 0.0548, 1.5510], [ 0.3914, -0.5591, 1.4105, -1.2428]]])

接下来手动去进行一下编码:

eps: float = 0.00001mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)print("mean: ", mean.shape)print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))mean: torch.Size([2, 3, 1])y_custom: tensor([[[ 1.1505, 0.5212, -0.1262, -1.5455], [-0.6586, -0.2132, -0.8173, 1.6890], [ 0.6000, 1.2080, -0.3813, -1.4267]], [[-0.0861, 1.0145, -1.5895, 0.6610], [ 0.8724, 0.9047, -1.5371, -0.2400], [ 0.1507, 0.5268, 0.9785, -1.6560]]])

可以发现和LayerNorm的结果是一样的,也就是说明Norm是对d_model进行的Norm,会给我们[batch,sqe_length]形状的平均值。 加下来进行batch_norm,

layer_norm = torch.nn.LayerNorm([seq_size,dim], elementwise_affine = False)eps: float = 0.00001mean = torch.mean(embedding[:, :, :], dim=(-2,-1), keepdim=True)var = torch.square(embedding[:, :, :] - mean).mean(dim=(-2,-1), keepdim=True)print("mean: ", mean.shape)print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

输出:

mean: torch.Size([2, 1, 1])y_custom: tensor([[[ 1.1822, 0.4419, -0.3196, -1.9889], [-0.6677, -0.2537, -0.8151, 1.5143], [ 0.7174, 1.2147, -0.0852, -0.9403]], [[-0.0138, 1.5666, -2.1726, 1.0590], [ 0.6646, 0.6852, -0.8706, -0.0442], [-0.1163, 0.1389, 0.4454, -1.3423]]])

可以看到BN的计算的mean形状为[2, 1, 1],并且Norm结果也和上面的两个不一样,这就充分说明了Norm是在对最后一个维度求平均。 那么什么又是Instancenorm呢?接下来再来实现一下instancenorm

instance_norm = InstanceNorm2d(3, affine=False)output = instance_norm(embedding.reshape(2,3,4,1)) #InstanceNorm2D需要(N,C,H,W)的shape作为输入layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)print(layer_norm(embedding))

输出:

tensor([[[ 1.1505, 0.5212, -0.1262, -1.5455], [-0.6586, -0.2132, -0.8173, 1.6890], [ 0.6000, 1.2080, -0.3813, -1.4267]], [[-0.0861, 1.0145, -1.5895, 0.6610], [ 0.8724, 0.9047, -1.5371, -0.2400], [ 0.1507, 0.5268, 0.9785, -1.6560]]])

可以看出无论是layernorm还是instancenorm,还是我们手动去求平均计算其Norm,结果都是一样的,由此我们可以得出一个结论:Layernorm实际上是在做Instancenorm!

如果喜欢文章请点个赞,笔者也是一个刚入门Transformer的小白,一起学习,共同努力。

本文链接地址:https://www.jiuchutong.com/zhishi/296105.html 转载请保留说明!

上一篇:图像融合、Transformer、扩散模型(图像融合名词解释)

下一篇:Vue | Vue.js 全家桶 Pinia状态管理(vue全家桶的app项目代码)

  • 音频服务未运行无法启动怎么解决(音频服务未运行怎么办)(音频服务未运行win7)

    音频服务未运行无法启动怎么解决(音频服务未运行怎么办)(音频服务未运行win7)

  • pr怎么恢复默认界面(pr怎么恢复默认轨道大小)

    pr怎么恢复默认界面(pr怎么恢复默认轨道大小)

  • 京东白条怎么提升额度(京东白条怎么提前还款全部金额)

    京东白条怎么提升额度(京东白条怎么提前还款全部金额)

  • 手机哪个孔是收录声音的(手机各个孔介绍)

    手机哪个孔是收录声音的(手机各个孔介绍)

  • windows search可以关闭吗(windows searchfilterhost)

    windows search可以关闭吗(windows searchfilterhost)

  • 彩铃过滤器校验失败怎么开通(彩铃过滤器校验不通过)

    彩铃过滤器校验失败怎么开通(彩铃过滤器校验不通过)

  • 爱奇艺连续包月第一个月可以取消吗(爱奇艺连续包月12元是什么意思)

    爱奇艺连续包月第一个月可以取消吗(爱奇艺连续包月12元是什么意思)

  • 红米k30pro标准版有光学防抖吗(红米k30pro标准版和变焦版哪个好)

    红米k30pro标准版有光学防抖吗(红米k30pro标准版和变焦版哪个好)

  • 支付宝不让截屏怎么设置(支付成功截图生成器)

    支付宝不让截屏怎么设置(支付成功截图生成器)

  • 钉钉电脑版如何改头像(钉钉电脑版如何删除离职员工)

    钉钉电脑版如何改头像(钉钉电脑版如何删除离职员工)

  • 抖音拉黑和移除的区别(抖音拉黑和移除还能搜到对方吗)

    抖音拉黑和移除的区别(抖音拉黑和移除还能搜到对方吗)

  • 鼠标属于什么设备(鼠标属于什么设置)

    鼠标属于什么设备(鼠标属于什么设置)

  • 在word功能区中拥有的选项卡分别是(在word2007中,功能区的三个主要部分)

    在word功能区中拥有的选项卡分别是(在word2007中,功能区的三个主要部分)

  • 谷歌浏览器电脑上为啥不能打开(谷歌浏览器电脑版怎么下载)

    谷歌浏览器电脑上为啥不能打开(谷歌浏览器电脑版怎么下载)

  • qq等级加速包有什么好处(qq等级加速包有图标吗)

    qq等级加速包有什么好处(qq等级加速包有图标吗)

  • 网线最远传输距离多少(网线最远传输距离)

    网线最远传输距离多少(网线最远传输距离)

  • iphone7p怎么刷机(iphone7p怎么刷机手机)

    iphone7p怎么刷机(iphone7p怎么刷机手机)

  • 拼多多怎样一起下单(拼多多怎样一起结算)

    拼多多怎样一起下单(拼多多怎样一起结算)

  • 抖音如何快速删除我喜欢(抖音如何快速删除粉丝)

    抖音如何快速删除我喜欢(抖音如何快速删除粉丝)

  • 滴滴怎么修改行程目的地(滴滴怎么修改行车路线)

    滴滴怎么修改行程目的地(滴滴怎么修改行车路线)

  • 什么是ssp平台(ssp是指)

    什么是ssp平台(ssp是指)

  • iphonex的3dtouch在哪

    iphonex的3dtouch在哪

  • 前端基本知识介绍(前端基础)

    前端基本知识介绍(前端基础)

  • GRU实现时间序列预测(PyTorch版)(gcn时间序列)

    GRU实现时间序列预测(PyTorch版)(gcn时间序列)

  • 合同履约成本资本化
  • 合同执行过程中应该怎么做
  • 计提递延所得税资产
  • 税收分类编码如果选择大类开票会怎样
  • 汇算清缴时纳税调整表调增金额是怎么算出来的
  • 企业所得税如何规避
  • 个体工商户增值税怎么计算
  • 预计净残值和残值
  • 什么公司不可以上市
  • 其他应收款计提坏账吗
  • 长期待摊费用做在什么记账凭证里
  • 变更法人流程具体流程图
  • 专用发票跨年度能入账吗
  • 代扣代缴附加税怎么做账
  • 公司按揭购车可以抵扣税吗
  • 盘亏设备一台
  • 施工企业已完工程成本如何结转
  • 四级主任科员是什么级别待遇
  • 评估资产没有发票和流水怎么办
  • 2019年城建税减免政策
  • 投资性房地产房产税如何计算
  • 对公账户一直没有流水怎么办
  • 为什么盈余公积补亏不会影响留存收益
  • 华为p50新款
  • 银行收付款凭证是什么
  • php常用函数大全
  • php连接mysql查询数据
  • 奥维尔的瓦兹河岸
  • php公众号
  • 会计核算的方法主要有
  • PHP:imagecolorresolvealpha()的用法_GD库图像处理函数
  • ssm框架集成
  • 普通纳税人怎么交税
  • 我已经用尽了洪荒之力漫画表情
  • 隔两个月发票如何作废
  • 电子回单是什么样子
  • 支付货款没有收据怎么办
  • 公允价值变动损益影响利润总额吗
  • 供货方代垫运费会计分录
  • 国地税合并后工资仍然不一致
  • 利润表的组成是指
  • 开普票需要公对公吗
  • 短期借款预提利息通过短期借款科目核算
  • 如何计算保费合同未规定加成比例
  • 稳岗返还的概念
  • 企业所得税免税和减半征收
  • 建设银行e信通介绍
  • 消费税也是流转税吗
  • 生产过程中报废怎么核算成本
  • 财务预付账款情况说明
  • 开具电费发票如何入账?
  • 个人所得税代扣代缴手续费
  • 商场联营扣点缴纳增值税税率
  • 坏账准备与应收账款的影响有哪些
  • 进项税额减免部分在重点税源表中怎么填
  • 汽车折旧年限及残值率是多少
  • 原始凭证的审核和填制
  • sqlserver存储过程怎么查看
  • ARP欺骗攻击原理
  • xp系统要求
  • ubuntu20怎么连接蓝牙鼠标
  • 高手养成计划 小说
  • win1020h2版本千万别更新
  • svchost占用
  • 怎么手动安装xp系统
  • win7系统开机黑屏自检
  • win7移动软件
  • android:ViewPager与FragmentPagerAdapter
  • css图片标签
  • python基础总结
  • 一个简单的shell脚本
  • 用python写随机数
  • 编写高性能代码时以下哪种技术可用于减少内存访问延迟
  • LinearLayout layout_weight解析
  • python 线程教程
  • jq 使用
  • javascript default
  • 北京税务局网站
  • 关于房地产企业所得税涉税处理表述正确的有
  • 软件著作权可以转让公司吗
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设