位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 朵唯女性手机营销策略(朵唯女性手机怎么样)(朵唯女性手机图片)

    朵唯女性手机营销策略(朵唯女性手机怎么样)(朵唯女性手机图片)

  • iqoo8pro支持hifi吗(iqoo8有没有hifi)

    iqoo8pro支持hifi吗(iqoo8有没有hifi)

  • 农业银行微信提醒怎么开通(农业银行微信提现免费)

    农业银行微信提醒怎么开通(农业银行微信提现免费)

  • 高速电子支付的发票在哪里开(高速电子支付的发票会发手机上吗)

    高速电子支付的发票在哪里开(高速电子支付的发票会发手机上吗)

  • 抖音怎么注销账号(抖音怎么注销账号手机号码)

    抖音怎么注销账号(抖音怎么注销账号手机号码)

  • 拼多多商家版怎么删除直播视频(拼多多商家版怎么找人工客服聊天)

    拼多多商家版怎么删除直播视频(拼多多商家版怎么找人工客服聊天)

  • 华为手机照片如何永久删除(华为手机照片如何添加日期水印)

    华为手机照片如何永久删除(华为手机照片如何添加日期水印)

  • 文字后面的下划线怎么添加不上(文字后面的下划线出不来)

    文字后面的下划线怎么添加不上(文字后面的下划线出不来)

  • 电脑qq录屏文件保存在哪里(电脑qq录屏文件太大如何变小)

    电脑qq录屏文件保存在哪里(电脑qq录屏文件太大如何变小)

  • 像素怎么算(相机像素怎么算)

    像素怎么算(相机像素怎么算)

  • 照片格式jpg和jpeg是什么意思(照片格式JPG和JPEG)

    照片格式jpg和jpeg是什么意思(照片格式JPG和JPEG)

  • 微信群主怎样撤回群员消息(微信群主怎样撤销群员发的消息)

    微信群主怎样撤回群员消息(微信群主怎样撤销群员发的消息)

  • 快手小店有没有运费险(快手小店有没有新手期)

    快手小店有没有运费险(快手小店有没有新手期)

  • 魅族16T屏幕材质(魅族16th屏幕材质)

    魅族16T屏幕材质(魅族16th屏幕材质)

  • 200m宽带可以用千兆端口吗(200m宽带可以用多少流量)

    200m宽带可以用千兆端口吗(200m宽带可以用多少流量)

  • win7的桌面是指(windows7的桌面指的是什么)

    win7的桌面是指(windows7的桌面指的是什么)

  • 京东红包过期了能恢复吗?(京东红包过期了怎么办)

    京东红包过期了能恢复吗?(京东红包过期了怎么办)

  • 怎么在微信里看身份证(怎么在微信里看走了多少步)

    怎么在微信里看身份证(怎么在微信里看走了多少步)

  • 双十二是淘宝还是天猫(双十二是淘宝还是京东搞活动)

    双十二是淘宝还是天猫(双十二是淘宝还是京东搞活动)

  • vivo手机怎么设置sos功能(vivo手机怎么设置字体大小)

    vivo手机怎么设置sos功能(vivo手机怎么设置字体大小)

  • 手机强制恢复出厂设置(手机强制恢复出厂设置软件)

    手机强制恢复出厂设置(手机强制恢复出厂设置软件)

  • 手机浏览器点开闪退怎么办(手机浏览器点开有不良视频出现怎么关闭)

    手机浏览器点开闪退怎么办(手机浏览器点开有不良视频出现怎么关闭)

  • 图片怎么重新命名(图片重新命名怎么弄)

    图片怎么重新命名(图片重新命名怎么弄)

  • 抖音私聊别人能看到吗(抖音私聊别人能第三者能看到吗)

    抖音私聊别人能看到吗(抖音私聊别人能第三者能看到吗)

  • nba2k16手机版rs键详解(nba2k16rs键使用教程)

    nba2k16手机版rs键详解(nba2k16rs键使用教程)

  • 拼多多AA收款能退款吗(拼多多收款码使用规则)

    拼多多AA收款能退款吗(拼多多收款码使用规则)

  • 苹果8怎么快速切换应用(苹果8怎么快速截图)

    苹果8怎么快速切换应用(苹果8怎么快速截图)

  • 键盘各键的作用(键盘各键的作用的讲解图)

    键盘各键的作用(键盘各键的作用的讲解图)

  • Bootstrap 框架详解(bootstrap框架的理解)

    Bootstrap 框架详解(bootstrap框架的理解)

  • 劳务公司一般纳税人要交什么税
  • 纳税义务发生时间记忆口诀
  • 利息收入交所得税吗
  • 资产负债表其他应付款包括哪些
  • 租车开发票属于什么类
  • 小规模纳税人劳务分包税率
  • 年报过期了
  • 长期利润分享计划属于短期薪酬吗
  • 留存收益筹资的优缺点
  • 进货没有开具发票能退吗
  • 会计人士教你在Excel中如何计算年均增长率
  • 住宿费可以开会议费吗
  • 对外捐赠衣物怎样入账
  • 工资做账原始凭证是什么
  • 地税有哪些税种类型
  • 保安服务费可以计入劳务费吗
  • 出现一窗式比对失败,该纳税人没有防伪税控比对信息!
  • 怎么查找使用手机的时间
  • 委托加工物资什么意思
  • Win10系统cpu性能如何调高 Win10把cpu性能调到极佳的方法
  • 滴滴发票开公司名称可以抵扣进项吗
  • 新建厂房环评流程
  • 笔记本电脑怎么重装系统
  • 我的世界1.12.2优化下载
  • linux递归创建目录命令
  • 失控发票进项转出后要补企业所得税吗
  • wordpress portfolio
  • php读取数据输出html
  • php面向对象的三大特征
  • 处置资产的账务处理
  • 作为大学生你能为国家安全贡献哪些力量论文
  • 总结的拼音
  • 小规模纳税人增值税月末处理
  • 可供分配利润是留存收益吗
  • 与下级往来账户贷方核算的内容有
  • 减值损失和减值损失区别
  • 减免税款账务处理
  • 一般纳税人企业所得税怎么征收
  • MySQL数据库远程登录
  • 个体户和公司的税收相差多少
  • 受托加工物资如何开票
  • 没有发票的支出怎么入账
  • 研发费用不一致说明
  • 国有划拔土地房整体可以买卖吗
  • 电子承兑没到期兑换手多少手续费
  • 固定资产未提完折旧
  • 采用成本法核算的长期股权投资
  • 计提印花税会计科目
  • 企业叉车折旧年限几年
  • 其他业务支出包括哪些内容科目
  • 无形资产的摊销方法
  • 数据库表的查询操作实验
  • windows vista home basic
  • windows server 2003与2008的区别联系与选择指南
  • 局域网 下载
  • ubuntu系统安装无线网卡驱动
  • ubuntu20.04怎么用
  • ubuntu-desktop启动
  • VMware虚拟机中安装MATE桌面环境
  • vc运行程序exe停止工作怎么办
  • ubuntu 4.10
  • win7电脑蓝牙图标怎么弄出来
  • macos触控
  • init systemd
  • win8装机教程
  • win降win7
  • 触发器csdn
  • shell脚本case语句判断成绩
  • cocos2dx4.0教程
  • java guns框架
  • js注释方法
  • python抓取软件界面数据
  • unity简单项目
  • jQuery+ajax+asp.net获取Json值的方法
  • 国家税务局浙江省电子税务局新版
  • 税务 涉税中介
  • 电子税务局个体工商户如何登陆
  • 客货两用车应如何运输
  • 吉林税务发票自动查询系统网
  • 党员e先锋中的支部云课堂在哪
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设