位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 苹果xr相机如何设置镜像(苹果xr相机如何背景变透明)

    苹果xr相机如何设置镜像(苹果xr相机如何背景变透明)

  • 在朋友圈如何发15秒视频(在朋友圈如何发文字消息)

    在朋友圈如何发15秒视频(在朋友圈如何发文字消息)

  • 表格a列显示不出来怎么办(excelabc列不显示)

    表格a列显示不出来怎么办(excelabc列不显示)

  • 苹果影音先锋坏了文件不见了(苹果影音先锋怎么用不了)

    苹果影音先锋坏了文件不见了(苹果影音先锋怎么用不了)

  • 戴尔g3散热风扇有异响(戴尔g3散热风扇一直转)

    戴尔g3散热风扇有异响(戴尔g3散热风扇一直转)

  • 超话发帖等级限制是几级(超话发帖等级限制怎么解除)

    超话发帖等级限制是几级(超话发帖等级限制怎么解除)

  • 苹果se来电闪光灯怎么设置(苹果se来电闪光怎么关闭)

    苹果se来电闪光灯怎么设置(苹果se来电闪光怎么关闭)

  • airpods盒子上的序列号在哪(airpods盒子上的型号写的A2700但是手机上显示A2698)

    airpods盒子上的序列号在哪(airpods盒子上的型号写的A2700但是手机上显示A2698)

  • qq看主页会留下记录吗(qq看主页会被发现吗)

    qq看主页会留下记录吗(qq看主页会被发现吗)

  • 华为移动数据按钮灰色(华为移动数据按键在哪)

    华为移动数据按钮灰色(华为移动数据按键在哪)

  • 双书名号怎么输(双书名号怎么使用)

    双书名号怎么输(双书名号怎么使用)

  • 支付宝盒子红灯一直闪怎么回事(支付宝盒子红灯闪是什么意思)

    支付宝盒子红灯一直闪怎么回事(支付宝盒子红灯闪是什么意思)

  • Excel取消自动筛选后会怎样(excel取消自动筛选后会怎样)

    Excel取消自动筛选后会怎样(excel取消自动筛选后会怎样)

  • 苹果怎么查被拦截的电话(苹果怎么查拦截电话)

    苹果怎么查被拦截的电话(苹果怎么查拦截电话)

  • 苹果手机双4g啥意思(ios双4g)

    苹果手机双4g啥意思(ios双4g)

  • 手机安装不了qq怎么办(手机安装不了QQ怎么办)

    手机安装不了qq怎么办(手机安装不了QQ怎么办)

  • 手机无线密码怎么查看(手机无线密码怎么显示出来)

    手机无线密码怎么查看(手机无线密码怎么显示出来)

  • 哈罗单车扫码后不开锁(哈罗单车扫码后显示地图)

    哈罗单车扫码后不开锁(哈罗单车扫码后显示地图)

  • 苹果手机数据怎么转移到华为手机(苹果手机数据怎么传到电脑)

    苹果手机数据怎么转移到华为手机(苹果手机数据怎么传到电脑)

  • 栈和队列在现实生活的应用(栈和队列使用场景)

    栈和队列在现实生活的应用(栈和队列使用场景)

  • 苹果7待机时间多久(苹果7待机时间长了就卡)

    苹果7待机时间多久(苹果7待机时间长了就卡)

  • 怎么关闭win11安全中心 win11安全中心关闭步骤(怎么关闭win11安装软件提示)

    怎么关闭win11安全中心 win11安全中心关闭步骤(怎么关闭win11安装软件提示)

  • win11怎么隐藏底部任务栏? windows11任务栏隐藏方法(Win11怎么隐藏底部)

    win11怎么隐藏底部任务栏? windows11任务栏隐藏方法(Win11怎么隐藏底部)

  • torch.cuda常用指令(torch.cuda.is_available())

    torch.cuda常用指令(torch.cuda.is_available())

  • 税控盘开票系统怎么升级
  • 财务软件进什么费用
  • 购买车间使用的设备计入什么
  • 取得社会团体会费专用票据可以税前扣除吗
  • 工程保险谁负责
  • 员工培训的费用按照多少钱计入安措费
  • 税费漏报
  • 铁路大票抵扣几个点
  • 工作未满12个月被辞退时前月平均工资怎么计算
  • 制造业印花税计税,按照去税金额计算
  • 公司车辆转让需要缴纳印花税吗
  • 应付账款是负数怎么回事
  • 手撕发票怎么区分地区开具
  • 增值税应交税费科目
  • 报价表含税点是什么意思?
  • 总资产增长率的含义
  • 信息技术服务在开票系统怎么选
  • 快递有发票快递如何收费
  • 个人所得税申请免税条件
  • 个税汇算清缴包含退休金吗
  • 微软 windows11
  • php转word
  • wordpress建网站详细教程
  • PHP:imagecreatefromgd2part()的用法_GD库图像处理函数
  • 毛地黄长什么样
  • vue知识点总结
  • chrome插件api
  • php二维数组查询指定值
  • 银行承兑汇票应由在承兑银行开立存款账户的存款人签发
  • 社保代扣代缴的规定
  • 更改Mysql root用户密码
  • 增值税纳税人放弃免税权的规定
  • 直接计入所有者权益的利得和损失,影响当期损益
  • 小规模企业免税收入会计分录
  • 关税完税价格计算增值税
  • 一般纳税人季报怎么填
  • 固定资产可以一次性折旧吗
  • 车辆保险费如何缴纳印花税的
  • 豆制品属于农副产品吗为什么
  • 其他应收款账务核销后放在哪个科目里
  • 股东往来款算投资款吗
  • 采购折扣怎么结算
  • 原材料暂估的业务包括
  • 财政总预算会计的主体是
  • 企业收到微信和企业微信
  • 影响折旧的因素有哪三个方面
  • 现金日记账本月合计怎么划线
  • 个人向公司账户存现金
  • 普通日记账如何记账
  • MySQL下载安装视频
  • mysql mac下载
  • windows隐藏功能
  • XP系统怎么删除密码
  • centosuuid
  • windows后台启动VirtualBox虚拟机让界面不在出现
  • apache zipfile
  • linux改变
  • linux安装gdb命令
  • 在linux 上使用QQ聊天程序
  • fsa是什么文件格式
  • NJeeves.exe - NJeeves进程文件是什么意思 有什么用
  • win7开机chkdsk
  • pphelper是什么文件
  • win8系统恢复
  • win8自启动
  • linux创建.c
  • 怎么对js代码程序进行设计
  • springmvc接收form表单
  • opengl 绘制
  • tensorflowoom
  • Bullet(Cocos2dx)之优化PhysicsDraw3D
  • linux每隔1s执行一次命令
  • unity导出3d模型
  • jquery 插件写法
  • mac如何配置pycharm
  • 调查问卷的背景资料怎么写
  • 国家税务总局云平台网址
  • 中药生产与加工和中药制药技术区别
  • 2020年河南麦收时间
  • 江苏电子口岸卡邮寄大概需要多久
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设