位置: IT常识 - 正文

Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析(pytorch说明文档)

编辑:rootadmin
Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析

推荐整理分享Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析(pytorch说明文档),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch documentation,pytorch中文文档,pytorch docs,pytorch中文官方文档,pytorch document,pytorch doc,pytorch document,pytorch documents,内容如对您有帮助,希望把文章链接给更多的朋友!

官方文档链接:MultiheadAttention — PyTorch 1.12 documentation

目录

多注意头原理

pytorch的多注意头

解读 官方给的参数解释:

多注意头的pytorch使用

完整的使用代码

多注意头原理

MultiheadAttention,翻译成中文即为多注意力头,是由多个单注意头拼接成的

它们的样子分别为:👇

        单头注意力的图示如下:

单注意力头 ​​ 

        整体称为一个单注意力头,因为运算结束后只对每个输入产生一个输出结果,一般在网络中,输出可以被称为网络提取的特征,那我们肯定希望提取多种特征,[ 比如说我输入是一个修狗狗图片的向量序列,我肯定希望网络提取到特征有形状、颜色、纹理等等,所以单次注意肯定是不够的 ]

        于是最简单的思路,最优雅的方式就是将多个头横向拼接在一起,每次运算我同时提到多个特征,所以多头的样子如下:

多注意力头

        其中的紫色长方块(Scaled Dot-Product Attention)就是上一张单注意力头,内部结构没有画出,如果拼接h个单注意力头,摆放位置就如图所示。

        因为是拼接而成的,所以每个单注意力头其实是各自输出各自的,所以会得到h个特征,把h个特征拼接起来,就成为了多注意力的输出特征。

pytorch的多注意头

        

首先可以看出我们调用的时候,只要写torch.nn.MultiheadAttention就好了,比如👇

import torchimport torch.nn as n# 先决定参数dims = 256 * 10 # 所有头总共需要的输入维度heads = 10 # 单注意力头的总共个数dropout_pro = 0.0 # 单注意力头# 传入参数得到我们需要的多注意力头layer = torch.nn.MultiheadAttention(embed_dim = dims, num_heads = heads, dropout = dropout_pro)解读 官方给的参数解释:

embed_dim - Total dimension of the model 模型的总维度(总输入维度)

        所以这里应该输入的是每个头输入的维度×头的数量

num_heads - Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

        num_heads即为注意头的总数量        

        注意看括号里的这句话,每个头的维度为 embed_dim除num_heads

        也就是说,如果我的词向量的维度为n,(注意不是序列的维度),我准备用m个头提取序列的特征,则embed_dim这里的值应该是n×m,num_heads的值为m。

【更新】这里其实还是有点小绕的,虽然官文说每个头的维度需要被头的个数除,但是自己在写网络定义时,如果你在输入到多注意力头前到特征为256(举例),这里定义时仍然写成256即可!!,假如你用了4个头,在源码里每个头的特征确实会变成64维,最后又重新拼接成为64乘4=256并输出,但是这个内部过程不用我们自己操心。

还有其他的一些参数可以手动设置:

dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).

bias – If specified, adds bias to input / output projection layers. Default: True.

add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.

add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.

Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析(pytorch说明文档)

kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).

vdim – Total number of features for values. Default: None (uses vdim=embed_dim).

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

多注意头的pytorch使用

如果看定义的话应该可以发现:torch.nn.MultiheadAttention是一个类

我们刚刚输入多注意力头的参数,只是’实例化‘出来了我们想要规格的一个多注意力头,

那么想要在训练的时候使用,我们就需要给它喂入数据,也就是调用forward函数,完成前向传播这一动作。

forward函数的定义如下:

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)

下面是所传参数的解读👇

前三个参数就是attention的三个基本向量元素Q,K,V

query – Query embeddings of shape  for unbatched input,  when batch_first=False or  when batch_first=True, where  is the target sequence length,  is the batch size, and  is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.  

       翻译一下就是说,如果输入不是以batch形式的,query的形状就是,是目标序列的长度,就是query embedding的维度,也就是输入词向量被变换成q后,q的维度,这个注释说是embed_dim, 说明输入词向量和q维度一致;

        若是以batch形式输入,且batch_first=False 则query的形状为,若 batch_first=True,则形状为。【batch_first是’实例化‘时可以设置的,默认为False】

key – Key embeddings of shape for unbatched input, when batch_first=False or when batch_first=True, where S is the source sequence length,is the batch size, and  is the key embedding dimension kdim. See “Attention Is All You Need” for more details.

        key也就是K,同理query,以batch形式,且batch_first=False,则key的形状为。是key embedding的维度,默认也是与相同,则是原序列的长度(source sequence length)

