位置: IT常识 - 正文

对Transformer中Add&Norm层的理解(transformer中的参数)

编辑:rootadmin
对Transformer中Add&Norm层的理解 对Add&Norm层的理解Add操作Norm操作Add操作

推荐整理分享对Transformer中Add&Norm层的理解(transformer中的参数),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:transform方法会对产生的标量值进行,transform方法会对产生的标量值进行,transformer add norm,transformer中的参数,transformer中的参数,transformer方法,transformer中的参数,transformer中的参数,内容如对您有帮助,希望把文章链接给更多的朋友!

首先我们还是先来回顾一下Transformer的结构:Transformer结构主要分为两大部分,一是Encoder层结构,另一个则是Decoder层结构,Encoder 的输入由 Input Embedding 和 Positional Embedding 求和输入Multi-Head-Attention,再通过Feed Forward进行输出。

由下图可以看出:在Encoder层和Decoder层中都用到了Add&Norm操作,即残差连接和层归一化操作。 什么是残差连接呢?残差连接就是把网络的输入和输出相加,即网络的输出为F(x)+x,在网络结构比较深的时候,网络梯度反向传播更新参数时,容易造成梯度消失的问题,但是如果每层的输出都加上一个x的时候,就变成了F(x)+x,对x求导结果为1,所以就相当于每一层求导时都加上了一个常数项‘1’,有效解决了梯度消失问题。

Norm操作

首先要明白Norm做了一件什么事,从刚开始接触Transformer开始,我认为所谓的Norm就是BatchNorm,但是有一天我看到了这篇文章,才明白了Norm是什么。

假设我们输入的词向量的形状是(2,3,4),2为批次(batch),3为句子长度,4为词向量的维度,生成以下数据:

[[w11, w12, w13, w14], [w21, w22, w23, w24], [w31, w32, w33, w34][w41, w42, w43, w44], [w51, w52, w53, w54], [w61, w62, w63, w64]]

如果是在做BatchNorm(BN)的话,其计算过程如下:BN1=(w11+w12+w13+w14+w41+ w42+w43+w44)/8,同理会得到BN2和BN3,最终得到[BN1,BN2,BN3] 3个mean

对Transformer中Add&Norm层的理解(transformer中的参数)

如果是在做LayerNorm(LN)的话,则会进如下计算:LN1=(w11+w12+w13+w14+w21+ w22+w23+w24+w31+w32+w33+w34)/12,同理会得到LN2,最终得到[LN1,LN2]两个mean

如果是在做InstanceNorm(IN)的话,则会进如下计算:IN1=(w11+w12+w13+w14)/4,同理会得到IN2,IN3,IN4,IN5,IN6,六个mean,[[IN1,IN2,IN3],[IN4,IN5,IN6]] 下图完美的揭示了,这几种Norm 接下来我们来看一下Transformer中的Norm:首先生成[2,3,4]形状的数据,使用原始的编码方式进行编码:

import torchfrom torch.nn import InstanceNorm2drandom_seed = 123torch.manual_seed(random_seed)batch_size, seq_size, dim = 2, 3, 4embedding = torch.randn(batch_size, seq_size, dim)layer_norm = torch.nn.LayerNorm(dim, elementwise_affine = False)print("y: ", layer_norm(embedding))

输出:

y: tensor([[[ 1.5524, 0.0155, -0.3596, -1.2083], [ 0.5851, 1.3263, -0.7660, -1.1453], [ 0.2864, 0.0185, 1.2388, -1.5437]], [[ 1.1119, -0.3988, 0.7275, -1.4406], [-0.4144, -1.1914, 0.0548, 1.5510], [ 0.3914, -0.5591, 1.4105, -1.2428]]])

接下来手动去进行一下编码:

eps: float = 0.00001mean = torch.mean(embedding[:, :, :], dim=(-1), keepdim=True)var = torch.square(embedding[:, :, :] - mean).mean(dim=(-1), keepdim=True)print("mean: ", mean.shape)print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))mean: torch.Size([2, 3, 1])y_custom: tensor([[[ 1.1505, 0.5212, -0.1262, -1.5455], [-0.6586, -0.2132, -0.8173, 1.6890], [ 0.6000, 1.2080, -0.3813, -1.4267]], [[-0.0861, 1.0145, -1.5895, 0.6610], [ 0.8724, 0.9047, -1.5371, -0.2400], [ 0.1507, 0.5268, 0.9785, -1.6560]]])

可以发现和LayerNorm的结果是一样的,也就是说明Norm是对d_model进行的Norm,会给我们[batch,sqe_length]形状的平均值。 加下来进行batch_norm,

layer_norm = torch.nn.LayerNorm([seq_size,dim], elementwise_affine = False)eps: float = 0.00001mean = torch.mean(embedding[:, :, :], dim=(-2,-1), keepdim=True)var = torch.square(embedding[:, :, :] - mean).mean(dim=(-2,-1), keepdim=True)print("mean: ", mean.shape)print("y_custom: ", (embedding[:, :, :] - mean) / torch.sqrt(var + eps))

