位置: IT常识 - 正文

Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)

编辑:rootadmin
Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例) 目录1 计算图原理2 基于计算图的传播3 神经网络计算图4 自动微分机5 Pytorch中的自动微分5.1 梯度缓存5.2 参数冻结1 计算图原理

推荐整理分享Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

计算图(Computational Graph)是机器学习领域中推导神经网络和其他模型算法,以及软件编程实现的有效工具。

计算图的核心是将模型表示成一张拓扑有序(Topologically Ordered)的有向无环图(Directed Acyclic Graph),其中每个节点uiu_iui​包含数值信息(可以是标量、向量、矩阵或张量)和算子信息fif_ifi​。拓扑有序指当前节点仅在全体指向它的节点被计算后才进行计算。

计算图的优点在于:

可以通过基本初等映射 的拓扑联结,形成复合的复杂模型,大多数神经网络模型都可以被计算图表示;便于实现自动微分机(Automatic Differentiation Machine),对给定计算图可基于链式法则由节点局部梯度进行反向传播。

计算图的基本概念如表所示,基于计算图的基本前向传播和反向传播算法如表

符号含义nnn计算图的节点数lll计算图的叶节点数LLL计算图的叶节点索引集CCC计算图的非叶节点索引集EEE计算图的有向边集合uiu_iui​计算图中的第iii节点或其值did_idi​uiu_iui​ 的维度fif_ifi​uiu_iui​的算子αi\alpha _iαi​uiu_iui​的全体关联输入Jj→i\boldsymbol{J}_{j\rightarrow i}Jj→i​节点uiu_iui​关于节点uju_juj​的雅克比矩阵Pi\boldsymbol{P}_iPi​输出节点关于输入节点的雅克比矩阵2 基于计算图的传播

基于计算图的前向传播算法如下

基于计算图的反向传播算法如下

以第一节的图为例,可知E={(1,3),(2,3),(2,4),(3,4)}E=\left\{ \left( 1,3 \right) ,\left( 2,3 \right) ,\left( 2,4 \right) ,\left( 3,4 \right) \right\}E={(1,3),(2,3),(2,4),(3,4)}。首先进行前向传播:

