
Image Segmentation: U-Net and U2-Net, with PyTorch Implementations

Editor: rootadmin

1. Image Segmentation


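Image segmentation is the technique and process of partitioning an image into several specific regions with distinct properties and extracting the objects of interest.

In practice, this is done by classifying every pixel of the image.

It is widely used in autonomous driving, automatic matting (background removal), medical imaging and other fields.

Image segmentation can be roughly divided into three categories: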

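Ordinary segmentation: separates the pixel regions that belong to different objects, for example separating the foreground from the background, or separating the dog region and the cat region from the background.

Semantic segmentation: on top of ordinary segmentation, also classifies the semantics of each region (that is, what object the region is), for example labeling every object in the picture with its category.

Instance segmentation: on top of semantic segmentation, also numbers each object, for example "this is dog A in the picture, that is dog B".

[Figure: ordinary segmentation / semantic segmentation / instance segmentation]

As can be seen, image segmentation maps one image to another image: the network's input is an image and its output is an image of the same size, so an encoder-decoder structure is a natural fit. U-Net and U2-Net can be used for semantic segmentation, producing the segmentation map the same way an image would be generated. Alternatively, classes can be split across channels: each output channel represents one class and is activated with a sigmoid.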
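To illustrate the per-channel variant: with a sigmoid on every output channel, each channel is an independent probability map for one class, and a binary mask per class is obtained by thresholding. This is a minimal sketch; the random logits stand in for a real network output, and the 0.5 threshold is an assumption.

```python
import torch

torch.manual_seed(0)

# Hypothetical raw network output: batch of 1, 3 classes (one per channel), 4×4 pixels
logits = torch.randn(1, 3, 4, 4)

# The sigmoid maps every channel independently to a probability in (0, 1)
probs = torch.sigmoid(logits)

# Threshold each channel separately (0.5 is an assumed threshold) to get one mask per class
masks = probs > 0.5  # bool tensor of shape (1, 3, 4, 4)

# Channels are independent, so a pixel may be assigned to several classes or to none
print(masks.shape)
print(masks.sum(dim=1))  # number of classes claimed by each pixel
```

Unlike a softmax over channels, nothing forces the channels of a pixel to sum to 1 here, which is why this layout also suits overlapping classes.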

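2. U-Net

U-Net uses an encoder-decoder structure: the input is first downsampled and then upsampled, and at every level the encoder's feature map is passed to the decoder through a skip connection and concatenated with the upsampled feature map.

Following the model diagram, the network can be built as follows.

First, the convolution block. As the diagram shows, every level of the network consists of two convolution layers, so the convolution block is built as follows: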

from torch import nn
import torch


class ConvolutionLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        """
        Convolution block: two 3×3 convolutions, each followed by BN and ReLU
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        """
        super(ConvolutionLayer, self).__init__()
        self.layer = nn.Sequential(
            # convolution layer; padding=1 keeps the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.BatchNorm2d(out_channels),  # BN layer
            nn.ReLU(),  # activation
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layer(x)

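One difference from the diagram is that padding is used, so the feature maps do not shrink during convolution. As a result, the two feature maps joined by each horizontal gray arrow have the same size and can be concatenated directly with torch.cat, without the cropping step of the original design.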
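The effect of padding can be checked with the output-size formula given in the PyTorch Conv2d documentation, out = floor((in + 2·padding − dilation·(kernel − 1) − 1) / stride) + 1. The helper below is written only for illustration and is not part of the original code.

```python
def conv2d_out_size(size, kernel=3, stride=1, padding=0, dilation=1):
    """Spatial output size of nn.Conv2d along one dimension (illustrative helper)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Without padding, each 3×3 convolution trims one pixel from every border,
# so the two convolutions of one level shrink 572 to 568, as in the paper's diagram
print(conv2d_out_size(conv2d_out_size(572)))

# With padding=1 the size is preserved, so skip connections can be concatenated directly
print(conv2d_out_size(conv2d_out_size(572, padding=1), padding=1))
```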

In the model diagram, the red arrows (max pool 2×2) are max pooling with a 2×2 window. Their purpose is downsampling, so a downsampling module can be defined as:

class DownSample(nn.Module):
    def __init__(self):
        """
        Downsampling by max pooling with a 2×2 window (halves height and width)
        """
        super(DownSample, self).__init__()
        self.layer = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.layer(x)

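In the model diagram, the green arrows (up-conv 2×2) are transposed convolutions (often called deconvolutions). Their purpose is upsampling, so an upsampling module can be defined as: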

class UpSample(nn.Module):
    def __init__(self, in_channels):
        """
        Transposed convolution for upsampling: doubles height and width and halves the channels
        :param in_channels: number of input channels
        """
        super(UpSample, self).__init__()
        self.layer = nn.Sequential(
            nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=(2, 2), stride=(2, 2)),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        return self.layer(x)
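The two resampling steps can be checked the same way. MaxPool2d(kernel_size=2, stride=2) floors odd sizes, while ConvTranspose2d(kernel_size=2, stride=2) with no padding produces (in − 1)·2 + 2 = 2·in. The sketch below uses illustrative helper names; it shows that an even size round-trips exactly but an odd one loses a pixel, in which case the upsampled map and the skip connection no longer match and torch.cat fails.

```python
def maxpool_out_size(size, kernel=2, stride=2):
    """Output size of nn.MaxPool2d(kernel, stride) with no padding (illustrative helper)."""
    return (size - kernel) // stride + 1


def conv_transpose_out_size(size, kernel=2, stride=2):
    """Output size of nn.ConvTranspose2d(kernel, stride) with no padding or output_padding."""
    return (size - 1) * stride + kernel


for size in (256, 565):
    down = maxpool_out_size(size)
    up = conv_transpose_out_size(down)
    # an even size comes back unchanged; an odd size comes back one pixel short
    print(size, '->', down, '->', up, 'match' if up == size else 'mismatch')
```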

Next, define the layers of the network:

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.conv1 = ConvolutionLayer(in_channels, 64)  # in_channels ==> 64 channels
        self.down1 = DownSample()  # downsample to 1/2
        self.conv2 = ConvolutionLayer(64, 128)  # 64 ==> 128 channels
        self.down2 = DownSample()  # downsample to 1/4
        self.conv3 = ConvolutionLayer(128, 256)  # 128 ==> 256 channels
        self.down3 = DownSample()  # downsample to 1/8
        self.conv4 = ConvolutionLayer(256, 512)  # 256 ==> 512 channels
        self.down4 = DownSample()  # downsample to 1/16
        self.conv5 = ConvolutionLayer(512, 1024)  # 512 ==> 1024 channels
        self.up1 = UpSample(1024)  # upsample to 1/8
        self.conv6 = ConvolutionLayer(1024, 512)  # 1024 ==> 512 channels
        self.up2 = UpSample(512)  # upsample to 1/4
        self.conv7 = ConvolutionLayer(512, 256)  # 512 ==> 256 channels
        self.up3 = UpSample(256)  # upsample to 1/2
        self.conv8 = ConvolutionLayer(256, 128)  # 256 ==> 128 channels
        self.up4 = UpSample(128)  # upsample to 1/1
        self.conv9 = ConvolutionLayer(128, 64)  # 128 ==> 64 channels
        self.predict = nn.Sequential(  # output layer, activated by sigmoid
            nn.Conv2d(64, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.Sigmoid()
        )

    def forward(self, image_tensor):
        pass  # implemented below

The forward pass, following the model diagram:

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        # ...... layers as defined above ......

    def forward(self, x):
        # Downsampling path; comments give (resolution relative to the input, channels)
        x1 = self.conv1(x)  # ===> 1/1, 64
        d1 = self.down1(x1)  # ===> 1/2, 64
        x2 = self.conv2(d1)  # ===> 1/2, 128
        d2 = self.down2(x2)  # ===> 1/4, 128
        x3 = self.conv3(d2)  # ===> 1/4, 256
        d3 = self.down3(x3)  # ===> 1/8, 256
        x4 = self.conv4(d3)  # ===> 1/8, 512
        d4 = self.down4(x4)  # ===> 1/16, 512
        x5 = self.conv5(d4)  # ===> 1/16, 1024
        # Upsampling path: up-conv, concatenate the skip connection, then the conv block
        up1 = self.up1(x5)  # ===> 1/8, 512
        x6 = self.conv6(torch.cat((x4, up1), dim=1))  # ===> 1/8, 512
        up2 = self.up2(x6)  # ===> 1/4, 256
        x7 = self.conv7(torch.cat((x3, up2), dim=1))  # ===> 1/4, 256
        up3 = self.up3(x7)  # ===> 1/2, 128
        x8 = self.conv8(torch.cat((x2, up3), dim=1))  # ===> 1/2, 128
        up4 = self.up4(x8)  # ===> 1/1, 64
        x9 = self.conv9(torch.cat((x1, up4), dim=1))  # ===> 1/1, 64
        mask = self.predict(x9)  # ===> 1/1, out_channels
        return mask

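Take a 3-channel 512×512 image as an example; its tensor has shape (1, 3, 512, 512). Through conv1 it becomes x1 (1, 64, 512, 512), downsampled to (1, 64, 256, 256); through conv2, x2 (1, 128, 256, 256), downsampled to (1, 128, 128, 128); through conv3, x3 (1, 256, 128, 128), downsampled to (1, 256, 64, 64); through conv4, x4 (1, 512, 64, 64), downsampled to (1, 512, 32, 32); through conv5, x5 (1, 1024, 32, 32). This completes the downsampling path; upsampling then restores the original image size.

x5 passes through up1 to give up1 (1, 512, 64, 64), which is concatenated (torch.cat) with x4 into a (1, 1024, 64, 64) tensor; conv6 turns it into x6 (1, 512, 64, 64).

x6 passes through up2 to give up2 (1, 256, 128, 128), which is concatenated with x3 into a (1, 512, 128, 128) tensor; conv7 turns it into x7 (1, 256, 128, 128).

x7 passes through up3 to give up3 (1, 128, 256, 256), which is concatenated with x2 into a (1, 256, 256, 256) tensor; conv8 turns it into x8 (1, 128, 256, 256).

x8 passes through up4 to give up4 (1, 64, 512, 512), which is concatenated with x1 into a (1, 128, 512, 512) tensor; conv9 turns it into x9 (1, 64, 512, 512).

Finally, x9 goes through the prediction layer predict, which outputs the segmentation map mask with shape (1, out_channels, 512, 512).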
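The shape walk-through above can be reproduced without running the network. The helper below is hypothetical (not part of the original code); it assumes this UNet's channel widths (64 doubling up to 1024), size-preserving convolutions and exact 2× resampling, and lists the shape after each convolution block.

```python
def trace_unet_shapes(batch=1, in_channels=3, size=512, out_channels=1, base=64, depth=4):
    """Trace (N, C, H, W) through the UNet above (hypothetical helper; assumes padded convs)."""
    if size % (2 ** depth) != 0:
        raise ValueError(f'size must be a multiple of {2 ** depth}')
    shapes = [('input', (batch, in_channels, size, size))]
    skips = []
    channels, s = base, size
    # Encoder: conv block at each level, then 2x2 max pooling (none after the bottom block)
    for level in range(depth + 1):
        shapes.append((f'conv{level + 1}', (batch, channels, s, s)))
        if level < depth:
            skips.append((channels, s))
            s //= 2
            channels *= 2
    # Decoder: up-conv halves channels and doubles the size, concatenating the skip
    # doubles the channels, and the conv block halves them again
    for level in range(depth):
        skip_c, skip_s = skips.pop()
        channels //= 2
        s *= 2
        assert (channels, s) == (skip_c, skip_s)  # shapes must match for torch.cat
        shapes.append((f'conv{depth + level + 2}', (batch, channels, s, s)))
    shapes.append(('predict', (batch, out_channels, s, s)))
    return shapes


for name, shape in trace_unet_shapes():
    print(name, shape)
```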

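The network is trained on the DRIVE dataset (retinal vessel segmentation). A sample image:

The corresponding label:

The input is a 3-channel image and the output is a 1-channel binary map. The original image size is 565×584.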

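For training, 256×256 patches are randomly cropped from the original images. At inference time, any image whose height and width are multiples of 16 can be used: the four 2×2 poolings halve the size four times (2^4 = 16), and the transposed convolutions must restore exactly the same sizes for the concatenations to work.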
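A full DRIVE image (565×584) is not a multiple of 16. One option, which is an assumption rather than something the original code does, is to zero-pad the tensor up to the next multiple of 16 and crop the prediction back afterwards. A minimal sketch with torch.nn.functional.pad; the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, multiple=16):
    """Zero-pad the bottom/right of an (N, C, H, W) tensor so H and W are multiples of `multiple`."""
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)


image = torch.rand(1, 3, 584, 565)  # assumed layout: 584 rows, 565 columns
padded, (h, w) = pad_to_multiple(image)
print(padded.shape)  # torch.Size([1, 3, 592, 576])

# After running the network on `padded`, crop the mask back to the original size:
# mask = net(padded)[..., :h, :w]
```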

The dataset class is defined as follows:

import random

import cv2
import torch
from torch.utils.data import Dataset


class DriveDataset(Dataset):
    def __init__(self, root='data/training'):
        super(DriveDataset, self).__init__()
        self.dataset = []
        start = 20
        for i in range(1, 21):
            # collect image/label paths in one-to-one correspondence (files 21 to 40)
            image_path = f'{root}/images/{i + start}_training.tif'
            label_path = f'{root}/1st_manual/{i + start}_manual1.gif'
            self.dataset.append((image_path, label_path))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        image_path, label_path = self.dataset[item]  # file paths
        image = cv2.imread(image_path)  # image (BGR)
        video = cv2.VideoCapture(label_path)
        _, mask_label = video.read()  # read the first frame of the GIF as the label mask
        video.release()
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask_label = cv2.cvtColor(mask_label, cv2.COLOR_BGR2GRAY)  # convert to a single channel

        # Randomly crop a 256×256 patch; image and label are cropped at the same position
        h, w = mask_label.shape
        x = random.randint(0, w - 256)
        y = random.randint(0, h - 256)
        image = image[y:y + 256, x:x + 256]
        mask_label = mask_label[y:y + 256, x:x + 256]

        # Convert to tensors scaled to [0, 1]
        image = torch.from_numpy(image).float().permute(2, 0, 1) / 255
        mask_label = torch.from_numpy(mask_label).unsqueeze(0).float() / 255
        return image, mask_label

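Each image is read together with its label and converted to tensors for the network to learn from. The labels are GIF files; they are loaded by opening them with OpenCV's VideoCapture and reading the first frame.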
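The key point of the crop in __getitem__ is that the image and its label share one random window. The same idea as a standalone function on tensors, shown as a sketch (the helper name is hypothetical; the original code crops the NumPy arrays before converting them). The toy check encodes each pixel's index into both tensors so that alignment is easy to verify.

```python
import random

import torch


def paired_random_crop(image, label, size=256):
    """Crop the same random size×size window from a (C, H, W) image and its (1, H, W) label."""
    _, h, w = label.shape
    top = random.randint(0, h - size)  # randint is inclusive on both ends
    left = random.randint(0, w - size)
    return (image[:, top:top + size, left:left + size],
            label[:, top:top + size, left:left + size])


# Toy tensors: every pixel stores its own flat index, in the image and in the label
h, w = 584, 565
coords = torch.arange(h * w, dtype=torch.float32).reshape(1, h, w)
image = coords.expand(3, h, w)
label = coords.clone()
img_crop, lbl_crop = paired_random_crop(image, label)
print(img_crop.shape, lbl_crop.shape)
print(torch.equal(img_crop[0], lbl_crop[0]))  # both crops come from the same window
```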

The trainer is defined as follows:

import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from u_net import UNet
from dataset import DriveDataset


class Trainer:
    def __init__(self):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # select device
        self.net = UNet(3, 1).to(self.device)  # instantiate U-Net
        if os.path.exists('unet.pth'):  # load weights if a checkpoint exists
            self.net.load_state_dict(torch.load('unet.pth', map_location='cpu'))
        self.dataset = DriveDataset()  # dataset
        self.data_loader = DataLoader(self.dataset, batch_size=3, shuffle=True, drop_last=True)  # data loader
        self.loss_func = nn.BCELoss()  # binary cross-entropy
        self.optimizer = torch.optim.Adam(self.net.parameters())  # Adam optimizer

    def train(self):
        for epoch in range(100000):  # iterate over epochs
            for i, (image, target) in enumerate(self.data_loader):
                image = image.to(self.device)
                target = target.to(self.device)
                out = self.net(image)  # forward pass
                loss = self.loss_func(out, target)  # compute loss
                self.optimizer.zero_grad()  # clear gradients
                loss.backward()  # backpropagation
                self.optimizer.step()  # update weights
                print(epoch, loss.item())
            if epoch % 5 == 0:
                # save a checkpoint and a preview: image | label | prediction
                torch.save(self.net.state_dict(), 'unet.pth')
                # newer torchvision versions use value_range instead of the removed range argument
                save_image([image[0], target[0].expand(3, 256, 256), out[0].expand(3, 256, 256)],
                           f'{epoch}.jpg', normalize=True, value_range=(0, 1))

Binary cross-entropy (BCELoss) is the loss function, and the network is optimized with Adam.

# Entry point, placed at the bottom of the trainer script after the Trainer class
if __name__ == '__main__':
    trainer = Trainer()
    trainer.train()
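For reference, nn.BCELoss with its default mean reduction computes −mean(y·log p + (1 − y)·log(1 − p)) over all pixels, where p is the sigmoid output and y the 0/1 label. The sketch below checks this on random tensors. It also shows nn.BCEWithLogitsLoss, which takes raw logits and fuses the sigmoid for numerical stability; using it would mean dropping the final Sigmoid from the model, which is a suggestion and not the original setup.

```python
import torch
from torch import nn

torch.manual_seed(0)
p = torch.rand(2, 1, 8, 8).clamp(1e-4, 1 - 1e-4)  # predicted probabilities (sigmoid output)
y = (torch.rand(2, 1, 8, 8) > 0.5).float()        # binary vessel mask

# Manual binary cross-entropy averaged over every pixel
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
builtin = nn.BCELoss()(p, y)
print(manual.item(), builtin.item())

# BCEWithLogitsLoss takes logits; log(p / (1 - p)) is the inverse of the sigmoid
logits = torch.log(p / (1 - p))
with_logits = nn.BCEWithLogitsLoss()(logits, y)
print(with_logits.item())
```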

The training process is shown below. Left: original image; middle: label; right: network prediction.

[Figure: training snapshots at epochs 0, 1 and 2]

Complete code: https://github.com/HibikiJie/UNetAndU2Net

3. U2-Net

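U2-Net is a stack of U-Nets: roughly speaking, each convolution block of a U-Net is replaced by a complete small U-Net (called a residual U-block, RSU, in the U2-Net paper).

Its network diagram is as follows:

Here, En_1 and De_1 have the same structure, as do En_2 and De_2, En_3 and De_3, and En_4 and De_4; En_5, En_6 and De_5 all share one structure.

First define En_1, En_2, En_3, En_4 and En_5 separately, as UNet1, UNet2, UNet3, UNet4 and UNet5.

Start with UNet1.

Note that the white blocks in the diagram denote convolutions that use the dilation parameter, so ConvolutionLayer is redefined as: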

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvolutionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, dilation=1):
        super(ConvolutionLayer, self).__init__()
        self.layer = nn.Sequential(
            # 3×3 convolution; padding equal to the dilation keeps the spatial size unchanged
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3),
                      padding=1 * dilation, dilation=(1 * dilation, 1 * dilation)),
            nn.BatchNorm2d(out_channels),  # BN
            nn.ReLU(inplace=True)  # activation
        )

    def forward(self, x):
        return self.layer(x)

Each convolution layer consists of Conv, BN and ReLU.
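Why padding=1 * dilation: a 3×3 kernel with dilation d spans an effective window of d·(3 − 1) + 1 = 2d + 1 pixels, so padding d on each side keeps the spatial size unchanged under the Conv2d output-size formula. A quick check in plain Python (the helper names are illustrative):

```python
def effective_kernel(kernel=3, dilation=1):
    """Width of the window a dilated kernel actually spans."""
    return dilation * (kernel - 1) + 1


def conv2d_out_size(size, kernel=3, stride=1, padding=0, dilation=1):
    """nn.Conv2d output size along one dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


for d in (1, 2, 4, 8):
    # padding = dilation, as in ConvolutionLayer above; the size stays 32
    print(d, effective_kernel(dilation=d), conv2d_out_size(32, padding=d, dilation=d))
```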

Upsampling here is not learned; it is done by bilinear interpolation:

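import torch.nn.functional as F


def upsample_like(src, tar):
    # F.upsample is deprecated; F.interpolate is its current equivalent
    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
    return src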

This function upsamples src to the same spatial size as tar.
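Because the target size is read from tar, this also works after pooling has floored an odd size, which a fixed 2× upsampling could not handle. A small check, assuming a hypothetical 35×35 feature map with 16 channels:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def upsample_like(src, tar):
    # Bilinearly resize src to the spatial size of tar
    return F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)


x = torch.randn(1, 16, 35, 35)                   # odd-sized feature map
down = nn.MaxPool2d(kernel_size=2, stride=2)(x)  # floors to 17×17
up = upsample_like(down, x)                      # back to exactly 35×35
print(down.shape, up.shape)
print(torch.cat((up, x), dim=1).shape)           # concatenation now succeeds
```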

Downsampling is again done by max pooling, so the DownSample code from the U-Net section can be reused.

UNet1 is therefore:

class UNet1(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet1, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.down1 = DownSample()
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down2 = DownSample()
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down3 = DownSample()
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down4 = DownSample()
        self.conv5 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down5 = DownSample()
        self.conv6 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.conv7 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv8 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv9 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv10 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv11 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv12 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv13 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        # Downsampling: the encoding process
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        d1 = self.down1(x1)
        x2 = self.conv2(d1)
        d2 = self.down2(x2)
        x3 = self.conv3(d2)
        d3 = self.down3(x3)
        x4 = self.conv4(d3)
        d4 = self.down4(x4)
        x5 = self.conv5(d4)
        d5 = self.down5(x5)
        x6 = self.conv6(d5)
        x7 = self.conv7(x6)
        # Upsampling: the decoding process
        x8 = self.conv8(torch.cat((x7, x6), dim=1))
        up1 = upsample_like(x8, x5)
        x9 = self.conv9(torch.cat((up1, x5), dim=1))
        up2 = upsample_like(x9, x4)
        x10 = self.conv10(torch.cat((up2, x4), dim=1))
        up3 = upsample_like(x10, x3)
        x11 = self.conv11(torch.cat((up3, x3), dim=1))
        up4 = upsample_like(x11, x2)
        x12 = self.conv12(torch.cat((up4, x2), dim=1))
        up5 = upsample_like(x12, x1)
        x13 = self.conv13(torch.cat((up5, x1), dim=1))
        return x13 + x0  # residual connection

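The code follows the diagram above and is very similar to the U-Net code, so it helps to compare the two side by side: U2-Net really is a stack of U-Nets. One addition is the residual connection: each block returns its decoder output plus x0, the output of its input convolution.

Similarly, the code for UNet2 is: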

class UNet2(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet2, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.down1 = DownSample()
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down2 = DownSample()
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down3 = DownSample()
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down4 = DownSample()
        self.conv5 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.conv6 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv7 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv8 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv9 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv10 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv11 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        # encode
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        d1 = self.down1(x1)
        x2 = self.conv2(d1)
        d2 = self.down2(x2)
        x3 = self.conv3(d2)
        d3 = self.down3(x3)
        x4 = self.conv4(d3)
        d4 = self.down4(x4)
        x5 = self.conv5(d4)
        x6 = self.conv6(x5)
        # decode
        x7 = self.conv7(torch.cat((x6, x5), dim=1))
        up1 = upsample_like(x7, x4)
        x8 = self.conv8(torch.cat((up1, x4), dim=1))
        up2 = upsample_like(x8, x3)
        x9 = self.conv9(torch.cat((up2, x3), dim=1))
        up3 = upsample_like(x9, x2)
        x10 = self.conv10(torch.cat((up3, x2), dim=1))
        up4 = upsample_like(x10, x1)
        x11 = self.conv11(torch.cat((up4, x1), dim=1))
        return x11 + x0

UNet3:

class UNet3(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet3, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.down1 = DownSample()
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down2 = DownSample()
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down3 = DownSample()
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.conv5 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv6 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv7 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv8 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv9 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        # encode
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        d1 = self.down1(x1)
        x2 = self.conv2(d1)
        d2 = self.down2(x2)
        x3 = self.conv3(d2)
        d3 = self.down3(x3)
        x4 = self.conv4(d3)
        x5 = self.conv5(x4)
        # decode
        x6 = self.conv6(torch.cat((x5, x4), dim=1))
        up1 = upsample_like(x6, x3)
        x7 = self.conv7(torch.cat((up1, x3), dim=1))
        up2 = upsample_like(x7, x2)
        x8 = self.conv8(torch.cat((up2, x2), dim=1))
        up3 = upsample_like(x8, x1)
        x9 = self.conv9(torch.cat((up3, x1), dim=1))
        return x9 + x0

UNet4:

class UNet4(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet4, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.down1 = DownSample()
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down2 = DownSample()
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv5 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv6 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv7 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        # encode
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        d1 = self.down1(x1)
        x2 = self.conv2(d1)
        d2 = self.down2(x2)
        x3 = self.conv3(d2)
        x4 = self.conv4(x3)
        # decode
        x5 = self.conv5(torch.cat((x4, x3), dim=1))
        up1 = upsample_like(x5, x2)
        x6 = self.conv6(torch.cat((up1, x2), dim=1))
        up2 = upsample_like(x6, x1)
        x7 = self.conv7(torch.cat((up2, x1), dim=1))
        return x7 + x0

UNet5 contains no pooling at all; it uses increasingly dilated convolutions instead:

class UNet5(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet5, self).__init__()
        # no pooling: dilations 1, 2, 4, 8 enlarge the receptive field instead
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=4)
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=8)
        self.conv5 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=4)
        self.conv6 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=2)
        self.conv7 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(torch.cat((x4, x3), dim=1))
        x6 = self.conv6(torch.cat((x5, x2), dim=1))
        x7 = self.conv7(torch.cat((x6, x1), dim=1))
        return x7 + x0
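UNet5 is used on maps that are already small, so instead of pooling it grows its receptive field through dilation. For a chain of stride-1 3×3 convolutions, each layer adds dilation·(3 − 1) pixels to the receptive field. The sketch computes this along UNet5's encoder path (conv0 to conv4); BN and ReLU are ignored since they do not change the receptive field, and the helper name is illustrative.

```python
def receptive_field(dilations, kernel=3):
    """Receptive field of a chain of stride-1 convolutions with the given dilations."""
    rf = 1
    for d in dilations:
        rf += d * (kernel - 1)
    return rf


# conv0..conv4 of UNet5: dilations 1, 1, 2, 4, 8
print(receptive_field([1, 1, 2, 4, 8]))
# The same five layers without dilation
print(receptive_field([1, 1, 1, 1, 1]))
```

With dilation the five layers see a 33×33 window instead of 11×11, at the same parameter count and without losing resolution.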

UNet1, UNet2, UNet3, UNet4 and UNet5 are then assembled into U2-Net.

Look at the network structure diagram again:


Therefore, En_1 and De_1 use UNet1;

En_2 and De_2 use UNet2;

En_3 and De_3 use UNet3;

En_4 and De_4 use UNet4;

En_5, En_6 and De_5 use UNet5.

The U2-Net network is therefore built as:

class U2Net(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(U2Net, self).__init__()
        # encoder
        self.en_1 = UNet1(in_channels, 32, 64)
        self.down1 = DownSample()
        self.en_2 = UNet2(64, 32, 128)
        self.down2 = DownSample()
        self.en_3 = UNet3(128, 64, 256)
        self.down3 = DownSample()
        self.en_4 = UNet4(256, 128, 512)
        self.down4 = DownSample()
        self.en_5 = UNet5(512, 256, 512)
        self.down5 = DownSample()
        self.en_6 = UNet5(512, 256, 512)
        # decoder: in_channels = upsampled features + skip connection
        self.de_5 = UNet5(1024, 256, 512)
        self.de_4 = UNet4(1024, 128, 256)
        self.de_3 = UNet3(512, 64, 128)
        self.de_2 = UNet2(256, 32, 64)
        self.de_1 = UNet1(128, 16, 64)
        # side outputs: a 3×3 conv maps each stage to out_channels prediction maps
        self.side1 = nn.Conv2d(64, out_channels, kernel_size=(3, 3), padding=1)
        self.side2 = nn.Conv2d(64, out_channels, kernel_size=(3, 3), padding=1)
        self.side3 = nn.Conv2d(128, out_channels, kernel_size=(3, 3), padding=1)
        self.side4 = nn.Conv2d(256, out_channels, kernel_size=(3, 3), padding=1)
        self.side5 = nn.Conv2d(512, out_channels, kernel_size=(3, 3), padding=1)
        self.side6 = nn.Conv2d(512, out_channels, kernel_size=(3, 3), padding=1)
        # fuse the six side outputs (6 * out_channels maps) with a 1×1 convolution
        self.out_conv = nn.Conv2d(6 * out_channels, out_channels, kernel_size=(1, 1))

    def forward(self, x):
        # ------ encode ------
        x1 = self.en_1(x)
        d1 = self.down1(x1)
        x2 = self.en_2(d1)
        d2 = self.down2(x2)
        x3 = self.en_3(d2)
        d3 = self.down3(x3)
        x4 = self.en_4(d3)
        d4 = self.down4(x4)
        x5 = self.en_5(d4)
        d5 = self.down5(x5)
        x6 = self.en_6(d5)
        up1 = upsample_like(x6, x5)
        # ------ decode ------
        x7 = self.de_5(torch.cat((up1, x5), dim=1))
        up2 = upsample_like(x7, x4)
        x8 = self.de_4(torch.cat((up2, x4), dim=1))
        up3 = upsample_like(x8, x3)
        x9 = self.de_3(torch.cat((up3, x3), dim=1))
        up4 = upsample_like(x9, x2)
        x10 = self.de_2(torch.cat((up4, x2), dim=1))
        up5 = upsample_like(x10, x1)
        x11 = self.de_1(torch.cat((up5, x1), dim=1))
        # ------ side outputs, resized to the input resolution ------
        sup1 = self.side1(x11)
        sup2 = self.side2(x10)
        sup2 = upsample_like(sup2, sup1)
        sup3 = self.side3(x9)
        sup3 = upsample_like(sup3, sup1)
        sup4 = self.side4(x8)
        sup4 = upsample_like(sup4, sup1)
        sup5 = self.side5(x7)
        sup5 = upsample_like(sup5, sup1)
        sup6 = self.side6(x6)
        sup6 = upsample_like(sup6, sup1)
        sup0 = self.out_conv(torch.cat((sup1, sup2, sup3, sup4, sup5, sup6), dim=1))
        return torch.sigmoid(sup0)
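To see how the pieces fit, the channel widths from the U2Net constructor can be tabulated together with each encoder stage's resolution. The sketch hard-codes those widths (copied from the code) and assumes a 512×512 input; it checks that every decoder's in_channels equals the upsampled feature from the stage below plus the skip connection from the matching encoder.

```python
# (name, in_channels, out_channels) copied from U2Net.__init__
encoders = [('en_1', 3, 64), ('en_2', 64, 128), ('en_3', 128, 256),
            ('en_4', 256, 512), ('en_5', 512, 512), ('en_6', 512, 512)]
decoders = [('de_5', 1024, 512), ('de_4', 1024, 256), ('de_3', 512, 128),
            ('de_2', 256, 64), ('de_1', 128, 64)]

size = 512  # assumed input resolution
resolutions = {}
for i, (name, _, out_c) in enumerate(encoders):
    resolutions[name] = (out_c, size >> i)  # every stage after en_1 follows a 2x2 max pool
    print(name, resolutions[name])

# de_k consumes upsample(previous output) concatenated with the output of en_k
previous_c = encoders[-1][2]  # output channels of en_6
for k, (name, in_c, out_c) in zip(range(5, 0, -1), decoders):
    skip_c = resolutions[f'en_{k}'][0]
    assert in_c == previous_c + skip_c, name
    print(name, 'in', in_c, '=', previous_c, '+', skip_c, '-> out', out_c)
    previous_c = out_c
```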

Complete U2-Net code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvolutionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, dilation=1):
        super(ConvolutionLayer, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3),
                      padding=1 * dilation, dilation=(1 * dilation, 1 * dilation)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


def upsample_like(src, tar):
    # bilinearly resize src to the spatial size of tar
    src = F.interpolate(src, size=tar.shape[2:], mode='bilinear', align_corners=False)
    return src


class DownSample(nn.Module):
    def __init__(self):
        super(DownSample, self).__init__()
        self.layer = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        return self.layer(x)


class UNet1(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet1, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.down1 = DownSample()
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down2 = DownSample()
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down3 = DownSample()
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down4 = DownSample()
        self.conv5 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down5 = DownSample()
        self.conv6 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.conv7 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv8 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv9 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv10 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv11 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv12 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv13 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        d1 = self.down1(x1)
        x2 = self.conv2(d1)
        d2 = self.down2(x2)
        x3 = self.conv3(d2)
        d3 = self.down3(x3)
        x4 = self.conv4(d3)
        d4 = self.down4(x4)
        x5 = self.conv5(d4)
        d5 = self.down5(x5)
        x6 = self.conv6(d5)
        x7 = self.conv7(x6)
        x8 = self.conv8(torch.cat((x7, x6), 1))
        up1 = upsample_like(x8, x5)
        x9 = self.conv9(torch.cat((up1, x5), 1))
        up2 = upsample_like(x9, x4)
        x10 = self.conv10(torch.cat((up2, x4), 1))
        up3 = upsample_like(x10, x3)
        x11 = self.conv11(torch.cat((up3, x3), 1))
        up4 = upsample_like(x11, x2)
        x12 = self.conv12(torch.cat((up4, x2), 1))
        up5 = upsample_like(x12, x1)
        x13 = self.conv13(torch.cat((up5, x1), 1))
        return x13 + x0


class UNet2(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet2, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.down1 = DownSample()
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down2 = DownSample()
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down3 = DownSample()
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down4 = DownSample()
        self.conv5 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.conv6 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv7 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv8 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv9 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv10 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv11 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        d1 = self.down1(x1)
        x2 = self.conv2(d1)
        d2 = self.down2(x2)
        x3 = self.conv3(d2)
        d3 = self.down3(x3)
        x4 = self.conv4(d3)
        d4 = self.down4(x4)
        x5 = self.conv5(d4)
        x6 = self.conv6(x5)
        x7 = self.conv7(torch.cat((x6, x5), dim=1))
        up1 = upsample_like(x7, x4)
        x8 = self.conv8(torch.cat((up1, x4), dim=1))
        up2 = upsample_like(x8, x3)
        x9 = self.conv9(torch.cat((up2, x3), dim=1))
        up3 = upsample_like(x9, x2)
        x10 = self.conv10(torch.cat((up3, x2), dim=1))
        up4 = upsample_like(x10, x1)
        x11 = self.conv11(torch.cat((up4, x1), dim=1))
        return x11 + x0


class UNet3(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet3, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.down1 = DownSample()
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down2 = DownSample()
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down3 = DownSample()
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.conv5 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv6 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv7 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv8 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv9 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        d1 = self.down1(x1)
        x2 = self.conv2(d1)
        d2 = self.down2(x2)
        x3 = self.conv3(d2)
        d3 = self.down3(x3)
        x4 = self.conv4(d3)
        x5 = self.conv5(x4)
        x6 = self.conv6(torch.cat((x5, x4), 1))
        up1 = upsample_like(x6, x3)
        x7 = self.conv7(torch.cat((up1, x3), 1))
        up2 = upsample_like(x7, x2)
        x8 = self.conv8(torch.cat((up2, x2), 1))
        up3 = upsample_like(x8, x1)
        x9 = self.conv9(torch.cat((up3, x1), 1))
        return x9 + x0


class UNet4(nn.Module):
    # a default for mid_channels before the non-default out_channels was a SyntaxError
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet4, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.down1 = DownSample()
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.down2 = DownSample()
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=1)
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv5 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv6 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=1)
        self.conv7 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        # encode
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        d1 = self.down1(x1)
        x2 = self.conv2(d1)
        d2 = self.down2(x2)
        x3 = self.conv3(d2)
        x4 = self.conv4(x3)
        # decode
        x5 = self.conv5(torch.cat((x4, x3), 1))
        up1 = upsample_like(x5, x2)
        x6 = self.conv6(torch.cat((up1, x2), 1))
        up2 = upsample_like(x6, x1)
        x7 = self.conv7(torch.cat((up2, x1), 1))
        return x7 + x0


class UNet5(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(UNet5, self).__init__()
        self.conv0 = ConvolutionLayer(in_channels, out_channels, dilation=1)
        self.conv1 = ConvolutionLayer(out_channels, mid_channels, dilation=1)
        self.conv2 = ConvolutionLayer(mid_channels, mid_channels, dilation=2)
        self.conv3 = ConvolutionLayer(mid_channels, mid_channels, dilation=4)
        self.conv4 = ConvolutionLayer(mid_channels, mid_channels, dilation=8)
        self.conv5 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=4)
        self.conv6 = ConvolutionLayer(mid_channels * 2, mid_channels, dilation=2)
        self.conv7 = ConvolutionLayer(mid_channels * 2, out_channels, dilation=1)

    def forward(self, x):
        x0 = self.conv0(x)
        x1 = self.conv1(x0)
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(torch.cat((x4, x3), 1))
        x6 = self.conv6(torch.cat((x5, x2), 1))
        x7 = self.conv7(torch.cat((x6, x1), 1))
        return x7 + x0


class U2Net(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(U2Net, self).__init__()
        # encoder
        self.en_1 = UNet1(in_channels, 32, 64)
        self.down1 = DownSample()
        self.en_2 = UNet2(64, 32, 128)
        self.down2 = DownSample()
        self.en_3 = UNet3(128, 64, 256)
        self.down3 = DownSample()
        self.en_4 = UNet4(256, 128, 512)
        self.down4 = DownSample()
        self.en_5 = UNet5(512, 256, 512)
        self.down5 = DownSample()
        self.en_6 = UNet5(512, 256, 512)
        # decoder
        self.de_5 = UNet5(1024, 256, 512)
        self.de_4 = UNet4(1024, 128, 256)
        self.de_3 = UNet3(512, 64, 128)
        self.de_2 = UNet2(256, 32, 64)
        self.de_1 = UNet1(128, 16, 64)
        # side outputs
        self.side1 = nn.Conv2d(64, out_channels, kernel_size=(3, 3), padding=1)
        self.side2 = nn.Conv2d(64, out_channels, kernel_size=(3, 3), padding=1)
        self.side3 = nn.Conv2d(128, out_channels, kernel_size=(3, 3), padding=1)
        self.side4 = nn.Conv2d(256, out_channels, kernel_size=(3, 3), padding=1)
        self.side5 = nn.Conv2d(512, out_channels, kernel_size=(3, 3), padding=1)
        self.side6 = nn.Conv2d(512, out_channels, kernel_size=(3, 3), padding=1)
        # the six concatenated side outputs have 6 * out_channels channels
        self.out_conv = nn.Conv2d(6 * out_channels, out_channels, kernel_size=(1, 1))

    def forward(self, x):
        # ------ encode ------
        x1 = self.en_1(x)
        d1 = self.down1(x1)
        x2 = self.en_2(d1)
        d2 = self.down2(x2)
        x3 = self.en_3(d2)
        d3 = self.down3(x3)
        x4 = self.en_4(d3)
        d4 = self.down4(x4)
        x5 = self.en_5(d4)
        d5 = self.down5(x5)
        x6 = self.en_6(d5)
        up1 = upsample_like(x6, x5)
        # ------ decode ------
        x7 = self.de_5(torch.cat((up1, x5), dim=1))
        up2 = upsample_like(x7, x4)
        x8 = self.de_4(torch.cat((up2, x4), dim=1))
        up3 = upsample_like(x8, x3)
        x9 = self.de_3(torch.cat((up3, x3), dim=1))
        up4 = upsample_like(x9, x2)
        x10 = self.de_2(torch.cat((up4, x2), dim=1))
        up5 = upsample_like(x10, x1)
        x11 = self.de_1(torch.cat((up5, x1), dim=1))
        # ------ side outputs ------
        sup1 = self.side1(x11)
        sup2 = self.side2(x10)
        sup2 = upsample_like(sup2, sup1)
        sup3 = self.side3(x9)
        sup3 = upsample_like(sup3, sup1)
        sup4 = self.side4(x8)
        sup4 = upsample_like(sup4, sup1)
        sup5 = self.side5(x7)
        sup5 = upsample_like(sup5, sup1)
        sup6 = self.side6(x6)
        sup6 = upsample_like(sup6, sup1)
        sup0 = self.out_conv(torch.cat((sup1, sup2, sup3, sup4, sup5, sup6), 1))
        return torch.sigmoid(sup0)


if __name__ == '__main__':
    u2net = U2Net(3, 1)
    x = torch.randn(1, 3, 512, 512)
    print(u2net(x).shape)  # torch.Size([1, 1, 512, 512])
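Once trained, the network outputs a probability map in [0, 1]. A minimal inference sketch, assuming a single-channel output and a 0.5 threshold; the helper name is hypothetical, and a tiny stand-in model with the same input/output contract replaces U2Net so the snippet runs on its own (in practice, pass `U2Net(3, 1)` with trained weights loaded).

```python
import torch
import torch.nn as nn


@torch.no_grad()
def predict_mask(model, image, threshold=0.5):
    """Run a segmentation model on a (3, H, W) float image in [0, 1]; return a uint8 0/255 mask."""
    model.eval()  # BatchNorm uses its running statistics at inference time
    prob = model(image.unsqueeze(0))[0, 0]  # (H, W) probability map
    return (prob > threshold).to(torch.uint8) * 255


# Stand-in model: 3 channels in, 1 sigmoid probability map out, like U2Net(3, 1)
dummy = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1), nn.Sigmoid())
mask = predict_mask(dummy, torch.rand(3, 64, 64))
print(mask.shape, mask.dtype, mask.unique())
```

Calling `model.eval()` matters here because every ConvolutionLayer contains BatchNorm, whose training-mode behavior on a single image differs from inference.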
Article link: https://www.jiuchutong.com/zhishi/287058.html (please keep this notice when reposting).
