位置: IT常识 - 正文

Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析(pytorch说明文档)

编辑:rootadmin
Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析

推荐整理分享Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析(pytorch说明文档),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch documentation,pytorch中文文档,pytorch docs,pytorch中文官方文档,pytorch document,pytorch doc,pytorch document,pytorch documents,内容如对您有帮助,希望把文章链接给更多的朋友!

官方文档链接:MultiheadAttention — PyTorch 1.12 documentation

目录

多注意头原理

pytorch的多注意头

解读 官方给的参数解释:

多注意头的pytorch使用

完整的使用代码

多注意头原理

MultiheadAttention,翻译成中文即为多注意力头,是由多个单注意头拼接成的

它们的样子分别为:👇

        单头注意力的图示如下:

单注意力头 ​​ 

        整体称为一个单注意力头,因为运算结束后只对每个输入产生一个输出结果,一般在网络中,输出可以被称为网络提取的特征,那我们肯定希望提取多种特征,[ 比如说我输入是一个修狗狗图片的向量序列,我肯定希望网络提取到特征有形状、颜色、纹理等等,所以单次注意肯定是不够的 ]

        于是最简单的思路,最优雅的方式就是将多个头横向拼接在一起,每次运算我同时提到多个特征,所以多头的样子如下:

多注意力头

        其中的紫色长方块(Scaled Dot-Product Attention)就是上一张单注意力头,内部结构没有画出,如果拼接h个单注意力头,摆放位置就如图所示。

        因为是拼接而成的,所以每个单注意力头其实是各自输出各自的,所以会得到h个特征,把h个特征拼接起来,就成为了多注意力的输出特征。

pytorch的多注意头

        

首先可以看出我们调用的时候,只要写torch.nn.MultiheadAttention就好了,比如👇

import torchimport torch.nn as n# 先决定参数dims = 256 * 10 # 所有头总共需要的输入维度heads = 10 # 单注意力头的总共个数dropout_pro = 0.0 # 单注意力头# 传入参数得到我们需要的多注意力头layer = torch.nn.MultiheadAttention(embed_dim = dims, num_heads = heads, dropout = dropout_pro)解读 官方给的参数解释:

embed_dim - Total dimension of the model 模型的总维度(总输入维度)

        所以这里应该输入的是每个头输入的维度×头的数量

num_heads - Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).

        num_heads即为注意头的总数量        

        注意看括号里的这句话,每个头的维度为 embed_dim除num_heads

        也就是说,如果我的词向量的维度为n,(注意不是序列的维度),我准备用m个头提取序列的特征,则embed_dim这里的值应该是n×m,num_heads的值为m。

【更新】这里其实还是有点小绕的,虽然官文说每个头的维度需要被头的个数除,但是自己在写网络定义时,如果你在输入到多注意力头前到特征为256(举例),这里定义时仍然写成256即可!!,假如你用了4个头,在源码里每个头的特征确实会变成64维,最后又重新拼接成为64乘4=256并输出,但是这个内部过程不用我们自己操心。

还有其他的一些参数可以手动设置:

dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).

bias – If specified, adds bias to input / output projection layers. Default: True.

add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.

add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.

Pytorch文档解读|torch.nn.MultiheadAttention的使用和参数解析(pytorch说明文档)

kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).

vdim – Total number of features for values. Default: None (uses vdim=embed_dim).

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

多注意头的pytorch使用

如果看定义的话应该可以发现:torch.nn.MultiheadAttention是一个类

我们刚刚输入多注意力头的参数,只是’实例化‘出来了我们想要规格的一个多注意力头,

那么想要在训练的时候使用,我们就需要给它喂入数据,也就是调用forward函数,完成前向传播这一动作。

forward函数的定义如下:

forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)

下面是所传参数的解读👇

前三个参数就是attention的三个基本向量元素Q,K,V

query – Query embeddings of shape  for unbatched input,  when batch_first=False or  when batch_first=True, where  is the target sequence length,  is the batch size, and  is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.  

       翻译一下就是说,如果输入不是以batch形式的,query的形状就是,是目标序列的长度,就是query embedding的维度,也就是输入词向量被变换成q后,q的维度,这个注释说是embed_dim, 说明输入词向量和q维度一致;

        若是以batch形式输入,且batch_first=False 则query的形状为,若 batch_first=True,则形状为。【batch_first是’实例化‘时可以设置的,默认为False】

