位置: IT常识 - 正文

DeepLabV3+:Mobilenetv2的改进以及浅层特征和深层特征的融合

编辑:rootadmin
DeepLabV3+:Mobilenetv2的改进以及浅层特征和深层特征的融合

目录

Mobilenetv2的改进

浅层特征和深层特征的融合

完整代码

参考资料


Mobilenetv2的改进

推荐整理分享DeepLabV3+:Mobilenetv2的改进以及浅层特征和深层特征的融合,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

在DeeplabV3当中,一般不会5次下采样,可选的有3次下采样和4次下采样。因为要进行五次下采样的话会损失较多的信息。

在这里mobilenetv2会从之前写好的模块中得到,但注意的是,我们在这里获得的特征是[-1],也就是最后的1x1卷积不取,只取循环完后的模型。

down_idx是InvertedResidual进行的次数。

# t, c, n, s[1, 16, 1, 1], [6, 24, 2, 2],    2[6, 32, 3, 2],    4[6, 64, 4, 2],    7  [6, 96, 3, 1],[6, 160, 3, 2],   14[6, 320, 1, 1], 

根据下采样的不同,当downsample_factor=8时,进行3次下采样,对倒数两次,步长为2的InvertedResidual进行参数的修改,让步长变为1,膨胀系数为2。

DeepLabV3+:Mobilenetv2的改进以及浅层特征和深层特征的融合

当downsample_factor=16时,进行4次下采样,只需对最后一次进行参数的修改。

import torchimport torch.nn as nnimport torch.nn.functional as Ffrom functools import partialfrom net.mobilenetv2 import mobilenetv2from net.ASPP import ASPPclass MobileNetV2(nn.Module): def __init__(self, downsample_factor=8, pretrained=True): super(MobileNetV2, self).__init__() model = mobilenetv2(pretrained) self.features = model.features[:-1] self.total_idx = len(self.features) self.down_idx = [2, 4, 7, 14] if downsample_factor == 8: for i in range(self.down_idx[-2], self.down_idx[-1]): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=4) ) elif downsample_factor == 16: for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) def _nostride_dilate(self, m, dilate): classname = m.__class__.__name__ if classname.find('Conv') != -1: if m.stride == (2, 2): m.stride = (1, 1) if m.kernel_size == (3, 3): m.dilation = (dilate//2, dilate//2) m.padding = (dilate//2, dilate//2) else: if m.kernel_size == (3, 3): m.dilation = (dilate, dilate) m.padding = (dilate, dilate) def forward(self, x): low_level_features = self.features[:4](x) x = self.features[4:](low_level_features) return low_level_features, x

forward当中,会输出两个特征层,一个是浅层特征层,具有浅层的语义信息;另一个是深层特征层,具有深层的语义信息。

浅层特征和深层特征的融合

 具有高语义信息的部分先进行上采样,低语义信息的特征层进行1x1卷积,二者进行特征融合,再进行3x3卷积进行特征提取

self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)

这一步就是获得那个绿色的特征层;

low_level_features = self.shortcut_conv(low_level_features)

从这里将是对浅层特征的初步处理(1x1卷积);

x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True)x = self.cat_conv(torch.cat((x, low_level_features), dim=1))

上采样后进行特征融合,这样我们输入和输出的大小才相同,每一个像素点才能进行预测;

完整代码# deeplabv3plus.pyimport torchimport torch.nn as nnimport torch.nn.functional as Ffrom functools import partialfrom net.xception import xceptionfrom net.mobilenetv2 import mobilenetv2from net.ASPP import ASPPclass MobileNetV2(nn.Module): def __init__(self, downsample_factor=8, pretrained=True): super(MobileNetV2, self).__init__() model = mobilenetv2(pretrained) self.features = model.features[:-1] self.total_idx = len(self.features) self.down_idx = [2, 4, 7, 14] if downsample_factor == 8: for i in range(self.down_idx[-2], self.down_idx[-1]): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=4) ) elif downsample_factor == 16: for i in range(self.down_idx[-1], self.total_idx): self.features[i].apply( partial(self._nostride_dilate, dilate=2) ) def _nostride_dilate(self, m, dilate): classname = m.__class__.__name__ if classname.find('Conv') != -1: if m.stride == (2, 2): m.stride = (1, 1) if m.kernel_size == (3, 3): m.dilation = (dilate//2, dilate//2) m.padding = (dilate//2, dilate//2) else: if m.kernel_size == (3, 3): m.dilation = (dilate, dilate) m.padding = (dilate, dilate) def forward(self, x): low_level_features = self.features[:4](x) x = self.features[4:](low_level_features) return low_level_features, xclass DeepLab(nn.Module): def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16): super(DeepLab, self).__init__() if backbone=="xception": # 获得两个特征层:浅层特征 主干部分 self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained) in_channels = 2048 low_level_channels = 256 elif backbone=="mobilenet": # 获得两个特征层:浅层特征 主干部分 self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained) in_channels = 320 low_level_channels = 24 else: raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone)) # ASPP特征提取模块 # 利用不同膨胀率的膨胀卷积进行特征提取 self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor) # 浅层特征边 self.shortcut_conv = nn.Sequential( nn.Conv2d(low_level_channels, 48, 1), nn.BatchNorm2d(48), nn.ReLU(inplace=True) ) self.cat_conv = nn.Sequential( nn.Conv2d(48+256, 256, kernel_size=(3,3), stride=(1,1), padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Conv2d(256, 256, kernel_size=(3,3), stride=(1,1), padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(0.1), ) self.cls_conv = nn.Conv2d(256, num_classes, kernel_size=(1,1), stride=(1,1)) def forward(self, x): H, W = x.size(2), x.size(3) # 获得两个特征层,low_level_features: 浅层特征-进行卷积处理 # x : 主干部分-利用ASPP结构进行加强特征提取 low_level_features, x = self.backbone(x) x = self.aspp(x) low_level_features = self.shortcut_conv(low_level_features) # 将加强特征边上采样,与浅层特征堆叠后利用卷积进行特征提取 x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True) x = self.cat_conv(torch.cat((x, low_level_features), dim=1)) x = self.cls_conv(x) x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) return x参考资料

