位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • vivox80pro怎么唤醒小v(vivox70怎么唤醒屏幕)

    vivox80pro怎么唤醒小v(vivox70怎么唤醒屏幕)

  • 笔记本闪屏怎么办(笔记本闪屏怎么强制关机)

    笔记本闪屏怎么办(笔记本闪屏怎么强制关机)

  • 小米10是否支持反向充电(小米10是否支持DP协议)

    小米10是否支持反向充电(小米10是否支持DP协议)

  • 怎么把ppt转成word(怎么把ppt转成word版本)

    怎么把ppt转成word(怎么把ppt转成word版本)

  • 苹果11怎样打开电池百分比(苹果11怎样打开麦克风)

    苹果11怎样打开电池百分比(苹果11怎样打开麦克风)

  • 苹果无线耳机二代和三代的区别(苹果无线耳机二代pro)

    苹果无线耳机二代和三代的区别(苹果无线耳机二代pro)

  • 朋友圈三张图片并排是怎么发的 (朋友圈三张图片并排是怎么发的?)

    朋友圈三张图片并排是怎么发的 (朋友圈三张图片并排是怎么发的?)

  • 内存储器相对于外存储器的特点是(内存储器相对于外存储器的特点是容量小,速度快)

    内存储器相对于外存储器的特点是(内存储器相对于外存储器的特点是容量小,速度快)

  • z5x升级版什么意思(z5x升级版和z5x有什么区别)

    z5x升级版什么意思(z5x升级版和z5x有什么区别)

  • 扫二维码付款可不可以查到那个人(扫二维码付款可以查到对方的微信号吗)

    扫二维码付款可不可以查到那个人(扫二维码付款可以查到对方的微信号吗)

  • 家里wifi卡是什么原因(家里面wifi卡怎么办)

    家里wifi卡是什么原因(家里面wifi卡怎么办)

  • b站硬币有什么用(b站硬币有什么用返多少)

    b站硬币有什么用(b站硬币有什么用返多少)

  • 表格怎么设置成双面打印(表格怎么设置成10×10的)

    表格怎么设置成双面打印(表格怎么设置成10×10的)

  • 什么是通信网络协议(通信网络的定义)

    什么是通信网络协议(通信网络的定义)

  • 支付宝收款语音播报怎么没声音(支付宝收款语音播报怎么添加店员)

    支付宝收款语音播报怎么没声音(支付宝收款语音播报怎么添加店员)

  • 苹果6可以投屏电视吗(苹果手机怎么投屏到苹果电脑上)

    苹果6可以投屏电视吗(苹果手机怎么投屏到苹果电脑上)

  • 华为mate30视频带美颜吗

    华为mate30视频带美颜吗

  • ps怎么拉长背景(ps怎么拉长背景颜色)

    ps怎么拉长背景(ps怎么拉长背景颜色)

  • 来电充电宝最迟多久还(来电充电宝最迟多久充满)

    来电充电宝最迟多久还(来电充电宝最迟多久充满)

  • iwatch4耐克版和普通版区别(iwatch耐克版区别)

    iwatch4耐克版和普通版区别(iwatch耐克版区别)

  • 电脑怎么下载微信到桌面(电脑怎么下载微博)

    电脑怎么下载微信到桌面(电脑怎么下载微博)

  • 怦然心动漫画(怦然心动漫画结局)

    怦然心动漫画(怦然心动漫画结局)

  • 怎么用手机拍一寸照片(怎么用手机拍一寸蓝底照片)

    怎么用手机拍一寸照片(怎么用手机拍一寸蓝底照片)

  • 华为p30解锁方式(华为p30怎样解锁锁屏密码)

    华为p30解锁方式(华为p30怎样解锁锁屏密码)

  • 工资表怎么导入个税系统计算个税
  • 可转债会计分录利息
  • 股东不任职
  • 商业折扣影响主要因素有
  • 无形资产登记什么明细账
  • 退税会计科目怎么做账
  • 5元印花税怎么申报
  • 其他综合收益转入投资收益
  • 过桥费是多少
  • 成本 费用区别
  • 增值税一般纳税人资格登记表
  • 建筑业异地预缴增值税流程
  • 印花税按什么征收
  • 库存盘亏的原因
  • 企业存款利息计入什么科目
  • 系统技术维护费计入什么科目
  • 拨款和支出的区别
  • 事业单位其他应付款贷方余额表示什么
  • 免征增值税对应的进项税额怎么处理
  • 农场管委会是什么性质单位
  • 营改增的税收政策
  • 环境保护税即将施行 有哪些点需要关注?
  • 金税四期具体内容
  • 收到运费发票是进项还是销项
  • 税前扣除的固定资产
  • 农业大棚卷帘机用什么油
  • 物流公司运费账务处理
  • 收到融资租赁发票怎么做账
  • 苹果x如何显示电量数字
  • 已经抵扣增值税专用发票对方要换票怎么办
  • 配置path环境变量
  • 公司收到保险公司赔款
  • linux怎么查看防火墙信息
  • php require的用法
  • erl.exe是什么进程
  • 以固定资产换入无形资产
  • 银行收取对公账户服务费有什么用
  • axios异步请求数据
  • 商贸企业税收优惠政策
  • 2023最新最全的祈祷视频
  • react+go
  • 发票验证校验码为什么只能填6位
  • 红字信息表必须要原件吗
  • 从公账发工资是什么凭证
  • 进项税转出的会计分录
  • 公司购买烟酒怎么入账
  • mysql内存使用详解
  • 人工成本与工资的关系
  • 营业外收入汇算清缴时需要调增吗
  • 跨年预收账款被税局要求确认收入怎么交增值税
  • 分公司撤销跨区经营
  • 软件公司会计科目
  • 一般纳税人销售旧货税率
  • 一定要正颌吗
  • 营业账簿印花税怎么申报
  • 每个月交工会经费
  • 资产负债表中的应交税费包括什么
  • 待报解预算收入给我转了钱是什么意思
  • 案例分析以前年龄的变化
  • 劳务公司拿什么挣钱
  • 其他债权投资减值影响账面价值吗
  • ubuntu和window双系统
  • ios自定义应用图标
  • xp系统怎么和win7系统共享
  • ip地址xp系统
  • linux启动u盘制作
  • mac怎么共享打印机设备
  • centos6.5升级到7.5
  • jquery Ajax 全局调用封装实例详解
  • js函数详解
  • 菜鸟app兼职
  • shell字符串操作命令
  • unity3d控制物体移动
  • jquery切换css样式
  • 安卓手机管家
  • python生成器send
  • unity物体碰撞爆炸
  • 夫妻相聚
  • 攸县丧葬
  • 矿产资源税税率2020年
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设