位置: IT常识 - 正文

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

编辑:rootadmin
关于Pytorch中的train()和eval()(以及no_grad()) 1、三剑客:train()、eval()、no_grad()1.1 train()1.2 eval()1.3 no_grad()2、简单分析下2.1 为什么要使用train()和eval()2.2 为什么可以把训练集的统计量用作测试集?3、我的坑

推荐整理分享关于Pytorch中的train()和eval()(以及no_grad())(pytorch train()),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch trace,pytorch tqdm,pytorch trace,pytorch jit trace,pytorch trace,pytorch tqdm,pytorch tqdm,pytorch trace,内容如对您有帮助,希望把文章链接给更多的朋友!

起源是我训练好了一个模型,新建一个推理脚本加载好checkpoint和预处理输入后推理,发现无论输入是哪一类甚至是随机数,其输出概率总是第一类的值最大,且总是在0.5附近,排查许久,发现是没有加上model.eval()函数。

因为我使用了model.no_grad(),下意识认为不需要加model.eval(),导致发生了本次事故

1、三剑客:train()、eval()、no_grad()

这三个函数实际上很常见,先来简单看下使用方法

1.1 train()

train()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.train()表示将该model设置为训练模式,一般在开始新epoch训练时,我们会首先执行该命令:

...model.train()# 将模型设置为训练模式for i, data in enumerate(train_loader): # 开始新epoch的训练images, labels = data images, labels = images.to(device), labels.to(device)...

1.2 eval()

同train()一样,其用法和含义也一样,eval()是nn.Module的方法,也就是你定义了一个网络model,那么mdoel.eval()表示将该model设置为验证模式,一般在开始验证当前model效果时,我们会首先执行该命令:

...model.eval()# 将模型设置为验证模式for i, data in enumerate(eval_loader): # 在验证集上验证images, labels = data images, labels = images.to(device), labels.to(device)...

1.3 no_grad()

no_grad()是torch库的方法,和上下文管理器with来搭配使用。 其作用是禁用梯度计算,当你确定不会调用tensor.backward()时。它将减少计算的内存消耗,否则这些计算将requires_grad=True。 如果设定了no_grad(),即使输入张量属性requires_grad为True,也不会计算梯度

一般我们进行模型验证或者模型推理时,就不需要梯度以及反向传播,所以我们可以在torch.no_grad()上下文管理器中执行我们的验证或推理任务,可以显著降低显存的使用。

with torch.no_grad():output=model(input_tensor)# 模型推理print(output) # model推理才涉及梯度等,print都不涉及了,所以在不在with之中已经无所谓了2、简单分析下2.1 为什么要使用train()和eval()

我们知道nn.Module中的BN层可以加速收敛,但是该层需要计算输入BatchTensor的均值和方差,毕竟一个BatchSize为64、128甚至更大,计算他们的均值和方差也简单。

关于Pytorch中的train()和eval()(以及no_grad())(pytorch train())

但问题是,当我们推理时,去对一张图像进行推理时,计算到BN层也需要该批次的均值和方差。但是现在就一个tensor,计算其均值和方差是没有意义的(一个样本的均值和方差统计量说明不了什么)。

实际上在推理时BN所需要的均值和方差是训练时的值(可以理解为训练时把训练样本的均值和方差记录下来了)。

问题来了,模型怎么知道我现在是训练状态还是推理状态?

当model.train()时,模型处于训练状态,模型会计算Batch的均值和方差

当model.eval()时,模型处于验证状态,模型会使用训练集的均值和方差作为验证数据的均值和方差

同样的还有Dropout层,Dropout层在训练时会随机失活某些神经元,提高模型泛化能力,但是在验证推理时,Dropout层不需要再失活了,也就是所有的神经元都要“干活”了。

总之train()和eval()最主要就是影响了BN层和Dropout层

2.2 为什么可以把训练集的统计量用作测试集?

为什么可以把训练集的统计量用作测试集,因为无论是训练集、验证集还是测试机,甚至是没有收集到的同类图像,他们都是独立同分布的。

换句话说,世界上所有的猫的图片组成一个集合,那么这个集合就存在一个分布,这个分布就像高斯分布、泊松分布等,只不过这个猫的集合分布可能更加复杂,暂叫它猫分布吧。

