位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 硬盘如何分区(win10固态硬盘如何分区)

    硬盘如何分区(win10固态硬盘如何分区)

  • 苹果手机怎么设置闹钟铃声(苹果手机怎么设置来电拦截)

    苹果手机怎么设置闹钟铃声(苹果手机怎么设置来电拦截)

  • vivo x27联系人怎么导入sim卡(vivox27联系人怎么存到手机里)

    vivo x27联系人怎么导入sim卡(vivox27联系人怎么存到手机里)

  • 佳能打印机灯闪烁代表什么原因(佳能打印机灯闪烁无法打印怎么办)

    佳能打印机灯闪烁代表什么原因(佳能打印机灯闪烁无法打印怎么办)

  • b站id怎么看注册时间(怎么看b站注册账号)

    b站id怎么看注册时间(怎么看b站注册账号)

  • 打印机卷纸怎么解决(打印机卷纸怎么拿出来)

    打印机卷纸怎么解决(打印机卷纸怎么拿出来)

  • 华为nova7se和nova7pro的区别(华为nova7se和nova7哪个好)

    华为nova7se和nova7pro的区别(华为nova7se和nova7哪个好)

  • 打印机e-是什么故障(打印机e-是什么故障京瓷)

    打印机e-是什么故障(打印机e-是什么故障京瓷)

  • 华为nova7se怎么分屏(华为nova7se怎么设置门禁卡)

    华为nova7se怎么分屏(华为nova7se怎么设置门禁卡)

  • 为什么没有分享二维码别人却进了群(转转为什么没有分享)

    为什么没有分享二维码别人却进了群(转转为什么没有分享)

  • 有线耳机连不上手机(华为有线耳机连不上)

    有线耳机连不上手机(华为有线耳机连不上)

  • airpods怎么连接电脑(airpods怎么连接安卓手机)

    airpods怎么连接电脑(airpods怎么连接安卓手机)

  • 安装包的后缀是什么(安装包后缀是iso怎么打开)

    安装包的后缀是什么(安装包后缀是iso怎么打开)

  • ipad如何增加页面(ipad文稿怎么加页数)

    ipad如何增加页面(ipad文稿怎么加页数)

  • word样本模板在哪(world样本模板 在哪里)

    word样本模板在哪(world样本模板 在哪里)

  • word怎么让下划线一样长(word怎么让下划线长度一样)

    word怎么让下划线一样长(word怎么让下划线长度一样)

  • 苹果十一支持无线充电吗(苹果十一支持无线充电么)

    苹果十一支持无线充电吗(苹果十一支持无线充电么)

  • iqoo怎么开启性能模式(iqoo手机性能模式怎么开)

    iqoo怎么开启性能模式(iqoo手机性能模式怎么开)

  • 抖音撤销消息对方能看到吗(抖音撤销消息对谁有影响)

    抖音撤销消息对方能看到吗(抖音撤销消息对谁有影响)

  • 支付宝怎么信用卡还贷(支付宝怎么信用卡提现)

    支付宝怎么信用卡还贷(支付宝怎么信用卡提现)

  • qq闺蜜关系怎么绑定(qq闺蜜关系怎么调日期)

    qq闺蜜关系怎么绑定(qq闺蜜关系怎么调日期)

  • 冷门暴利生意——清洁公司(冷门暴利行业)

    冷门暴利生意——清洁公司(冷门暴利行业)

  • 【计算机视觉】数字图像处理(四)—— 图像增强(计算机视觉的未来发展方向有哪些)

    【计算机视觉】数字图像处理(四)—— 图像增强(计算机视觉的未来发展方向有哪些)

  • 核定征收一般纳什么税
  • 个体户需要税务申报吗?
  • 固定资产新规则
  • 税盘抵扣的会计分录
  • 营业外收入记账
  • 公司简易注销需要清算吗
  • 固定资产计提折旧的原则
  • 股权转让个人所得税如何申报
  • 金税三期国地税合并
  • 土地评估费计入什么会计科目
  • 餐馆的前期投资预算
  • 换出资产为固定资产,差额计入
  • 地税印花税税率是多少
  • 混合销售行为如何缴纳消费税
  • 关于住宿费增值税专发票抵扣问题
  • 小企业会计准则成本核算方法选什么
  • win11怎么关闭进程
  • win10系统中怎么共享文件
  • 结转本月各项损益
  • 上个月退货会计分录
  • php中表单的使用
  • 固定资产减值如何确定
  • w11系统安卓
  • mac怎么切换输入方式
  • 处置子公司的收益
  • 劳务所得税怎么计算公式
  • 销售折扣增值税如何处理
  • 个体工商户与家庭生活难以划分的费用
  • 拱门国家公园景点
  • 递延收益摊销金额
  • php有哪些
  • 收到投资款的会计科目怎么入账
  • wrap激活
  • php判断用户是否登录
  • 民间非营利组织如何纳税
  • joinby命令
  • 长期借款的处理原则
  • 办公用品和低值易耗品节省成本吗
  • 织梦cms为什么不维护了
  • 当月销售次月开票就按次月申报
  • 罚款记入其他应收款科目
  • php中isset函数作用
  • 生产企业出口转内销增值税申报表怎么填
  • 在阿里云的云主机之间怎么通信
  • mongodb 教程
  • 税法增值税的不同
  • 管理费用的借贷科目
  • 简易计税开票税率
  • 税率开错了会影响贷款吗
  • 其他应付款不用付了会计分录
  • 待处理财产损益期末余额在哪方
  • 个人所得税工资薪金包括哪些内容
  • 运输公司发票抵扣
  • 个人咨询费发票怎么开
  • 增值税当月缴纳还是次月缴纳
  • 上月多出来的薪资怎么算
  • 暂估入库需要入什么科目
  • 出口样品未报关处罚
  • 申报财产租赁合同怎么写
  • 技术服务发票怎么做成本
  • 大学里学分不满不让毕业是真的吗
  • winXP系统截图
  • 怎么用ubuntu
  • centos双网卡配置超详细
  • xp系统怎么卸载程序
  • centos inode
  • 笔记本运行WINCC不显示全屏
  • w7提高开机速度
  • 开启win7
  • 简单实现多彩慕斯蛋糕淋面的方法
  • js中用var定义变量的格式
  • nodejs基础教程
  • 关于批处理的说法错误的是
  • django 自定义权限管理
  • jquery操作元素内容的方法
  • android四大组件的作用
  • 怎么在电脑上下载浙政钉
  • 宁夏地税领导班子名单
  • 资源税从价计征的有哪些
  • 国税三所电话
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设