ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Pytorch深度学习实战教程(二):UNet语义分割网络

2021-05-11 16:04:23  阅读:223  来源: 互联网

标签:kernel 教程 self Pytorch channels stride UNet True size


一、前言

本文属于Pytorch深度学习语义分割系列教程。

该系列文章的内容有:

Pytorch的基本使用
语义分割算法讲解
如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章《Pytorch深度学习实战教程(一):语义分割基础与环境搭建》。

本文的开发环境采用上一篇文章搭建好的Windows环境,环境情况如下:

开发环境:Windows

开发语言:Python3.7.4

框架版本:Pytorch1.3.0

CUDA:10.2

cuDNN:7.6.0

本文主要讲解UNet网络结构,以及相应代码的代码编写。

二、UNet网络结构

在语义分割领域,基于深度学习的语义分割算法开山之作是FCN (Fully Convolutional Networks for Semantic Segmentation), 而UNet是遵循FCN的原理,并进行了相应的改进,使其适应小样本的简单分割问题。
研究一个深度学习算法,可以先看网络结构,看懂网络结构后,再Loss计算方法、训练方法等。本文主要针对UNet的网络结构进行讲解,其他内容会在后续章节进行说明。

1.网络结构原理

UNet最早发表在2015的MICCAI会议上,4年多的时间,论文引用量已经达到了9700多次。
UNet成为了大多做医疗影像语义分割任务的baseline,同时也启发了大量研究者对于U型网络结构的研究,发表了一批基于UNet网络结构的改进方法的论文。
UNet网络结构,最主要的两个特点是:U型网络结构和Skip Connection跳层链接。
在这里插入图片描述
UNet是一个对称的网络结构,左侧为下采样,右侧为上采样。
按照功能可以将左侧的一系列下采样操作成为encoder, 将右侧的一系列上采样操作成为decoder。
Skip Connection中间四条灰色的平行线,Skip Connection就是在上采样的过程中,融合下采样过程中的Feature map。
Skip Connection用到的融合的操作也很简单,就是将feature map的通道进行叠加,俗称Concat。
Concat操作也很好理解,举个例子:一本大小为10cm10cm,厚度为3cm的书A,和一本大小为10cm10cm,厚度为4cm的书B。
将书A和书B,边缘对齐地摞在一起。这样就得到了,大小为10cm*10cm, 厚度为7cm的书,类似这种:
在这里插入图片描述
这种摞在一起的操作,就是concat。
同样道理,对于feature map, 一个大小为256x256x64的feature map,即feature map的w(宽)为256,h(高)为256,c(通道数)为64,。和一个大小为256x256x32的feature map进行concat融合,就会得到一个大小为256x256x96的feature map。
在实际使用中,Concat融合的两个feature map的大小不一定相同,例如256x256x64的feature map和240x240x32的feature map进行concat。
这时候就有两种办法:
第一种:将大256x256x64的feature map进行裁剪,裁剪为240x240x64的feature map,比如上下左右,各舍弃8pixel,裁剪后再进行concat,得到240x240x96的feature map。
第二种:将小240x240x32的feature map进行padding操作,padding为256x256x32的feature map,比如上下左右,各补8pixel,padding后在进行concat,得到256x256x96的feature map。
UNet采用的concat方案就是第二种,将小的feature map进行padding,padding的方式是补0,一种常规的常量填充。

2.代码

我们将整个UNet网络拆分为多个模块进行讲解。
DoubleConv模块:
先看下连续两次的卷积操作。

在这里插入图片描述
从UNet网络中可以看出,不管是下采样过程还是上采样过程,每一层都会连续进行两次卷积操作,这种操作在UNet网络中重复很多次,可以单独写一个DoubleConv模块:

import torch.nn as nn

class DoubleConv(nn.Module):
	“”“(convolution => [BN] => ReLU) * 2”“”
	def __init__(self, in_channels, out_channels):
		super().__init__()
		self.double_conv = nn.Sequential(
		nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 0),
		nn.BatchNorm2d(out_channels),
		nn.ReLU(inplace = True),
		nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 0),
		nn.BatchNorm2d(out_channels),
		nn.ReLU(inplace = True)
		)
	def forward(self, x):
		return self.double_conv(x)
		

解释下,上述的pytorch代码,torch.nn.Sequential是一个时序容器,Modules会以它们传入的顺序被添加到容器中。比如上述代码的操作顺序:卷积->BN->ReLU->卷积->BN->ReLU
DoubleConv模块的in_channels和out_channels可以灵活设定,以便扩展使用。
如上图所示的网络,in_channels设为1,out_channels设为64。
输入图片大小为572x572,经过步长为1,padding为0的3x3卷积,得到570x570的feature map,再经过一次卷积得到568x568的feature map。
计算公式:O= (H-F+2*P)/ S + 1
H为输入feature map大小,O为输出feature map大小,F为卷积核的大小,P为padding的大小,S为步长。

Down模块
在这里插入图片描述
UNet网络一共有4次下采样过程,模块化代码如下:

class Down(nn.Module):
	"""Downscaling with maxpool then double conv"""
	def __init__(self, in_channels, out_channels):
		super.__init__():
		self.maxpool_conv = nn.Sequential(
			nn.MaxPool2d(2),
			DoubleConv(in_channels, out_channels)
		)
	def forward(self, x):
		return self.maxpool_conv(x)

