位置: IT常识 - 正文

UNet - unet网络(program status)

编辑:rootadmin
UNet - unet网络

目录

1. u-net介绍

2. u-net网络结构

3. u-net 网络搭建

3.1 DoubleConv

3.2 Down 下采样

3.3 Up 上采样

3.4 网络输出

3.5 UNet 网络

UNet 网络

forward  前向传播

3.6 网络的参数

4. 完整代码


1. u-net介绍

推荐整理分享UNet - unet网络(program status),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:unet网络的优缺点,unet网络的优缺点,描写建筑工人的文章,钆喷酸葡胺和钆贝葡胺的区别,缺氧为什么会引起神经细胞兴奋性降低,3点40的飞机,几点到机场,郝加明教授,钆喷酸葡胺和钆贝葡胺的区别,内容如对您有帮助,希望把文章链接给更多的朋友!

Unet网络是医学图像分割领域常用的分割网络,因为网络的结构很像个U,所以称为Unet

Unet 网络是针对像素点的分类,之前介绍的LeNet、ResNet等等都是图像分类,最后分的是整幅图像的类别,而Unet是对像素点输出的是前景还是背景的分类

注:因为Unet 具体的网络框架均有所不同,例如有的连续卷积后会改变图像的size,有的上采样用的是线性插值的方法。这里只介绍same卷积和上采样用的转置卷积

Unet网络是个U型结构,左边是Encoder,右边为Decoder

左边是下采样的过程,通过减少图像size,增加图像channel来提取特征。

右边是还原图像的过程,上采样将逐步还原图像的size,这里上采样的输入特征图不仅仅是上一步的输出,还包含了左边对应特征信息。

2. u-net网络结构

本章采用的unet网络如图,为了后面数据的训练和预测。这里实现的方式和下图有些细小的区别,具体的会在下面讲解

首先,网络输入图像的size设定为(480,480)的灰度图像(注意:这里输入是单通道的灰度图)

然后经过成对的3*3卷积,将图像的深度加深,变成维度为(64,480,480),这里因为图像的size没有变,又因为kernel_size = 3,stride = 1,因此需要保证padding = 1

接下来是下采样层,先经过一个最大池化层,stride = 2,kernel_size = 2 将图像的size变为原来的一半。然后接两个3*3 的卷积,输出的特征图维度是(128,240,240)

下采样层总共有四次,根据每次下采样都会将图像的size减半,图像的channel翻倍来计算的话。最后一次图像的size = 480 / (2^4) = 30 ,channel = 64 * (2^4) = 1024 ,所以最后一次下采样图像的维度为(1024,30,30)------> 这里和图上不一样,因为后面用的是转置卷积

左边的下采样部分实现后,就是右边的上采样部分

上采样会使图像的channel减半,size变为两倍,正好和下采样的部分反过来。这里利用的操作是转置卷积,转置卷积具体的实现这里不做介绍,主要看它的维度变换。转置卷积变换的公式为:

这里为了保证图像的size变为两倍,所以要保证 out = 2 * in ,而in的系数2只能从stride来,所以公式变为out = 2 * in - 2 - 2 * padding + ksize ,这里我们让ksize = 2,因此padding = 0 就可以满足要求。而channel的减半只需要把卷积核的个数减半即可

之前介绍过,最后一层的维度是(1024,30,30),这样通过转置卷积的操作图像的维度就变成了(512,60,60),刚好等于左边下采样的维度!! 所以将它们加在一块,然后进行成对的3*3卷积

之后就是和下采样的次数一样,重复四次上采样,直到将图像还原成(64,480,480)

最后一步,如果是图像分类的话,这里应该是全连接层找最大的预测值了。但是Unet是像素点的分类,所以最后产生的也是一副图像,因为这时候图像的size已经是480不需要变了,只需要将图像的channel改变,所以这里只需要一个kernel_size = 1的卷积核就可以了。

注:最后输出图像的维度是(480,480)的灰度图像,准确的说是二值图像

3. u-net 网络搭建3.1 DoubleConv

观察unet 网络可以发现,3*3的卷积核都是成对出现的,所以这里将成对卷积核的操作封装成一个类

UNet - unet网络(program status)

1. 因为采用的是两个连续的3*3  卷积,不改变图像的size,所以这里卷积的参数要设置padding=1

2. ResNet 介绍过,BN代替Dropout 的时候,不需要Bias 

3. 最后经过ReLU 激活函数

3.2 Down 下采样