{u3=u1+u2=5u4=u2u3=15\begin{cases} u_3=u_1+u_2=5\\ u_4=u_2u_3=15\\\end{cases}{u3​=u1​+u2​=5u4​=u2​u3​=15​

{J1→3=∂u3/∂u1=1J2→3=∂u3/∂u2=1J2→4=∂u4/∂u2=u3=5J3→4=∂u4/∂u3=u2=3\begin{cases} \boldsymbol{J}_{1\rightarrow 3}={{\partial u_3}/{\partial u_1=}}1\\ \boldsymbol{J}_{2\rightarrow 3}={{\partial u_3}/{\partial u_2=}}1\\ \boldsymbol{J}_{2\rightarrow 4}={{\partial u_4}/{\partial u_2=}}u_3=5\\ \boldsymbol{J}_{3\rightarrow 4}={{\partial u_4}/{\partial u_3=}}u_2=3\\\end{cases}⎩⎨⎧​J1→3​=∂u3​/∂u1​=1J2→3​=∂u3​/∂u2​=1J2→4​=∂u4​/∂u2​=u3​=5J3→4​=∂u4​/∂u3​=u2​=3​

Pytorch深度学习实战3-5:详解计算图与自动微分机(附实例)

接着进行反向传播:

{P4=1P3=P4J3→4=3P2=P4J2→4+P3J2→3=8P1=P3J1→3=3\begin{cases} \boldsymbol{P}_4=1\\ \boldsymbol{P}_3=\boldsymbol{P}_4\boldsymbol{J}_{3\rightarrow 4}=3\\ \boldsymbol{P}_2=\boldsymbol{P}_4\boldsymbol{J}_{2\rightarrow 4}+\boldsymbol{P}_3\boldsymbol{J}_{2\rightarrow 3}=8\\ \boldsymbol{P}_1=\boldsymbol{P}_3\boldsymbol{J}_{1\rightarrow 3}=3\\\end{cases}⎩⎨⎧​P4​=1P3​=P4​J3→4​=3P2​=P4​J2→4​+P3​J2→3​=8P1​=P3​J1→3​=3​

3 神经网络计算图

一个神经网络的计算图实例如下,所有参数都可以用之前的模型表示

L{u1=W1∈Rn1×nu2=b1∈Rn1u3=x∈Rnu4=W2∈Rn2×n1u5=b2∈Rn2u6=y∈Rn2  C{u7=z1∈Rn1=W1x+b1u8=a1∈Rn1=σ(z1)u9=z2∈Rn2=W2a1+b2u10=y∈Rn2=σ(z2)u11=E∈R=12(y−y~)T(y−y~)L\begin{cases} u_1=\boldsymbol{W}^1\in \mathbb{R} ^{n_1\times n_0}\\ u_2=\boldsymbol{b}^1\in \mathbb{R} ^{n_1}\\ u_3=\boldsymbol{x}\in \mathbb{R} ^{n_0}\\ u_4=\boldsymbol{W}^2\in \mathbb{R} ^{n_2\times n_1}\\ u_5=\boldsymbol{b}^2\in \mathbb{R} ^{n_2}\\ u_6=\boldsymbol{y}\in \mathbb{R} ^{n_2}\\\end{cases}\,\, C\begin{cases} u_7=\boldsymbol{z}^1\in \mathbb{R} ^{n_1}=\boldsymbol{W}^1\boldsymbol{x}+\boldsymbol{b}^1\\ u_8=\boldsymbol{a}^1\in \mathbb{R} ^{n_1}=\sigma \left( \boldsymbol{z}^1 \right)\\ u_9=\boldsymbol{z}^2\in \mathbb{R} ^{n_2}=\boldsymbol{W}^2\boldsymbol{a}^1+\boldsymbol{b}^2\\ u_{10}=\boldsymbol{y}\in \mathbb{R} ^{n_2}=\sigma \left( \boldsymbol{z}^2 \right)\\ u_{11}=E\in \mathbb{R} =\frac{1}{2}\left( \boldsymbol{y}-\boldsymbol{\tilde{y}} \right) ^T\left( \boldsymbol{y}-\boldsymbol{\tilde{y}} \right)\\\end{cases}L⎩⎨⎧​u1​=W1∈Rn1​×n0​u2​=b1∈Rn1​u3​=x∈Rn0​u4​=W2∈Rn2​×n1​u5​=b2∈Rn2​u6​=y∈Rn2​​C⎩⎨⎧​u7​=z1∈Rn1​=W1x+b1u8​=a1∈Rn1​=σ(z1)u9​=z2∈Rn2​=W2a1+b2u10​=y∈Rn2​=σ(z2)u11​=E∈R=21​(y−y~​)T(y−y~​)​

4 自动微分机

自动微分机的基本原理是:

跟踪记录从输入张量到输出张量的计算过程,并生成一幅前向传播计算图,计算图中的节点与张量一一对应;基于计算图反向传播原理即可链式地求解输出节点关于各节点的梯度。

必须指出,Pytorch不允许张量对张量求导,故输出节点必须是标量,通常为损失函数或输出向量的加权和;为节约内存,每次反向传播后Pytorch会自动释放前向传播计算图,即销毁中间计算节点的梯度和节点间的连接结构。

5 Pytorch中的自动微分

Tensor在自动微分机中的重要属性如表所示。

属性含义device该节点运行的设备环境,即CPU/GPUrequires_grad自动微分机是否需要对该节点求导,缺省为Falsegrad输出节点对该节点的梯度,缺省为Nonegrad_fn中间计算节点关于全体输入节点的映射,记录了前向传播经过的操作。叶节点为Noneis_leaf该节点是否为叶节点

完成前向传播后,调用反向传播API即可更新各节点梯度,具体如下

backward(gradient=None, retain_graph=None, create_graph=None)

其中

gradient是权重向量,当输出节点yyy不为标量时需指定与其同维的gradient,并以标量gradientTygradient^TygradientTy为输出进行反向传播retain_graph用于缓存前向传播计算图,可应用于一次传播测试多个损失函数等情形;creat_graph用于构造导数计算图,可用于进一步求解高阶导数。5.1 梯度缓存

中间计算节点的梯度需要通过retain_grad()方法进行缓存

w1 = torch.tensor([[2.], [3.]], requires_grad=True)b1 = torch.tensor([1.], requires_grad=True)x = torch.tensor([[10.], [20.]])y = torch.mm(w1.transpose(0, 1), x) + b1y.retain_grad()# 若不缓存则y.grad=Noneout = 3*yout.backward()>> tensor([[30.], [60.]]) tensor([3.]) None tensor([[3.]])5.2 参数冻结

若希望冻结网络部分参数,只调整优化另一部分参数;或按顺序训练分支网络而屏蔽对主网络梯度的,可使用detach()方法从计算图中分离节点,阻断反向传播。分离的节点与原节点共享值内存,但不具有grad和grad_fn属性。

# 记第一层网络w1-b1为f,第二层网络w2-b2为gw1 = torch.tensor([[2.], [3.]], requires_grad=True)w2 = torch.tensor([3.], requires_grad=True)b1 = torch.tensor([1.], requires_grad=True)b2 = torch.tensor([2.], requires_grad=True)x = torch.tensor([[10.], [20.]])y = torch.mm(w1.transpose(0, 1), x) + b1y_ = y.detach()z = w2 * y_ + b2out = 3*zout.backward()print(w1.grad, b1.grad, w2.grad, b2.grad)>> None None tensor([243.]) tensor([3.]) # f被冻结,梯度不更新# 若不使用detach冻结y之前的网络,则>> tensor([[ 90.], [180.]]) tensor([9.]) tensor([243.]) tensor([3.])

🔥 更多精彩专栏:

《ROS从入门到精通》《Pytorch深度学习实战》《机器学习强基计划》《运动规划实战精讲》…

👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

本文链接地址:https://www.jiuchutong.com/zhishi/300350.html 转载请保留说明!

上一篇:web前端期末大作业实例 (1500套) 集合(web前端期末大作业旅游页面)

下一篇:OpenAI GPT-3模型详解(gpt3 模型大小)

  • SEO内容为王之怎样创造伪原创(seo内容建设)

    SEO内容为王之怎样创造伪原创(seo内容建设)

  • 步步为赢——三步创建强势品牌(步步为赢百科)

    步步为赢——三步创建强势品牌(步步为赢百科)

  • 前面板的耳机排线接口(前面板没声音)(前面板和后面板的耳机孔)

    前面板的耳机排线接口(前面板没声音)(前面板和后面板的耳机孔)

  • 儿童新冠疫苗接种记录在哪里查询(儿童新冠疫苗接种2022年新规)

    儿童新冠疫苗接种记录在哪里查询(儿童新冠疫苗接种2022年新规)

  • qq里说的属性是啥(qq属性在哪里)

    qq里说的属性是啥(qq属性在哪里)

  • 充电宝鼓起来怎么处理(充电宝鼓起来怎么扔)

    充电宝鼓起来怎么处理(充电宝鼓起来怎么扔)

  • 升5g要换手机卡吗(升5g需要换卡么)

    升5g要换手机卡吗(升5g需要换卡么)

  • medly安卓闪退(medly安卓版下载教程)

    medly安卓闪退(medly安卓版下载教程)

  • vivoy50是5G手机吗(vivoy50是5g手机吗?)

    vivoy50是5G手机吗(vivoy50是5g手机吗?)

  • 苹果减弱动态效果什么意思(苹果减弱动态效果费电吗)

    苹果减弱动态效果什么意思(苹果减弱动态效果费电吗)

  • 夏新蓝牙耳机怎么只有一个耳机能播放(夏新蓝牙耳机怎么样)

    夏新蓝牙耳机怎么只有一个耳机能播放(夏新蓝牙耳机怎么样)

  • 手淘和淘宝有什么区别(手淘是指手机淘宝吗)

    手淘和淘宝有什么区别(手淘是指手机淘宝吗)

  • 加入粉丝团可以退出吗(加入粉丝团可以获得专属优惠券话术)

    加入粉丝团可以退出吗(加入粉丝团可以获得专属优惠券话术)

  • 苹果手机卡屏是什么原因造成的(苹果手机卡屏是什么原因)

    苹果手机卡屏是什么原因造成的(苹果手机卡屏是什么原因)

  • 计算机网络中wan中文名(计算机网络中网关的主要作用是)

    计算机网络中wan中文名(计算机网络中网关的主要作用是)

  • 手机怎样在电视上投屏(手机怎样在电视投屏?)

    手机怎样在电视上投屏(手机怎样在电视投屏?)

  • 华为mate30闪退怎么办(华为mate30总闪退)

    华为mate30闪退怎么办(华为mate30总闪退)

  • 黑鲨手机skr ao是几代(黑鲨skwa0是什么机型)

    黑鲨手机skr ao是几代(黑鲨skwa0是什么机型)

  • xr支持18w快充吗(苹果xr支持18w充电吗)

    xr支持18w快充吗(苹果xr支持18w充电吗)

  • oppoa3智能语音助手怎么打开(oppo a3语音助手)

    oppoa3智能语音助手怎么打开(oppo a3语音助手)

  • 如何通过QQ号查手机号(如何通过qq号查别人的个人信息)

    如何通过QQ号查手机号(如何通过qq号查别人的个人信息)

  • 搜狗浏览器如何设置不显示图片(搜狗浏览器如何截图)

    搜狗浏览器如何设置不显示图片(搜狗浏览器如何截图)

  • 快手显示无法连接网络(快手显示无法连接网络怎么办)

    快手显示无法连接网络(快手显示无法连接网络怎么办)

  • Win10修改hosts文件无法保存的解决方法(不用更改权限)(win10修改hosts文件权限)

    Win10修改hosts文件无法保存的解决方法(不用更改权限)(win10修改hosts文件权限)

  • 进项税额转出会影响利润吗
  • 产权转移数据的交易价格和固定资产科目
  • 扶贫入股分红能领多久
  • 企业多交所得税不想退税在电子税务局如何处理
  • 销售过程中客户买的是什么
  • 虚开增值税发票的涉税风险如何防范
  • 私营公司会计资取公司资金
  • 负数发票作废了对原来的正数发票有什么影响
  • 盈余公积可用于集体福利吗
  • 营改增后房地产公司税种及税率
  • 企业事故赔偿支出可以抵税吗
  • 预付房租是否需要分摊处理呢?
  • 营改增后取得施工作业收入需要交哪些税?
  • 出口退税要交企业所得吗
  • 增值税专用发票有效期是多长时间
  • 建筑企业增值税预缴
  • 加计扣除申报表填报说明
  • 加油的普票可以抵扣进项税吗
  • 固定资产对外投资增值税
  • 记账凭证的分类和基本内容
  • 生产企业出口退税全部流程
  • 总额法和净额法哪个合理
  • 法人变更后的涉税问题
  • 车辆保险抵扣会计分录
  • 专票可以当普票用不抵扣吗
  • 当月凭证做完怎么结转?
  • 交通费,通讯费均按照上级行标准领取
  • 子网掩码和默认网关怎么填
  • 2016年最佳歌曲
  • win10怎么用wifi上网
  • 在建工程的施工方案可以外传吗
  • 购买机器配件怎么做会计分录
  • linux命令行怎么用
  • 已删除好友的聊天记录
  • wlms.exe是什么
  • 冲回上年多提的费用会计分录
  • php strrchr
  • 哪些货物出口不需要运输条件
  • 企业法人个人贷款企业有风险吗
  • 基础代谢
  • php并发编程
  • 记账凭证和原始凭证都是登记账簿的直接依据
  • 阿尔莫什
  • 成本票不够怎么交税
  • 融资租入固定资产的改建支出
  • 售后租回交易形式是什么
  • 不交社保个税怎么处理
  • 公司帮非公司员工缴税
  • 出售汽车固定资产要交什么税
  • 向境外机构支付的服务费税收政策
  • 抵扣联和发票联算一张发票吗
  • 餐饮服务税率是服务类税率还是货物类
  • 企业法人需要本人到场吗
  • mysql1290报错
  • 会计常用分录
  • 本月发生费用未支付会计处理
  • 财政补助收入怎样申报企业所得税
  • 报销单可以当记账凭证吗
  • 快递费用是否可以开发票
  • 小规模减免附加税的会计处理
  • 季报企业所得税弥补亏损数怎么填
  • 会议中发生的相关事件
  • 五险怎么做账
  • 公对私 预付款 税
  • 存储过程实现业务逻辑
  • ubuntu20.04怎么用
  • ubuntu系统安装教程详细
  • mac os 删除
  • 自动登录xp系统怎么办
  • Win7如何安装音频设备
  • 跑跑跑游戏
  • 批处理替换文件中的某个内容
  • 文件包解密
  • 我们要什么行政执法监督机制和能力建设严格落实行政
  • 纳税服务主要职责
  • 广东省广州市税务局分数线
  • 泉州企业医保哪里缴费
  • 税控盘清卡的步骤
  • 广西残疾人保障金比例
  • 美国网购消费者个人信息保护法
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设