位置: IT常识 - 正文

【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

编辑:rootadmin
【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

推荐整理分享【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

新年新气象,兄弟们新年快乐。撒花!!!

之前我们的项目已经讲过了常见的4种深度学习任务(当然还有一些没有接触到的,例如GAN和今年大红的Transformer),今天这个blog我们就来谈谈一谈常见的损失函数。损失函数的更新也是非常的快,各位大佬的想法也是层出不穷,我们站在巨人的肩膀上,就可以看的更远,走的更远。

1.BCELoss

BCELoss又叫二分类交叉熵损失,顾名思义,它是用来做二分类的损失函数,我们先来看看BCELoss的公式。

其中pt---模型预测值,target---标签值, w---权重值,一般是1

上面这个公式是单个样本的,当一个batch有N个样本时

这么说是不是显得很苍白无力,所以我们来一个例子吧,我们先创建一个pt和target

torch.manual_seed(0)pt = torch.rand(2, 3)target = torch.tensor([[0., 0., 1.], [1., 0., 0.]])print(pt)print(target)

pt我用的随机数代替的,target一般是0或者1,我们print一下,看看目前的数值是多少

这里的torch.rand(2, 3)中的2代表2个样本,3代表每个样本是一个1*3的向量

好了,我们来挨个计算:

pt的第一行第一列的值是 0.4963,它对应的标签target的第一行第一列的值是0,所以求根据刚才的公式L(pt,target) = -w*(target * ln(pt) + (1-target) * ln(1-pt)),w一般取1

L = -1 * (0*ln(0.4963)+1*ln(1-0.4963)) = -ln(1-0.4963) = 0.685774426230532 ≈ 0.6857

pt的第一行第二列的值是 0.7682,它对应的标签target的第一行第一列的值是0

L = -1 * (0*ln(0.7682)+1*ln(1-0.7682)) = -ln(1-0.7682) = 1.46188034807648  ≈ 1.4620

pt的第一行第三列的值是 0.0885,它对应的标签target的第一行第一列的值是1

L = -1 * (1*ln(0.0885)+0*ln(1-0.0885)) = -ln(0.0885) = -2.424752726968253  ≈ 2.4250

接下去,我就不算了,留个兄弟们来算,我们用代码来验证一下算对了没有吧

def com(x, y): loss = -(y * torch.log(x) + (1 - y) * torch.log(1 - x)) return loss losss = com(pt, target) print(losss)

此时x就是pt,y也就是target,值得注意的是torch.log = ln,它不是真的log,看看计算结果吧

看第一行,和我们刚刚的计算结果完全吻合,确实是这么算的,没跑了

别忘了,同时每一个样本也要求一下平均值

第一个样本的平均值是 (0.6857 + 1.4620 + 2.4250)/ 3 = 1.524233333333333333333

第二个样本的平均值是 (2.0247 + 0.3673 + 1.0053)/ 3 = 1.132433333333333333333

【Pytorch深度学习50篇】·······第六篇:【常见损失函数篇】-----BCELoss及其变种

根据公式:

 所以loss = (1.524233333333333333333 + 1.132433333333333333333)/ 2 ≈1.328333

 上代码看看是不是这么回事吧

torch.manual_seed(0) pt = torch.rand(2, 3) target = torch.tensor([[0., 0., 1.], [1., 0., 0.]]) print('pt:',pt) print('target:',target) def com(x, y): loss = -(y * torch.log(x) + (1 - y) * torch.log(1 - x)) return loss losss = com(pt, target) print(losss) losss = torch.mean(com(pt, target)) print('总loss:',losss)

看看结果

 不错,一模一样,算对了。但是你肯定有疑问了,你这是你自己手算的,代码也是你自己写的,你只能证明你的计算和你的代码是对上了,怎么证明真正的和BCELoss对上了,那我们请出Pytorch的nn.BCELoss来看看结果吧

torch.manual_seed(0)pt = torch.rand(2, 3)target = torch.tensor([[0., 0., 1.], [1., 0., 0.]])print('pt:',pt)print('target:',target)loss = nn.BCELoss()print('pytorch loss:',loss(pt, target))

怎么样,我是不是算对了。

值得注意的是,在用BCELoss的时候,要记得先经过一个sigmoid或者softmax,以保证pt是0-1之间的。当然了,pytorch不可能想不到这个啊,所以它还提供了一个函数nn.BCEWithLogitsLoss()他会自动进行sigmoid操作。棒棒的!

2.带权重的BCELoss

先看看BCELoss的公式,w就是所谓的权重

 torch.nn.BCELoss()中,其实提供了一个weight的参数

我们要保持weight的形状和维度与target一致就可以了。

于是我手写一个带权重BCELoss,上代码

