位置: IT常识 - 正文

【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

编辑:rootadmin
【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

推荐整理分享【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

新年新气象,兄弟们新年快乐。撒花!!!

之前我们的项目已经讲过了常见的4种深度学习任务(当然还有一些没有接触到的,例如GAN和今年大红的Transformer),今天这个blog我们就来谈谈一谈常见的损失函数。损失函数的更新也是非常的快,各位大佬的想法也是层出不穷,我们站在巨人的肩膀上,就可以看的更远,走的更远。

1.BCELoss

BCELoss又叫二分类交叉熵损失,顾名思义,它是用来做二分类的损失函数,我们先来看看BCELoss的公式。

其中pt---模型预测值,target---标签值, w---权重值,一般是1

上面这个公式是单个样本的,当一个batch有N个样本时

这么说是不是显得很苍白无力,所以我们来一个例子吧,我们先创建一个pt和target

torch.manual_seed(0)pt = torch.rand(2, 3)target = torch.tensor([[0., 0., 1.], [1., 0., 0.]])print(pt)print(target)

pt我用的随机数代替的,target一般是0或者1,我们print一下,看看目前的数值是多少

这里的torch.rand(2, 3)中的2代表2个样本,3代表每个样本是一个1*3的向量

好了,我们来挨个计算:

pt的第一行第一列的值是 0.4963,它对应的标签target的第一行第一列的值是0,所以求根据刚才的公式L(pt,target) = -w*(target * ln(pt) + (1-target) * ln(1-pt)),w一般取1

L = -1 * (0*ln(0.4963)+1*ln(1-0.4963)) = -ln(1-0.4963) = 0.685774426230532 ≈ 0.6857

pt的第一行第二列的值是 0.7682,它对应的标签target的第一行第一列的值是0

L = -1 * (0*ln(0.7682)+1*ln(1-0.7682)) = -ln(1-0.7682) = 1.46188034807648  ≈ 1.4620

pt的第一行第三列的值是 0.0885,它对应的标签target的第一行第一列的值是1

L = -1 * (1*ln(0.0885)+0*ln(1-0.0885)) = -ln(0.0885) = -2.424752726968253  ≈ 2.4250

接下去,我就不算了,留个兄弟们来算,我们用代码来验证一下算对了没有吧

def com(x, y): loss = -(y * torch.log(x) + (1 - y) * torch.log(1 - x)) return loss losss = com(pt, target) print(losss)

此时x就是pt,y也就是target,值得注意的是torch.log = ln,它不是真的log,看看计算结果吧

看第一行,和我们刚刚的计算结果完全吻合,确实是这么算的,没跑了

别忘了,同时每一个样本也要求一下平均值

第一个样本的平均值是 (0.6857 + 1.4620 + 2.4250)/ 3 = 1.524233333333333333333

第二个样本的平均值是 (2.0247 + 0.3673 + 1.0053)/ 3 = 1.132433333333333333333

【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

根据公式:

 所以loss = (1.524233333333333333333 + 1.132433333333333333333)/ 2 ≈1.328333

 上代码看看是不是这么回事吧

torch.manual_seed(0) pt = torch.rand(2, 3) target = torch.tensor([[0., 0., 1.], [1., 0., 0.]]) print('pt:',pt) print('target:',target) def com(x, y): loss = -(y * torch.log(x) + (1 - y) * torch.log(1 - x)) return loss losss = com(pt, target) print(losss) losss = torch.mean(com(pt, target)) print('总loss:',losss)

看看结果

 不错,一模一样,算对了。但是你肯定有疑问了,你这是你自己手算的,代码也是你自己写的,你只能证明你的计算和你的代码是对上了,怎么证明真正的和BCELoss对上了,那我们请出Pytorch的nn.BCELoss来看看结果吧

torch.manual_seed(0)pt = torch.rand(2, 3)target = torch.tensor([[0., 0., 1.], [1., 0., 0.]])print('pt:',pt)print('target:',target)loss = nn.BCELoss()print('pytorch loss:',loss(pt, target))

怎么样,我是不是算对了。

值得注意的是,在用BCELoss的时候,要记得先经过一个sigmoid或者softmax,以保证pt是0-1之间的。当然了,pytorch不可能想不到这个啊,所以它还提供了一个函数nn.BCEWithLogitsLoss()他会自动进行sigmoid操作。棒棒的!

2.带权重的BCELoss

先看看BCELoss的公式,w就是所谓的权重

 torch.nn.BCELoss()中,其实提供了一个weight的参数

我们要保持weight的形状和维度与target一致就可以了。

于是我手写一个带权重BCELoss,上代码