key – Key embeddings of shape for unbatched input, when batch_first=False or when batch_first=True, where S is the source sequence length,is the batch size, and  is the key embedding dimension kdim. See “Attention Is All You Need” for more details.

        key也就是K,同理query,以batch形式,且batch_first=False,则key的形状为。是key embedding的维度,默认也是与相同,则是原序列的长度(source sequence length)

value – Value embeddings of shape for unbatched input,  when batch_first=False or when batch_first=True, where  is the source sequence length,  is the batch size, and  is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

         value是V,与key同理

     其他的参数先不赘述

key_padding_mask – If specified, a mask of shape (N, S)(N,S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be (S)(S). Binary and byte masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.

need_weights – If specified, returns attn_output_weights in addition to attn_outputs. Default: True.

attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S)(L,S) or (N\cdot\text{num\_heads}, L, S)(N⋅num_heads,L,S), where NN is the batch size, LL is the target sequence length, and SS is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.

average_attn_weights – If true, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads)

层的输出格式:

attn_output - Attention outputs of shape when input is unbatched,  when batch_first=False or  when batch_first=True, where  is the target sequence length,  is the batch size, and  is the embedding dimension embed_dim.

        以batch输入,且batch_first=False,attention输出的形状为, 是目标序列长度,是batch的大小,是embed_dim(第一步实例化设置的)

attn_output_weights - Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads of shape ) when input is unbatched or , where NN is the batch size,is the target sequence length, and S is the source sequence length. If average_weights=False, returns attention weights per head of shapewhen input is unbatched or .

        只有当need_weights的值为True时才返回此参数。

完整的使用代码multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)attn_output, attn_output_weights = multihead_attn(query, key, value)
本文链接地址:https://www.jiuchutong.com/zhishi/287189.html 转载请保留说明!

上一篇:【前端文件下载】直接下载和在浏览器显示下载进度的下载方法(前端实现文件下载功能)

