位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • QQ 群推广的8个常用方法(qq群适合推广什么)

    QQ 群推广的8个常用方法(qq群适合推广什么)

  • 华为nova7se支不支持nfc(华为nova7se支不支持OTG)

    华为nova7se支不支持nfc(华为nova7se支不支持OTG)

  • 12寸屏幕长宽是多少(12寸屏幕长宽是多少英寸)

    12寸屏幕长宽是多少(12寸屏幕长宽是多少英寸)

  • iphonex可以使用nfc吗(iphonex可以使用无线充电吗?)

    iphonex可以使用nfc吗(iphonex可以使用无线充电吗?)

  • oppo应用通知在哪里设置(oppo软件通知怎样打开)

    oppo应用通知在哪里设置(oppo软件通知怎样打开)

  • 电视机无线网络连接失败原因(电视机无线网络开关打不开)

    电视机无线网络连接失败原因(电视机无线网络开关打不开)

  • 华为9x怎么截图全屏(华为9x怎么截图快捷键)

    华为9x怎么截图全屏(华为9x怎么截图快捷键)

  • 苹果三包法退换货条件(苹果手机三包规定退换货)

    苹果三包法退换货条件(苹果手机三包规定退换货)

  • 哔哩哔哩会员可以几个人用(哔哩哔哩会员可以几个人一起用)

    哔哩哔哩会员可以几个人用(哔哩哔哩会员可以几个人一起用)

  • ipad充电器可以充iphone11吗(ipad充电器可以充电脑吗)

    ipad充电器可以充iphone11吗(ipad充电器可以充电脑吗)

  • 微博加入黑名单 对方会知道吗(微博加入黑名单对方还能看我主页吗)

    微博加入黑名单 对方会知道吗(微博加入黑名单对方还能看我主页吗)

  • 天猫上车是什么意思(天猫上买车便宜吗)

    天猫上车是什么意思(天猫上买车便宜吗)

  • 华为mate30设置手机铃声(mate30怎么设置)

    华为mate30设置手机铃声(mate30怎么设置)

  • 港版7pa1661是全网通吗(港版苹果7pa1661支持电信吗)

    港版7pa1661是全网通吗(港版苹果7pa1661支持电信吗)

  • 手机键盘怎么打开(手机键盘怎么打勾号)

    手机键盘怎么打开(手机键盘怎么打勾号)

  • 苹果11pro可以双卡吗(苹果11pro可以双开微信吗)

    苹果11pro可以双卡吗(苹果11pro可以双开微信吗)

  • 拼多多差评怎样修改(拼多多怎么求评价)

    拼多多差评怎样修改(拼多多怎么求评价)

  • ipad双重认证是什么意思(双重认证ipad一直让输密码)

    ipad双重认证是什么意思(双重认证ipad一直让输密码)

  • 新手如何使用共享单车(新手如何使用共享单车一次多少饯)

    新手如何使用共享单车(新手如何使用共享单车一次多少饯)

  • 微信怎么刷公交车(微信怎么刷公交卡)

    微信怎么刷公交车(微信怎么刷公交卡)

  • 漫画app开发有什么功能(漫画app开源)

    漫画app开发有什么功能(漫画app开源)

  • wan口设置已断开服务器无响应(wan口设置已断开服务器无响应怎么解决)

    wan口设置已断开服务器无响应(wan口设置已断开服务器无响应怎么解决)

  • win10系统的电脑怎么开启在线语音识别?(win10系统的电脑能装win7吗)

    win10系统的电脑怎么开启在线语音识别?(win10系统的电脑能装win7吗)

  • Flex 4 的十大变化(flex:4)

    Flex 4 的十大变化(flex:4)

  • 工资表个税多扣了账务处理递减
  • 免税收入不征税收入有哪些
  • 增值税的计税依据是什么
  • 二手房个人所得税是买方交还是卖方交
  • 技术服务的分录
  • 新准则经营租赁会计分录
  • 残保金滞纳金可以税前扣除吗
  • 卷式发票是什么样的
  • 全额工资是到手工资吗
  • 债权转增资本应缴纳什么税
  • 筹建期间有收入怎么办
  • 营改增一般纳税人标准
  • 海关免税设备清单
  • 进口货物会计分录举例
  • 公司聘请专家的差旅费可以税前扣除吗
  • 税金及附加怎么计提
  • 冲回多提所得税
  • 退役士兵创业就业政策
  • 纳税人提供植物油的税率
  • 差额征税的小微企业免税销售额
  • 兼营非应税劳务行为举例
  • 无票收入增值税申报表怎么填小规模纳税人
  • 企业所得税分期收款确认收入的时间政策
  • 购买预付卡账务处理
  • 促销购买
  • 已经折旧完的固定资产怎么处理
  • 分公司是小微企业总公司是一般纳税人,如何做合并报表
  • avcodec是什么意思
  • win10系统损坏开不了机
  • icon files
  • 职业病治疗费用谁承担
  • 长期借款利息费用的资本化账务处理
  • PHP:imagecreatefromxbm()的用法_GD库图像处理函数
  • 应收票据到期后账务处理
  • 航空业燃油附加率是多少
  • php日期计算器
  • php返回数据给ajax
  • 蓝桥杯b组2020
  • gpt指标
  • ipcrm命令
  • dpkg命令详解
  • mkfifo命令
  • 个税申报系统操作指南
  • 应收应付核销规则及常见问题
  • 销售费用进项税额转出会计分录怎么写
  • 进项税没入账补入账分录
  • 共同开发产品
  • 运费计入什么会计分录
  • 现金短缺或溢余指的是什么
  • 弥补以前年度亏损怎么算
  • 客户多付的货款 不用退回 进营业外收入吗
  • 代扣缴纳税款会计分录
  • 银行汇票计入什么会计科目
  • 差额征税问题
  • 计提印花税会计分录
  • 茶农的茶叶自产自销需要办哪些证
  • 资产负债表中的货币资金怎么算
  • mysql的多表查询语句
  • win8专业版系统更改电脑设置没反应
  • wsinspector.exe是什么进程
  • 在Ubuntu Trusty 14.04 (LTS) (64-bit)安装Docker的步骤
  • centos7 login账号
  • CentOS系统中与时间的相关命令详解
  • WINDOWS系统还原主要作用
  • centos init
  • hkcmd是什么进程
  • 无损音乐如何播放
  • Win10桌面任务栏能不能删除
  • win8系统咋样
  • unix网络命令
  • win10预览版21301bug
  • linux文件系统的根目录的i节点号为
  • html模板 js
  • 数组observer
  • linux实现shell
  • js onkeypress与onkeydown 事件区别详细说明
  • jQuery使用animate实现ul列表项相互飘动效果示例
  • ugui粒子ui层级
  • 建筑企业如何
  • 广东省国税局局长潘
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设