位置: IT常识 - 正文

【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

编辑:rootadmin
【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

推荐整理分享【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

新年新气象,兄弟们新年快乐。撒花!!!

之前我们的项目已经讲过了常见的4种深度学习任务(当然还有一些没有接触到的,例如GAN和今年大红的Transformer),今天这个blog我们就来谈谈一谈常见的损失函数。损失函数的更新也是非常的快,各位大佬的想法也是层出不穷,我们站在巨人的肩膀上,就可以看的更远,走的更远。

1.BCELoss

BCELoss又叫二分类交叉熵损失,顾名思义,它是用来做二分类的损失函数,我们先来看看BCELoss的公式。

其中pt---模型预测值,target---标签值, w---权重值,一般是1

上面这个公式是单个样本的,当一个batch有N个样本时

这么说是不是显得很苍白无力,所以我们来一个例子吧,我们先创建一个pt和target

torch.manual_seed(0)pt = torch.rand(2, 3)target = torch.tensor([[0., 0., 1.], [1., 0., 0.]])print(pt)print(target)

pt我用的随机数代替的,target一般是0或者1,我们print一下,看看目前的数值是多少

这里的torch.rand(2, 3)中的2代表2个样本,3代表每个样本是一个1*3的向量

好了,我们来挨个计算:

pt的第一行第一列的值是 0.4963,它对应的标签target的第一行第一列的值是0,所以求根据刚才的公式L(pt,target) = -w*(target * ln(pt) + (1-target) * ln(1-pt)),w一般取1

L = -1 * (0*ln(0.4963)+1*ln(1-0.4963)) = -ln(1-0.4963) = 0.685774426230532 ≈ 0.6857

pt的第一行第二列的值是 0.7682,它对应的标签target的第一行第一列的值是0

L = -1 * (0*ln(0.7682)+1*ln(1-0.7682)) = -ln(1-0.7682) = 1.46188034807648  ≈ 1.4620

pt的第一行第三列的值是 0.0885,它对应的标签target的第一行第一列的值是1

L = -1 * (1*ln(0.0885)+0*ln(1-0.0885)) = -ln(0.0885) = -2.424752726968253  ≈ 2.4250

接下去,我就不算了,留个兄弟们来算,我们用代码来验证一下算对了没有吧

def com(x, y): loss = -(y * torch.log(x) + (1 - y) * torch.log(1 - x)) return loss losss = com(pt, target) print(losss)

此时x就是pt,y也就是target,值得注意的是torch.log = ln,它不是真的log,看看计算结果吧

看第一行,和我们刚刚的计算结果完全吻合,确实是这么算的,没跑了

别忘了,同时每一个样本也要求一下平均值

第一个样本的平均值是 (0.6857 + 1.4620 + 2.4250)/ 3 = 1.524233333333333333333

第二个样本的平均值是 (2.0247 + 0.3673 + 1.0053)/ 3 = 1.132433333333333333333

【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

根据公式:

 所以loss = (1.524233333333333333333 + 1.132433333333333333333)/ 2 ≈1.328333

 上代码看看是不是这么回事吧

torch.manual_seed(0) pt = torch.rand(2, 3) target = torch.tensor([[0., 0., 1.], [1., 0., 0.]]) print('pt:',pt) print('target:',target) def com(x, y): loss = -(y * torch.log(x) + (1 - y) * torch.log(1 - x)) return loss losss = com(pt, target) print(losss) losss = torch.mean(com(pt, target)) print('总loss:',losss)

看看结果

 不错,一模一样,算对了。但是你肯定有疑问了,你这是你自己手算的,代码也是你自己写的,你只能证明你的计算和你的代码是对上了,怎么证明真正的和BCELoss对上了,那我们请出Pytorch的nn.BCELoss来看看结果吧

torch.manual_seed(0)pt = torch.rand(2, 3)target = torch.tensor([[0., 0., 1.], [1., 0., 0.]])print('pt:',pt)print('target:',target)loss = nn.BCELoss()print('pytorch loss:',loss(pt, target))

怎么样,我是不是算对了。

值得注意的是,在用BCELoss的时候,要记得先经过一个sigmoid或者softmax,以保证pt是0-1之间的。当然了,pytorch不可能想不到这个啊,所以它还提供了一个函数nn.BCEWithLogitsLoss()他会自动进行sigmoid操作。棒棒的!

2.带权重的BCELoss

先看看BCELoss的公式,w就是所谓的权重

 torch.nn.BCELoss()中,其实提供了一个weight的参数

我们要保持weight的形状和维度与target一致就可以了。

于是我手写一个带权重BCELoss,上代码

class BCE_WITH_WEIGHT(torch.nn.Module): def __init__(self, alpha=0.25, reduction='mean'): super(BCE_WITH_WEIGHT, self).__init__() self.alpha = alpha self.reduction = reduction def forward(self, predict, target): pt = predict loss = -((1-self.alpha) * target * torch.log(pt+1e-5) + self.alpha * (1 - target) * torch.log(1 - pt+1e-5)) if self.reduction == 'mean': loss = torch.mean(loss) elif self.reduction == 'sum': loss = torch.sum(loss) return loss

 核心带代码是