下一篇:踩坑记录1——RK3588编译OpenCV(踩坑视频)

  • 支付宝我的家怎么踢人呢(支付宝我的家怎么存钱)

    支付宝我的家怎么踢人呢(支付宝我的家怎么存钱)

  • 苹果为什么qq有消息手机上方不显示(iphone为什么出现qq有消息没显示)

    苹果为什么qq有消息手机上方不显示(iphone为什么出现qq有消息没显示)

  • 华为手机出现来源不明的照片(华为手机出现来电闪光灯)

    华为手机出现来源不明的照片(华为手机出现来电闪光灯)

  • 电脑显示器接口哪个好(电脑显示器接口叫什么)

    电脑显示器接口哪个好(电脑显示器接口叫什么)

  • 小米5c如何变成全网通(小米5c如何变成中文模式)

    小米5c如何变成全网通(小米5c如何变成中文模式)

  • iphone11扬声器声音小(iphone11扬声器声音不一样大)

    iphone11扬声器声音小(iphone11扬声器声音不一样大)

  • 安卓微信手动删除的聊天记录可以恢复吗(安卓微信手动删除聊天)

    安卓微信手动删除的聊天记录可以恢复吗(安卓微信手动删除聊天)

  • vivo手机无缘无故黑屏(vivo手机无缘无故没有声音)

    vivo手机无缘无故黑屏(vivo手机无缘无故没有声音)

  • mate30 4g和5g版本区别大吗(华为mate30 4g版和5g版有什么区别)

    mate30 4g和5g版本区别大吗(华为mate30 4g版和5g版有什么区别)

  • 电脑关机了又自动启动怎么回事(电脑关机了又自动重启是怎么回事)

    电脑关机了又自动启动怎么回事(电脑关机了又自动重启是怎么回事)

  • 反向有功电量怎么回事(反向有功怎么计算)

    反向有功电量怎么回事(反向有功怎么计算)

  • 百度网盘视频下载放哪里(百度网盘视频下载到手机相册)

    百度网盘视频下载放哪里(百度网盘视频下载到手机相册)

  • 戴尔笔记本电脑i5和i7有什么区别(戴尔笔记本电脑哪款性价比最高)

    戴尔笔记本电脑i5和i7有什么区别(戴尔笔记本电脑哪款性价比最高)

  • mate30 6g 8g区别(华为mate308g和6g区别)

    mate30 6g 8g区别(华为mate308g和6g区别)

  • i59400f配什么主板(i59400f配什么主板可以超频)

    i59400f配什么主板(i59400f配什么主板可以超频)

  • 文件下面横线怎么画(文件下面的横线)

    文件下面横线怎么画(文件下面的横线)

  • vivo手机sos怎么取消(vivo手机sos怎么打开)

    vivo手机sos怎么取消(vivo手机sos怎么打开)

  • oppoa9x私密照片怎么查看(oppo手机私密相片)

    oppoa9x私密照片怎么查看(oppo手机私密相片)

  • 拼多多免拼卡在哪查看(拼多多免拼卡在哪里查看)

    拼多多免拼卡在哪查看(拼多多免拼卡在哪里查看)

  • 怎么才能让自己打字快(怎么才能让自己开心起来不压抑)

    怎么才能让自己打字快(怎么才能让自己开心起来不压抑)

  • 路由器登录网址有几种(360路由器登录网址)

    路由器登录网址有几种(360路由器登录网址)

  • 锁屏有热点资讯怎么关(锁屏热点资讯怎么关闭)

    锁屏有热点资讯怎么关(锁屏热点资讯怎么关闭)

  • 堆内存和栈内存区别(堆内存和栈内存溢出)

    堆内存和栈内存区别(堆内存和栈内存溢出)

  • 苹果开机亮标志又黑屏(苹果手机开机苹果标志闪一下是什么原因)

    苹果开机亮标志又黑屏(苹果手机开机苹果标志闪一下是什么原因)

  • js构造继承有什么优点(js继承方式及其优缺点)

    js构造继承有什么优点(js继承方式及其优缺点)

  • 茶叶自产自销成本核算
  • 个税退税是公司退还是个人退
  • 个人所得税申报错误如何更正申报
  • 一般纳税人需要申报什么税
  • 上市公司回购优先股
  • 正常经营损失
  • 外币借款汇兑差额资本化额怎么计算
  • 退回多交的所得税怎么做分录小规模
  • 企业开发票的人员要经过培训吗?
  • 个人开劳务发票几个点
  • 餐费报销需要发票吗
  • 水电费的发票要交税吗
  • 有产权车位转让需要什么手续和费用
  • 已抵扣未入账的红字发票
  • 制造费用分配的的标准是什么?
  • 公司地址变更代办需要多少钱
  • 购房房产税如何支付
  • 土地使用权评估中的成本法
  • 未生产期间的折旧费记到哪
  • 美元汇户和钞户的区别
  • 劳务公司开发票,劳务公司怎么转取收入
  • 公司水费怎么算
  • php iswriteable
  • 离职补偿金的会计实务处理
  • 餐费计入什么费用
  • php curl_init
  • 使用小程序实现im
  • php使用curl
  • 原生php和框架php的区别
  • 最新预提房租会计分录
  • 电子承兑汇票到期提示付款后多久到账
  • 公司一般户财务负责人是另一公司法人
  • 织梦logo
  • 公司购买基金入什么科目
  • 小微企业的税收优惠政策2023
  • 五种差异化收费方式
  • 无法取得发票的成本能列支吗
  • 会员退费怎么算
  • 一般纳税人收到的普票可以抵扣吗
  • 怎么计算城市
  • 关税怎么入账
  • 收到住宿费普通发票会计分录
  • 暂估入库的账务处理含税吗
  • 土地作为无形资产入账依据
  • 产品检验费怎样计算
  • 境外所得税税收抵免操作指南
  • 装修费用是当月支付吗
  • 原材料用于在建工程
  • 资本公积的核算维度是什么
  • 发出存货的计价应当采用
  • 去年的福利费没有用完,今年可以用吗
  • 维修费收入怎么结转销售成本
  • 加油票的发票抬头怎么写
  • 银行存款手工账
  • 事业单位接受捐赠会计处理
  • 可供出售金融资产和交易性金融资产
  • 回收锯末木屑价格
  • sql server 错误
  • SQL Server的通用分页存储过程 未使用游标,速度更快!
  • redhat配置bond
  • 轻松玩转职场职场沟通与写作技巧答案
  • Windows Server AD 访问数量控制配置方法
  • xp系统如何禁止软件联网
  • xp电脑怎么样
  • mac电脑文件夹怎么重命名
  • win7怎么添加播放设备
  • ubuntu 20.04 unity
  • windows8更新不了怎么办
  • win7电脑全屏
  • 简单强悍是哪首歌
  • 可输入文字查找的软件
  • python处理文本文件代码优化
  • jquery 修改
  • [置顶]JM259194
  • 车辆购置税2024年政策
  • 上海自贸区税务大厅地址
  • 山东济南税务局投诉电话
  • 上海个体工商户怎么报税
  • 即征即退进项税额分摊方法
  • 国税局发票查询电话
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设