然后定义下采样的操作

1. 这里下采样采用的就是最大池化层,kernel_size = 2,padding =2 会让图像的size减半

2. 然后经过两个连续3*3 的卷积

3. 将 下采样+两个3*3 的卷积 封装成一个新的类Down

3.3 Up 上采样

然后是定义上采样

 

1. 上采样用的是转置卷积,会将图像的size扩大两倍

2.  注意这里不是定义成 Sequential ,因为 Sequential 会从上到下顺序传播。这里还需要一步尺度融合,就是拼接的操作

3. 前向传播的时候,图像首先上采样,会将channel减小一半,size扩大两倍。这样就和左边对应的下采样的位置维度一致,将它们通过torch.cat 拼接,dim = 1是因为batch的维度是0 。然后经过两个3*3 的卷积就行了

3.4 网络输出

最后网络的输出很简单,经过一个1*1 的卷积核,不改变size的情况下。通过卷积核的个数调整图像的channel就行了

3.5 UNet 网络UNet 网络

网络的框架很简单,因为每个小的模块已经搭好了,将它们拼接起来就行了

因为搭建小的模块的时候,我们对于模块的输入都是in和out channel,所以在定义网络的时候,每个模块只要传入对应的channel就行了。

这里按照UNet 网络的框架设置

forward  前向传播

前向传播的过程如下:

在下采样的时候,每个输出都要用变量保存,为了和后面上采样拼接使用

3.6 网络的参数# 计算 UNet 的网络参数个数model = UNet(in_channels=1,num_classes=1)print("Total number of paramerters in networks is {} ".format(sum(x.numel() for x in model.parameters())))

UNet 网络参数个数为:

4. 完整代码

代码:

import torch.nn as nnimport torch# 搭建unet 网络class DoubleConv(nn.Module): # 连续两次卷积 def __init__(self,in_channels,out_channels): super(DoubleConv,self).__init__() self.double_conv = nn.Sequential( nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,bias=False), # 3*3 卷积核 nn.BatchNorm2d(out_channels), # 用 BN 代替 Dropout nn.ReLU(inplace=True), # ReLU 激活函数 nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1,bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self,x): # 前向传播 x = self.double_conv(x) return xclass Down(nn.Module): # 下采样 def __init__(self,in_channels,out_channels): super(Down, self).__init__() self.downsampling = nn.Sequential( nn.MaxPool2d(kernel_size=2,stride=2), DoubleConv(in_channels,out_channels) ) def forward(self,x): x = self.downsampling(x) return xclass Up(nn.Module): # 上采样 def __init__(self, in_channels, out_channels): super(Up,self).__init__() self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) # 转置卷积 self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 = self.upsampling(x1) x = torch.cat([x2, x1], dim=1) # 从channel 通道拼接 x = self.conv(x) return xclass OutConv(nn.Module): # 最后一个网络的输出 def __init__(self, in_channels, num_classes): super(OutConv, self).__init__() self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1) def forward(self, x): return self.conv(x)class UNet(nn.Module): # unet 网络 def __init__(self, in_channels = 1, num_classes = 1): super(UNet, self).__init__() self.in_channels = in_channels # 输入图像的channel self.num_classes = num_classes # 网络最后的输出 self.in_conv = DoubleConv(in_channels, 64) # 第一层 self.down1 = Down(64, 128) # 下采样过程 self.down2 = Down(128, 256) self.down3 = Down(256, 512) self.down4 = Down(512, 1024) self.up1 = Up(1024, 512) # 上采样过程 self.up2 = Up(512, 256) self.up3 = Up(256, 128) self.up4 = Up(128, 64) self.out_conv = OutConv(64, num_classes) # 网络输出 def forward(self, x): # 前向传播 输入size为 (10,1,480,480),这里设置batch = 10 x1 = self.in_conv(x) # torch.Size([10, 64, 480, 480]) x2 = self.down1(x1) # torch.Size([10, 128, 240, 240]) x3 = self.down2(x2) # torch.Size([10, 256, 120, 120]) x4 = self.down3(x3) # torch.Size([10, 512, 60, 60]) x5 = self.down4(x4) # torch.Size([10, 1024, 30, 30]) x = self.up1(x5, x4) # torch.Size([10, 512, 60, 60]) x = self.up2(x, x3) # torch.Size([10, 256, 120, 120]) x = self.up3(x, x2) # torch.Size([10, 128, 240, 240]) x = self.up4(x, x1) # torch.Size([10, 64, 480, 480]) x = self.out_conv(x) # torch.Size([10, 1, 480, 480]) return x# 计算 UNet 的网络参数个数model = UNet(in_channels=1,num_classes=1)print("Total number of paramerters in networks is {} ".format(sum(x.numel() for x in model.parameters())))
本文链接地址:https://www.jiuchutong.com/zhishi/298952.html 转载请保留说明!

