位置: IT常识 - 正文

Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad(pytorch sgd优化器)

编辑:rootadmin
Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad

目录

写在前面

一、 torch.optim.SGD 随机梯度下降

SGD代码

SGD算法解析

1.MBGD(Mini-batch Gradient Descent)小批量梯度下降法

 2.Momentum动量

3.NAG(Nesterov accelerated gradient)

SGD总结

二、torch.optim.ASGD随机平均梯度下降

三、torch.optim.Rprop

四、torch.optim.Adagrad 自适应梯度

Adagrad 代码

Adagrad 算法解析

AdaGrad总结


推荐整理分享Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad(pytorch sgd优化器),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:scipy优化器,python优化工具箱,pytorch adam优化器参数,scipy优化器,pytorch sgd优化器,pytorch sgd优化器,pytorch adam优化器参数,pytorch sgd优化器,内容如对您有帮助,希望把文章链接给更多的朋友!

优化器系列文章列表

Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad

Pytorch优化器全总结(二)Adadelta、RMSprop、Adam、Adamax、AdamW、NAdam、SparseAdam

Pytorch优化器全总结(三)牛顿法、BFGS、L-BFGS 含代码

Pytorch优化器全总结(四)常用优化器性能对比 含代码

写在前面

        优化器时深度学习中的重要组件,在深度学习中有举足轻重的地位。在实际开发中我们并不用亲手实现一个优化器,很多框架都帮我们实现好了,但如果不明白各个优化器的特点,就很难选择适合自己任务的优化器。接下来我会开一个系列,以Pytorch为例,介绍所有主流的优化器,如果都搞明白了,对优化器算法的掌握也就差不多了。

        作为系列的第一篇文章,本文介绍Pytorch中的SGD、ASGD、Rprop、Adagrad,其中主要介绍SGD和Adagrad。因为这四个优化器出现的比较早,都存在一些硬伤,而作为现在主流优化器的基础又跳不过,所以作为开端吧。

        我们定义一个通用的思路框架,方便在后面理解各算法之间的关系和改进。首先定义待优化参数 ,目标函数,学习率为  ,然后我们进行迭代优化,假设当前的epoch为,参数更新步骤如下:

1. 计算目标函数关于当前参数的梯度: 

                               (1)

 2. 根据历史梯度计算一阶动量和二阶动量:

                (2)

                 (3)

 3. 计算当前时刻的下降梯度: 

                           (4)

4. 根据下降梯度进行更新:  

                        (5)

        下面介绍的所有优化算法基本都能套用这个流程,只是式子(4)的形式会有变化。

一、 torch.optim.SGD 随机梯度下降

        该类可实现 SGD 优化算法,带动量 的SGD 优化算法和带 NAG(Nesterov accelerated gradient)的 SGD 优化算法,并且均可拥有 weight_decay(权重衰减) 项。

SGD代码'''params(iterable)- 参数组,优化器要优化的那部分参数。lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。momentum(float)- 动量,通常设置为 0.9,0.8dampening(float)- dampening for momentum ,暂时不了其功能,在源码中是这样用的:buf.mul_(momentum).add_(1 - dampening, d_p),值得注意的是,若采用nesterov,dampening 必须为 0.weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数nesterov(bool)- bool 选项,是否使用 NAG(Nesterov accelerated gradient)'''class torch.optim.SGD(params, lr=<object object>, momentum=0, dampening=0, weight_decay=0, nesterov=False)SGD算法解析1.MBGD(Mini-batch Gradient Descent)小批量梯度下降法

        明明类名是SGD,为什么介绍MBGD呢,因为在Pytorch中,torch.optim.SGD其实是实现的MBGD,要想使用SGD,只要将batch_size设成1就行了。

        MBGD就是结合BGD和SGD的折中,对于含有 n个训练样本的数据集,每次参数更新,选择一个大小为 m(m<n) 的mini-batch数据样本计算其梯度,其参数更新公式如下,其中是一个batch的开始:

                (6)

优点:使用mini-batch的时候,可以收敛得很快,有一定摆脱局部最优的能力。

缺点:a.在随机选择梯度的同时会引入噪声,使得权值更新的方向不一定正确

           b.不能解决局部最优解的问题

 2.Momentum动量

         动量是一种有助于在相关方向上加速SGD并抑制振荡的方法,通过将当前梯度与过去梯度加权平均,来获取即将更新的梯度。如下图b图所示。它通过将过去时间步长的更新向量的一小部分添加到当前更新向量来实现这一点:

 动量项通常设置为0.9或类似值。

参数更新公式如下,其中ρ 是动量衰减率,m是速率(即一阶动量):

                         (7)

                (8)

                   (9)