这里的代码很简单,就是一个maxpool池化层,进行下采样,然后接一个DoubleConv模块。
至此,UNet网络的左半部分的下采样过程的代码都写好了,接下来是右半部分的上采样过程。
Up模块:
上采样过程用到的最多的当然就是上采样了,除了常规的上采样操作,还有进行特征的融合。
在这里插入图片描述
这块的代码实现起来也稍复杂一些:

class Up(nn.Moudule):
	"""Upscaling then double conv"""
	def __init__(self, in_channels, out_channels, bilinear=True):
		super().__init__()
		# if bilinear, use the normal convolutions to reduce the number of channels
		if bilinear:
			self.up = nn.Upsample(scale_factor = 2, mode='bilinear', align_corner = True)
		else:
			self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, 				   			  stride=2)
		self.conv = DoubleConv(in_channels, out_channels)

	def forward(self, x1, x2):
		x1 = self.up(x1)
		# input is CHW
		diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
		diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
		
		x1 = F.pad(x1, [diffX // 2, diffX - diffY // 2])
		x = torch.cat([x2, x1], dim = 1)
		
		return self.conv(x)

代码复杂一些,我们可以分开来看,首先是__init__初始化函数里定义的上采样方法以及卷积采用DoubleConv。上采样,定义了两种方法:Upsample和ConvTranspose2d,也就是双线性插值和反卷积
双线性插值很好理解,示意图:
在这里插入图片描述
熟悉双线性插值的朋友对于这幅图应不陌生,简单地讲:已知Q11、Q12、Q21、Q22四个点坐标,通过Q11和Q21求R1,再通过Q12和Q22求R2,最后通过R1和R2求P,这个过程就是双线性插值。
对于一个feature map而言,其实就像是在像素点中间补点,补的点的值是多少,是由相邻像素点的值决定的。
反卷积,顾名思义,就是反着卷积。卷积是让feature map越来越小,反卷积就是让feature map越来越大,示意图:
在这里插入图片描述
下面蓝色为原始图片,周围白色的虚线方块为padding结果,通常为0,上面绿色为卷积后的图片。
这个示意图就是一个从2x2的feature map -> 4x4的feature map过程。
在forward前向传播函数中,x1接收的是上采样的数据,x2接收的是特征融合的数据。特征融合的方法就是,上文提到的,先对小的feature map进行padding,在进行concat。

OutConv模块:
用上述的DoubleConv模块、Down模块、Up模块就可以拼出UNet的主题网络结构。UNet网络的输出需要根据分割数量,整合输出通道,结果如下图所示:
在这里插入图片描述
操作很简单,就是channel的变换,上图展示的是分类为2的情况(通道为2)

虽然这个操作很简单,也就调用一次,为了美观整洁,也封装一下吧。

class OutConv(nn.Module):
	def __init__(self, in_channels, out_channels):
		super(OutConv, self).__init__()
		self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
	def forward(self, x):
		return self.conv(x)

至此,UNet网络用到的模块都已经写好,我们可以将上述的模块代码都放到一个unet_part.py文件里,然后再创建unet_model.py,根据UNet网络结构,设置每个模块的输入输出通道个数以及调用顺序,编写如下代码:

import torch.nn.functional as F
from unet_parts import *
class UNet(nn.Module):
	def __init__(self, n_channels, n_classes, bilinear=False):
		super(UNet, self).__init__()
		self.n_channels = n_channels
		self.n_classes = n_classes
		self.bilinear = bilinear
		
		self.inc = DoubleConv(n_channels, 64)
		self.down1 = Down(64, 128)
		self.down2 = Down(128, 256)
		self.down3 = Down(256, 512)
		self.down4 = Down(512, 1024)
		self.up1 = Up(1024, 512, bilinear)
		self.up2 = Up(512, 256, bilinear)
		self.up3 = Up(256, 128, bilinear)
		self.up4 = Up(128, 64, bilinear)
		self.outc = OutConv(64, n_classes)
	
	def forward(self, x):
		x1 = self.inc(x)
		x2 = self.down1(x1)
		x3 = self.down(x2)
		x4 = self.down(x3)
		x5 = self.down(x4)
		x = self.up1(x5, x4)
		x = self.up2(x, x3)
		x = self.up3(x, x2)
		x = self.up4(x, x1)
		logits = self.outc(x)
		return logits

if __name__ == '__main__':
	net = UNet(n_channels=3, n_classes=1)
	print(net)
		

使用命令python unet_model.py,如果没有错误,你会得到如下结果:

17
118
119
120
121
122
123
124
125
126
127
UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down2): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down3): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (down4): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
          (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (up1): Up(
    (up): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up2): Up(
    (up): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up3): Up(
    (up): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (up4): Up(
    (up): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (conv): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (outc): OutConv(
    (conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)

网络搭建完成,下一步就是使用网络进行训练了。

总结

本文主要讲解了UNet网络结构,并对UNet网络进行了模块化梳理。
下篇文章讲解如何使用UNet网络,编写训练代码。

标签:kernel,教程,self,Pytorch,channels,stride,UNet,True,size
来源: https://blog.csdn.net/qq_40147888/article/details/115490563

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有