上一篇:【JavaScript】五个常用功能/案例:判断特定结尾字符串 | 获取指定字符串 | 颜色字符串转换 | 字符串转驼峰格式 | 简易购物车(javascripts)

下一篇:Swagger-的使用(详细教程)

  • 国家反诈中心短信预警怎么开启(国家反诈中心短信提醒怎么回事?)

    国家反诈中心短信预警怎么开启(国家反诈中心短信提醒怎么回事?)

  • 芒果TV的播放模式在哪里设置(芒果tv播放清晰度在哪调)

    芒果TV的播放模式在哪里设置(芒果tv播放清晰度在哪调)

  • excel怎么快速插入多列(excel如何快速插行快捷键)

    excel怎么快速插入多列(excel如何快速插行快捷键)

  • s6微信视频支持美颜不(s6微信视频美颜怎么设置)

    s6微信视频支持美颜不(s6微信视频美颜怎么设置)

  • 转转保证金能退吗(转转保证金200退到哪里)

    转转保证金能退吗(转转保证金200退到哪里)

  • 苹果绑定支付宝付款方式被拒(苹果绑定支付宝付款怎么解除)

    苹果绑定支付宝付款方式被拒(苹果绑定支付宝付款怎么解除)

  • 苹果笔记本桌面图标不见了怎么办(苹果笔记本桌面图标怎么放到桌面)

    苹果笔记本桌面图标不见了怎么办(苹果笔记本桌面图标怎么放到桌面)

  • 手机基带坏了能修吗(手机基带坏了能更新系统吗)

    手机基带坏了能修吗(手机基带坏了能更新系统吗)

  • 显示器有必要2k吗(显示器有必要2k屏幕吗)

    显示器有必要2k吗(显示器有必要2k屏幕吗)

  • 苹果x和xs电池容量(苹果x跟xs电池容量)

    苹果x和xs电池容量(苹果x跟xs电池容量)

  • iphonex能控制空调吗(苹果x如何控制空调)

    iphonex能控制空调吗(苹果x如何控制空调)

  • ipad上面的小孔是干嘛的(ipad边上的小孔)

    ipad上面的小孔是干嘛的(ipad边上的小孔)

  • mouse1是哪个键(mouse1是哪个键盘按键)

    mouse1是哪个键(mouse1是哪个键盘按键)

  • 拨打的用户已关机是什么意思(拨打的用户已关机或不在服务区是什么意思)

    拨打的用户已关机是什么意思(拨打的用户已关机或不在服务区是什么意思)

  • 抖音保存本地视频在哪(抖音保存本地视频在哪设置)

    抖音保存本地视频在哪(抖音保存本地视频在哪设置)

  • 抖音私信有没有已读功能(抖音私信有没有自动回复)

    抖音私信有没有已读功能(抖音私信有没有自动回复)

  • 如何升级电脑操作系统(如何升级电脑操作软件)

    如何升级电脑操作系统(如何升级电脑操作软件)

  • 手机虚拟卡什么意思(手机虚拟卡有什么用途)

    手机虚拟卡什么意思(手机虚拟卡有什么用途)

  • qq手机文件保存位置(qq手机文件保存的位置在哪里)

    qq手机文件保存位置(qq手机文件保存的位置在哪里)

  • 咸鱼消息提醒在哪设置(闲鱼消息短信提醒)

    咸鱼消息提醒在哪设置(闲鱼消息短信提醒)

  • 短信回收站在哪(苹果手机短信回收站在哪)

    短信回收站在哪(苹果手机短信回收站在哪)

  • 复制发朋友圈为什么会折叠(为什么复制发朋友圈的信息不完全显示)

    复制发朋友圈为什么会折叠(为什么复制发朋友圈的信息不完全显示)

  • 华为p30pro玩游戏微信不提醒(华为p30pro玩游戏发热)

    华为p30pro玩游戏微信不提醒(华为p30pro玩游戏发热)

  • b站下载的视频在手机哪里(b站下载的视频怎么保存到电脑)

    b站下载的视频在手机哪里(b站下载的视频怎么保存到电脑)

  • 淘宝怎么投诉卖家电话(淘宝怎么投诉卖家怎么找到客服小蜜人工服务)

    淘宝怎么投诉卖家电话(淘宝怎么投诉卖家怎么找到客服小蜜人工服务)

  • 拉斯梅德拉斯的古罗马金矿遗址,西班牙莱昂 (© DEEPOL by plainpicture/David Santiago Garcia)(梅拉和艾斯德斯)

    拉斯梅德拉斯的古罗马金矿遗址,西班牙莱昂 (© DEEPOL by plainpicture/David Santiago Garcia)(梅拉和艾斯德斯)

  • 【CSS】课程网站 网格商品展示 模块制作 ③ ( 清除浮动需求 | 没有设置高度的盒子且内部设置了浮动 | 使用双伪元素清除浮动 )(css教程网站)

    【CSS】课程网站 网格商品展示 模块制作 ③ ( 清除浮动需求 | 没有设置高度的盒子且内部设置了浮动 | 使用双伪元素清除浮动 )(css教程网站)

  • bind命令  显示或设置键盘按键与其相关的功能(bind函数错误)

    bind命令 显示或设置键盘按键与其相关的功能(bind函数错误)

  • 广东省增值税发票勾选平台
  • 小规模纳税人公户的钱怎么转出来
  • 财务人员如何管理固定资产
  • 收取境外服务费收入如何开票
  • 合伙企业法人股东
  • 筹建期间购买的机械配件
  • 关于年底双薪和分红问题的处理
  • 合伙企业应纳税所得额公式是什么
  • 进项税当月申报怎么申报
  • 应付账款借方余额在资产负债表中怎么列示
  • 生产出口退税企业内部加工费占多少比例
  • 建筑企业劳务分包税务筹划
  • 异地上班员工报销路费
  • 代收代付如何进行账务处理?
  • 新产品开发费用怎么扣除
  • 给了钱不给发票可以报警吗
  • 季度企业所得税计算方法举例
  • 母子公司间的借款现金流计入哪里
  • 怎么给个体户开电子发票
  • 计算企业所得税的公式
  • 固定资产评估增值
  • 雇佣临时工发生意外能追房东赔偿吗
  • 一般纳税人为其他公司制作标书怎么缴税?
  • 会员卡系统多少钱一套
  • 科目余额表借贷方余额不一致
  • 如何以快捷方式打印文件
  • 原始股东减持要交多少税
  • 购买的烟酒怎么入账科目
  • 做汽车配件销售怎么找客户
  • 如何输入特殊符号带圈数字11
  • 广告公司的设计服务费计入什么科目
  • 借款利息税前扣除需要发票
  • zendstudio怎么创建php项目
  • php基于正则批量输出
  • 公司股东年底分红怎么做账
  • win11的截屏
  • 防抖节流实现原理
  • laravel中间件是什么意思
  • 计提工会经费的标准
  • 慈善组织接受股票捐赠流程
  • 小规模纳税人能开6%增值税专用发票吗
  • python中如何合并csv
  • dom事件种类
  • 公司对自己内部的要求
  • 融资租赁手续费一次性还是摊销
  • 现金发放工资会扣税吗
  • 分成收入计入什么科目
  • 个税年度汇算清缴总结
  • 分页存储的优缺点
  • 结转上年
  • 挂靠设计公司费用标准如何记账?
  • 土地使用税减免税优惠
  • 防伪税控技术维护费普通发票怎么申报
  • 进项税额转出结转还是红冲
  • 跨年多计提的工资如何处理
  • 金税盘离线开票时间超限的处理方法
  • 多交增值税不能抵扣
  • 记账凭证烂了要紧吗
  • mysql函数大全以及举例
  • xp系统关机界面设置
  • windows常用功能
  • vmware vnc连接
  • Ubuntu操作系统安全维护
  • linux中添加用户和组的操作
  • windows下键盘不能用
  • gsicon.exe是什么进程 作用是什么 gsicon进程查询
  • -f linux命令
  • 在linux操作系统中,/etc/rc.d/init.d
  • win10使用ie8
  • 原生js制作日历软件
  • 如何用python画花瓣
  • pm2启动nodejs
  • js小数计算精度问题
  • jquery修改css
  • 如何在国税网上做企业会计制度备
  • 个人自行申报纳税
  • 个体户定额纳税
  • 财政法和经济法的关系
  • 武汉办房产证契税怎么交
  • 跨区域涉税事项报告表
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设