loss = -((1-self.alpha) * target * torch.log(pt+1e-5) + self.alpha * (1 - target) * torch.log(1 - pt+1e-5))

alpha就是权重了,一般很多时候,正负样本是不平衡的,如果不加入权重,网络训练的时候,训练的关注的重点就跑到了样本多的那一类样本上去,对样本少的就不公平了,所以为了维护世界和平,贯彻爱与真实的邪恶,可爱又迷人的反派角色,带权重的损失函数就出现了。

大家可以看到,我在有一个地方是torch.log(pt+1e-5),1e-5的意思就是10的-5次方,为什么要加入1e-5,这个跟ln函数有关系,因为ln(0) = -无穷大,这样损失就爆炸了,训练就会出错误,所以默认就把它加上了。

3.BCE版本的Focal_Loss

FocalLoss的公式

此时的pt就是刚刚的那个pt了,此时的pt就是刚刚我们的BCEloss的结果了 

先上代码看看吧

class BCEFocalLoss(torch.nn.Module): def __init__(self, gamma=2, alpha=0.25, reduction='mean'): super(BCEFocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha self.reduction = reduction def forward(self, predict, target): pt = predict loss = - ((1 - self.alpha) * ((1 - pt+1e-5) ** self.gamma) * (target * torch.log(pt+1e-5)) + self.alpha * ( (pt++1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt+1e-5))) if self.reduction == 'mean': loss = torch.mean(loss) elif self.reduction == 'sum': loss = torch.sum(loss) return loss

核心代码:

loss = - ((1 - self.alpha) * ((1 - pt+1e-5) ** self.gamma) * (target * torch.log(pt+1e-5)) + self.alpha * ( (pt+1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt+1e-5)))

Focalloss的目前不仅是为了控制样本不平衡的现象,还有个作用就是,让网络着重训练难样本。

好了,BCE讲的差不多了,讲的不对的地方,欢迎大家指出。

至此,敬礼,salute!!!

老规矩,上咩咩狗

本文链接地址:https://www.jiuchutong.com/zhishi/299220.html 转载请保留说明!

上一篇:element-ui动态表单和验证(elementui动态表单数据回显)

下一篇:【HTML+CSS】实现网页的导航栏和下拉菜单(html cssjs)

  • 微信零钱单日限额多少(微信零钱单日限额10000怎么解除)

    微信零钱单日限额多少(微信零钱单日限额10000怎么解除)

  • 腾讯视频小鹅花钱在哪(腾讯视频小鹅花钱征信不好能申请吗)

    腾讯视频小鹅花钱在哪(腾讯视频小鹅花钱征信不好能申请吗)

  • 笔记本有麦克风吗(笔记本有麦克风功能吗)

    笔记本有麦克风吗(笔记本有麦克风功能吗)

  • loading页面是什么意思 (loading界面什么意思)

    loading页面是什么意思 (loading界面什么意思)

  • ed2k是什么文件(ed2k文件下载软件)

    ed2k是什么文件(ed2k文件下载软件)

  • qq举报人了对自己有影响吗(qq 举报人后会被他本人发现吗)

    qq举报人了对自己有影响吗(qq 举报人后会被他本人发现吗)

  • 苹果x屏幕模糊怎么调(苹果x显示效果模糊)

    苹果x屏幕模糊怎么调(苹果x显示效果模糊)

  • 华为手环app叫什么(华为手环app叫什么名称)

    华为手环app叫什么(华为手环app叫什么名称)

  • 学生用计算器咋关机(学生用计算器咋开4次根号)

    学生用计算器咋关机(学生用计算器咋开4次根号)

  • 华为笔记本开机键在哪(华为笔记本开机怎么开)

    华为笔记本开机键在哪(华为笔记本开机怎么开)

  • 微信受限制几天能恢复(微信受限制几天会解封)

    微信受限制几天能恢复(微信受限制几天会解封)

  • ipad里面的计算器在哪里找(ipad上面的计算器)

    ipad里面的计算器在哪里找(ipad上面的计算器)

  • 红包封面序列号怎么弄(微信红包封面序列号)

    红包封面序列号怎么弄(微信红包封面序列号)

  • 主副显示器切换快捷键(主显示器副显示器)

    主副显示器切换快捷键(主显示器副显示器)

  • 手机充电口小芯片坏了(手机充电口小芯片坏了多少钱)

    手机充电口小芯片坏了(手机充电口小芯片坏了多少钱)

  • 通常所说的cache是指(通常所说的I/O设备指什么)

    通常所说的cache是指(通常所说的I/O设备指什么)

  • vivo手机怎么开启nfc(vivo手机怎么开空调)

    vivo手机怎么开启nfc(vivo手机怎么开空调)

  • 头条小视频怎么能保存下来(头条小视频怎么清屏播放)

    头条小视频怎么能保存下来(头条小视频怎么清屏播放)

  • a1524的iphone 6 plus是什么版本(a1524的iphone 6 plus可以用电信卡吗)

    a1524的iphone 6 plus是什么版本(a1524的iphone 6 plus可以用电信卡吗)

  • win10锁屏壁纸在哪个文件夹(win10锁屏壁纸在哪)

    win10锁屏壁纸在哪个文件夹(win10锁屏壁纸在哪)

  • Vivox27手机返回键怎么设置(vivox27手机返回键)

    Vivox27手机返回键怎么设置(vivox27手机返回键)

  • iphonex怎样正确充电(苹果x操作技巧)

    iphonex怎样正确充电(苹果x操作技巧)

  • 手机无法加载插件设置(手机无法加载插件怎么处理)

    手机无法加载插件设置(手机无法加载插件怎么处理)

  • word第一页不要页码

    word第一页不要页码

  • 借款约定分期还,如何计算诉讼时效?(借款人约定分期还款中途可以起诉吗)

    借款约定分期还,如何计算诉讼时效?(借款人约定分期还款中途可以起诉吗)

  • 完美解决 Error: Cannot find module ‘@vue/cli-plugin-eslint‘ 报错(完美解决战网已休眠正在唤醒它)

    完美解决 Error: Cannot find module ‘@vue/cli-plugin-eslint‘ 报错(完美解决战网已休眠正在唤醒它)

  • 【前端】Vue+Element UI案例:通用后台管理系统-导航栏(前端vue3)

    【前端】Vue+Element UI案例:通用后台管理系统-导航栏(前端vue3)

  • 出口退税新政策报关费发票要怎么开
  • 捐赠支出税前扣除条件
  • 关联企业的判定标准
  • 安全生产费用怎么入账
  • 应付账款周转率和存货周转率公式
  • 个体工商户2023年税收政策
  • 附加税费申报没有怎么填
  • 定期定额的个税起征点
  • 减税降费对企业的影响案例
  • 公司帮其他单位开发票违法吗
  • 向职工支付职工福利费
  • 煤炭企业补偿费会计分录
  • 资产减值确定计量原则包括哪些
  • 购置税发票如何下载
  • 普通发票扣税
  • 有留抵税额的会计处理
  • 支付明年报刊费
  • 预收账款在什么科目核算
  • 网上纳税申报怎么填
  • 未分配利润应该在借方还是贷方??
  • 案例分析建筑业发展趋势
  • 污泥处置中心所得税优惠政策
  • 税盘忘记清盘了怎么办
  • 年所得12万元以上的纳税人,在纳税年度终了后
  • 损失性费用的会计科目有
  • 建筑劳务预缴税款后怎么申报
  • win10打开游戏老是提示
  • 公司名下汽车过户给个人
  • 违约金条款的特点
  • winds10教育版
  • php嵌入js
  • laravel跨库查询
  • 收到应缴财政款
  • 评估价格是按原值还是净值
  • 企业要方便客户与企业的沟通,尽可能降低
  • 扣除年度未扣除怎么计算
  • 抵债资产怎么入账
  • 退回以前年度多交的附加税怎么做分录
  • php后端主要会涉及到哪些技术
  • next frame
  • Vue中 Vue-Baidu-Map基本使用
  • sbc奇思妙想
  • 工伤保险赔付计算
  • 帝国cms怎么增加子栏目
  • php登录不了
  • 发票一式两联
  • 母子公司吸收合并的税收有哪些
  • 金蝶新建账套如何录入固定资产账套
  • 利润表中的利息费用是利息支出吗
  • 自制小汽车
  • 个税年终奖计算方法2022税率表
  • 营业外支出在贷方
  • 装修费应该按几折算
  • 发工资摘要没写工资
  • 公司送礼怎么记账
  • 收房租的收据怎么写
  • 冲减上年度多计提所得税
  • 季报能弥补以前年度亏损吗
  • 收到退回多付的材料退款
  • 税控盘的会计分录怎么做
  • 餐饮费与业务招标的关系
  • mysql的json数据类型
  • Linux下卸载MySQL数据库
  • windows mysql my.cnf
  • 任务管理器边框怎么设置
  • Win10 Mobile 10581预览版升级界面曝光 上手视频观赏
  • geom是什么文件
  • windows关机电源不断电
  • win7的记事本在哪里打开
  • win7的环境变量如何还原
  • android observer
  • android项目打包成jar包
  • linux压缩命令compress
  • javascript数据类型有哪些
  • jquery插件库免费
  • 临时税务登记证是什么意思
  • 安徽省渔业管理办法第十条规定
  • 天津环保网站官网
  • 什么是税务证书密码
  • 实名办税人员承诺书范本
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设