3.NAG(Nesterov accelerated gradient)

        NAG的思想是在动量法的基础上展开的。动量法是思想是,将当前梯度与过去梯度加权平均,来获取即将更新的梯度。在知道梯度之后,更新自变量到新的位置。也就是说我们其实在每一步,是知道下一时刻位置的。这时Nesterov就说了:那既然这样的话,我们何不直接采用下一时刻的梯度来和上一时刻梯度进行加权平均呢?下面两张图看明白,就理解NAG了:

Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad(pytorch sgd优化器)

        

 

NAG和经典动量法的差别就在B点和C点梯度的不同。 

 参数更新公式:

                (10)

                        (11)

                           (12)

        上式中的就是图中的B到C那一段向量,就是C点坐标(参数)。可以看到NAG除了式子(10)与式子(7)有所不同,其余公式和Momentum是一样的。

        一般情况下NAG方法相比Momentum收敛速度快、波动也小。实际上NAG方法用到了二阶信息,所以才会有这么好的结果。

         Nesterov动量梯度的计算在模型参数施加当前速度之后,因此可以理解为往标准动量中添加了一个校正因子。在凸批量梯度的情况下,Nesterov动量将额外误差收敛率从(k步后)改进到  ,然而,在随机梯度情况下,Nesterov动量对收敛率的作用却不是很大。

SGD总结

使用了Momentum或NAG的MBGD有如下特点:

优点:加快收敛速度,有一定摆脱局部最优的能力,一定程度上缓解了没有动量的时候的问题

缺点:a.仍然继承了一部分SGD的缺点

          b.在随机梯度情况下,NAG对收敛率的作用不是很大

          c.Momentum和NAG都是为了使梯度更新更灵活。但是人工设计的学习率总是有些生硬,下面介绍几种自适应学习率的方法。

推荐程度:带Momentum的torch.optim.SGD 可以一试。

二、torch.optim.ASGD随机平均梯度下降

        ASGD 也称为 SAG,表示随机平均梯度下降(Averaged Stochastic Gradient Descent),简单地说 ASGD 就是用空间换时间的一种 SGD,因为很少使用,所以不详细介绍,详情可参看论文: http://riejohnson.com/rie/stograd_nips.pdf

'''params(iterable)- 参数组,优化器要优化的那些参数。lr(float)- 初始学习率,可按需随着训练过程不断调整学习率。lambd(float)- 衰减项,默认值 1e-4。alpha(float)- power for eta update ,默认值 0.75。t0(float)- point at which to start averaging,默认值 1e6。weight_decay(float)- 权值衰减系数,也就是 L2 正则项的系数。'''class torch.optim.ASGD(params, lr=0.01, lambd=0.0001, alpha=0.75, t0=1000000.0, weight_decay=0)

 推荐程度:不常见

三、torch.optim.Rprop

        该类实现 Rprop 优化方法(弹性反向传播),适用于 full-batch,不适用于 mini-batch,因而在 mini-batch 大行其道的时代里,很少见到。

'''params - 参数组,优化器要优化的那些参数。lr - 学习率etas (Tuple[float, float])- 乘法增减因子step_sizes (Tuple[float, float]) - 允许的最小和最大步长'''class torch.optim.Rprop(params, lr=0.01, etas=(0.5, 1.2), step_sizes=(1e-06, 50))

优点:它可以自动调节学习率,不需要人为调节

缺点:仍依赖于人工设置一个全局学习率,随着迭代次数增多,学习率会越来越小,最终会趋近于0

推荐程度:不推荐

四、torch.optim.Adagrad 自适应梯度

        该类可实现 Adagrad 优化方法(Adaptive Gradient),Adagrad 是一种自适应优化方法,是自适应的为各个参数分配不同的学习率。这个学习率的变化,会受到梯度的大小和迭代次数的影响。梯度越大,学习率越小;梯度越小,学习率越大。

Adagrad 代码'''params (iterable) – 待优化参数的iterable或者是定义了参数组的dictlr (float, 可选) – 学习率(默认: 1e-2)lr_decay (float, 可选) – 学习率衰减(默认: 0)weight_decay (float, 可选) – 权重衰减(L2惩罚)(默认: 0)initial_accumulator_value - 累加器的起始值,必须为正。'''class torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0)Adagrad 算法解析

        AdaGrad对学习率进行了一个约束,对于经常更新的参数,我们已经积累了大量关于它的知识,不希望被单个样本影响太大,希望学习速率慢一些;对于偶尔更新的参数,我们了解的信息太少,希望能从每个偶然出现的样本身上多学一些,即学习速率大一些。这样大大提高梯度下降的鲁棒性。而该方法中开始使用二阶动量,才意味着“自适应学习率”优化算法时代的到来。         在SGD中,我们每次迭代对所有参数进行更新,因为每个参数使用相同的学习率。而AdaGrad在每个时间步长对每个参数使用不同的学习率。AdaGrad消除了手动调整学习率的需要。AdaGrad在迭代过程中不断调整学习率,并让目标函数中的每个参数都分别拥有自己的学习率。大多数实现使用学习率默认值为0.01,开始设置一个较大的学习率。

        AdaGrad引入了二阶动量。二阶动量是迄今为止所有梯度值的平方和,即它是用来度量历史更新频率的。也就是说,我们的学习率现在是,从这里我们就会发现 是恒大于0的,而且参数更新越频繁,二阶动量越大,学习率就越小,这一方法在稀疏数据场景下表现非常好,参数更新公式如下: 

                                                            (13)

                                (14)

