位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 腾讯视频vip会员怎么共享给别人登录(怎么登陆别人的腾讯视频vip会员)

    腾讯视频vip会员怎么共享给别人登录(怎么登陆别人的腾讯视频vip会员)

  • 全民k歌怎么上热门(全民k歌怎么上传自己的录音)

    全民k歌怎么上热门(全民k歌怎么上传自己的录音)

  • 华为的黑名单去哪找(华为黑名单在哪里移除)

    华为的黑名单去哪找(华为黑名单在哪里移除)

  • 为什么word转pdf目录出现错误(为什么word转pdf目录出现未定义书签)

    为什么word转pdf目录出现错误(为什么word转pdf目录出现未定义书签)

  • win10鼠标移动方向相反(win10鼠标自动移动到窗口边缘)

    win10鼠标移动方向相反(win10鼠标自动移动到窗口边缘)

  • 微信发送的文件怎样改文件名(微信发送的文件怎么让它失效)

    微信发送的文件怎样改文件名(微信发送的文件怎么让它失效)

  • 苹果11关机充电怎么没有显示(苹果11关机充电开机不充电)

    苹果11关机充电怎么没有显示(苹果11关机充电开机不充电)

  • 拼多多3分钟回复率不能低于多少(拼多多3分钟回复率多久刷新)

    拼多多3分钟回复率不能低于多少(拼多多3分钟回复率多久刷新)

  • 蓝牙充电仓一直闪红灯(蓝牙充电仓一直亮绿灯什么意思)

    蓝牙充电仓一直闪红灯(蓝牙充电仓一直亮绿灯什么意思)

  • 光圈数值越小光圈越大吗

    光圈数值越小光圈越大吗

  • 闲鱼曝光0怎么回事(闲鱼曝光怎么回事)

    闲鱼曝光0怎么回事(闲鱼曝光怎么回事)

  • 老家信号差怎么能增强(老家信号差网络差咋办)

    老家信号差怎么能增强(老家信号差网络差咋办)

  • 苹果手机可不可以下载两个微信(苹果手机可不可以登陆两个微信)

    苹果手机可不可以下载两个微信(苹果手机可不可以登陆两个微信)

  • 手机显示3g是什么意思(手机上面显示3g是什么意思)

    手机显示3g是什么意思(手机上面显示3g是什么意思)

  • 微信群右上角有个圈点开是视频(微信群右上角有个蓝色圈圈)

    微信群右上角有个圈点开是视频(微信群右上角有个蓝色圈圈)

  • 抖音别人直播怎么点亮(抖音别人直播怎么录屏保存视频)

    抖音别人直播怎么点亮(抖音别人直播怎么录屏保存视频)

  • 怎么找回抖音密码(找回我原来的抖音)

    怎么找回抖音密码(找回我原来的抖音)

  • 苹果怎么关掉拍照声音(苹果怎么关掉拍照咔嚓声)

    苹果怎么关掉拍照声音(苹果怎么关掉拍照咔嚓声)

  • 为什么闲鱼发布的东西没有浏览量(为什么闲鱼发布没人看)

    为什么闲鱼发布的东西没有浏览量(为什么闲鱼发布没人看)

  • 南卡蓝牙耳机是国产吗(南卡蓝牙耳机是上市公司吗)

    南卡蓝牙耳机是国产吗(南卡蓝牙耳机是上市公司吗)

  • 华为优享版啥区别(华为手机优享版和标准版)

    华为优享版啥区别(华为手机优享版和标准版)

  • 闲鱼怎么清除所有聊天记录(闲鱼如何清空)

    闲鱼怎么清除所有聊天记录(闲鱼如何清空)

  • 怎么删除pr不用的项目(怎么删除pr不用的音频)

    怎么删除pr不用的项目(怎么删除pr不用的音频)

  • 特许权使用费所得包括
  • 销售货物应税劳务服务清单给客户一份是不是就可以了
  • 电子发票会自动发送到邮箱吗
  • 生产企业的基础设施是指
  • 一般纳税人降为小规模还能升为一般纳税人吗
  • 小规模给一般纳税人开专票能抵扣吗
  • 通过应交税费核算的
  • 机动船舶缴纳车船税吗
  • 掌握无形资产核算方法
  • 所有逾期未抵扣进项税额
  • 利润分配包括缴纳所得税吗
  • 开具红字发票抵扣后如何退税?
  • 公司购买网络交换机入什么会计科目
  • 小企业所得税申报流程
  • 上个月银行流水没有录这个月补录
  • 预付材料款会计分录
  • 低价销售代替非正常损失避免转出进项税?
  • 高新技术企业优惠政策叠加
  • 金融业保险业
  • 固定资产清理的借贷方向表示什么
  • 软件行业服务费印花税
  • 国内废钢贸易需要资质吗
  • 家具属于固定资产什么类别
  • 路由器怎么设置2.4g网络
  • 解放双手神器说说
  • 跨月的成本费用如何,调整
  • 企业签订的借款合同印花税
  • 汇款手续费应由谁承担
  • php面包屑导航
  • 普通发票主营业务收入销项负数发票怎么做账
  • 其他应付款很多
  • code ide
  • 出差补贴如何入账报销
  • b/s架构的正确解释方式
  • 简述php中常用魔术方法及其各自的作用
  • 已销售未出库如何结转成本
  • 员工的通讯费怎么报销
  • php mysql_list_dbs()函数用法示例
  • 纳税人数字签名怎么填
  • vue封装组建
  • 【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种
  • 房地产开发企业资质证书
  • 织梦系统网站搭建教程
  • 普通的增值税
  • 个税申报信息怎么填
  • 对于投资者而言购买债券型理财产品面临的最大风险来自
  • sql20054n
  • 毛利率在餐饮中的应用
  • 固定资产计入管理费用就不用折旧了吗
  • 单位购买的化妆品怎么用
  • 个体户是什么概念
  • 国有控股企业和国有参股企业的区别
  • 应收账款期初余额在借方还是贷方
  • 成本计算的方法定额法
  • 进项抵扣怎么操作
  • 上年度固定资产费用化了,财报怎么算
  • 安全生产费会计准则
  • 本年利润在明细里怎么填
  • 会计账务处理程序有哪些类型
  • 物资采购是
  • 非盈利社会团体可以开发票吗
  • 小规模纳税人注册资金要求多少
  • 支票有效期过期了怎么办
  • ubuntu docker教程
  • fedora常用命令
  • mac识别文字软件
  • 苹果电脑怎么下魔兽争霸
  • JavaScript对HTML DOM使用EventListener进行操作
  • shell 脚本 判断
  • vue eventhub
  • js的继承方式
  • jquery课程总结
  • 原生js实现promise
  • windows python2和python3共存
  • js 对象key
  • swift method swizzling
  • 深圳国税局发票查询
  • 2020年增值税运费税率是多少
  • 副局长是由局长任命的吗
  • 2022年广州社保基数
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设