位置: IT常识 - 正文

分割网络损失函数总结!交叉熵,Focal loss,Dice,iou,TverskyLoss!(网络分割算法)

编辑:rootadmin
分割网络损失函数总结!交叉熵,Focal loss,Dice,iou,TverskyLoss! 文章目录前言一、交叉熵loss二、Focal loss一、Dice损失函数一、IOU损失一、TverskyLoss总结前言

推荐整理分享分割网络损失函数总结!交叉熵,Focal loss,Dice,iou,TverskyLoss!(网络分割算法),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:gan网络损失函数,网络分割算法,分割网络有哪些,分类网络损失函数,分割网络有哪些,网络分割算法,分类网络损失函数,分割网络损失函怎么写,内容如对您有帮助,希望把文章链接给更多的朋友!

在实际训练分割网络任务过程中,损失函数的选择尤为重要。对于语义分割而言,极有可能存在着正负样本不均衡,或者说类别不平衡的问题,因此选择一个合适的损失函数对于模型收敛以及准确预测有着至关重要的作用。

一、交叉熵loss

M为类别数; yic为示性函数,指出该元素属于哪个类别; pic为预测概率,观测样本属于类别c的预测概率,预测概率需要事先估计计算;

缺点: 交叉熵Loss可以用在大多数语义分割场景中,但它有一个明显的缺点,那就是对于只用分割前景和背景的时候,当前景像素的数量远远小于背景像素的数量时,即背景元素的数量远大于前景元素的数量,背景元素损失函数中的成分就会占据主导,使得模型严重偏向背景,导致模型训练预测效果不好。

分割网络损失函数总结!交叉熵,Focal loss,Dice,iou,TverskyLoss!(网络分割算法)

同理BCEloss同样面临着这个问题,BCEloss如下。 对所有N个类别都做一次二分类损失计算。

#二值交叉熵,这里输入要经过sigmoid处理import torchimport torch.nn as nnimport torch.nn.functional as Fnn.BCELoss(F.sigmoid(input), target)#多分类交叉熵, 用这个 loss 前面不需要加 Softmax 层nn.CrossEntropyLoss(input, target)二、Focal loss

何凯明团队在RetinaNet论文中引入了Focal Loss来解决难易样本数量不平衡,我们来回顾一下。 对样本数和置信度做惩罚,认为大样本的损失权重和高置信度样本损失权重较低。

class FocalLoss(nn.Module): """ copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' Focal_Loss= -1*alpha*(1-pt)*log(pt) :param num_class: :param alpha: (tensor) 3D or 4D the scalar factor for this criterion :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more focus on hard misclassified example :param smooth: (float,double) smooth value when cross entropy :param balance_index: (int) balance class index, should be specific when alpha is float :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. """ def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): super(FocalLoss, self).__init__() self.apply_nonlin = apply_nonlin self.alpha = alpha self.gamma = gamma self.balance_index = balance_index self.smooth = smooth self.size_average = size_average if self.smooth is not None: if self.smooth < 0 or self.smooth > 1.0: raise ValueError('smooth value should be in [0,1]') def forward(self, logit, target): if self.apply_nonlin is not None: logit = self.apply_nonlin(logit) num_class = logit.shape[1] if logit.dim() > 2: # N,C,d1,d2 -> N,C,m (m=d1*d2*...) logit = logit.view(logit.size(0), logit.size(1), -1) logit = logit.permute(0, 2, 1).contiguous() logit = logit.view(-1, logit.size(-1)) target = torch.squeeze(target, 1) target = target.view(-1, 1) # print(logit.shape, target.shape) # alpha = self.alpha if alpha is None: alpha = torch.ones(num_class, 1) elif isinstance(alpha, (list, np.ndarray)): assert len(alpha) == num_class alpha = torch.FloatTensor(alpha).view(num_class, 1) alpha = alpha / alpha.sum() elif isinstance(alpha, float): alpha = torch.ones(num_class, 1) alpha = alpha * (1 - self.alpha) alpha[self.balance_index] = self.alpha else: raise TypeError('Not support alpha type') if alpha.device != logit.device: alpha = alpha.to(logit.device) idx = target.cpu().long() one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() one_hot_key = one_hot_key.scatter_(1, idx, 1) if one_hot_key.device != logit.device: one_hot_key = one_hot_key.to(logit.device) if self.smooth: one_hot_key = torch.clamp( one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth) pt = (one_hot_key * logit).sum(1) + self.smooth logpt = pt.log() gamma = self.gamma alpha = alpha[idx] alpha = torch.squeeze(alpha) loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt if self.size_average: loss = loss.mean() else: loss = loss.sum() return loss一、Dice损失函数

