位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 小米civi跑分

    小米civi跑分

  • word形状填充在哪里(word形状填充在哪)

    word形状填充在哪里(word形状填充在哪)

  • bootmgr is conmpressed电脑上怎么解决(bootmgrisconmpressed无法开机怎么办)

    bootmgr is conmpressed电脑上怎么解决(bootmgrisconmpressed无法开机怎么办)

  • 苹果8p怎么分屏多任务(苹果8p怎么分屏小窗口操作)

    苹果8p怎么分屏多任务(苹果8p怎么分屏小窗口操作)

  • 叨叨记账可以互相聊天吗(叨叨记账怎么互动)

    叨叨记账可以互相聊天吗(叨叨记账怎么互动)

  • 手机更新后没有声音怎么回事(手机更新后没有录屏了怎么办)

    手机更新后没有声音怎么回事(手机更新后没有录屏了怎么办)

  • 魅族17上市时间(魅族20价格)

    魅族17上市时间(魅族20价格)

  • oppor9splus一直自动重启怎么回事(oppor9splus一直自动重启,拆除两个电感后有什么影响)

    oppor9splus一直自动重启怎么回事(oppor9splus一直自动重启,拆除两个电感后有什么影响)

  • 你拨叫的用户暂时无法接通什么意思(你拨叫的用户暂时无人接听请稍后再拨)

    你拨叫的用户暂时无法接通什么意思(你拨叫的用户暂时无人接听请稍后再拨)

  • 动态锁屏壁纸怎么没有声音(动态锁屏壁纸怎么关闭)

    动态锁屏壁纸怎么没有声音(动态锁屏壁纸怎么关闭)

  • 如何在word中画线段图(如何在word中画表格)

    如何在word中画线段图(如何在word中画表格)

  • 华为手机屏幕变成黑色怎么恢复(华为手机屏幕变绿色了怎么回事)

    华为手机屏幕变成黑色怎么恢复(华为手机屏幕变绿色了怎么回事)

  • 微信代扣服务在哪里可以取消的?(微信代扣业务是什么)

    微信代扣服务在哪里可以取消的?(微信代扣业务是什么)

  • mumimo需要终端支持吗(muos终端)

    mumimo需要终端支持吗(muos终端)

  • 微信怎么一次删除多个好友(微信怎么一次删除很多聊天记录)

    微信怎么一次删除多个好友(微信怎么一次删除很多聊天记录)

  • 华为手机设置彩虹电池(华为手机设置彩铃免费)

    华为手机设置彩虹电池(华为手机设置彩铃免费)

  • 红米note8支持屏幕指纹吗(redmi note8屏幕怎么样)

    红米note8支持屏幕指纹吗(redmi note8屏幕怎么样)

  • 红蜘蛛多媒体教学软件怎么卸载(红蜘蛛多媒体教室 使用)

    红蜘蛛多媒体教学软件怎么卸载(红蜘蛛多媒体教室 使用)

  • 美易怎么申请退款(美易怎么申请退款退货)

    美易怎么申请退款(美易怎么申请退款退货)

  • 小爱是什么品牌(小爱是什么品牌的手机)

    小爱是什么品牌(小爱是什么品牌的手机)

  • 苹果手机找不到订阅管理怎么办(苹果手机找不到了在家里怎么找)

    苹果手机找不到订阅管理怎么办(苹果手机找不到了在家里怎么找)

  • ios怎么添加信任(iphone怎么添加信任)

    ios怎么添加信任(iphone怎么添加信任)

  • oppo手机微信铃声在哪里设置(oppo手机微信铃声大小怎么调)

    oppo手机微信铃声在哪里设置(oppo手机微信铃声大小怎么调)

  • 帝国cms后台登录次数不超过5次限制(帝国cms怎么样)

    帝国cms后台登录次数不超过5次限制(帝国cms怎么样)

  • 怎么把python代码做成软件(怎么把python代码发给别人运行)

    怎么把python代码做成软件(怎么把python代码发给别人运行)

  • 小规模纳税人执行小企业会计准则吗
  • 提现的手续费怎么做账
  • 建筑业预缴企业所得税
  • 小规模没有税控怎么办
  • 2021年超市发票税率是多少
  • 我公司给对方公司付款
  • 结算金额大于合同金额
  • 未竣工验收导致发生质量问题由谁承担责任
  • 会计上的未达账项是什么
  • 税报完了可以撤销吗
  • 源泉扣缴税率是多少
  • 建筑业会计实操视频教程
  • 想要房贷利息抵税怎么办
  • 赠送样品需要交税吗
  • 收到注册资金要交税吗
  • 礼服租赁套餐
  • 公司没有车牌可以买车吗
  • 未确认融资费用摊销额怎么计算
  • 过桥费和过路费去哪里了
  • 合伙企业累进税率
  • 外企投资应该怎么投资
  • 企业汇算清缴后还能更正吗
  • 收入与成本不匹配建议怎么写
  • 小规模纳税人需要建账吗
  • 发给职工的交通补贴
  • 小规模纳税人买车可以抵税吗
  • 申报高新企业的当年一定要研发费用加计扣除吗
  • bios屏蔽接口
  • 销售方红字发票账务处理
  • 什么是产品生产者之间争取最有利的关系
  • 员工冲借款应该怎么做账
  • 退款后发票如何查询
  • php中execute
  • 供热企业享受增值税吗
  • 金银首饰以旧换新消费税怎么算
  • 投资者投资企业项目的主要目的是
  • php获取随机数
  • php读取文件内容
  • tensorflow gan
  • 修改Laravel5.3中的路由文件与路径
  • -mtime命令
  • 单位收到投标保证金会计分录
  • 高新技术企业研发费用加计扣除
  • 增值税在纳税申报表中怎么得出
  • 税前扣除是什么时候
  • 收到多开发票的会计分录
  • mysql union or
  • mongo 更新数据
  • ibm.data.db2
  • 进项有效期
  • 房产税的政策依据
  • 税务局退回水利基金账务怎么处理
  • 产品不符合要求
  • 流动性比率过小什么意思
  • 为什么出口退税是贷方
  • 以前年度损益这个科目
  • 购买汽车后,需要缴纳的税种有哪些
  • 补缴税务滞纳金
  • 重要性水平如何判断
  • 外资企业计提盈余公积吗
  • 新收入准则要求
  • 如何使用u盘安装win11
  • wcu.exe是什么
  • windows打不开添加打印机
  • vsftp查看状态
  • linux系统的安全机制有哪些
  • win7找不到启动
  • win7 64位纯净版系统c盘空间显示与实际占用空间不对的解决方法图文教程
  • win8升win8.1
  • js require()
  • react用什么ui
  • jquery浮动窗口
  • node connect
  • css如何控制图片位置
  • 用python编写登录程序
  • perl脚本调试方法
  • vlc录制没反应
  • shell脚本实例精讲
  • js tab切换
  • 代扣代缴申报表 填表说明
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设