class BCE_WITH_WEIGHT(torch.nn.Module): def __init__(self, alpha=0.25, reduction='mean'): super(BCE_WITH_WEIGHT, self).__init__() self.alpha = alpha self.reduction = reduction def forward(self, predict, target): pt = predict loss = -((1-self.alpha) * target * torch.log(pt+1e-5) + self.alpha * (1 - target) * torch.log(1 - pt+1e-5)) if self.reduction == 'mean': loss = torch.mean(loss) elif self.reduction == 'sum': loss = torch.sum(loss) return loss

 核心带代码是

loss = -((1-self.alpha) * target * torch.log(pt+1e-5) + self.alpha * (1 - target) * torch.log(1 - pt+1e-5))

alpha就是权重了,一般很多时候,正负样本是不平衡的,如果不加入权重,网络训练的时候,训练的关注的重点就跑到了样本多的那一类样本上去,对样本少的就不公平了,所以为了维护世界和平,贯彻爱与真实的邪恶,可爱又迷人的反派角色,带权重的损失函数就出现了。

大家可以看到,我在有一个地方是torch.log(pt+1e-5),1e-5的意思就是10的-5次方,为什么要加入1e-5,这个跟ln函数有关系,因为ln(0) = -无穷大,这样损失就爆炸了,训练就会出错误,所以默认就把它加上了。

3.BCE版本的Focal_Loss

FocalLoss的公式

此时的pt就是刚刚的那个pt了,此时的pt就是刚刚我们的BCEloss的结果了 

先上代码看看吧