AdaGrad总结

        AdaGrad在每个时间步长对每个参数使用不同的学习率。并且引入了二阶动量,二阶动量是迄今为止所有梯度值的平方和。

优点:AdaGrad消除了手动调整学习率的需要。AdaGrad在迭代过程中不断调整学习率,并让目标函数中的每个参数都分别拥有自己的学习率。

缺点:a.仍需要手工设置一个全局学习率  , 如果  设置过大的话,会使regularizer过于敏感,对梯度的调节太大

        b.在分母中累积平方梯度,由于每个添加项都是正数,因此在训练过程中累积和不断增长。这导致学习率不断变小并最终变得无限小,此时算法不再能够获得额外的知识即导致模型不会再次学习。

 推荐程度:不推荐

优化器系列文章列表

Pytorch优化器全总结(一)SGD、ASGD、Rprop、Adagrad

Pytorch优化器全总结(二)Adadelta、RMSprop、Adam、Adamax、AdamW、NAdam、SparseAdam

Pytorch优化器全总结(三)牛顿法、BFGS、L-BFGS 含代码

Pytorch优化器全总结(四)常用优化器性能对比 含代码

本文链接地址:https://www.jiuchutong.com/zhishi/298759.html 转载请保留说明!

上一篇:【原创】基于JavaWeb的医院预约挂号系统(医院挂号管理系统毕业设计)

