位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 小米9pro有没有耳机孔(小米9pro有没有高刷新率)

    小米9pro有没有耳机孔(小米9pro有没有高刷新率)

  • 恢复回收站清空的照片(恢复回收站清空的文件怎么恢复)

    恢复回收站清空的照片(恢复回收站清空的文件怎么恢复)

  • 京东企业用户和个人用户区别吗(京东企业用户和plus会员)

    京东企业用户和个人用户区别吗(京东企业用户和plus会员)

  • 华为mate30送的是什么耳机(华为mate30赠送礼包)

    华为mate30送的是什么耳机(华为mate30赠送礼包)

  • 被对方拉黑了对方换头像看得到吗(被对方拉黑了对方还能看到我的作品吗)

    被对方拉黑了对方换头像看得到吗(被对方拉黑了对方还能看到我的作品吗)

  • 分页符的作用是什么(分页符作用是分页吗)

    分页符的作用是什么(分页符作用是分页吗)

  • 怎么查无线网是否欠费(怎么查无线网是否到期)

    怎么查无线网是否欠费(怎么查无线网是否到期)

  • 华为nova7和荣耀30有什么区别(华为nova7和荣耀30的区别)

    华为nova7和荣耀30有什么区别(华为nova7和荣耀30的区别)

  • 荣耀v30卡槽怎么取卡出来(荣耀v30卡槽怎么打开图解)

    荣耀v30卡槽怎么取卡出来(荣耀v30卡槽怎么打开图解)

  • 128kbs网速能看视频吗(128k网速能看视频吗)

    128kbs网速能看视频吗(128k网速能看视频吗)

  • 华为充电器口叫什么(华为充电器接口中文叫什么)

    华为充电器口叫什么(华为充电器接口中文叫什么)

  • 华为手机上方hd是什么意思怎样去掉(华为手机上方hd怎么关)

    华为手机上方hd是什么意思怎样去掉(华为手机上方hd怎么关)

  • 快手送过礼的主播怎么在赞里面删除(快手送过礼的主播怎么在赞里面彻底删除)

    快手送过礼的主播怎么在赞里面删除(快手送过礼的主播怎么在赞里面彻底删除)

  • 手机横屏怎么办(0p0p手机横屏怎么办)

    手机横屏怎么办(0p0p手机横屏怎么办)

  • 苹果11后面几个摄像头(苹果11 后面)

    苹果11后面几个摄像头(苹果11 后面)

  • 华为webview有什么用(华为wedview)

    华为webview有什么用(华为wedview)

  • 谷歌服务框架安装失败怎么办(谷歌服务框架安装了谷歌商店还是闪退)

    谷歌服务框架安装失败怎么办(谷歌服务框架安装了谷歌商店还是闪退)

  • 剪映怎么搜索音乐(剪映怎么搜索音频文件)

    剪映怎么搜索音乐(剪映怎么搜索音频文件)

  • 梦与诗(梦与诗原文)

    梦与诗(梦与诗原文)

  • 腾讯视频不能投屏怎么回事(腾讯视频不能投屏怎么解决)

    腾讯视频不能投屏怎么回事(腾讯视频不能投屏怎么解决)

  • 苹果短信前面有个月亮(苹果短信前面有一个月亮是什么意思)

    苹果短信前面有个月亮(苹果短信前面有一个月亮是什么意思)

  • python如何赚外快(python怎么赚外快)

    python如何赚外快(python怎么赚外快)

  • win10重装win7(win10重装win7后键盘鼠标不能用)

    win10重装win7(win10重装win7后键盘鼠标不能用)

  • MAC中安装软件提示软件已损坏或提示不是App store下载的解决方法(mac安装软件提示需要更高版本)

    MAC中安装软件提示软件已损坏或提示不是App store下载的解决方法(mac安装软件提示需要更高版本)

  • 什么是Python中的闭包(什么叫python)

    什么是Python中的闭包(什么叫python)

  • 服务类一般纳税人无进项
  • 税金及附加都包含什么
  • 应收出口退税金额无法收回怎么做账?
  • 公司法人已变更,前法人被失信
  • 资产现金流量收益率计算例题
  • 广告牌制作加盟厂家
  • 小规模纳税人开票税率
  • 公司零星支出没有发票收据怎么开
  • 业务招待费扣除标准营业收入包括
  • 增值税专用发票抵扣税额是什么意思
  • 现金折扣冲减销售收入冲销项税吗
  • 怎么计算股票的压力位和支撑位
  • 失业社保补助金领取条件
  • 期末留抵税额可以留多久
  • 预缴增值税为什么记借方
  • 盈余公积可用于集体福利吗
  • 当期应税销售收入是含税还是不含税
  • 印花税按什么税率
  • 开办期间的税控设备怎么入账?
  • 企业所得税如何缴纳
  • 增值税开票金额在哪里看
  • 小微企业所得税优惠政策
  • 税务变更
  • 居间合同怎么签才算有效
  • 企业购入土地如何处理
  • 当月计提当月缴纳的增值税还用结转吗
  • 审图费发票需要备注吗
  • 内部存货交易的抵消分录例题讲解
  • 培训费没有发票怎么办
  • 收到投资的会计科目
  • 利率和利息的区别白话
  • 工厂的绿化费进项税额
  • 向非绑定账户转账超限是什么意思
  • 支付设备维修费用计入什么科目
  • RegSrvc.exe - RegSrvc是什么进程 有什么用
  • 财务里计提是什么意思
  • PHP:iterator_to_array()的用法_spl函数
  • 出租固定资产取得的收入属于收入要素吗
  • 排灯节起源
  • phpstorm怎么样
  • php有多简单
  • 软件产品即征即退申请表
  • php操作字符串
  • html/css/javascript
  • ps制作折扇效果图
  • 本月的进项票可以抵扣上月税款吗?
  • css选择上一个兄弟
  • python中列表的索引用法
  • java事件处理机制三个重要概念
  • 未开票收入如何做账
  • 为什么有些网站会自动复制
  • 新的会计制度
  • 辞退福利为什么不计入产品成本
  • 定期定额个体经营所得税申报错了怎么办
  • 房产税的政策依据
  • 委托代销业务的会计分录
  • 以银行承兑汇票支付购买原材料款
  • 服务行业服务费怎么入账
  • 弱电工程属于什么行业
  • 固定资产确认条件最新
  • 颁发数字证书要符合什么条件
  • 企业是否必须建立巡察制度
  • Windows10 64位安装MySQL5.6.35的图文教程
  • u盘装win8系统教程图解
  • bios和cmos的区别和联系
  • linux快速查找历史命令
  • win7系统的摄像功能在哪
  • xp开机提示explorer
  • ubuntu系统升级 开机黑屏怎么解决
  • win7查看本机信息
  • win10显示win8
  • win8系统自带浏览器
  • js时间倒计时定时器怎么弄
  • 编写折半查找的程序
  • js的settimeout方法
  • shell可以多线程吗
  • 深入理解中国式现代化
  • 税务局打印申报表
  • 80491232税务申报代码
  • 环保税申报操作手册
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设