输出:

mean: torch.Size([2, 1, 1])y_custom: tensor([[[ 1.1822, 0.4419, -0.3196, -1.9889], [-0.6677, -0.2537, -0.8151, 1.5143], [ 0.7174, 1.2147, -0.0852, -0.9403]], [[-0.0138, 1.5666, -2.1726, 1.0590], [ 0.6646, 0.6852, -0.8706, -0.0442], [-0.1163, 0.1389, 0.4454, -1.3423]]])

可以看到BN的计算的mean形状为[2, 1, 1],并且Norm结果也和上面的两个不一样,这就充分说明了Norm是在对最后一个维度求平均。 那么什么又是Instancenorm呢?接下来再来实现一下instancenorm

instance_norm = InstanceNorm2d(3, affine=False)output = instance_norm(embedding.reshape(2,3,4,1)) #InstanceNorm2D需要(N,C,H,W)的shape作为输入layer_norm = torch.nn.LayerNorm(4, elementwise_affine = False)print(layer_norm(embedding))

输出:

tensor([[[ 1.1505, 0.5212, -0.1262, -1.5455], [-0.6586, -0.2132, -0.8173, 1.6890], [ 0.6000, 1.2080, -0.3813, -1.4267]], [[-0.0861, 1.0145, -1.5895, 0.6610], [ 0.8724, 0.9047, -1.5371, -0.2400], [ 0.1507, 0.5268, 0.9785, -1.6560]]])

可以看出无论是layernorm还是instancenorm,还是我们手动去求平均计算其Norm,结果都是一样的,由此我们可以得出一个结论:Layernorm实际上是在做Instancenorm!

如果喜欢文章请点个赞,笔者也是一个刚入门Transformer的小白,一起学习,共同努力。

本文链接地址:https://www.jiuchutong.com/zhishi/296105.html 转载请保留说明!

上一篇:图像融合、Transformer、扩散模型(图像融合名词解释)