class BCEFocalLoss(torch.nn.Module): def __init__(self, gamma=2, alpha=0.25, reduction='mean'): super(BCEFocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha self.reduction = reduction def forward(self, predict, target): pt = predict loss = - ((1 - self.alpha) * ((1 - pt+1e-5) ** self.gamma) * (target * torch.log(pt+1e-5)) + self.alpha * ( (pt++1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt+1e-5))) if self.reduction == 'mean': loss = torch.mean(loss) elif self.reduction == 'sum': loss = torch.sum(loss) return loss

核心代码:

loss = - ((1 - self.alpha) * ((1 - pt+1e-5) ** self.gamma) * (target * torch.log(pt+1e-5)) + self.alpha * ( (pt+1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt+1e-5)))

Focalloss的目前不仅是为了控制样本不平衡的现象,还有个作用就是,让网络着重训练难样本。

好了,BCE讲的差不多了,讲的不对的地方,欢迎大家指出。

至此,敬礼,salute!!!

老规矩,上咩咩狗

本文链接地址:https://www.jiuchutong.com/zhishi/299220.html 转载请保留说明!

上一篇:element-ui动态表单和验证(elementui动态表单数据回显)

下一篇:【HTML+CSS】实现网页的导航栏和下拉菜单(html cssjs)

  • 如何知道对方有没有删除自己微信(如何知道对方有没有屏蔽我的朋友圈)

    如何知道对方有没有删除自己微信(如何知道对方有没有屏蔽我的朋友圈)

  • 华为手机怎么清理缓存(华为手机怎么清后幕)

    华为手机怎么清理缓存(华为手机怎么清后幕)

  • windows101909版本要不要升级(windows101909版本怎么屏蔽更新)

    windows101909版本要不要升级(windows101909版本怎么屏蔽更新)

  • 剪映怎么降噪不成功呢(剪映怎么设置降噪)

    剪映怎么降噪不成功呢(剪映怎么设置降噪)

  • 静音模式震动啥意思(静音模式震动总是自动关闭)

    静音模式震动啥意思(静音模式震动总是自动关闭)

  • oppoa92s支持微信视频美颜吗(oppoa92s微信视频聊天怎么开美颜)

    oppoa92s支持微信视频美颜吗(oppoa92s微信视频聊天怎么开美颜)

  • 华为P30怎么设置应用锁(华为p30怎么设置锁屏壁纸)

    华为P30怎么设置应用锁(华为p30怎么设置锁屏壁纸)

  • 操作系统的分类(操作系统的分类及具体举例)

    操作系统的分类(操作系统的分类及具体举例)

  • 抖音被别人实名认证了怎么办(抖音被别人实名认证了而且封禁了)

    抖音被别人实名认证了怎么办(抖音被别人实名认证了而且封禁了)

  • pad屏幕无法旋转(pad屏幕旋转不了)

    pad屏幕无法旋转(pad屏幕旋转不了)

  • 为什么苹果录屏很模糊(为什么苹果录屏资源写入器无法存储)

    为什么苹果录屏很模糊(为什么苹果录屏资源写入器无法存储)

  • 知乎如何设置不允许回复(知乎如何设置不让别人看见收藏)

    知乎如何设置不允许回复(知乎如何设置不让别人看见收藏)

  • 小米6支持pd快充吗(小米6支持pd快充嘛)

    小米6支持pd快充吗(小米6支持pd快充嘛)

  • 半导体数码显示器的内部接法有哪两种形式(半导体数码显示器的内部接法有两种形式: 接法和 接法)

    半导体数码显示器的内部接法有哪两种形式(半导体数码显示器的内部接法有两种形式: 接法和 接法)

  • 手机上怎么把两张照片合成一张(手机上怎么把两张图片叠在一起)

    手机上怎么把两张照片合成一张(手机上怎么把两张图片叠在一起)

  • 手机快捷图标怎么找回(手机快捷图标怎么删除掉)

    手机快捷图标怎么找回(手机快捷图标怎么删除掉)

  • vivox20有nfc功能吗(vivox20有没有nfc)

    vivox20有nfc功能吗(vivox20有没有nfc)

  • wps2019高级选项在哪里(wps的高级选项卡)

    wps2019高级选项在哪里(wps的高级选项卡)

  • 荣耀手环3怎么强制关机(荣耀手环3怎么连接手机)

    荣耀手环3怎么强制关机(荣耀手环3怎么连接手机)

  • 修改APN有什么坏处(apn更改会怎么样)

    修改APN有什么坏处(apn更改会怎么样)

  • 抖音可以自动播放下一条吗(抖音怎么下载视频到手机)

    抖音可以自动播放下一条吗(抖音怎么下载视频到手机)

  • 如何把照片变成表情包(如何把照片变成卡通图)

    如何把照片变成表情包(如何把照片变成卡通图)

  • 横向表格怎么制作(横向表格怎么制作日期姓名数量)

    横向表格怎么制作(横向表格怎么制作日期姓名数量)

  • 新版Edge浏览器开启“在关闭多个标签页之前询问”功能(新版edge浏览器如何恢复设置)

    新版Edge浏览器开启“在关闭多个标签页之前询问”功能(新版edge浏览器如何恢复设置)

  • 前端 Git-Hooks 工程化实践(git web hook)

    前端 Git-Hooks 工程化实践(git web hook)

  • Discus X 3 门户改造熊掌号网页教程

    Discus X 3 门户改造熊掌号网页教程

  • 个人去税务局开劳务费怎么交税
  • 怎么确认债权
  • 企业所得税减免优惠政策
  • 企业会计准则可以中途变更吗
  • 2021成品油增值税计算
  • 小规模纳税人农产品进项税抵扣
  • 公司报销费用发票怎么开
  • 海关废品回收
  • 加工企业购入辅料记入什么科目?
  • 合同印花税进哪个科目
  • 钢化玻璃税率是多少?
  • 一个月无纳税凭证怎么处理
  • 一般纳税人转为小规模2022政策
  • 已付款未收到发票
  • 个税附加扣除如何填写合适
  • 营业外收入怎么申报
  • 进项税和销项税的借贷方向
  • 借款本金和借款余额
  • 工程哪些材料可以做
  • 售后回租 出租方
  • 1697506445
  • 如何查询发票是否验旧
  • 民办非企业年底额度不能低于多少
  • 企业与企业之间借款账务如何处理
  • 结转到生产成本的科目
  • 地方水利建设基金
  • 仓储费计入存货成本吗
  • 子公司开票给母公司,冲减利润,怎么避免税务风险
  • 增值税专用发票几个点
  • 液晶显示器容易坏点
  • uniapp引入全局scss
  • 阿里云onedata
  • 广告代理费制度
  • 工程实践指的是
  • transformer for
  • 什么是男人无力的行为
  • 美国人用什么英语词典
  • python累加求和代码,直到最后一项小于10^-6
  • 丧葬补贴金和抚恤金怎样领取
  • 进项税转出能转回吗
  • 发票管理办法是法律吗
  • 管理费用属于费用类吗
  • 应交税费在会计科目的借贷方向
  • 汇算清缴利润调增70万会预警吗
  • 汇算清缴时发现收入少了
  • 商业银行提取的盈余公积可用于
  • 长期待摊费用如何评估
  • 经营费用属于什么类科目
  • 私募基金如何做账
  • 充值优惠怎么写
  • 以旧换新会计科目
  • 车辆不在公司名下加油费可以进公司吗
  • 电子记录表怎么填写
  • 资产负债表的期末数是指什么
  • 企业支付宝买东西怎么买
  • 工程预算费用怎么做会计分录
  • 保险公司理赔款如何入账
  • 预付差旅费属于什么类型
  • mysql %s
  • xp系统直接开机
  • 戴尔电脑u盘快速启动
  • 微软今天正式停产了吗
  • win10网卡驱动不正常连不上网怎么办
  • win8可以装pr2018吗
  • 如何让卖家给你乖乖退款
  • 批处理常用命令总结
  • ogre 引擎
  • 批处理自动关机命令
  • jquery的加载事件
  • Node.js中的什么模块是用于处理文件和目录的
  • 象棋软件编程
  • 百度贴吧上传图片大小
  • 有关中秋节的古诗
  • jquery瀑布流
  • BootStrap glyphicon图标无法显示的解决方法
  • python语言基本语法
  • python socket用法
  • javascript程序设计教程
  • android studio操作指南
  • 河北电子税务局开票流程
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设