位置: IT常识 - 正文

Boundary Loss 原理与代码解析(bounded linear functional)

编辑:rootadmin
Boundary Loss 原理与代码解析

推荐整理分享Boundary Loss 原理与代码解析(bounded linear functional),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:loss.backward()原理,loose bound,boundary-crossing,loose bound,boundary-crossing,boundary loss for highly unbalanced segmentation,boundary solution,boundary load,内容如对您有帮助,希望把文章链接给更多的朋友!

paper:Boundary loss for highly unbalanced segmentation

Introduction

在医学图像分割中任务中通常存在严重的类别不平衡问题,目标前景区域的大小常常比背景区域小几个数量级,比如下图中前景区域比背景区域小500倍以上。

分割通常采用的交叉熵损失函数,在高度不平衡的问题上存在着众所周知的缺点即它假设所有样本和类别的重要性相同,这通常会导致训练的不稳定,并导致决策边界偏向于数量多的类别。对于类别不平衡问题,一种常见的策略是对数目多的类别进行降采样来重新平衡类别的先验分布,但是这种策略限制了训练图像的使用。另一种策略是加权,即对数量少的类别赋予更大的权重,对数量多的类别赋予更小的权重,虽然这种方法对一些不平衡的问题是有效的,但处理极度不平衡的数据时还是有困难。在少数几个像素上计算的交叉熵梯度通常包含了噪声,赋予少数类别更大的权重进一步加大了噪声从而导致训练的不稳定。

分割中另一种常见的损失函数dice loss,在不平衡的医学图像分割问题中通常比ce loss的效果好。但遇到非常小的区域时可能会遇到困难,错误分类的像素可能会导致loss的剧烈降低,从而导致优化的不稳定。此外,dice loss对应精度和召回的调和平均,当true positive不变时,false postive和false negative重要性相同,因此dice loss主要适用于这两种类型的误差数量差不多的情况。

Contributions

CE loss和Dice loss分别是基于分布和基于区域的损失函数,本文提出了一种基于边界的损失函数,它在轮廓空间而不是区域空间上采用距离度量的形式。边界损失计算的不是区域上积分,而是区域之间边界上积分,因此可以缓解高度不平衡分割问题中区域损失的相关问题。

但是怎么根据CNN的regional softmax输出来表示对应的boundary points是个很大的挑战,本文受到用离散基于图的优化方法来计算曲线演化梯度流的启发,采用积分方法来计算边界的变化,避免了轮廓点上的局部微分计算,最终的boundary loss是网络输出区域softmax概率的线性函数和,因此可以和现有的区域损失结合使用。

Formulation

 \(I:\Omega \subset \mathbb{R}^{2,3}\rightarrow \mathbb{R}\) 表示空间域 \(\Omega\) 中的一张图片,\(g:\Omega \rightarrow \begin{Bmatrix} 0,1 \end{Bmatrix}\) 是该图片的ground truth分割二值图,如果像素 \(p\) 属于目标区域 \(G\subset \Omega\) (前景区域),\(g(p)=1\),否则为0,即 \(p\in\Omega\setminus G\)(背景区域)。\(s_{\theta}:\Omega\rightarrow [0,1]\) 表示分割网络的softmax概率输出,\(S_{\theta}\subset\Omega\) 表示模型输出的对应前景区域即 \(S_{\theta}=\begin{Bmatrix} p\in\Omega|s_{\theta}(p)\geqslant \delta  \end{Bmatrix}\),其中 \(\delta\) 是提前设定的阈值。

我们的目的是构建一个边界损失函数 \(Dist(\partial G,\partial S_{\theta })\),它采用 \(\Omega\) 中区域边界空间中距离度量的形式,其中 \(\partial G\) 是ground truth区域 \(G\) 的边界的一种表示(比如边界上所有点的集和),\(\partial S_{\theta }\) 是网络输出定义的分割区域的边界。如何将 \(\partial S_{\theta }\) 上的点表示成网络输出区域 \(s_{\theta }\) 的可导函数尚不清楚。考虑下面的形状空间上非对称 \(L_{2}\ distance\) 的表示,它评估的是两个临近边界 \(\partial S\) 和 \(\partial G\) 之间的距离变化