class BCE_WITH_WEIGHT(torch.nn.Module): def __init__(self, alpha=0.25, reduction='mean'): super(BCE_WITH_WEIGHT, self).__init__() self.alpha = alpha self.reduction = reduction def forward(self, predict, target): pt = predict loss = -((1-self.alpha) * target * torch.log(pt+1e-5) + self.alpha * (1 - target) * torch.log(1 - pt+1e-5)) if self.reduction == 'mean': loss = torch.mean(loss) elif self.reduction == 'sum': loss = torch.sum(loss) return loss

 核心带代码是

loss = -((1-self.alpha) * target * torch.log(pt+1e-5) + self.alpha * (1 - target) * torch.log(1 - pt+1e-5))

alpha就是权重了,一般很多时候,正负样本是不平衡的,如果不加入权重,网络训练的时候,训练的关注的重点就跑到了样本多的那一类样本上去,对样本少的就不公平了,所以为了维护世界和平,贯彻爱与真实的邪恶,可爱又迷人的反派角色,带权重的损失函数就出现了。

大家可以看到,我在有一个地方是torch.log(pt+1e-5),1e-5的意思就是10的-5次方,为什么要加入1e-5,这个跟ln函数有关系,因为ln(0) = -无穷大,这样损失就爆炸了,训练就会出错误,所以默认就把它加上了。

3.BCE版本的Focal_Loss

FocalLoss的公式

此时的pt就是刚刚的那个pt了,此时的pt就是刚刚我们的BCEloss的结果了 

先上代码看看吧

class BCEFocalLoss(torch.nn.Module): def __init__(self, gamma=2, alpha=0.25, reduction='mean'): super(BCEFocalLoss, self).__init__() self.gamma = gamma self.alpha = alpha self.reduction = reduction def forward(self, predict, target): pt = predict loss = - ((1 - self.alpha) * ((1 - pt+1e-5) ** self.gamma) * (target * torch.log(pt+1e-5)) + self.alpha * ( (pt++1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt+1e-5))) if self.reduction == 'mean': loss = torch.mean(loss) elif self.reduction == 'sum': loss = torch.sum(loss) return loss

核心代码:

loss = - ((1 - self.alpha) * ((1 - pt+1e-5) ** self.gamma) * (target * torch.log(pt+1e-5)) + self.alpha * ( (pt+1e-5) ** self.gamma) * ((1 - target) * torch.log(1 - pt+1e-5)))

Focalloss的目前不仅是为了控制样本不平衡的现象,还有个作用就是,让网络着重训练难样本。

好了,BCE讲的差不多了,讲的不对的地方,欢迎大家指出。

至此,敬礼,salute!!!

老规矩,上咩咩狗

本文链接地址:https://www.jiuchutong.com/zhishi/299220.html 转载请保留说明!

上一篇:element-ui动态表单和验证(elementui动态表单数据回显)