这个猫分布中每一个样本都肯定是服从这个猫分布的,但同时这些样本互不相关联,我们把其中一部分拿来做训练集,再拿一小部分做测试集。

我们设计了一个模型在训练集上训练,因为训练集也服从猫分布,所以模型在训练集上“锻炼”出来的能力,就是从小块训练集去拟合整个猫分布。

即从少量猫图上去推理所有猫图,从而具有泛化能力,去推理没有见过的但同类的图像也有非常好的效果。但是这也容易造成管中窥豹,只看到事物的一部分,见不全面,所以模型又无法识别出所有的猫图。

3、我的坑

我下意识以为使用了no_grad()就不需要再设置了eval(),导致训练效果很好,自己以测试,其输出的概率毫无逻辑。

eval()是影响BN层和Dropout层 而no_grad()是不计算梯度 两个是风马牛不相及,当然搭配使用效果即好还剩内存!

本文链接地址:https://www.jiuchutong.com/zhishi/295386.html 转载请保留说明!

上一篇:面试官问出这几道算法题,你能扛住么?(面试官问几个问题)

下一篇:微信小程序【获取用户昵称头像和昵称(附源码)】(微信小程序获取位置信息的权限在哪里修改位置)

  • 让微信营销粉丝破万的七个关键词(微信营销粉丝获取渠道)

    让微信营销粉丝破万的七个关键词(微信营销粉丝获取渠道)

  • 荣耀X30max怎么拦截骚扰电话(荣耀30怎么设置拦截所有陌生号码)

    荣耀X30max怎么拦截骚扰电话(荣耀30怎么设置拦截所有陌生号码)

  • 网易云可以看好友在线吗(网易云可以看好友最近听过的歌吗)

    网易云可以看好友在线吗(网易云可以看好友最近听过的歌吗)

  • 苹果手机相册视频如何慢放(苹果手机相册视频怎么剪辑)

    苹果手机相册视频如何慢放(苹果手机相册视频怎么剪辑)

  • 运算器可以储存信息吗(运算器可以储存数据吗)

    运算器可以储存信息吗(运算器可以储存数据吗)

  • 鼠标没电了怎么办(鼠标没电了怎么换电池)

    鼠标没电了怎么办(鼠标没电了怎么换电池)

  • 键盘上找不到win键怎么办(键盘上找不到顿号怎么打)

    键盘上找不到win键怎么办(键盘上找不到顿号怎么打)

  • 华为play4t和play4tpro有什么区别(华为play4t和play3哪个好)

    华为play4t和play4tpro有什么区别(华为play4t和play3哪个好)

  • 表格合并单元格后怎么分成多行(excel表格合并单元格)

    表格合并单元格后怎么分成多行(excel表格合并单元格)

  • 五千毫安的充电宝能充几次(五千毫安的充电宝能带上高铁吗)

    五千毫安的充电宝能充几次(五千毫安的充电宝能带上高铁吗)

  • 联想电脑自带的联想电脑管家可以卸载吗(联想电脑自带的解压软件在哪里)

    联想电脑自带的联想电脑管家可以卸载吗(联想电脑自带的解压软件在哪里)

  • 电脑没反应按什么键恢复(电脑没反应什么原因)

    电脑没反应按什么键恢复(电脑没反应什么原因)

  • 路由器性能主要看什么(路由器性能主要包括)

    路由器性能主要看什么(路由器性能主要包括)

  • 苹果手机qq更新不了怎么办(苹果手机QQ更新了要重新输入密码吗)

    苹果手机qq更新不了怎么办(苹果手机QQ更新了要重新输入密码吗)

  • 华为手机突然图标变黑(华为手机突然图标乱了)

    华为手机突然图标变黑(华为手机突然图标乱了)

  • qq语音通话怎么开摄像头(qq语音通话怎么关闭对方声音)

    qq语音通话怎么开摄像头(qq语音通话怎么关闭对方声音)

  • 荣耀20电池能用多久(荣耀20电池用多久)

    荣耀20电池能用多久(荣耀20电池用多久)

  • 华为mate30pro有几个版本(华为mate30pro有几种型号)

    华为mate30pro有几个版本(华为mate30pro有几种型号)

  • 手机怎么看历史浏览记录(手机怎么看历史安装的软件)

    手机怎么看历史浏览记录(手机怎么看历史安装的软件)

  • 华为watch gt2怎么查看日常活动(华为watch gt2怎么添加跳绳)

    华为watch gt2怎么查看日常活动(华为watch gt2怎么添加跳绳)

  • oppor17长宽高多少厘米(oppor17长宽高多少)

    oppor17长宽高多少厘米(oppor17长宽高多少)

  • 荣耀20怎么清后台(荣耀20怎么清屏)

    荣耀20怎么清后台(荣耀20怎么清屏)

  • 新款iphone11双卡双待吗(苹果iphone11双卡)

    新款iphone11双卡双待吗(苹果iphone11双卡)

  • qq扩列如何提高人气(qq扩列分数怎么变高)

    qq扩列如何提高人气(qq扩列分数怎么变高)

  • airpods连接失败(airpods连接失败再试一次)

    airpods连接失败(airpods连接失败再试一次)

  • mysqlimport命令  MySQL服务器数据导入(mysql常用命令行大全)

    mysqlimport命令 MySQL服务器数据导入(mysql常用命令行大全)

  • 关税是直接税还是间接税
  • 什么是存货周转率?存货周转率的意义是什么
  • 厂部管理人员薪酬计入什么费用
  • etc发票开票中
  • 抵扣白条账单是怎么回事
  • 母子公司无偿划转股权印花税
  • 已经缴纳的税款在哪里查询
  • 贴现利息可以抵扣吗
  • 企业生产设备发生的日常维修费用
  • 新会计准则有预提费用吗
  • 冲销预付账款后该如何做账务处理呢?
  • 增值税税控设备服务费
  • 劳务公司购买材料怎么做账
  • 公司成立多久费用可进开办费
  • 增值税留抵税额抵减欠税
  • 外购的货物用于集体福利进项税额可以抵扣吗
  • 房产赠与流程是什么意思
  • 小规模纳税人附加税会计分录
  • 季度所得税报表怎么填
  • 报销差旅费抵扣进项税分录
  • 多交的增值税附加税怎么做账
  • 支付装修押金的会计科目
  • 教您电脑网速很慢怎么办
  • 各类基本社会保障性缴款是单位缴纳部分吗
  • 大型机械拆装
  • 委托代销委托方的账务处理
  • 债务重组是什么工作
  • 重楼的功效与作用价格
  • 代理金融业务
  • 交所得税的会计科目
  • 开票逃税的处罚
  • codecline
  • php内核剖析
  • 认缴没有实缴怎么做账
  • vite 配置
  • php字符串的三种定义方式
  • 当年实现的利润弥补以前年度亏损还是提盈余公积
  • pgadmin配置
  • thinkphp3.0
  • 超市账目月底怎么核算
  • 中国烟草资产负债表
  • 公司以现金形式发工资的最好解释
  • 营业收入小于利息收入
  • 预缴增值税附加税
  • 台账如何做到表中分好几个表
  • 支付水费委托收款
  • 增值税申报表填写顺序
  • 个体户年报纳税一般填多少合适
  • 进项发票认证后暂不抵扣
  • 注销的企业
  • 私营企业员工享受探亲假吗
  • 资本公积的意思是
  • 车辆违章有几种处理方法
  • 销售商品的折扣
  • 单位组织活动主持词
  • 商业折扣和销售折让计入财务费用吗
  • 发票抬头注意事项
  • 让渡是什么
  • 收到投资款怎么做凭证
  • 会计账簿按用途分类可以分为
  • windows注册表简单应用
  • 用u盘装系统怎么操作步骤
  • window配置在哪
  • win10鼠标怎么换
  • 注册表cmd
  • python绘制球面
  • jquery 定位
  • jquery-easyui-1.3.3
  • unity3d应用开发
  • php和mysql的结合是目前web开发中的黄金组合
  • Apache服务器的安全缺陷
  • nodejs中的session
  • 提高你工作效率的方法
  • android studio post请求数据获取
  • numpy基础知识
  • 广东增值税电子普通发票图片
  • 农机免税范围
  • 云南国税通用发票查询
  • 个人所得税申请专项扣除有什么用
  • 内蒙民生认证系统
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设