其中 \(p\in\Omega\) 是边界 \(\partial G\) 上的一点,\(y_{\partial S}(p)\) 是边界 \(\partial S\) 上对应的点,即 \(y_{\partial S}(p)\) 是 \(\partial G\) 上点 \(p\) 处的发现与 \(\partial S\) 的交点,如图2(a)所示,\(\left \| \cdot  \right \|\) 表示 \(L_{2}\) 范数。和其它直接调用轮廓 \(\partial S\)上点的轮廓军距离一样,对于 \(\partial S=\partial S_{\theta}\) 式(2)不能直接作为loss函数使用。但是很容易证明式(2)中的微分边界变化可以用积分方法来近似,这就避免了涉及轮廓上点的微分计算,并用区域积分来表示边界变化,如下

其中 \(\bigtriangleup S\) 表示两个轮廓之间的区域,\(D_{G}:\Omega\rightarrow \mathbb{R}^{+}\) 是一个相对于边界 \(\partial G\) 的distance map,即 \(D_{G}(q)\) 表示任意点 \(q\in\Omega\) 与轮廓 \(\partial G\) 上最近点 \(z_{\partial G}(q)\) 之间的距离:\(D_{G}(q)=\left \| q-z_{\partial G}(q) \right \|\),如图2(b)所示。

为了证明这种近似,沿连接 \(\partial G\) 上的一点 \(p\) 与 \(y_{\partial S}(p)\) 之间的法线对距离图 \(2D_{G}(q)\) 进行积分通过如下的转换可得 \(\left \| y_{\partial S(p)}-p \right \|^{2}\)

Boundary Loss 原理与代码解析(bounded linear functional)

由式(3)进一步得到下式

其中 \(s:\Omega\rightarrow \left \{ 0,1 \right \}\) 是区域 \(S\) 的二元指示函数:\(s(q)=1\ if\ q\in S\) 属于目标否则为0。\(\phi _{G}:\Omega\rightarrow \mathbb{R}\) 是边界 \(\partial G\) 的水平集表示:\(\phi _{G}(q)=-D_{G}(q)\ if\ q\in G\) 否则 \(\phi _{G}(q)=D_{G}(q)\)。对于 \(S=S_{\theta}\),即用网络的softmax输出 \(s_{\theta}(q)\) 替换式(4)中的 \(s(q)\),我们就得到了如下所示的边界损失

注意我们去掉了式(4)中的最后一项,因为它不包含模型参数。水平集函数 \(\phi_{G}\) 是直接根据gt区域 \(G\) 提前计算得到的。边界损失可以与常用的基于区域的损失函数结合起来用于 \(N\) 类的分割问题

其中 \(\alpha \in\mathbb{R}\) 是平衡两个损失的权重参数。

在式(5)中,每个点 \(q\) 的softmax输出通过距离函数进行加权,在基于区域的损失函数中,这种到边界距离的信息被忽略了,区域内每个点不管到边界距离大小都都按同样的权重进行处理。

在作者提出的边界损失中,当距离函数中所有的负值都保留(模型对即gt区域中所有像素的softmax预测都为1)而所有的正值都舍去(即模型对背景的softmax预测都为0)时,边界损失到达全局最小,即模型的softmax预测正好输出ground truth时边界损失最小,这也验证了边界损失的有效性。

在后续的实验中可以看到,通常要把边界损失和区域损失结合起来使用才能取得好的效果。作者在文中解释的原因没太看懂,贴一下原文

 "As discussed earlier, the global optimum of our boundary loss corresponds to a strictly negative value, with the softmax probabilities yielding a non-empty foreground region. However, an empty foreground, with approximately null values of the softmax probabilities almost everywhere, corresponds to very low gradients. Therefore, this trivial solution is close to a local minimum or a saddle point. This is why we integrate our boundary loss with a regional loss"

ExperimentsComparision of regional losses

在于其它损失函数的对比实验中,\alpha采用rebalance策略,即初始值为0.01,每个epoch后增加0.01。