下一篇:〖大前端 - 基础入门三大核心之JS篇㉓〗- JavaScript 的「数组」(大前端入门指南)

  • 微信推广网店的十二种玩法(微信网站推广)

    微信推广网店的十二种玩法(微信网站推广)

  • 芒果tv怎么查看自己的二维码(芒果tv怎么查看全部电影)

    芒果tv怎么查看自己的二维码(芒果tv怎么查看全部电影)

  • 笔记本电脑充电器通用吗(笔记本电脑充电器发热很烫什么原因)

    笔记本电脑充电器通用吗(笔记本电脑充电器发热很烫什么原因)

  • 快手申请仲裁的意思是什么(快手小店申请仲裁管用吗)

    快手申请仲裁的意思是什么(快手小店申请仲裁管用吗)

  • 被搜账号状态异常无法显示(被搜账号状态异常是被拉黑了吗)

    被搜账号状态异常无法显示(被搜账号状态异常是被拉黑了吗)

  • 音箱电流声大是什么回事(音箱电流声大是怎么回事)

    音箱电流声大是什么回事(音箱电流声大是怎么回事)

  • a2197是ipad几代几寸(ipad a2197是第几代)

    a2197是ipad几代几寸(ipad a2197是第几代)

  • 苹果11mwn82ch/a是不是国行(苹果11mwn82cha是不是国行)

    苹果11mwn82ch/a是不是国行(苹果11mwn82cha是不是国行)

  • 天猫直送超时怎么赔偿(天猫直送超时怎么处理)

    天猫直送超时怎么赔偿(天猫直送超时怎么处理)

  • 买家旺旺名是什么意思(买家旺旺名字是淘宝号名字么)

    买家旺旺名是什么意思(买家旺旺名字是淘宝号名字么)

  • 手机版知乎怎么复制(手机版知乎怎么匿名发文章)

    手机版知乎怎么复制(手机版知乎怎么匿名发文章)

  • 超级nfc和全功能nfc区别(超级nfcsim)

    超级nfc和全功能nfc区别(超级nfcsim)

  • 微信取消小米自动续费(小米手机怎么关闭微信自动续费)

    微信取消小米自动续费(小米手机怎么关闭微信自动续费)

  • ie浏览器如何截图(ie游览器怎么截长图)

    ie浏览器如何截图(ie游览器怎么截长图)

  • 三星note10 5g韩版和国行有什么区别(三星note10 5g韩版怎么样)

    三星note10 5g韩版和国行有什么区别(三星note10 5g韩版怎么样)

  • 手机突然没有声音是什么原因(手机突然没有声音了怎么解决)

    手机突然没有声音是什么原因(手机突然没有声音了怎么解决)

  • 华为p30pro特殊功能(华为p30pro功能键介绍)

    华为p30pro特殊功能(华为p30pro功能键介绍)

  • 快手极速版怎么放大(快手极速版怎么取消微信绑定提现)

    快手极速版怎么放大(快手极速版怎么取消微信绑定提现)

  • 翻新手机和新机区别(翻新手机好吗)

    翻新手机和新机区别(翻新手机好吗)

  • ssl协议的主要作用是什么(ssl协议的具体应用有哪些)

    ssl协议的主要作用是什么(ssl协议的具体应用有哪些)

  • 怎样可以看到对方正在输入(怎样可以看到对方的朋友圈)

    怎样可以看到对方正在输入(怎样可以看到对方的朋友圈)

  • 快手青少年模式怎么关闭(快手青少年模式怎么开启)

    快手青少年模式怎么关闭(快手青少年模式怎么开启)

  • 番茄社区怎么看直播(番茄社区从哪里看直播)

    番茄社区怎么看直播(番茄社区从哪里看直播)

  • 苹果手机qq视频能开悬浮窗么(苹果手机qq视频可以美颜吗)

    苹果手机qq视频能开悬浮窗么(苹果手机qq视频可以美颜吗)

  • 华为nova5支持5g吗(华为nova5支持5g信号吗)

    华为nova5支持5g吗(华为nova5支持5g信号吗)

  • 动态范围和宽容度区别(动态宽容度)

    动态范围和宽容度区别(动态宽容度)

  • 手机号码黑名单哪里找(怎么看手机号码黑名单)

    手机号码黑名单哪里找(怎么看手机号码黑名单)

  • 天猫手机如何分期付款(天猫手机怎么分期付款)

    天猫手机如何分期付款(天猫手机怎么分期付款)

  • Python---time模块(pythontime模块)

    Python---time模块(pythontime模块)

  • 小规模纳税人可以开什么发票
  • 长期股权投资中应采用成本法核算的是
  • 结转销售成本的分录
  • 个人所得税怎么交
  • 围挡制作开票的税收分类
  • 退土增税后账务处理
  • 税控系统维护费账务处理
  • 金蝶标准版结转损益发生错误
  • 快递增值税税率
  • 公司注销后还会有事吗
  • 关于进项税额转出的规定
  • 待认证发票后面需要做附件吗
  • 销售退回冲减主营业务收入吗
  • 挂靠行为应当如何纳税?
  • 个人独资企业出资额是注册资本吗
  • 会计低值易耗品有哪些
  • 1697508513
  • 个税扣除每个月更新吗
  • 免税农产品购进怎么做账
  • 对个体工商户个人的认识
  • 年初建账的期初余额
  • Win11任务栏不显示
  • 公司向法人借款会计分录
  • 收到办公室桶装水开的普票怎样入账?
  • 银行承兑汇票付款提示期限
  • 三证合一后的税务登记证查询方法
  • 弃置费用预计负债的会计处理
  • 公司主要开支是指什么
  • wordpress抓取网页
  • 修改配置文件是什么意思
  • 可抵免境外所得税税额
  • 共管账户可以转账吗
  • 机器学习中的数学原理——过拟合、正则化与惩罚函数
  • Visual studio 2019 社区版下载和安装
  • NovelAi + Webui + Stable-diffusion本地配置
  • thinkphp跨域
  • php或者判断
  • 政府会计累计盈余解析
  • 什么经营范围可以开门票发票
  • 应收货款计入什么科目
  • 红冲以前年度的费用怎么做账
  • 毛利率代表什么?如何计算毛利率?
  • 普通发票为什么只能领一张
  • 软件开发企业怎么结转成本
  • 财付通支付备付金
  • 为什么生产经营许可证要第三方代办
  • 同一个客户有应收也有应付怎么办
  • 小规模纳税人不超过10万免增值税
  • 公司银行账号注销需要法人到场吗
  • 账户信息变更说明
  • 公司发放给员工的福利又要回
  • 应交税金借方余额在报表列示
  • 餐饮服务属于什么职业
  • 备用金如何管理制度
  • 取得存款利息收入需附
  • mysql在本地主机创建用户账号
  • xp系统控制面板在哪里打开
  • xp无法识别的usb设备unknown device
  • xp怎么装系统步骤图解
  • ati2plab.exe是什么进程 ati2plab进程安全吗
  • 怎么看win8.1的版本
  • win10应用商店叫什么
  • 在linux操作系统中
  • linux系统设置
  • jquery怎么写
  • java如何自定义函数
  • cmd命令如何进入d盘
  • perl read
  • vtk下载步骤
  • mysql如何将查询结果输出到文件
  • shell脚本 su
  • android基础知识总结
  • python正则\b
  • JQUERY的AJAX请求缓存里的数据问题处理
  • android开发环境搭建实验报告总结
  • 详解Javascript事件驱动编程
  • 人工智能在税务领域应用中的风险与规制
  • 企业税务代码是什么号
  • 模范劳动者
  • 消费发票上的金额含税吗
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设