位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 粘贴自iphone弹窗怎么关闭(粘贴自iphone弹窗什么意思)

    粘贴自iphone弹窗怎么关闭(粘贴自iphone弹窗什么意思)

  • 小米手环可以连接其他手机吗(小米手环可以连接iphone吗)

    小米手环可以连接其他手机吗(小米手环可以连接iphone吗)

  • qq幽灵表情什么意思(qq表情幽灵啥意思)

    qq幽灵表情什么意思(qq表情幽灵啥意思)

  • usb无线网卡无法连接(怎样连接wifi网络)

    usb无线网卡无法连接(怎样连接wifi网络)

  • 微博会员苹果无法购买(苹果开微博会员不生效)

    微博会员苹果无法购买(苹果开微博会员不生效)

  • 手机卡就是sim卡吗(电话卡是sim)

    手机卡就是sim卡吗(电话卡是sim)

  • 抖音点赞失效解决办法(抖音点赞失效解决方案)

    抖音点赞失效解决办法(抖音点赞失效解决方案)

  • qq匹配聊天怎么开(qq匹配聊天怎么开启)

    qq匹配聊天怎么开(qq匹配聊天怎么开启)

  • 爱思助手能检测出苹果零件有没有被换过吗(爱思助手能检测苹果手表吗)

    爱思助手能检测出苹果零件有没有被换过吗(爱思助手能检测苹果手表吗)

  • 手机qq软件打不开了怎么办(手机qq软件打不开怎么回事)

    手机qq软件打不开了怎么办(手机qq软件打不开怎么回事)

  • 爱奇艺怎么设置儿童模式(爱奇艺怎么设置几个人用)

    爱奇艺怎么设置儿童模式(爱奇艺怎么设置几个人用)

  • qq安全达人图标怎么熄灭(qq安全达人图标上面有一个红杠)

    qq安全达人图标怎么熄灭(qq安全达人图标上面有一个红杠)

  • imei imsi权限是什么意思(im权限是啥)

    imei imsi权限是什么意思(im权限是啥)

  • 电脑域名解析错误怎么解决(电脑域名解析错误是怎么回事)

    电脑域名解析错误怎么解决(电脑域名解析错误是怎么回事)

  • 天猫精灵cc怎么视频通话(天猫精灵CC怎么激活)

    天猫精灵cc怎么视频通话(天猫精灵CC怎么激活)

  • 华为tdlte什么型号(华为tdlte什么型号手机)

    华为tdlte什么型号(华为tdlte什么型号手机)

  • 苹果xsmax带指纹吗(苹果xs max有指纹)

    苹果xsmax带指纹吗(苹果xs max有指纹)

  • 抖音注销后能重新注册吗(抖音注销后能重新登陆吗)

    抖音注销后能重新注册吗(抖音注销后能重新登陆吗)

  • 荣耀note10能当遥控器吗(荣耀note10带红外遥控吗)

    荣耀note10能当遥控器吗(荣耀note10带红外遥控吗)

  • 微信发什么会掉东西有特效(微信发什么会掉落东西)

    微信发什么会掉东西有特效(微信发什么会掉落东西)

  • 如何运营公众号(如何运营公众号月入3万)

    如何运营公众号(如何运营公众号月入3万)

  • iphonexr定时开机设置(iphonexr定时开关机)

    iphonexr定时开机设置(iphonexr定时开关机)

  • 程序坞是什么意思(程序坞三部分)

    程序坞是什么意思(程序坞三部分)

  • 多多果园水滴福袋提醒怎么关闭(多多果园水滴福利退单了会怎么样)

    多多果园水滴福袋提醒怎么关闭(多多果园水滴福利退单了会怎么样)

  • 电费发票隔月开如何做账
  • 航天金税盘使用说明
  • 4s店事故处理流程
  • 小规模纳税人开普票要交税吗
  • 管理费用与税金及附加哪个会影响利润
  • 房地产开发企业销售自行开发的房地产项目
  • 附加税减半征收的条件
  • 视同销售收入是纳税调整项目吗
  • 企业可以一次性补交员工十年养老保险吗
  • 企业开办费可不交税吗
  • 分支机构属于小型微利企业吗
  • 农产品加工需要交税吗
  • 买车支付的车辆购置税怎么入账
  • 个体户不足征是否要交房产税
  • 小规模纳税人销售农产品税率是多少
  • 模具维修费用清单表格
  • 加工属于什么税收分类
  • 不能取得进项发票但结转成本,税务说明怎么写
  • 工会经费可以不提吗
  • 教资认定流程详细步骤2023
  • 销售应税消费品应交的消费税分录
  • 支付保洁费用
  • 研发费用入账
  • 农业机耕服务是什么税收分类编码
  • 待抵扣进项税额借贷方向
  • 分支机构增值税汇总纳税怎么申报?
  • 一般纳税人出租不动产
  • 企业所得税退税怎么做账务处理
  • window10拖动窗口的手势
  • doc文档隐藏
  • 幼儿园会计做账实操
  • linux太卡顿
  • 金针菜的养殖方法和技术
  • 电脑xmp是啥
  • deepin缩放
  • win10dev预览版
  • 摊销无形资产会影响无形资产的账面价值吗
  • 管理费用属于产品成本项目的费用吗
  • 购买免税农产品的会计分录
  • 元宇宙与nft
  • 微信php接口
  • php require函数
  • 没有残疾人就业保障金需要申报吗
  • 施工企业内部常设置?主要出实验报告
  • 二手固定资产怎么折旧
  • 小规模纳税人的企业所得税怎么算
  • 银行回单打回来会计要做什么
  • 计提事业发展基金分录
  • 定金罚则可以约定吗
  • 开增值税专用发票的好处
  • 物流辅助服务是
  • 飞机票抵扣进项税怎么填附表二
  • 研发费用加计扣除的条件
  • 当期应纳增值税税额的计算流程
  • 小规模无票收入纳税申报表怎么填
  • 研发支出 期末
  • 货物已发出可以退款吗
  • 设备升级是什么意思
  • 当月作废的专票还是要交增值税吗
  • 预付卡发票如何做分录
  • 保险公司理赔款如何入账
  • 银行存款明细账借方代表什么
  • 长期无法收回的应收账款如何处理
  • 企业内建立小企业属于哪个阶段
  • mysql的基本操作语句
  • mysql 5.7启动
  • 在Linux环境下mysql的root密码忘记解决方法(三种)
  • mysql创建用户密码命令
  • kali linux 视频教程
  • 删除xp本地保存的视频
  • mac系统有txt吗
  • 如何让xp系统崩溃
  • windows预览0x80072ee7
  • bootstrap modal 位置
  • tomcat怎么启动
  • 常用的批处理命令
  • 如何开具分包发票流程
  • 税务登记证办理
  • 亚马逊利用大数据练就读心术
  • 廉政谈话什么是廉政?
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设