下一篇:【HTML+CSS】实现网页的导航栏和下拉菜单(html cssjs)

  • 抖音私信对话框在哪(抖音私信对话框颜色深了)

    抖音私信对话框在哪(抖音私信对话框颜色深了)

  • 手机屏幕发白(手机屏幕发白闪烁怎么回事)

    手机屏幕发白(手机屏幕发白闪烁怎么回事)

  • oppo手机在设置里面麦克风是哪一个(oppo手机在设置里面怎么调字体大小)

    oppo手机在设置里面麦克风是哪一个(oppo手机在设置里面怎么调字体大小)

  • 华为怎么开启下面三个键(华为怎么开启下面的导航键)

    华为怎么开启下面三个键(华为怎么开启下面的导航键)

  • 安卓手机越更新越卡吗(安卓手机越更新越卡是真的吗)

    安卓手机越更新越卡吗(安卓手机越更新越卡是真的吗)

  • 荣耀30指纹识别不灵敏(荣耀30指纹识别没有了)

    荣耀30指纹识别不灵敏(荣耀30指纹识别没有了)

  • 大写锁定已打开是什么意思(大写锁定已打开卡在桌面)

    大写锁定已打开是什么意思(大写锁定已打开卡在桌面)

  • 抖音在线是什么意思(抖音在线是什么时候在线)

    抖音在线是什么意思(抖音在线是什么时候在线)

  • 微信有通知但没有消息(微信通知有信息,到微信里看不见)

    微信有通知但没有消息(微信通知有信息,到微信里看不见)

  • 为什么电脑上没有压缩这一功能(为什么电脑上没有word文档)

    为什么电脑上没有压缩这一功能(为什么电脑上没有word文档)

  • 为什么话费无故被扣除(为什么话费无故扣钱)

    为什么话费无故被扣除(为什么话费无故扣钱)

  • 华为怎么设置隐私相册(华为怎么设置隐藏键盘)

    华为怎么设置隐私相册(华为怎么设置隐藏键盘)

  • 200m宽带可以用千兆端口吗(200m宽带可以用多少流量)

    200m宽带可以用千兆端口吗(200m宽带可以用多少流量)

  • 关闭4g通话啥意思(4g+关掉)

    关闭4g通话啥意思(4g+关掉)

  • 快充手机可以充一夜吗(快充手机可以充一个晚上的电吗)

    快充手机可以充一夜吗(快充手机可以充一个晚上的电吗)

  • 华为freebuds3怎么关机(华为freebuds3怎么样)

    华为freebuds3怎么关机(华为freebuds3怎么样)

  • 搜电充电宝丢了怎么办(搜电充电宝丢了怎么归还)

    搜电充电宝丢了怎么办(搜电充电宝丢了怎么归还)

  • 微信朋友的新动态怎么关闭(微信朋友的新动态怎么找)

    微信朋友的新动态怎么关闭(微信朋友的新动态怎么找)

  • 电脑可以下载手写输入法吗(电脑可以下载手机版剪映吗)

    电脑可以下载手写输入法吗(电脑可以下载手机版剪映吗)

  • 怎么设置微信已读回执(怎么设置微信已读未读功能)

    怎么设置微信已读回执(怎么设置微信已读未读功能)

  • 闲鱼被骗投诉有用吗(闲鱼如果被骗了 举报官方会受理吗)

    闲鱼被骗投诉有用吗(闲鱼如果被骗了 举报官方会受理吗)

  • windows10有必要分区吗(win10有没有必要分盘)

    windows10有必要分区吗(win10有没有必要分盘)

  • vivox27能开空调吗(vivox27可以开空调)

    vivox27能开空调吗(vivox27可以开空调)

  • 火狐浏览器如何升级(火狐浏览器如何截图)

    火狐浏览器如何升级(火狐浏览器如何截图)

  • 如何使用网络测试分析仪?(如何进行网络测试网速测试)

    如何使用网络测试分析仪?(如何进行网络测试网速测试)

  • 视同销售收入是纳税调整项目吗?
  • 扶贫入股分红能领多久
  • 公司卖东西怎么开票
  • 成本结算怎么处理?
  • 外籍人员可以在中国工作吗
  • 什么是保函业务?如何进行核算?
  • 汽车租赁企业
  • 哪些企业适用于品种法
  • 年终奖分摊到每个月
  • 百旺购货方红字信息表怎么开具
  • 企业所得税查增值税吗
  • 分公司独立核算和非独立核算区别
  • 资产减少应注意的问题有哪些?
  • 企业所得税季报和年报的区别
  • 办理营业执照需要钱吗
  • 查账征收的个体户需要申报个人所得税吗
  • 国税2017年16号文
  • 小微企业所得税优惠政策
  • 做内账收入含税吗
  • macbook设置壁纸后开机变回原样
  • win11发热严重怎么解决
  • 员工 意外保险
  • 超市库存商可以分为几大类
  • PHP:oci_num_rows()的用法_Oracle函数
  • 苹果手机设置来电铃声怎么设置
  • 月末账务结转
  • 土地增值税怎么计算举例说明
  • php数组函数,选班长
  • hbuilderx怎么运行代码
  • 工程投标保证金一般是多少
  • 房地产企业现金流管理问题研究
  • 微信小程序自定义tabbar
  • 营改增前取得的有形动产为标的物
  • 冷饮成本价
  • 固定资产处置怎么计算
  • 供应商发票多开了3毛钱能做到财务费吗
  • python requests读取服务器响应
  • 不动产简易征收增值税发票 可以抵扣
  • 累计折旧会影响净残值吗
  • 企业向银行借入长期借款,应借记
  • MySQL导入导出命令
  • 所得税预缴政策
  • 货物已到发票未开具
  • 认缴 实收资本
  • 端午节发放的现金福利会计处理
  • 房地产企业固定资产贷款
  • 应付账款周转率越大越好还是越小越好?
  • 工程挂靠取得的收入怎么做账?
  • 现金日记账和银行日记账必须逐月结出余额
  • 分期付款购无形资产怎么入账?
  • 什么是发票抬头怎么填
  • 企业取得土地使用权会计处理
  • 施工图审查费计算公式
  • 对公账户提取备用金怎么做账
  • 日记账的建账工作
  • 存货与总账对账
  • mysql获取所有表的数据量
  • mysql绿色版和安装版有什么区别
  • Python MySQL进行数据库表变更和查询
  • windows写字板功能
  • windows8进入桌面
  • 安装WIN10系统后怎么调过设置
  • linux tar -czvf
  • win8应用商店官网
  • win10不停的自动重启
  • win8 embedded
  • windows7 无线服务
  • win10mobile下载官网
  • win10预览版绿屏重启解决
  • dos字符串替换
  • Node.js中的全局变量有哪些
  • express中间件面试题
  • nodejs实战
  • node.js golang
  • 异步加载场景
  • 安卓 触摸屏
  • 代收的款项支付需不需要开发票
  • 汽车运输发票税率是多少
  • 广东电子发票开票软件?
  • 我国公益性企业有哪些
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设