集合相似度度量函数。通常用于计算两个样本的相似度,属于metric learning。X为真实目标mask,Y为预测目标mask,我们总是希望X和Y交集尽可能大,占比尽可能大,但是loss需要逐渐变小,所以在比值前面添加负号。 可以缓解样本中前景背景(面积)不平衡带来的消极影响,前景背景不平衡也就是说图像中大部分区域是不包含目标的,只有一小部分区域包含目标。Dice Loss训练更关注对前景区域的挖掘,即保证有较低的FN,但会存在损失饱和问题,而CE Loss是平等地计算每个像素点的损失。因此单独使用Dice Loss往往并不能取得较好的结果,需要进行组合使用,比如Dice Loss+CE Loss或者Dice Loss+Focal Loss等。

该处说明原文链接:https://blog.csdn.net/Mike_honor/article/details/125871091

def dice_loss(prediction, target): """Calculating the dice loss Args: prediction = predicted image target = Targeted image Output: dice_loss""" smooth = 1.0 i_flat = prediction.view(-1) t_flat = target.view(-1) intersection = (i_flat * t_flat).sum() return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))def calc_loss(prediction, target, bce_weight=0.5): """Calculating the loss and metrics Args: prediction = predicted image target = Targeted image metrics = Metrics printed bce_weight = 0.5 (default) Output: loss : dice loss of the epoch """ bce = F.binary_cross_entropy_with_logits(prediction, target) prediction = F.sigmoid(prediction) dice = dice_loss(prediction, target) loss = bce * bce_weight + dice * (1 - bce_weight) return loss一、IOU损失

该损失函数与Dice损失函数类似,都是metric learning衡量,在实验中都可以尝试,在小目标分割收敛中有奇效!

def SoftIoULoss( pred, target): # Old One pred = torch.sigmoid(pred) smooth = 1 # print("pred.shape: ", pred.shape) # print("target.shape: ", target.shape) intersection = pred * target loss = (intersection.sum() + smooth) / (pred.sum() + target.sum() -intersection.sum() + smooth) # loss = (intersection.sum(axis=(1, 2, 3)) + smooth) / \ # (pred.sum(axis=(1, 2, 3)) + target.sum(axis=(1, 2, 3)) # - intersection.sum(axis=(1, 2, 3)) + smooth) loss = 1 - loss.mean() # loss = (1 - loss).mean() return loss一、TverskyLoss

分割任务也有不同侧重点,如医学分割更加关注召回率(高灵敏度),即真实mask尽可能都被预测出来,不太关注预测mask有没有多预测。B为真实mask,A为预测mask。|A-B|为假阳,|B-A|为假阴,alpha和beta可以控制假阳和假阴之间的权衡。若我们更加关注召回,则放大|B-A|的影响。 其中alpha和beta可以影响找回率和准确率,若想目标有较高的召回率,那么我们可以选择较高的beta。

class TverskyLoss(nn.Module): def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1., square=False): """ paper: https://arxiv.org/pdf/1706.05721.pdf """ super(TverskyLoss, self).__init__() self.square = square self.do_bg = do_bg self.batch_dice = batch_dice self.apply_nonlin = apply_nonlin self.smooth = smooth self.alpha = 0.3 self.beta = 0.7 def forward(self, x, y, loss_mask=None): shp_x = x.shape if self.batch_dice: axes = [0] + list(range(2, len(shp_x))) else: axes = list(range(2, len(shp_x))) if self.apply_nonlin is not None: x = self.apply_nonlin(x) tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square) tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth) if not self.do_bg: if self.batch_dice: tversky = tversky[1:] else: tversky = tversky[:, 1:] tversky = tversky.mean() return -tversky总结

