位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 美团订酒店怎么订(美团订酒店怎么代付)

    美团订酒店怎么订(美团订酒店怎么代付)

  • 保存快捷键是什么(保存快捷键是什么加什么)

    保存快捷键是什么(保存快捷键是什么加什么)

  • 怎么把微信设置成不在线(怎么把微信设置成免打扰模式)

    怎么把微信设置成不在线(怎么把微信设置成免打扰模式)

  • 怎么开启华为手机的小艺功能(怎么开启华为手机的小艺)

    怎么开启华为手机的小艺功能(怎么开启华为手机的小艺)

  • 苹果手机照片全部不见了怎么办(苹果手机照片全图缩小做头像)

    苹果手机照片全部不见了怎么办(苹果手机照片全图缩小做头像)

  • c盘windows哪些资料可以删除(c盘windows文件夹哪些文件可以删除)

    c盘windows哪些资料可以删除(c盘windows文件夹哪些文件可以删除)

  • 华为手机有语音助手吗(华为手机有语音播报功能吗)

    华为手机有语音助手吗(华为手机有语音播报功能吗)

  • laptop是什么电脑(laptop laptop)

    laptop是什么电脑(laptop laptop)

  • 确认收货为什么扣了钱(确认收货为什么要输入支付密码)

    确认收货为什么扣了钱(确认收货为什么要输入支付密码)

  • qq音乐访客数量不对(qq音乐访客数量却没有增加)

    qq音乐访客数量不对(qq音乐访客数量却没有增加)

  • 怎么拒收消息(不删微信怎么拒收消息)

    怎么拒收消息(不删微信怎么拒收消息)

  • 华为p40pro是90hz吗(华为p40Pro是1080p屏吗?)

    华为p40pro是90hz吗(华为p40Pro是1080p屏吗?)

  • 微博怎么设置密码(微博怎么设置密码锁屏)

    微博怎么设置密码(微博怎么设置密码锁屏)

  • mpp是什么文件(mpp是什么文件,怎么打开)

    mpp是什么文件(mpp是什么文件,怎么打开)

  • ps怎么设置画笔压感(ps怎么添加画笔笔尖形状)

    ps怎么设置画笔压感(ps怎么添加画笔笔尖形状)

  • 手机知乎怎么搜索(知乎app搜索)

    手机知乎怎么搜索(知乎app搜索)

  • 小米9pro怎么关闭桌面图标角标(小米9怎么关闭miui)

    小米9pro怎么关闭桌面图标角标(小米9怎么关闭miui)

  • iphone8plus怎么开热点(iPhone8plus怎么开流量)

    iphone8plus怎么开热点(iPhone8plus怎么开流量)

  • oppok3线下有卖吗(oppok3在实体店大约多少钱)

    oppok3线下有卖吗(oppok3在实体店大约多少钱)

  • 打开蜂窝数据不显示4g(蜂窝数据打开没反应)

    打开蜂窝数据不显示4g(蜂窝数据打开没反应)

  • 为什么打不出去电话(电信电话为什么打不出去)

    为什么打不出去电话(电信电话为什么打不出去)

  • 魅族15哪一年出的(魅族15刚上市多少钱)

    魅族15哪一年出的(魅族15刚上市多少钱)

  • QQ怎么开启聊天记录漫游(qq怎么开启聊天标识)

    QQ怎么开启聊天记录漫游(qq怎么开启聊天标识)

  • openCV实践项目:拖拽虚拟方块(opencv项目开发实战)

    openCV实践项目:拖拽虚拟方块(opencv项目开发实战)

  • ausearch命令  检索审计记录(auth命令)

    ausearch命令 检索审计记录(auth命令)

  • 建筑业购买材料计入什么科目
  • 全国增值税发票查验平台入口
  • 银行应发贷款和实际收到的贷款为什么不一致
  • 租房子没有
  • 私募基金如何做大规模
  • 发票开错对方已抵扣怎么处理
  • 小规模免税的税额怎么处理
  • 水费抵扣进项税税率是多少
  • 运输企业印花税按什么缴纳
  • 支付招聘费收到专票怎么记账
  • 收购农产品进项税抵扣税率是多少
  • 已认证抵扣的进项发票,次月开具红字发票信息表,凭证
  • 出口样品收汇不报关会计分录
  • 非正常原因导致的存货盘亏或毁损非正常原因是哪些
  • 物业公司收取水费如何开具发票
  • 实物型产品的基础知识
  • 当月税负率怎么算
  • 提取备用金的手续费会计分录
  • 对公账户的利息收入如何入账
  • 个人所得税生产经营所得投资者减除费用
  • 劳务费开发票还要代扣代缴吗?
  • 普票红冲对方已入账发票拿不回来
  • 企业增资还需要增资账户么
  • 教育费附加免征还计提吗
  • 银行 环保
  • 车辆过路费凭什么收费
  • 企业取得财政拨款怎么算
  • 购买固定资产算投资吗
  • 电脑bios怎么设置网络启动
  • 农产品进项转出的规定
  • 前端获取当前地址
  • php讲解
  • PHP+Mysql+jQuery文件下载次数统计实例讲解
  • 发票认证了,但是没有入账
  • 外汇延期收款办理操作指南
  • 苹果手机铃声删除在哪里
  • 承租人对经营租赁的会计处,怎么快速记住方法
  • 融资租赁业务的操作程序
  • wordpress登录注册
  • 外商投资工业企业有哪些
  • 小微企业减免额怎么计算
  • 外贸出口企业的税务风险
  • 工会的钱怎么取出来
  • 研发失败能做加工企业吗
  • php设计思路
  • 会计核算方式有几种
  • sql优化的方法及思路
  • 房地产企业购买礼品赠送客户
  • 弃置费用的摊余成本
  • 车险代买的出了事故怎么办
  • 劳务费怎么做会计科目
  • 公司社保外包了,没给我社保卡怎么办?
  • 投标费用如何入账科目
  • 互联网金融理财产品的优势
  • 盘亏的固定资产是资产吗
  • 银行信用贷款发放邮件后多久到账
  • 计入税金及附加借方的内容
  • 电话费计入什么二级科目
  • win9怎么截图
  • macbook和windows
  • win7清除usb插拔记录
  • 如何在windows中对硬盘进行分区
  • windows运行不了怎么办
  • win8.1语言包下载
  • mmc.exe是什么
  • Ghost XP SP3 YN8.0装机版 (雨林木风)
  • cocos3.0
  • js cocos
  • node-red 全局变量
  • css中层叠的含义
  • 已经序列化的表单怎么再添加
  • javascript判断
  • js 三元
  • javascript闭包的作用
  • flask框架下使用scrapy框架
  • 国家税务局浙江省电子税务局新版
  • 宝鸡税务局长
  • 不配合税务检查的法律责任
  • 河北保定地税局官网
  • 创新税务稽查方案
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设