DeepLabV3-/论文精选 at main · Auorui/DeepLabV3- (github.com)

(6条消息) 憨批的语义分割重制版9——Pytorch 搭建自己的DeeplabV3+语义分割平台_Bubbliiiing的博客-CSDN博客

本文链接地址:https://www.jiuchutong.com/zhishi/290684.html 转载请保留说明!

上一篇:解决RTX 3090 with CUDA capability sm_86 is not compatible with the current PyTorch installation.(解决脱发的8个方法)

下一篇:在妈妈身旁玩耍的北极熊宝宝们,加拿大曼尼托巴省 (© Andre Gilden/Minden Pictures)(在妈妈身边的说说)

  • 金税三期怎么合理避税
  • 递延所得税负债是什么科目
  • 房屋租赁印花税怎么算
  • 新公司第一年要做亏
  • 工会经费每月必须60块钱
  • 高温费国家有规定,一定要支付吗?
  • 服务不动产扣除项目怎么填
  • 出售未计提完折旧的固定资产
  • 国税地税电子钥匙价格
  • 企业年报社保都是0人的公司
  • 个人所得税合并扣税
  • 微税平台抄税的步骤是怎样的?
  • 乐器的税率
  • 兼兼的意思
  • 利息支出应计入
  • 结算会计和往来账的区别
  • 季度财务报表怎么打印
  • 小规模企业增值税税收优惠政策2023
  • 红字增值税发票含税吗
  • 消费税在哪个环节征税
  • 工商注销债务承担
  • 小规模纳税人个税怎么申报
  • 合并范围外关联方交易是否抵消
  • 电脑维修会不会对电脑有影响
  • bios中怎么设置显卡
  • 会计分录有哪几种形式
  • 创建自定义对象主要哪几种方法,并写出基本语法结构?
  • 结转出租包装物因不能使用而报废的残料价值
  • 行政单位职工福利费使用范围
  • 最大的apple商店
  • 发放员工奖励
  • 创业投资企业可以签订代持股协议吗合法吗
  • 商业承兑汇票贴现
  • 建筑业营改增后税务问题
  • 业务招待费的账务处理金额
  • vscode+cmake
  • php文档系统
  • 职教费可以抵扣进项么
  • 蓝牙11
  • web网页设计期末作业猫眼电影首页
  • spring获取bean的完全限定类名
  • node.js快速入门
  • 管家婆付款单凭证科目如何修改
  • 增值税普通发票可以抵扣吗
  • 哪些发票可以抵企业所得税
  • 如何补缴以前年度的税
  • 货物运输行业前景如何
  • 记账凭证去根据什么填制
  • 会计中金额的正负怎么算
  • 公司账务不正规,账务外包的,财务助理有风险吗
  • 综合所得减除费用标准
  • 航天税盘服务费开的普票可以抵税吗
  • 预缴增值税需要提交什么资料
  • 关联交易现金流
  • 应收账款少收会计分录
  • 用友作废的凭证怎么恢复
  • 长期挂账的其他应付款税务风险
  • 稽查人员是干嘛的
  • sqlserver 禁用触发器 超时
  • win7,win8.1,win10命令行配置ip地址图文教程
  • mac怎么用bootcamp
  • Win10 Mobile 10572预览版上手体验视频
  • Virtualbox共享文件
  • hptasks.exe是病毒吗 是什么进程 hptasks进程说明
  • linux k
  • [OpenGL ES 04]3D变换实践篇:平移,旋转,缩放
  • 基于是什么意思
  • 疯狂冒险王官网
  • js中eval函数是干嘛的
  • node.js gui
  • 请简述vue-router路由的作用
  • javascript教程chm
  • shell判断文件是否存在且大小不为0
  • bootstrapping怎么做
  • JavaScript中的6种运算符总结
  • ActivityManagerService (三)
  • 河南省发票查询真伪查询系统
  • 成都国家税务局每个月交全民付的钱是什么
  • 西安市港务区属于哪个街道办
  • 物业管理用房如何申请
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设