value – Value embeddings of shape for unbatched input,  when batch_first=False or when batch_first=True, where  is the source sequence length,  is the batch size, and  is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

         value是V,与key同理

     其他的参数先不赘述

key_padding_mask – If specified, a mask of shape (N, S)(N,S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be (S)(S). Binary and byte masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.

need_weights – If specified, returns attn_output_weights in addition to attn_outputs. Default: True.

attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S)(L,S) or (N\cdot\text{num\_heads}, L, S)(N⋅num_heads,L,S), where NN is the batch size, LL is the target sequence length, and SS is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.

average_attn_weights – If true, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads)

层的输出格式:

attn_output - Attention outputs of shape when input is unbatched,  when batch_first=False or  when batch_first=True, where  is the target sequence length,  is the batch size, and  is the embedding dimension embed_dim.

        以batch输入,且batch_first=False,attention输出的形状为, 是目标序列长度,是batch的大小,是embed_dim(第一步实例化设置的)

attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads of shape ) when input is unbatched or , where NN is the batch size,is the target sequence length, and S is the source sequence length. If average_weights=False, returns attention weights per head of shapewhen input is unbatched or .

        只有当need_weights的值为True时才返回此参数。

完整的使用代码multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)attn_output, attn_output_weights = multihead_attn(query, key, value)
本文链接地址:https://www.jiuchutong.com/zhishi/287189.html 转载请保留说明!

上一篇:【前端文件下载】直接下载和在浏览器显示下载进度的下载方法(前端实现文件下载功能)