下一篇:Vue | Vue.js 全家桶 Pinia状态管理(vue全家桶的app项目代码)

  • 博客推广实用的技巧和优化注意事项有三 (博客推广的优缺点)

    博客推广实用的技巧和优化注意事项有三 (博客推广的优缺点)

  • 微信能设置专属提示音吗(微信能设置专属语音铃声吗)

    微信能设置专属提示音吗(微信能设置专属语音铃声吗)

  • ipad pro2021怎么开120hz(iPad Pro2021怎么开高刷)

    ipad pro2021怎么开120hz(iPad Pro2021怎么开高刷)

  • 中国联通流量包月租费怎么取消(中国联通流量包哪个最划算)

    中国联通流量包月租费怎么取消(中国联通流量包哪个最划算)

  • 苹果11怎么分屏多窗口(苹果11怎么分屏两个应用)

    苹果11怎么分屏多窗口(苹果11怎么分屏两个应用)

  • 华为荣耀20青春版耳机孔在哪(华为荣耀20青春版多少钱一台)

    华为荣耀20青春版耳机孔在哪(华为荣耀20青春版多少钱一台)

  • 移动充值卡不能充支付宝了(移动充值卡不能用)

    移动充值卡不能充支付宝了(移动充值卡不能用)

  • 钢化膜能撕下来重贴吗(钢化膜能撕下来再贴吗)

    钢化膜能撕下来重贴吗(钢化膜能撕下来再贴吗)

  • qq备注怎么弄(qq备注咋弄)

    qq备注怎么弄(qq备注咋弄)

  • cookie是干嘛的(cookie是干啥的)

    cookie是干嘛的(cookie是干啥的)

  • prtscsysrq键按了没反应(按prtscr键没反应)

    prtscsysrq键按了没反应(按prtscr键没反应)

  • 作业帮可以连接其他打印机吗(作业帮可以连接其他品牌的打印机吗)

    作业帮可以连接其他打印机吗(作业帮可以连接其他品牌的打印机吗)

  • 手机QQ发说说怎么艾特qq好友(手机QQ发说说怎么保持原画质)

    手机QQ发说说怎么艾特qq好友(手机QQ发说说怎么保持原画质)

  • 微信卸载后了聊天记录还在吗(微信卸载后聊天记录不见了怎么恢复)

    微信卸载后了聊天记录还在吗(微信卸载后聊天记录不见了怎么恢复)

  • 快手关注就是粉丝吗(快手关注就是粉丝怎么办)

    快手关注就是粉丝吗(快手关注就是粉丝怎么办)

  • ipad一边充电一边看视频伤机器吗(ipad一边充电一边用损害大吗)

    ipad一边充电一边看视频伤机器吗(ipad一边充电一边用损害大吗)

  • 为什么键盘上有的数字按不出来(为什么键盘上有些键打不出字来了)

    为什么键盘上有的数字按不出来(为什么键盘上有些键打不出字来了)

  • ipad内存怎么扩大内存(ipad内存)

    ipad内存怎么扩大内存(ipad内存)

  • ipad类纸膜可以贴在钢化膜上吗(ipad类纸膜可以保护眼睛吗)

    ipad类纸膜可以贴在钢化膜上吗(ipad类纸膜可以保护眼睛吗)

  • 重启到恢复模式recovery什么意思(重启到恢复模式怎么解除vivoo)

    重启到恢复模式recovery什么意思(重启到恢复模式怎么解除vivoo)

  • 字节数是什么(字节数是什么单位)

    字节数是什么(字节数是什么单位)

  • 闲鱼买东西先付款吗(闲鱼买东西先付款再调试发货吗)

    闲鱼买东西先付款吗(闲鱼买东西先付款再调试发货吗)

  • 华为yal一al00是什么型号(华为yal-al00)

    华为yal一al00是什么型号(华为yal-al00)

  • 华为p30pro支持光学防抖吗(华为p30pro光线感应怎么设置)

    华为p30pro支持光学防抖吗(华为p30pro光线感应怎么设置)

  • 苹果11电池百分比怎么调(苹果11电池百分比80正常吗)

    苹果11电池百分比怎么调(苹果11电池百分比80正常吗)

  • oppoa57呼吸灯怎么设置(oppo a55呼吸灯)

    oppoa57呼吸灯怎么设置(oppo a55呼吸灯)

  • vivonex3防水吗(vivonex3防水测试视频)

    vivonex3防水吗(vivonex3防水测试视频)

  • 云存储能做什么(云存储能干些什么)

    云存储能做什么(云存储能干些什么)

  • 不能作为微机输出设备的是(不能作为微机输出装备的是什么)

    不能作为微机输出设备的是(不能作为微机输出装备的是什么)

  • 苹果x频繁自动关机(iphonex手机老是自动关机)

    苹果x频繁自动关机(iphonex手机老是自动关机)

  • 免抵退税办法不得抵扣的进项税额
  • 小规模合作社免税吗
  • 应纳所得税额的税率
  • 车船税不交有什么影响 三大影响要注意
  • 代收代缴消费税会计分录
  • 汽车维修公司做账基本流程
  • 装修工程一切险
  • 公司汽车购置税怎么交
  • 应付债券利息计入哪里
  • 销售自己使用过的物品
  • 销售折扣怎么开
  • 分配辅助生产车间成本记账凭证
  • 合作保证金可以退吗
  • 公司是否可以开电子发票
  • 建筑企业一般纳税人提供建筑服务属于老项目
  • 增值税专用发票和普通发票的区别
  • 发票反写是什么时候
  • 购入工程物资的账务处理
  • 发票上可以盖两次章吗
  • 计提资产减值损失账务处理
  • 会计计提和冲回
  • 增值税多交了怎么申请退税
  • 建筑业预缴税款怎么退税
  • win 7怎么办
  • composer.json和composer.lock
  • 补缴社保公积金申请书怎么写
  • 附有退货条款的销售
  • 财政发票可以报销吗
  • 应收票据利息会计科目
  • php实验报告
  • 增值税待认证进项税额
  • 增值税的法律法规最新
  • js如何随机生成字符串
  • centos配置php环境
  • 动力和燃料的区别
  • 低代码框架开发
  • 商品仓储费用会增加吗
  • redis zset源码
  • js如何把字符串转换成数字
  • 网上抄税报税操作流程
  • 年度报表资产总额平均值怎么算
  • python __call__
  • 怎么查电子发票真伪
  • 个人所得税数据怎么导入新电脑
  • sql server功能选择
  • 发票上的印记能去掉吗
  • 权益法的股权比例
  • 资产减值损失科目
  • 以前年度社保计提出错了怎么调整
  • 工会经费记在什么科目
  • 以银行存款支付固定资产修理费
  • 制造费用转入生产成本摘要怎么写
  • 收房租的收据怎么写
  • 暂估成本估多了怎么办
  • 给个体工商户付款可以打到法人卡上吗
  • 提前还贷款要满十八岁吗
  • 如何删除windowsedb
  • Fedora Core 5.0 安装教程,菜鸟图文版(图文界面)
  • 5个经常被忽略的成语
  • ubuntu pdf编辑器
  • centos7脚本
  • 进程 com surrogate
  • win1020h2无法重启
  • win8 开机后无法进入系统
  • js require()
  • 通过手机号怎么查对方的位置
  • cocos2dx开发的游戏有哪些
  • jqueryui
  • dos命令查看磁盘分区
  • jquery读写文件
  • python数据清洗的方法有哪些
  • canvas实例
  • js生成二维数组
  • python3中raw_input的用法
  • js字段截取
  • javascript中的this属性
  • 税务局稽查科是干什么的工作
  • 百望电子发票查询下载
  • 甘肃省国家税务局电子税务局
  • 高新区地税办税服务厅
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设