从表中可以看到不管是cross-entropy loss、general dice loss还是focal loss,在于boundary loss结合使用后都获得了一定的精度提升,表明了边界损失的有效性。 

Selection of \(\alpha\)

作者对比了三种不同的方式,一是constant \(\alpha\),即在整个训练过程中 \(\alpha\) 的值保持不变;二是increase \(\alpha\),即初始设置为一个大于0但比较小的值,在每个epoch结束后逐渐增加 \(\alpha\)值,但区域损失的权重保持不变,直到训练结束,两种损失的权重一样大;三是rebalance \(\alpha\),即按 \((1-\alpha)L_{R}+\alpha L_{B}\) 的方式组合两种损失,每个epoch后增加 \(\alpha\) 的值,随着训练的进行边界损失的权重越来越大,而区域损失的权重越来越小。实验结果如下

可以看出,Rebalance的策略获得了最优结果,因此在于其它区域损失的结果对比实验中,也全部使用了该策略。

Implementation

其中data是ground truth,这里只考虑二分类的情况,即前景和背景。logits是softmax后的输出,这里为了方便相当于通过argmax或是阈值的方式将模型输出中的每个像素划分到对应类别了,实际上这里的值应该是softmax的输出,介于[0, 1]之间。其中计算distance map是通过scipy库中的distance_transform_edt函数,关于这个函数的介绍可参考 scipy.ndimage.distance_transform_edt 和 cv2.distanceTransform用法

import torchimport numpy as npfrom torch import einsumfrom torch import Tensorfrom scipy.ndimage import distance_transform_edt as distancefrom typing import Any, Callable, Iterable, List, Set, Tuple, TypeVar, Union# switch between representationsdef probs2class(probs: Tensor) -> Tensor: b, _, w, h = probs.shape # type: Tuple[int, int, int, int] assert simplex(probs) res = probs.argmax(dim=1) assert res.shape == (b, w, h) return resdef probs2one_hot(probs: Tensor) -> Tensor: _, C, _, _ = probs.shape assert simplex(probs) res = class2one_hot(probs2class(probs), C) assert res.shape == probs.shape assert one_hot(res) return resdef class2one_hot(seg: Tensor, C: int) -> Tensor: if len(seg.shape) == 2: # Only w, h, used by the dataloader seg = seg.unsqueeze(dim=0) assert sset(seg, list(range(C))) b, w, h = seg.shape # type: Tuple[int, int, int] res = torch.stack([seg == c for c in range(C)], dim=1).type(torch.int32) assert res.shape == (b, C, w, h) assert one_hot(res) return resdef one_hot2dist(seg: np.ndarray) -> np.ndarray: assert one_hot(torch.Tensor(seg), axis=0) C: int = len(seg) res = np.zeros_like(seg) # res = res.astype(np.float64) for c in range(C): posmask = seg[c].astype(np.bool) if posmask.any(): negmask = ~posmask res[c] = distance(negmask) * negmask - (distance(posmask) - 1) * posmask return resdef simplex(t: Tensor, axis=1) -> bool: _sum = t.sum(axis).type(torch.float32) _ones = torch.ones_like(_sum, dtype=torch.float32) return torch.allclose(_sum, _ones)def one_hot(t: Tensor, axis=1) -> bool: return simplex(t, axis) and sset(t, [0, 1]) # Assert utilsdef uniq(a: Tensor) -> Set: return set(torch.unique(a.cpu()).numpy())def sset(a: Tensor, sub: Iterable) -> bool: return uniq(a).issubset(sub)class SurfaceLoss(): def __init__(self): # Self.idc is used to filter out some classes of the target mask. Use fancy indexing self.idc: List[int] = [1] # 这里忽略背景类 https://github.com/LIVIAETS/surface-loss/issues/3 # probs: bcwh, dist_maps: bcwh def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor: assert simplex(probs) assert not one_hot(dist_maps) pc = probs[:, self.idc, ...].type(torch.float32) dc = dist_maps[:, self.idc, ...].type(torch.float32) multiplied = einsum("bcwh,bcwh->bcwh", pc, dc) loss = multiplied.mean() return lossif __name__ == "__main__": data = torch.tensor([[[0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]]) # (b, h, w)->(1,4,7) data2 = class2one_hot(data, 2) # (b, num_class, h, w): (1,2,4,7) data2 = data2[0].numpy() # (2,4,7) data3 = one_hot2dist(data2) # bcwh logits = torch.tensor([[[0, 0, 0, 0, 0, 0, 0], [0, 1, 1, 1, 1, 1, 0], [0, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0]]]) # (b, h, w) logits = class2one_hot(logits, 2) Loss = SurfaceLoss() data3 = torch.tensor(data3).unsqueeze(0) res = Loss(logits, data3, None) print('loss:', res)