下一篇:踩坑记录1——RK3588编译OpenCV(踩坑视频)

  • 如何利用免费博客平台做外链?(如何利用免费博主赚钱)

    如何利用免费博客平台做外链?(如何利用免费博主赚钱)

  • 三星手机怎么样好用吗(三星手机怎么样)(三星手机怎么样恢复出厂设置)

    三星手机怎么样好用吗(三星手机怎么样)(三星手机怎么样恢复出厂设置)

  • iosnfc怎么复制门禁卡的(nfc iphone怎么复制门禁卡)

    iosnfc怎么复制门禁卡的(nfc iphone怎么复制门禁卡)

  • mrjn2ch是ipad几(mrjp2ch/a是ipad几)

    mrjn2ch是ipad几(mrjp2ch/a是ipad几)

  • 一个手机号码可以绑定几个QQ(一个手机号码可以绑定几个支付宝账户)

    一个手机号码可以绑定几个QQ(一个手机号码可以绑定几个支付宝账户)

  • 钉钉怎么调声音(钉钉怎么调声音类型)

    钉钉怎么调声音(钉钉怎么调声音类型)

  • 对方关机怎么强制打通电话(对方关机了怎样才能唤醒)

    对方关机怎么强制打通电话(对方关机了怎样才能唤醒)

  • 魅族一直显示插着耳机(魅族手机显示在充电但是却充不进去)

    魅族一直显示插着耳机(魅族手机显示在充电但是却充不进去)

  • home键有声音怎么回事(home键声音怎么调出来)

    home键有声音怎么回事(home键声音怎么调出来)

  • ps怎么填充颜色(手机wps怎么填充颜色)

    ps怎么填充颜色(手机wps怎么填充颜色)

  • iPhone11支持动态壁纸么(苹果11设置动态)

    iPhone11支持动态壁纸么(苹果11设置动态)

  • 多次呼转暂时无法接通是什么意思(已经多次呼转)

    多次呼转暂时无法接通是什么意思(已经多次呼转)

  • 怎样在优酷下载视频(怎样在优酷下载视频到手机上)

    怎样在优酷下载视频(怎样在优酷下载视频到手机上)

  • 电脑微信登录怎么升级版(电脑微信登录怎么打印文件)

    电脑微信登录怎么升级版(电脑微信登录怎么打印文件)

  • airpods能连安卓吗(AirPods能连安卓手机吗)

    airpods能连安卓吗(AirPods能连安卓手机吗)

  • 微信解冻需要什么条件(微信解冻需要什么资料)

    微信解冻需要什么条件(微信解冻需要什么资料)

  • 苹果xs有指纹解锁吗(苹果xs指纹解锁在哪里设置)

    苹果xs有指纹解锁吗(苹果xs指纹解锁在哪里设置)

  • 腾讯视频积分买的东西在哪里看(腾讯视频积分买什么划算)

    腾讯视频积分买的东西在哪里看(腾讯视频积分买什么划算)

  • 设置磁盘缓冲区的目的(系统内存中设置磁盘缓冲区)

    设置磁盘缓冲区的目的(系统内存中设置磁盘缓冲区)

  • airpods可以用安卓手机吗(airpods可以用安卓充电吗)

    airpods可以用安卓手机吗(airpods可以用安卓充电吗)

  • 荣耀8x是什么屏幕(荣耀8x采用什么屏幕)

    荣耀8x是什么屏幕(荣耀8x采用什么屏幕)

  • 微信号怎么改成手机号码(微信号怎么改成电话号码不加字母)

    微信号怎么改成手机号码(微信号怎么改成电话号码不加字母)

  • 电脑系统如何安装系统win7?(电脑系统如何安装字体)

    电脑系统如何安装系统win7?(电脑系统如何安装字体)

  • 虚拟机是什么(虚拟机是什么意思)

    虚拟机是什么(虚拟机是什么意思)

  • 云下的麦田,西班牙巴利亚多利德 (© Carlos Javier García Prieto/EyeEm/Getty Images)(云霞下的麦田)

    云下的麦田,西班牙巴利亚多利德 (© Carlos Javier García Prieto/EyeEm/Getty Images)(云霞下的麦田)

  • 天猫提现一直没到账
  • 增值税多提了怎么处理
  • 每年的第一季度
  • 个税申报时提示扣缴单位无有效的税费种认定信息
  • 发票认证抵扣后还有用吗
  • 领用自产应税消费品负担的消费税计入在建工程成本吗
  • 电子商务支付平台有哪些
  • 公司办公室收到上级主管部门的一份
  • 工资退回怎么处理
  • 企业所得税计提金额怎么算
  • 小规模企业的企业所得税怎么交
  • 电子承兑都是银行承兑吗
  • 现金盘亏账务处理分录
  • 项目顾问是什么意思
  • 进项税额转出会计分录月末如何结转
  • 增值税发票备注栏怎么填写
  • 记账凭证是出纳编制吗
  • 1697509029
  • 利润分成的会计分录
  • 电子缴税付款凭证怎么做记账凭证
  • word文档打印时不打印批注
  • 期末调汇凭证怎么调
  • 购买种子怎么做账
  • win10任务栏怎么隐藏
  • macos使用技巧
  • 做胃镜多少钱了
  • 现金流量套期的例子
  • gst插件
  • PHP:mcrypt_enc_self_test()的用法_Mcrypt函数
  • Java8 Stream流Collectors.toMap当key重复时报异常(IllegalStateException)
  • 老生常谈PHP 文件写入和读取(必看篇)
  • 一个实用的php验证
  • 阳光穿透云层是什么效应
  • etc通行费发票可以抵扣吗
  • 应收票据计提利息
  • win11更新22468
  • mysql各种索引的使用场景
  • 融资租赁首付租金定义
  • 虚开普票的立案标准
  • 购买的优惠卷到期后退款
  • 金蝶迷你版怎么打印明细账
  • 记账王怎么打开以前的账套
  • 免税农产品范围目录的文件
  • 一般纳税人建筑劳务税率
  • 一般纳税人什么时候用简易计税
  • 用人单位垫付生育津贴垫付金额和垫付天数
  • 增值税进项税额加计抵减政策
  • 行政事业单位核销固定资产的账务处理
  • 财务费用怎么用
  • 制造费用包括哪些内容科目
  • 在建工程项目包括
  • 贴现费用分录
  • 库存商品发出计价测试
  • 期末没有结账成本怎么办
  • 税控盘反写怎么操作流程
  • 未按照规定编制应急预案的,责令限期改正,可以处罚款
  • 会计及库管岗位职责(要求)
  • win7 64位运行软件提示MSCOMCTL.OCX丢失或无效该怎么办?
  • win7环境变量在哪打开啊
  • xp双系统怎么设置默认系统
  • xp系统怎样设置无线网络连接
  • xp系统1
  • linux系统如何进入终端
  • windows8.1开始
  • linux的curl
  • win10系统开机自动还原
  • bootstrap-treeview.js
  • easyui combobox设置值
  • html用div来写表格
  • 原生js实现ajax步骤
  • javascript中用于声明变量的关键字
  • javascript 基础篇3 类,回调函数,内置对象,事件处理
  • web开发 java
  • python网页验证码
  • jquery判断对象是否存在
  • 支持国税普通发展的原因
  • 企业自建房如何缴税
  • 监察室主任岗位职责
  • 珠海选调生2021公告
  • 自然人电子税务局
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设