位置: IT常识 - 正文

loss.item()用法和注意事项详解(loss for)

编辑:rootadmin
loss.item()用法和注意事项详解

推荐整理分享loss.item()用法和注意事项详解(loss for),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:loss at,loss at,loss method,loss from,loss at,loss的用法,loss into for,loss=mse,内容如对您有帮助,希望把文章链接给更多的朋友!

.item()方法是,取一个元素张量里面的具体元素值并返回该值,可以将一个零维张量转换成int型或者float型,在计算loss,accuracy时常用到。

作用:

1.item()取出张量具体位置的元素元素值 2.并且返回的是该位置元素值的高精度值 3.保持原元素类型不变;必须指定位置

4.节省内存(不会计入计算图)

import torchloss = torch.randn(2, 2)print(loss)print(loss[1,1])print(loss[1,1].item())

输出结果

tensor([[-2.0274, -1.5974],         [-1.4775,  1.9320]]) tensor(1.9320) 1.9319512844085693

其它:loss = criterion(out, label) loss_sum += loss # <--- 这里

运行着就发现显存炸了,观察发现随着每个batch显存消耗在不断增大…因为输出的loss的数据类型是Variable。PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大,那么消耗的显存也就越来越大。

loss.item()用法和注意事项详解(loss for)

正确的loss一般是这样写 

loss_sum += loss.data[0]

其它注意事项:

使用loss += loss.detach()来获取不需要梯度回传的部分。

使用loss.item()直接获得对应的python数据类型。

补充阅读,pytorch 计算图

Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系。

Pytorch中的计算图是动态图。这里的动态主要有两重含义。

第一层含义是:计算图的正向传播是立即执行的。无需等待完整的计算图创建完毕,每条语句都会在计算图中动态添加节点和边,并立即执行正向传播得到计算结果。

第二层含义是:计算图在反向传播后立即销毁。下次调用需要重新构建计算图。如果在程序中使用了backward方法执行了反向传播,或者利用torch.autograd.grad方法计算了梯度,那么创建的计算图会被立即销毁,释放存储空间,下次调用需要重新创建。

1,计算图的正向传播是立即执行的。

import torchw = torch.tensor([[3.0,1.0]],requires_grad=True)b = torch.tensor([[3.0]],requires_grad=True)X = torch.randn(10,2)Y = torch.randn(10,1)Y_hat = X@w.t() + b # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关loss = torch.mean(torch.pow(Y_hat-Y,2))print(loss.data)print(Y_hat.data)tensor(17.8969)tensor([[3.2613], [4.7322], [4.5037], [7.5899], [7.0973], [1.3287], [6.1473], [1.3492], [1.3911], [1.2150]])

2,计算图在反向传播后立即销毁。

import torchw = torch.tensor([[3.0,1.0]],requires_grad=True)b = torch.tensor([[3.0]],requires_grad=True)X = torch.randn(10,2)Y = torch.randn(10,1)Y_hat = X@w.t() + b # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关loss = torch.mean(torch.pow(Y_hat-Y,2))#计算图在反向传播后立即销毁,如果需要保留计算图, 需要设置retain_graph = Trueloss.backward() #loss.backward(retain_graph = True) #loss.backward() #如果再次执行反向传播将报错

参考链接:pytorch学习:loss为什么要加item()_dlvector的博客-CSDN博客_loss.item()

https://blog.csdn.net/cs111211/article/details/126221102

本文链接地址:https://www.jiuchutong.com/zhishi/298517.html 转载请保留说明!

上一篇:神经网络模型之BP算法及实例分析(神经网络模型是干嘛的)

下一篇:UNIAPP手机号一键登录(uniapp获取手机通讯录)

  • 应交增值税月末出现借方余额怎么处理
  • 代销商品怎么交增值税
  • 应征增值税不含税销售额(3%征收率)怎么填2020年
  • 劳务收入个税需要进行所得税汇算吗
  • 10万以内免交的增值税怎么做帐
  • 计税金额是含税还是不含税
  • 一般纳税人税种认定有几个增值税要怎么申报呀
  • 突然收到银联入账收入怎么办
  • 一般开发间接费
  • 营业执照的注册地址怎么填
  • 营改增后取得施工作业收入需要交哪些税?
  • 税审需要什么资料和材料
  • 增值税减除后附加税计算方法
  • 施工单位项目部牌子
  • 小规模纳税人月末结转增值税
  • 不同税率的依据
  • 财税2018年39号公告残保金
  • 一年内到期的应收质保金
  • 医院医保统筹支付后还能报销吗
  • 管理不善存货损失 企业所得税
  • 税控设备实际抵减增值税时如何做分录?
  • 计提企业所得税的账务处理
  • 实验设备折旧率
  • 利润表的上期金额和本期金额之间的关系
  • 当月没有进项税额抵扣怎么办
  • 反映留存收益的账户
  • 存货折扣怎样做账
  • 购入的财务软件怎么入账
  • 补税后算偷税漏税吗
  • 长期待摊费用挂账原因
  • Win11错误提示"the pc must support secure boot"怎么解决
  • php vr
  • php数组函数大全
  • 新准则规定
  • 短期投资计入什么科目
  • php中.的作用
  • 视同销售的行为
  • 纳入资本公积
  • 资本公积属于谁
  • deepwiser怎么用
  • 为什么生产工人工资不属于固定成本
  • vue开发教程
  • 发行债券支付的费用要减吗
  • 管理费用属于什么类
  • 培训费开票属于哪个征收明目
  • python从键盘输入正整数n,计算1+2+3
  • 为什么盈余公积减少,未分配利润增加
  • 房地产企业卖房子增值税税率
  • python中的比较
  • 发票校验码被章盖住了
  • 普票和专票的
  • sql 数据计算
  • 房产税征收对象和依据2021
  • 市政道路基础设施
  • 厂家给经销商的活动方案怎么写
  • 公司车辆违章
  • 民间非盈利组织使用什么会计准则
  • 注册资金没有到位
  • 银行收取服务费
  • 记账凭证应交税费填写样本图片
  • 固定资产转让开票大类是什么
  • 其他应付款长期挂账违反什么规定
  • 核定征收方式有哪些
  • windows 10的安装
  • mysql字段超长
  • windows9怎么截图
  • 容器内存限制
  • windows2008无法识别usb
  • win10升级win1
  • Win10 Build 14279正式推送 更新后QQ可能会崩溃
  • js定义函数的几种方法
  • 控制数值颜色
  • 不用javascript可以吗
  • android实现简单的计算机界面
  • python教程目录
  • js继承的方法
  • 山东省国家税务局官网
  • 金华市税务
  • 网上如何申请
  • 青岛工商全程电子化
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设