注意,对于某一类的目标区域,在计算distance map时,该区域外的距离都是正值,该区域内的距离都是负值,且距离区域边界越远,绝对值越大。当有多类时,计算distance map是每一类单独计算的,每一类的目标区域当做前景值为1,其它区域都是背景值为0。理想情况下,模型应该将区域外的像素都预测为背景即全预测为0,将区域内的像素都预测为前景即1,此时的loss是负值且达到全局最小。 

本文链接地址:https://www.jiuchutong.com/zhishi/299877.html 转载请保留说明!

上一篇:ChatGPT 逆天测试,结局出乎预料

下一篇:VueRouter的两种模式(vuerouter模块化)

  • vivox70怎么设置字体大小(vivox70怎么设置呼叫转移)

    vivox70怎么设置字体大小(vivox70怎么设置呼叫转移)

  • word怎么做表格(word怎么做表格在电脑上怎么操作)

    word怎么做表格(word怎么做表格在电脑上怎么操作)

  • 电脑的序列号从哪看(电脑 序列号)

    电脑的序列号从哪看(电脑 序列号)

  • 苹果xsmax频繁自动重启(苹果xsmax总是自动关机)

    苹果xsmax频繁自动重启(苹果xsmax总是自动关机)

  • 宽带灯怎么才正常(宽带灯光怎么闪才正常)

    宽带灯怎么才正常(宽带灯光怎么闪才正常)

  • 苹果手机已停用请在14分钟使用是什么情况(苹果手机已停用怎么办才能解开)

    苹果手机已停用请在14分钟使用是什么情况(苹果手机已停用怎么办才能解开)

  • 抖音注销了别人还能看见作品吗(抖音注销了别人看到是什么样子)

    抖音注销了别人还能看见作品吗(抖音注销了别人看到是什么样子)

  • 负责路由tcpip协议包括哪些协议(路由器服务协议tcp)

    负责路由tcpip协议包括哪些协议(路由器服务协议tcp)

  • ipadair有128g的吗(ipadair有没有128g)

    ipadair有128g的吗(ipadair有没有128g)

  • 什么是web浏览器(web浏览器手机版)

    什么是web浏览器(web浏览器手机版)

  • wps删除整页(wps删除整页怎么删?这样操作简单又快捷)

    wps删除整页(wps删除整页怎么删?这样操作简单又快捷)

  • 戴尔笔记本怎么截图(戴尔笔记本怎么用u盘重装系统)

    戴尔笔记本怎么截图(戴尔笔记本怎么用u盘重装系统)

  • 苹果11怎么设置微信美颜(苹果11怎么设置动态壁纸)

    苹果11怎么设置微信美颜(苹果11怎么设置动态壁纸)

  • oppo手机电话号码怎么保存到卡上(oppo手机电话号码导入sim卡)

    oppo手机电话号码怎么保存到卡上(oppo手机电话号码导入sim卡)

  • 华为荣耀9x指纹在哪里(华为荣耀9x指纹识别地方烫手)

    华为荣耀9x指纹在哪里(华为荣耀9x指纹识别地方烫手)

  • 交换机的交换技术(交换机的交换技术有那三种?)

    交换机的交换技术(交换机的交换技术有那三种?)

  • 滴滴没单子原地等吗(滴滴没单子的时候怎么办)

    滴滴没单子原地等吗(滴滴没单子的时候怎么办)

  • qq宝贝为什么停运(为什么qq宝贝不能玩了)

    qq宝贝为什么停运(为什么qq宝贝不能玩了)

  • mamimo dutti什么牌子(massimo dutti档次)

    mamimo dutti什么牌子(massimo dutti档次)

  • iphone xs max美版是双卡吗(iphone xs max美版是双卡双待吗)

    iphone xs max美版是双卡吗(iphone xs max美版是双卡双待吗)

  • iphonex型号(iphoneX型号号码)

    iphonex型号(iphoneX型号号码)

  • 任务栏图标重叠在一起解决方法(任务栏图标重叠一起)

    任务栏图标重叠在一起解决方法(任务栏图标重叠一起)

  • NPFMSG.exe - NPFMSG是什么进程 有什么用

    NPFMSG.exe - NPFMSG是什么进程 有什么用

  • Python爬虫之Web自动化测试工具Selenium&&Chrome handless(web爬虫视频教程)

    Python爬虫之Web自动化测试工具Selenium&&Chrome handless(web爬虫视频教程)

  • 织梦dedecms自定义表单添加地区联动显示数字解决方法(织梦怎么改网站主页)

    织梦dedecms自定义表单添加地区联动显示数字解决方法(织梦怎么改网站主页)

  • 需要进项税额转出的发票还用勾选吗
  • 企业所得税纳税人
  • 差额征收如何做账
  • 所得税在什么情况下扣除
  • 电子发票限额多了怎么办
  • 利润表中第3栏营业税金及附加等于什么
  • 企业所得有哪些税种
  • 无形资产摊销起止时间
  • 有留抵税额可以享受加计抵减吗
  • 别人提供原材料加工后加工费
  • 发票折扣有没有限制
  • 应收账款已收回但是账面还有余额怎么处理
  • 混凝土的增值税率是多少
  • 投资收益年底结转怎么算
  • 增值税专用发票几个点
  • 运输发票的税率有几种
  • 买入返售金融资产是资产还是负债
  • 子公司是否可以共享总公司的资质
  • 其他应付款注销时怎么冲平
  • 资产负债表与现金流量表的关系
  • 微软为XSX推出星空版主机壳
  • 如何安装电脑系统win7电路连接
  • 个体户个人所得税免征额是多少
  • window10下载cad2014
  • 公司过户费用怎么入账
  • 月底增值税怎么计提
  • 无形资产摊销时点
  • 购进农产品发生非正常损失
  • php入门课程
  • 以银行存款交纳欠缴税金会计分录
  • vue-admin-master
  • 企业可以超范围经营吗
  • php中.的作用
  • 前端axios是什么
  • php数组实现原理
  • 最全vue项目实战
  • html六边形的盒子怎么做
  • cp命令使用
  • 长期借款主要包括哪些
  • 上市公司净资产转正的方法
  • 帝国cms8.0
  • 成本和费用有着根本的区别
  • 工资以现金形式发放英文
  • CentOS 7下MySQL服务启动失败的快速解决方法
  • 怎样备份mysql数据库
  • 开个分公司有啥好处
  • 备用金有发票抵扣吗
  • 股东权益合计等于净资产吗
  • 查询发票真伪
  • 股权转让需要哪些手续及流程
  • 员工的工资属于固定资产吗
  • 净资产收益率多少才是好股
  • 机票价格分类
  • 职工薪酬包括的内容
  • 收到银联商务客户短信
  • 如果以前做了错事怎么办
  • 新建企业需要什么手续
  • win10电脑系统配置
  • 系统审核策略配置
  • 主板bios无法重置
  • windowsxp
  • vsftp查看状态
  • ntldr文件在哪
  • win10 开始
  • 编写一个定时间隔为5ms的子程序
  • win7开机错误代码
  • python火车订票系统
  • Node.js中的全局变量有哪些
  • shell遍历sql查询结果
  • JavaScript中的数据类型
  • js动态改变网页标题
  • java项目怎么变成web项目
  • 一个简单的javaweb项目
  • jquery 插件写法
  • 银行人员司法查询给查错了,怎么办
  • 云办税大厅
  • 江苏税务网上办税服务厅服务提醒
  • 西安税务机关
  • 辽宁省国家税务总局
  • 电子税务平台怎么红冲纸质发票
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设