位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 蓝奏云怎么注册(蓝奏云怎么注册账号)

    蓝奏云怎么注册(蓝奏云怎么注册账号)

  • 微信名第一次怎么修改(微信初始微信名)

    微信名第一次怎么修改(微信初始微信名)

  • 苹果11微信视频发黄怎么调(苹果11微信视频手机发热怎么回事)

    苹果11微信视频发黄怎么调(苹果11微信视频手机发热怎么回事)

  • iphonexsmax支持双卡吗(iphonexsmax支持双移动卡吗?)

    iphonexsmax支持双卡吗(iphonexsmax支持双移动卡吗?)

  • 卖家收货后不确认退款(卖家收货后不确认收货)

    卖家收货后不确认退款(卖家收货后不确认收货)

  • 个人通话录音怎么调取(通话咋个录音)

    个人通话录音怎么调取(通话咋个录音)

  • 显卡会烧坏吗(显卡坏了的症状有哪些)

    显卡会烧坏吗(显卡坏了的症状有哪些)

  • f8键有什么作用(笔记本f8键没反应啊)

    f8键有什么作用(笔记本f8键没反应啊)

  • 电脑如何下载微信(电脑如何下载微信并安装到桌面)

    电脑如何下载微信(电脑如何下载微信并安装到桌面)

  • 小米在哪里查看已卸载的软件(小米在哪里查看云相册)

    小米在哪里查看已卸载的软件(小米在哪里查看云相册)

  • 新买的移动硬盘第一步要怎么做(新买的移动硬盘插上去就可以用吗)

    新买的移动硬盘第一步要怎么做(新买的移动硬盘插上去就可以用吗)

  • 苹果puls什么意思(苹果手机中的plus是什么意思)

    苹果puls什么意思(苹果手机中的plus是什么意思)

  • 抖音好友怎么备注姓名(抖音上加微信怎么加)

    抖音好友怎么备注姓名(抖音上加微信怎么加)

  • 如何关闭小米音乐锁屏(如何关闭小米音箱)

    如何关闭小米音乐锁屏(如何关闭小米音箱)

  • 华为免打扰怎么关闭(华为免打扰怎么设置快捷键)

    华为免打扰怎么关闭(华为免打扰怎么设置快捷键)

  • iphone8p运行内存多大(iphone8p运行内存无故占满)

    iphone8p运行内存多大(iphone8p运行内存无故占满)

  • iphone8怎么拍夜景(iphone8如何拍夜景)

    iphone8怎么拍夜景(iphone8如何拍夜景)

  • 小米6如何恢复快充(小米6如何恢复原系统)

    小米6如何恢复快充(小米6如何恢复原系统)

  • b站音频缓存在哪(b站的音频缓存)

    b站音频缓存在哪(b站的音频缓存)

  • 手机qq斗地主在哪里找(qq斗地主在哪找)

    手机qq斗地主在哪里找(qq斗地主在哪找)

  • 自己怎么创建公众号(自己怎么创建公司)

    自己怎么创建公众号(自己怎么创建公司)

  • 千牛挂起与不挂起区别(千牛挂起能收到消息吗)

    千牛挂起与不挂起区别(千牛挂起能收到消息吗)

  • 若依前后端分离版:增加新的登录接口,用于小程序或者APP获取token,并使用若依的验证方法(若依前后端分离需要准备啥)

    若依前后端分离版:增加新的登录接口,用于小程序或者APP获取token,并使用若依的验证方法(若依前后端分离需要准备啥)

  • 如何设置BIOS开机启动项将开机第一启动项设置为U盘或光驱(如何设置bios开关机)

    如何设置BIOS开机启动项将开机第一启动项设置为U盘或光驱(如何设置bios开关机)

  • 如何解决Win7台式电脑没声音?(win7s)

    如何解决Win7台式电脑没声音?(win7s)

  • 经典动画库 animate.css 的应用(经典动画动漫)

    经典动画库 animate.css 的应用(经典动画动漫)

  • 公共电话亭是否应该被拆除
  • 企业增值税退税是算企业利润的吗
  • 盈余公积计提比例必须是10%么
  • 建筑企业印花税的计税依据
  • 房产税的纳税义务人是征税范围内房屋产权所有人
  • 收到赠送的商品并销售
  • 增值税申报表中应税货物销售额
  • 餐饮发票个人抬头怎么写
  • 中介行业风险
  • 挂靠管理费如何入账?
  • 工业企业电费出售会计分录怎么写?
  • 对公账户的钱怎么取出来才不用交税
  • 转账支票遗失能挂失吗
  • 厂家核销费用直接抵扣
  • 个人购买车辆的发票可以贷款吗
  • 材料短缺赔偿会计分录怎么写?
  • 价内税和价外税名词解释
  • 职工工资个人所得税缴纳标准
  • 商品进销差价属不属于存货
  • 食堂伙食费账务处理
  • 征信证明怎么开啊
  • 应付的工资属于什么科目
  • 顶账物品都有哪些
  • 升级win10到专业版
  • macbookpro怎么添加文件夹
  • scanserver.exe - scanserver是什么进程 有什么用
  • php几天可以速成
  • 将现金存入银行,登记银行存款日记账的依据一般是
  • laravel with查询指定字段
  • 建安企业所得税怎么算2.25税率
  • 房屋租金应缴纳多少
  • 银行罚息计入什么会计科目
  • php静态缓存
  • 今夕七夕
  • 文化事业2021
  • 暂估成本的账务处理分录
  • 装修费还没摊销完就搬家了
  • 材料款零头抹掉怎么做凭证
  • 其他收益在资产负债表哪点
  • flask框架入门
  • 飞机票抵扣进项税含民航发展基金吗
  • 客运公司做账怎么做
  • 简易征收税率表
  • 已经提完折旧的房产价值评估
  • 四种股利分配政策及适用情况
  • 企业固定资产贷款二押的风险
  • mysql 重复记录查询
  • 分公司能不能作为行政处罚的主体
  • 企业所得税入账凭证
  • 现金短缺与溢余解析
  • 漏缴增值税处罚规定
  • 退休人员在企业工作工资怎么算
  • 小规模纳税人系统查询
  • 预付账款的账务处理视频教程
  • 固定资产会计核算方法
  • mysql删除方法
  • 微软2016是window多少
  • win xp怎么样
  • 系和系怎么区分
  • linux dicom
  • linux 中find
  • win10怎么检查
  • 为什么国外程序员比国内厉害
  • window8系统安装步骤
  • windows开始界面
  • win7系统的wlan在哪里?
  • win8.1 应用商店是不是不能用了
  • win10系统应用更新
  • win10不能玩qq堂没反应
  • 环境变量windows
  • python字符串常用方法
  • js字符串函数
  • android入门视频教程
  • shell脚本中执行命令语句
  • javascript编程技术
  • 北京市国家税务局
  • 广东省电子税务局电话
  • 人防异地建设费标准
  • 国家税务贵州省税务
  • 税务党课主题或党课题目
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设