位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 新浪微博怎么注销账号(新浪微博怎么注销不了)

    新浪微博怎么注销账号(新浪微博怎么注销不了)

  • 滴滴一口价是什么意思(滴滴一口价怎么收费)

    滴滴一口价是什么意思(滴滴一口价怎么收费)

  • 钉钉视频会议怎样两个群同时进行(钉钉视频会议怎么共享屏幕)

    钉钉视频会议怎样两个群同时进行(钉钉视频会议怎么共享屏幕)

  • qq留言怎么设置仅彼此可见(qq留言怎么设置背景照片)

    qq留言怎么设置仅彼此可见(qq留言怎么设置背景照片)

  • 有线耳机连接不上手机(有线耳机连接不稳定怎么办)

    有线耳机连接不上手机(有线耳机连接不稳定怎么办)

  • 手机只能照相不能录像(手机只能照相不能拍视频了)

    手机只能照相不能录像(手机只能照相不能拍视频了)

  • 淘宝掘金团队怎么退出(淘宝挖金团队从哪里进入)

    淘宝掘金团队怎么退出(淘宝挖金团队从哪里进入)

  • 美团otc是什么订单(美团oc是什么意思)

    美团otc是什么订单(美团oc是什么意思)

  • exynos 980处理器是啥水平(Exynos 980处理器相当于)

    exynos 980处理器是啥水平(Exynos 980处理器相当于)

  • hg8546m哪个口是千兆口(hg8540m哪个口是千兆)

    hg8546m哪个口是千兆口(hg8540m哪个口是千兆)

  • 手机莫名其妙充不进去电(手机莫名其妙充值50元)

    手机莫名其妙充不进去电(手机莫名其妙充值50元)

  • dig-tl10是什么型号(dig_tl10)

    dig-tl10是什么型号(dig_tl10)

  • opporeno后面几个摄像头(oppo reno后面那个突出的点是什么)

    opporeno后面几个摄像头(oppo reno后面那个突出的点是什么)

  • 苹果电脑丢了能追踪吗(苹果电脑丢了能不能找到)

    苹果电脑丢了能追踪吗(苹果电脑丢了能不能找到)

  • 苹果手机更新后怎么卸载软件(苹果手机更新后发烫)

    苹果手机更新后怎么卸载软件(苹果手机更新后发烫)

  • 怎么关掉win杀毒软件(win杀毒软件如何关闭)

    怎么关掉win杀毒软件(win杀毒软件如何关闭)

  • 入驻京东自营怎么收费(京东自营入驻店铺流程)

    入驻京东自营怎么收费(京东自营入驻店铺流程)

  • 苹果手机3g怎么切换到4g(苹果手机3g怎么设置4g)

    苹果手机3g怎么切换到4g(苹果手机3g怎么设置4g)

  • 数据预处理的方法(数据预处理的方法 离散化)

    数据预处理的方法(数据预处理的方法 离散化)

  • 苹果7p有前置呼吸灯吗(苹果7p有前置呼叫功能吗)

    苹果7p有前置呼吸灯吗(苹果7p有前置呼叫功能吗)

  • 如何通过10086转赠流量(10086怎么转流量给别人)

    如何通过10086转赠流量(10086怎么转流量给别人)

  • 荣耀20如何添加小工具(荣耀20如何添加门禁卡)

    荣耀20如何添加小工具(荣耀20如何添加门禁卡)

  • rio-ul00什么型号(rio ul00华为是什么型号)

    rio-ul00什么型号(rio ul00华为是什么型号)

  • mac和ipad如何互传文件呢?MAC传文件到IPAD方法介绍(macbook怎么和ipad)

    mac和ipad如何互传文件呢?MAC传文件到IPAD方法介绍(macbook怎么和ipad)

  • rteng7.exe - rteng7是什么进程 有什么用

    rteng7.exe - rteng7是什么进程 有什么用

  • php单例模式有什么用(php单例模式优点)

    php单例模式有什么用(php单例模式优点)

  • 补缴以前年度企业所得税如何填报汇算清缴表
  • 小规模要交增值税怎么计提
  • 支付给职工以及为职工支付的现金包括哪些
  • 房屋维修基金帐户怎么查
  • 未抵扣的进项发票是什么意思
  • 利润表收入含其他收入吗怎么填
  • 个税年度累计计算器
  • 资产申报是什么
  • 净收益营运指数大于1说明什么
  • 存货成本包括消费吗
  • 行政单位应缴预算款的管理原则
  • 因管理不善的材料盘亏如何做账
  • 压覆矿产赔偿标准法律依据
  • 制造费用的
  • 未分配利润的计税基础是
  • 公允价值变动损益属于当期损益吗
  • 土地增值税清算是什么意思
  • 固定资产处置的账务处理
  • 融资租赁租金及利息计算
  • 税金及附加包括地方教育费附加吗
  • 收到以前年度退回的企业所得税怎么做账
  • 有关于秋天的诗句
  • 金融资产有哪三类代码
  • 收到上年度企业所得税退税款
  • 广告公司的成本是什么
  • 付款交单和承兑交单对卖方来说都有一定风险
  • 代收代付款项入账需要什么资料
  • kb4539601安装失败
  • phpfilter
  • 员工辞退补偿金扣个税吗
  • php中execute
  • 小微企业城建税及附加减免优惠
  • 文本超出单元格
  • php gdb
  • 增值税发票已认证抵扣还可以进项税额转出吗
  • wget下载yum
  • 承包安装工程活怎么接
  • mongodb win7
  • 货代一般一个柜利润多少
  • 小微企业需要专职安全员吗
  • 零申报企业所得税的资产总额怎么填写
  • 小规模纳税人怎么开专票
  • 临时用工费用计入什么会计科目
  • 小规模纳税人购入货物收到增值税专用发票
  • 委托代销安排的迹象有哪些
  • 公司账户资金转个人账户
  • 包材库存
  • 业务招待费可以计入销售费用吗
  • 企业实缴资本如何查
  • 建设工程施工管理
  • 没有什么费用
  • 红冲发票金额大于原发票金额
  • w7系统ip地址
  • windows预览版计划
  • Win7系统开机流程
  • win8系统崩溃无法开机
  • 如何设置win10系统输入法
  • ubuntu的安装步骤
  • win7禁用administrator
  • win7架设ftp服务器
  • win7系统如何修改默认浏览器
  • win7系统玩英雄联盟
  • 你会支持国产系统吗英文
  • 冗余文件是什么意思
  • 创建react native项目
  • 微信小程序实现人脸识别
  • our与my的区别
  • android注册界面设计
  • 网页设计div css布局
  • ajax成功不走success
  • android源码分析
  • python搭建虚拟环境torch
  • jquery mobile app
  • 判断jquery对象是否存在
  • 税务评估风险等级是什么
  • 落实与什么动词搭配
  • 雄安属于北京管吗
  • 出口退税申报时闿
  • 一般纳税人提供公共交通运输服务免征增值税
  • 西安市人力资源和社会保障局关于2020年
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设