在经过一系列实验后,发现后四种损失函数更加适合小目标分割网络训练。但是每个任务都有差异,如果时间很充裕的话,可以挨个尝试一下。

本文链接地址:https://www.jiuchutong.com/zhishi/298409.html 转载请保留说明!

上一篇:vue3生命周期及setup介绍(vue3生命周期及使用)

下一篇:Cifar-10图像分类/Pytorch/LeNet/AlexNet(cifar10图像分类实验报告)

  • 苹果11桌面小组件怎么设置(苹果11桌面小组件照片无可用内容)

    苹果11桌面小组件怎么设置(苹果11桌面小组件照片无可用内容)

  • 微信相册里的照片怎么删除(微信相册里的照片占内存吗)

    微信相册里的照片怎么删除(微信相册里的照片占内存吗)

  • 小米mix3屏幕刷新率(小米mix3屏幕刷新率是多少赫兹)

    小米mix3屏幕刷新率(小米mix3屏幕刷新率是多少赫兹)

  • 虚拟机连接不上网络(虚拟机连接不上finalshell)

    虚拟机连接不上网络(虚拟机连接不上finalshell)

  • 乐划锁屏是干什么的(乐划锁屏干什么用)

    乐划锁屏是干什么的(乐划锁屏干什么用)

  • 华为笔记本充电没反应(华为笔记本充电70就不充了)

    华为笔记本充电没反应(华为笔记本充电70就不充了)

  • 电脑蓝屏0x0000001e(电脑蓝屏0x0000001A是什么问题)

    电脑蓝屏0x0000001e(电脑蓝屏0x0000001A是什么问题)

  • 猎豹wifi老是自动断开(猎豹wifi自动断开)

    猎豹wifi老是自动断开(猎豹wifi自动断开)

  • 下载微信步骤(oppo手机下载微信步骤)

    下载微信步骤(oppo手机下载微信步骤)

  • 网页打开速度慢的原因(网页打开速度慢源码问题)

    网页打开速度慢的原因(网页打开速度慢源码问题)

  • 苹果7plus用起来不流畅(苹果7plus运行速度慢怎么办)

    苹果7plus用起来不流畅(苹果7plus运行速度慢怎么办)

  • 关闭微信别人能看到吗(关闭微信别人能看到微信运动吗)

    关闭微信别人能看到吗(关闭微信别人能看到微信运动吗)

  • 手机qq聊天背景怎么恢复默认(手机QQ聊天背景图片怎么保存下来?)

    手机qq聊天背景怎么恢复默认(手机QQ聊天背景图片怎么保存下来?)

  • 大王卡看快手为什么还要流量(大王卡看快手为什么扣费)

    大王卡看快手为什么还要流量(大王卡看快手为什么扣费)

  • 抖音号被禁了怎么解封(抖音号被禁怎么解决)

    抖音号被禁了怎么解封(抖音号被禁怎么解决)

  • 探探往左滑是什么意思(探探往左滑右滑怎么办)

    探探往左滑是什么意思(探探往左滑右滑怎么办)

  • iphonex是iphone10吗(iphone x是苹果10吗)

    iphonex是iphone10吗(iphone x是苹果10吗)

  • 京东买手机可以退货吗(京东买手机可以七天无理由退货吗)

    京东买手机可以退货吗(京东买手机可以七天无理由退货吗)

  • 乐视手机怎么截长图(乐视手机怎么截图快捷键)

    乐视手机怎么截长图(乐视手机怎么截图快捷键)

  • word如何快速生成目录(word如何快速生成流程图)

    word如何快速生成目录(word如何快速生成流程图)

  • 华为nova4为什么发烫(华为nova4为什么更新不了鸿蒙)

    华为nova4为什么发烫(华为nova4为什么更新不了鸿蒙)

  • 手机与调音台连接方法(手机和调音台怎么连)

    手机与调音台连接方法(手机和调音台怎么连)

  • vivo手机反向充电怎么打开(vivo手机反向充电功能)

    vivo手机反向充电怎么打开(vivo手机反向充电功能)

  • Samsung三星笔记本电脑BIOS设置全功能菜单详解(三星笔记app功能介绍)

    Samsung三星笔记本电脑BIOS设置全功能菜单详解(三星笔记app功能介绍)

  • Vue生命周期钩子剖析(共12个钩子)(vue生命周期钩子函数)

    Vue生命周期钩子剖析(共12个钩子)(vue生命周期钩子函数)

  • 【Vue全家桶】新一代的状态管理--Pinia(vue全家桶教程)

    【Vue全家桶】新一代的状态管理--Pinia(vue全家桶教程)

  • 国家税收与地方税收
  • 实际已缴纳所得税额在汇算清缴报告里怎么看
  • 专用增值税发票和普通发票区别
  • 住宿发票要附清单吗
  • 资产负债表要素包括几项
  • 基础设施特许权包括
  • 一般纳税人的含税收入怎么算
  • 企业无法收回的账款
  • 公司注册核税后如何建账?
  • 公休假补贴多少钱
  • 房地产企业有投资性房地产吗
  • 房产交易会涉及哪些费用
  • 招待费怎么处理
  • 通讯费税前扣除标准
  • 做季报和月报增发的区别
  • 一般纳税人申报表填写顺序
  • 网络发票开具
  • 粮食购销企业
  • qq游戏怎么玩不了怎么回事
  • 如何设置电源键关闭屏幕
  • 期间费用包括哪几个科目
  • 清算资金往来借贷方什么意思
  • win7安装驱动程序
  • 公司注销帐上的钱取出来要交税吗
  • retrorun.exe - retrorun有什么用 是什么进程
  • 计提安全费用含税还是不含税
  • 票据承兑与票据贴现的区别
  • 股东股权折价转让会计分录
  • 金钱树的养殖方法 盆栽
  • 支付手续费方式委托代销商品确认收入
  • 应付职工薪酬包括个人社保和个税吗
  • laravel视频教程
  • php use function
  • echart设置legend
  • php浮点数四舍五入
  • 小微企业免征增值税政策2023
  • 工程服务费会计怎么做账
  • thinkphp excel
  • yii2框架中文手册
  • php新手入门教程
  • Using Visual Leak Detector
  • 政府会计代扣公积金怎么做分录
  • 增值税专用发票丢了怎么补救
  • 促销服务费分录
  • 单位收的房租可以发工资吗
  • python tkinter ttk
  • mysql建表的完整步骤
  • php 上传
  • 事务所的账务处理
  • sqlserver2005数据库磁盘满了什么文件能删
  • 企业福利费账务处理
  • 小规模免税收入怎么做账
  • 营业利润是负数什么原因
  • 外出经营必须办理外管证吗
  • 业务协作费计入什么科目
  • 内部招待所管理规定
  • 住宿费记入成本会计分录
  • 工程款发票开给委托方要如何处理?
  • 服务类企业主要经营范围
  • 在建工程转固规定
  • 息税前利润变动百分比计算公式
  • sql server的修改语句
  • Ubuntu下mysql与mysql workbench安装教程
  • mysql 5.7.11 winx64.zip安装配置方法图文教程
  • centos 命令
  • ARP欺骗攻击原理
  • fedora系统安装教程
  • win10系统如何查看激活状态
  • 音频文件恢复
  • 电脑系统垃圾
  • linux chakan
  • linux !!
  • opengl绘制三维图形代码
  • js弹出层效果
  • Node.js中的全局变量有哪些
  • jquery页面跳转的方法
  • unity控制三维模型
  • androidstudio和idea
  • 我国税务师事务所有哪些
  • 税收征管工作的基本目标
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设