位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 苹果13pro像素多少万(iphone13pro照片像素)

    苹果13pro像素多少万(iphone13pro照片像素)

  • 查看ip命令

    查看ip命令

  • nova7se和nova6有哪些区别(nova7和nova6se哪个好)

    nova7se和nova6有哪些区别(nova7和nova6se哪个好)

  • 华为p40pro后盖是什么材质(华为p40pro后盖是否原装)

    华为p40pro后盖是什么材质(华为p40pro后盖是否原装)

  • 操作系统负责管理计算机的(操作系统负责管理所有)

    操作系统负责管理计算机的(操作系统负责管理所有)

  • 能上微信不能打开网页(能上微信不能打开网页win10)

    能上微信不能打开网页(能上微信不能打开网页win10)

  • 拼多多预售商品发货得多长时间(拼多多预售商品可靠吗)

    拼多多预售商品发货得多长时间(拼多多预售商品可靠吗)

  • 平板可以连接打印机吗(平板可以连接打印机开发票吗)

    平板可以连接打印机吗(平板可以连接打印机开发票吗)

  • 苹果max接电话声音小(iphone xs max接电话声音突然变小)

    苹果max接电话声音小(iphone xs max接电话声音突然变小)

  • 11悬浮球怎么打开(悬浮球怎么打不开了)

    11悬浮球怎么打开(悬浮球怎么打不开了)

  • blued视频聊天会被录屏吗(blued视频聊天会被录吗)

    blued视频聊天会被录屏吗(blued视频聊天会被录吗)

  • 惠普打印机打不出来字怎么回事(惠普打印机打不出黑色怎么回事)

    惠普打印机打不出来字怎么回事(惠普打印机打不出黑色怎么回事)

  • 手机换内屏对手机有影响吗(手机换内屏对手机里的照片有影响吗)

    手机换内屏对手机有影响吗(手机换内屏对手机里的照片有影响吗)

  • 电脑微信可以视频聊天吗(电脑微信可以视频吗)

    电脑微信可以视频聊天吗(电脑微信可以视频吗)

  • icloud注销是退出吗(icloud注销和退出登录一样吗)

    icloud注销是退出吗(icloud注销和退出登录一样吗)

  • vivos5怎么开启游戏模式(vivoy5s怎么开游戏模式)

    vivos5怎么开启游戏模式(vivoy5s怎么开游戏模式)

  • 手机栏显示hd怎样关掉(手机界面上显示hd)

    手机栏显示hd怎样关掉(手机界面上显示hd)

  • 闲鱼上卖东西要注意什么(闲鱼卖东西要银行卡吗)

    闲鱼上卖东西要注意什么(闲鱼卖东西要银行卡吗)

  • 淘宝直播标签重要吗(淘宝主播标签能更换吗)

    淘宝直播标签重要吗(淘宝主播标签能更换吗)

  • 苹果5.5寸是什么型号(苹果5.5英寸是什么型号)

    苹果5.5寸是什么型号(苹果5.5英寸是什么型号)

  • 拼多多能上传视频吗(拼多多能上传视频赚钱吗)

    拼多多能上传视频吗(拼多多能上传视频赚钱吗)

  • oppok3是什么机身(oppok3是5g手机吗?)

    oppok3是什么机身(oppok3是5g手机吗?)

  • 苹果x录音功能在哪里(苹果x录音功能怎么打开)

    苹果x录音功能在哪里(苹果x录音功能怎么打开)

  • Vue中@change、@input和@blur的区别以及什么是@keyup

    Vue中@change、@input和@blur的区别以及什么是@keyup

  • python列表清除元素的四种方式(python中列表清空)

    python列表清除元素的四种方式(python中列表清空)

  • phpcms怎么登陆后台(phpcms手机端)

    phpcms怎么登陆后台(phpcms手机端)

  • 营业税金及附加计算公式
  • 印花税缴纳方式一经选择1年之内不得修改
  • 小企业会计准则调整以前年度费用分录
  • 发票面额增大
  • 小规模纳税人免税销售额是含税还是不含税
  • 工会经费免征三年需要申报吗
  • 代扣业务员佣金怎么做账
  • 出口退税收汇凭证是什么
  • 公司购买二手车怎么抵税
  • 海关对现金携带数量有要求吗
  • 向境外企业购买国内企业股权
  • 研发用的原材料怎么开领料单
  • 质保金算合同资产
  • 中介收中介费后就不管了
  • 去年的季度所得税额怎么做账
  • win7 excel
  • 事业单位基建账并入大账规定
  • 多付货款退回的法律依据
  • 教学用具属于什么项目类别
  • 多交的所得税退回来账务处理
  • 限额领料单属于外来原始凭证吗
  • PHP:pg_prepare()的用法_PostgreSQL函数
  • win11修改版
  • msoicons.exe是什么文件
  • 解决问题
  • 发包工程补付工程款分录
  • 营改增开始时间
  • 文化建设事业费优惠政策
  • re.findall()用法
  • 税局 不负责任
  • thinkphp自定义标签page
  • 固定资产为什么提折旧,有何实际意义
  • 集合框架有何好处
  • java泛型简单例子
  • 补充医疗保险属于什么
  • access创建一个表
  • pytest unittest
  • 小规模当月开普票作废流程
  • 股权收购被收购方怎么做账
  • 子公司注销母公司投资损失企业所得税
  • 每月工资不一样怎么算误工费呢
  • 资产负债表应交税费是负数正常吗
  • 支付贷款手续费怎么入账
  • 以前多计提的税款怎么办
  • 研发费用的核算方法
  • 车辆保险都入什么
  • 净值怎么算?
  • 员工借款可以直接转账吗
  • 城镇土地使用税减免税政策
  • 阿里云linux 服务器 字符集
  • win7安装mysql5.5
  • unix是什么语言
  • solaris ssh offline
  • mac安装windows10体验
  • 彻底关闭windows10自动更新工具
  • 苹果系统声音怎么设置方法
  • cmd.exe是什么意思
  • linux卸载apache2
  • pqtray.exe - pqtray 是什么进程 有什么用
  • linux php 开发教程
  • linux使用vi编辑文件
  • win8怎么隐藏桌面图标
  • 边做游戏边学
  • js获取中文拼音
  • python利用for循环求1到100的奇数之和
  • jQuery使用$.ajax提交表单完整实例
  • 基于web的旅游网站毕业设计
  • Android使用领域是什么
  • ANDROID手机客户端软件开发工程师
  • AssetBundle.Unload(false)的作用
  • python自动化源码
  • flash谈广告
  • python爬虫程序下载网页上内容
  • Android网络通讯哪个最简单
  • 12333医保缴费具体步骤
  • 个人所得税完税证明在哪里查询
  • 公司购买车辆是什么费用
  • 沙子属于矿产资源
  • 银行是不是要交社保
  • 企业所得税零申报
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设