位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 苹果11pro充电器多少瓦(苹果11pro充电器是什么样的)

    苹果11pro充电器多少瓦(苹果11pro充电器是什么样的)

  • windows7系统移动硬盘提示格式化怎么办

    windows7系统移动硬盘提示格式化怎么办

  • 华为nova6手机卡怎么插(华为Nova6手机卡咋拿出来)

    华为nova6手机卡怎么插(华为Nova6手机卡咋拿出来)

  • 酷狗怎么通过昵称搜人(酷狗音乐只知道酷狗名怎么加好友)

    酷狗怎么通过昵称搜人(酷狗音乐只知道酷狗名怎么加好友)

  • 怎么把好评改成差评(怎么更改好评为差评)

    怎么把好评改成差评(怎么更改好评为差评)

  • 联想笔记本开机显示invalid(联想笔记本开机黑屏什么都不显示)

    联想笔记本开机显示invalid(联想笔记本开机黑屏什么都不显示)

  • 淘宝快速提升销量的方法(淘宝快速提升销售数据)

    淘宝快速提升销量的方法(淘宝快速提升销售数据)

  • 打印机感光鼓在哪(打印机感光鼓断开怎么弄)

    打印机感光鼓在哪(打印机感光鼓断开怎么弄)

  • 屏幕顶部触控失灵(屏幕顶部触控失灵且跳屏)

    屏幕顶部触控失灵(屏幕顶部触控失灵且跳屏)

  • i34130配什么主板(i34130配什么主板最好)

    i34130配什么主板(i34130配什么主板最好)

  • 苹果隔多久可以二退(苹果多久可以吃苹果)

    苹果隔多久可以二退(苹果多久可以吃苹果)

  • 拼多多运费怎么算(拼多多运费怎么那么便宜)

    拼多多运费怎么算(拼多多运费怎么那么便宜)

  • 微信视频背景怎么设置(微信视频背景怎么制作)

    微信视频背景怎么设置(微信视频背景怎么制作)

  • 苹果11提示灯怎么开启(苹果11提示灯怎么关掉)

    苹果11提示灯怎么开启(苹果11提示灯怎么关掉)

  • 海思麒麟810相当于骁龙多少(海思麒麟810相当于苹果几)

    海思麒麟810相当于骁龙多少(海思麒麟810相当于苹果几)

  • qq语音限制人数吗(qq语音聊天上限)

    qq语音限制人数吗(qq语音聊天上限)

  • 小爱同学音响不联网能用吗(小爱同学音响不通电怎么弄)

    小爱同学音响不联网能用吗(小爱同学音响不通电怎么弄)

  • 手机怎样在图片上打字(手机怎样在图片下面编辑文字)

    手机怎样在图片上打字(手机怎样在图片下面编辑文字)

  • vivoiqoo有红外线吗(vivoiqoo5红外线)

    vivoiqoo有红外线吗(vivoiqoo5红外线)

  • 抖音上怎么做照片卡点视频(抖音怎么做照片合集)

    抖音上怎么做照片卡点视频(抖音怎么做照片合集)

  • 手机开不了机了怎么办(手机开不了机了怎么解决)

    手机开不了机了怎么办(手机开不了机了怎么解决)

  • 怎样注册拼多多新人号(怎样注册拼多多商家版店铺)

    怎样注册拼多多新人号(怎样注册拼多多商家版店铺)

  • 爱奇艺怎么授权给别人(爱奇艺会员如何授权)

    爱奇艺怎么授权给别人(爱奇艺会员如何授权)

  • 探探vip有什么功能(探探vip有什么功能可以隐身吗)

    探探vip有什么功能(探探vip有什么功能可以隐身吗)

  • 华为双电信卡设置方法(华为手机双电信卡设置)

    华为双电信卡设置方法(华为手机双电信卡设置)

  • 苹果旁白有什么用(苹果旁白是)

    苹果旁白有什么用(苹果旁白是)

  • 华为私隐空间密码解除(华为隐私空间设置密码)

    华为私隐空间密码解除(华为隐私空间设置密码)

  • 微信运动记录运动数据在哪里(微信运动记录运动轨迹)

    微信运动记录运动数据在哪里(微信运动记录运动轨迹)

  • 爱奇艺如何关闭自动续费(爱奇艺如何关闭字幕功能)

    爱奇艺如何关闭自动续费(爱奇艺如何关闭字幕功能)

  • WordPress为旧文章批量设置特色图(wordpress文章保存在哪里)

    WordPress为旧文章批量设置特色图(wordpress文章保存在哪里)

  • 个人出租不动产税率
  • 买交强险需要把车开过去吗
  • 固定资产一次性折旧的账务处理和税务处理
  • 劳务公司临时工工资需要申报吗
  • 去税务局申报需要带营业执照吗
  • 事业单位财政拨款取得方式
  • 利润表中财务费用为负数是什么意思
  • 金税盘全额抵扣分录
  • 房产过户的相关问题
  • 股东拿不到钱
  • 境内企业得到境外企业的红利是否需要缴纳所得税?
  • 评估入账的开发权是否可以税前扣除?
  • 施工单位项目部牌子
  • 如何区分纳税调额和补税
  • 年度中期是几月份
  • 管理费用明细是什么意思
  • 公司不盈利用交税吗
  • 合伙企业分红是免税企业需要缴纳什么税
  • 企业购置房产折旧
  • 摊销费用多做如何做账?
  • 华为折叠手机mateXs3
  • 网购iphone注意什么
  • PHP:oci_new_connect()的用法_Oracle函数
  • php中表单的使用
  • php的编辑工具有哪些
  • 销售折让负数发票如何入账
  • 特许权使用费税前扣除标准
  • php axios
  • launcher是啥
  • java.exe进程可以关掉吗
  • 【安装 】
  • phpsutdy
  • 员工安置费标准出台
  • 交易性金融资产属于流动资产
  • 中科院院士2023增选
  • 企业预缴所得税怎么算
  • 月末结转本年利润吗
  • 长期借款账务处理会计分录怎么写
  • 机动车组织机构代码查询
  • jqueryfor
  • css选择器怎么用
  • 应收账款计提坏账准备方法
  • 用支票偿还货款
  • mongodb索引使用正则表达式
  • sqlserver怎么把数据库导出来
  • mysql5.7压缩包安装配置教程
  • mongodb数据库的层次结构
  • 使用命令方式安装程序
  • 怎么查企业历史
  • 暂估收入时会有哪些凭证
  • 一般纳税人销售旧货可以开专票吗
  • 什么是一般公共预算财政拨款
  • 以产品偿还债务怎么算
  • 新增建筑物
  • 法人转移公司资产怎么办
  • 公司租的房子电费发票怎么开
  • 待处理财产损益期末结转到哪里
  • mysql必知必会读书心得
  • 请创建一个die类
  • sqlserver2019的使用
  • win7一直弹广告怎么办
  • win7总是提示激活
  • 搜索功能使用方法
  • slmgr.vbs /dli
  • linux calloc
  • linux常用命令chown
  • 自建ss
  • win7 64位系统双击桌面所有程序提示"文件没有与之关联的程序来执行"的解决方法
  • win8系统如何
  • 如何检测装有监控器?
  • js填写input
  • javascriptz
  • 游戏开发吧
  • JavaScript驾驭网页-获取网页元素
  • js有哪些作用域,分别是什么意思
  • android自定义组件开发详解
  • 境外承包工程款收入
  • 贵州电子税务局app下载
  • 江苏税务app操作手册
  • 地税怎么交税
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设