
UNET Building Segmentation and Contour Recognition

2021-09-19



Semantic Segmentation with the UNET Model

The UNET Model

The UNet semantic segmentation model has become popular in Kaggle image segmentation competitions such as data-science-bowl-2018 and airbus-ship-detection, and it also performs very well on medical images. It is simple, efficient, easy to understand and build, and it does not need a particularly large training set.

The network structure from the UNet paper is shown in the figure below. The structure is fairly simple: the left half is essentially an encoder and the right half a decoder. The encoder extracts features, mainly by applying 3x3 convolutions followed by max-pooling; the decoder mainly applies up-convolutions and 3x3 convolutions. Two details are worth noting: 1. feature maps from the encoder are copied and cropped into the decoder (copy and crop); 2. the final output layer uses a 1x1 convolution. UNet performs four upsampling steps and uses a skip connection at each stage, rather than supervising and backpropagating the loss only on the highest-level semantic features. This ensures that the recovered feature maps fuse more low-level features, lets features at different scales be combined, and thereby enables multi-scale prediction and deep supervision; the four upsampling steps also recover edges and other fine details of the segmentation map more precisely. The skip connections carry a great deal of information from the input image and help compensate for the information lost by downsampling; in that sense they behave much like residual connections.

[Figure: the UNet network architecture from the original paper]

To summarize the UNet structure: the left side is an encoder whose job is to extract features, and the right side is a decoder that produces the output through upsampling. Compared with FCN, UNet fuses feature maps by concatenation. The benefit is that deeper layers have larger receptive fields and focus on the more abstract, semantic characteristics of the image, while shallower feature maps focus on texture; both kinds of feature maps are useful, and concatenating them lets the network learn from both.
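
The multi-scale structure can be traced with a quick shape walk-through; the snippet below is only a sketch, assuming a 512x512 input as in the implementation later in this post.

# Four 2x2 max-poolings halve the resolution each time; four upsampling steps
# bring it back, concatenating the same-sized encoder feature map at every step.
encoder_sizes = [512 // 2 ** i for i in range(5)]   # [512, 256, 128, 64, 32]
decoder_sizes = encoder_sizes[-2::-1]               # [64, 128, 256, 512]
print(encoder_sizes, decoder_sizes)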

The figure below shows the result of using UNet to extract buildings from remote-sensing imagery. A dataset of only a dozen or so images was enough, and the trained result is acceptable.

[Figure: building extraction result on remote-sensing imagery]

Code Implementation

The code below is adapted from another author's Keras UNet implementation: keras实现unet模型 (reference 2).

Model

The backbone is a VGG16 network. Each block applies two or three 3x3 convolutions, the resulting feature map is recorded for the skip connection, and the block ends with a max-pooling; this is repeated for five blocks.

# standalone Keras imports (adjust to tensorflow.keras if that is your environment)
from keras import layers
from keras.initializers import random_normal

def VGG16(img_input):
    # Block 1
    # 512,512,3 -> 512,512,64
    x = layers.Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block1_conv1')(img_input)
    x = layers.Conv2D(64, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block1_conv2')(x)
    feat1 = x
    # 512,512,64 -> 256,256,64
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)

    # Block 2
    # 256,256,64 -> 256,256,128
    x = layers.Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block2_conv1')(x)
    x = layers.Conv2D(128, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block2_conv2')(x)
    feat2 = x
    # 256,256,128 -> 128,128,128
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)


    # Block 3
    # 128,128,128 -> 128,128,256
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block3_conv1')(x)
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block3_conv2')(x)
    x = layers.Conv2D(256, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block3_conv3')(x)
    feat3 = x
    # 128,128,256 -> 64,64,256
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)

    # Block 4
    # 64,64,256 -> 64,64,512
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block4_conv1')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block4_conv2')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block4_conv3')(x)
    feat4 = x
    # 64,64,512 -> 32,32,512
    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)

    # Block 5
    # 32,32,512 -> 32,32,512
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block5_conv1')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block5_conv2')(x)
    x = layers.Conv2D(512, (3, 3),
                      activation='relu',
                      padding='same',
                      kernel_initializer = random_normal(stddev=0.02), 
                      name='block5_conv3')(x)
    feat5 = x
    return feat1, feat2, feat3, feat4, feat5
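
A quick sanity check of the encoder feature shapes (a sketch; it assumes the standalone Keras imports used above):

from keras import backend as K
from keras.layers import Input

inp = Input(shape=(512, 512, 3))
for feat in VGG16(inp):
    print(K.int_shape(feat))
# (None, 512, 512, 64), (None, 256, 256, 128), (None, 128, 128, 256),
# (None, 64, 64, 512), (None, 32, 32, 512)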

The code above is the encoder (feature-extraction) stage; the code below is the upsampling decoder:

from keras.layers import Concatenate, Conv2D, Input, UpSampling2D
from keras.models import Model

# Note: the default input_shape is (256,256,3), but the shape comments below assume a 512x512x3 input.
def Unet(input_shape=(256,256,3), num_classes=21):
    inputs = Input(input_shape)
    feat1, feat2, feat3, feat4, feat5 = VGG16(inputs) 
    channels = [64, 128, 256, 512]
    # 32, 32, 512 -> 64, 64, 512
    P5_up = UpSampling2D(size=(2, 2))(feat5)
    # 64, 64, 512 + 64, 64, 512 -> 64, 64, 1024
    P4 = Concatenate(axis=3)([feat4, P5_up])
    # 64, 64, 1024 -> 64, 64, 512
    P4 = Conv2D(channels[3], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P4)
    P4 = Conv2D(channels[3], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P4)
    # 64, 64, 512 -> 128, 128, 512
    P4_up = UpSampling2D(size=(2, 2))(P4)
    # 128, 128, 256 + 128, 128, 512 -> 128, 128, 768
    P3 = Concatenate(axis=3)([feat3, P4_up])
    # 128, 128, 768 -> 128, 128, 256
    P3 = Conv2D(channels[2], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P3)
    P3 = Conv2D(channels[2], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P3)

    # 128, 128, 256 -> 256, 256, 256
    P3_up = UpSampling2D(size=(2, 2))(P3)
    # 256, 256, 256 + 256, 256, 128 -> 256, 256, 384
    P2 = Concatenate(axis=3)([feat2, P3_up])
    # 256, 256, 384 -> 256, 256, 128
    P2 = Conv2D(channels[1], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P2)
    P2 = Conv2D(channels[1], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P2)

    # 256, 256, 128 -> 512, 512, 128
    P2_up = UpSampling2D(size=(2, 2))(P2)
    # 512, 512, 128 + 512, 512, 64 -> 512, 512, 192
    P1 = Concatenate(axis=3)([feat1, P2_up])
    # 512, 512, 192 -> 512, 512, 64
    P1 = Conv2D(channels[0], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P1)
    P1 = Conv2D(channels[0], 3, activation='relu', padding='same', kernel_initializer = random_normal(stddev=0.02))(P1)

    # 512, 512, 64 -> 512, 512, num_classes
    P1 = Conv2D(num_classes, 1, activation="softmax")(P1)

    model = Model(inputs=inputs, outputs=P1)
    return model

At this point, the structure of the UNet model is complete.
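
A minimal usage sketch, assuming a binary building/background task (num_classes=2) and the dice_loss_with_CE loss defined in the Train section below:

model = Unet(input_shape=(512, 512, 3), num_classes=2)
model.compile(optimizer='adam', loss=dice_loss_with_CE())
model.summary()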

Train

  • Data loading

    There is not much special about the data loading itself; the main steps are: 1. data augmentation, i.e. randomly transforming each image: first rescale it, then randomly flip it, then distort its colors, and so on; 2. encoding the label masks into one-hot form. A small worked example of the one-hot step follows the Generator class below.

    
    import os
    import cv2
    import numpy as np
    from PIL import Image
    from random import shuffle

    def rand(a=0, b=1):
        # uniform random value in [a, b); small helper the augmentation below relies on
        return np.random.rand() * (b - a) + a

    class Generator(object):
        def __init__(self,batch_size,train_lines,image_size,num_classes,dataset_path):
            self.batch_size     = batch_size
            self.train_lines    = train_lines
            self.train_batches  = len(train_lines)
            self.image_size     = image_size
            self.num_classes    = num_classes
            self.dataset_path   = dataset_path
    
        def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5):
            label = Image.fromarray(np.array(label))
    
            h, w = input_shape
            # resize image
            rand_jit1 = rand(1-jitter,1+jitter)
            rand_jit2 = rand(1-jitter,1+jitter)
            new_ar = w/h * rand_jit1/rand_jit2
    
            scale = rand(0.25, 2)
            if new_ar < 1:
                nh = int(scale*h)
                nw = int(nh*new_ar)
            else:
                nw = int(scale*w)
                nh = int(nw/new_ar)
            image = image.resize((nw,nh), Image.BICUBIC)
            label = label.resize((nw,nh), Image.NEAREST)
            label = label.convert("L")
            
            # flip image or not
            flip = rand()<.5
            if flip: 
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                label = label.transpose(Image.FLIP_LEFT_RIGHT)
            
            # place image
            dx = int(rand(0, w-nw))
            dy = int(rand(0, h-nh))
            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_label = Image.new('L', (w,h), (0))
            new_image.paste(image, (dx, dy))
            new_label.paste(label, (dx, dy))
            image = new_image
            label = new_label
    
            # distort image
            hue = rand(-hue, hue)
            sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat)
            val = rand(1, val) if rand()<.5 else 1/rand(1, val)
            x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
            x[..., 0] += hue*360
            x[..., 0][x[..., 0] > 360] -= 360   # wrap hue into the valid [0, 360) range
            x[..., 0][x[..., 0] < 0] += 360
            x[..., 1] *= sat
            x[..., 2] *= val
            x[x[:,:, 0]>360, 0] = 360
            x[:, :, 1:][x[:, :, 1:]>1] = 1
            x[x<0] = 0
            image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
            return image_data,label
            
        def generate(self, random_data = True):
            i = 0
            length = len(self.train_lines)
            inputs = []
            targets = []
            while True:
                if i == 0:
                    shuffle(self.train_lines)
                annotation_line = self.train_lines[i]
                name = annotation_line.split()[0]
    
                # read the image and its label mask from disk
                jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "JPEGImages"), name + ".jpg"))
                png = Image.open(os.path.join(os.path.join(self.dataset_path, "labels"), name + ".png"))
    
                if random_data:
                    jpg, png = self.get_random_data(jpg,png,(int(self.image_size[1]),int(self.image_size[0])))
                else:
                    jpg, png = letterbox_image(jpg, png, (int(self.image_size[1]),int(self.image_size[0])))
                
                inputs.append(np.array(jpg)/255)
                
                png = np.array(png)
                png[png >= self.num_classes] = self.num_classes
                seg_labels = np.eye(self.num_classes+1)[png.reshape([-1])]
                seg_labels = seg_labels.reshape((int(self.image_size[1]),int(self.image_size[0]),self.num_classes+1))
                
                targets.append(seg_labels)
                i = (i + 1) % length
                if len(targets) == self.batch_size:
                    tmp_inp = np.array(inputs)
                    tmp_targets = np.array(targets)
                    inputs = []
                    targets = []
                    yield tmp_inp, tmp_targets
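
    A small worked example of the one-hot encoding step above, on a toy 2x2 label mask with num_classes = 2:

    import numpy as np

    num_classes = 2
    png = np.array([[0, 1],
                    [1, 0]])                    # class index per pixel
    png[png >= num_classes] = num_classes       # clamp invalid ids into the extra channel
    seg_labels = np.eye(num_classes + 1)[png.reshape([-1])]
    seg_labels = seg_labels.reshape((2, 2, num_classes + 1))
    print(seg_labels.shape)                     # (2, 2, 3): one channel per class plus one extra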
    
  • Loss functions: Dice + CE, and plain CE (a quick numeric sanity check follows the code below)

    import tensorflow as tf
    from keras import backend as K

    def dice_loss_with_CE(beta=1, smooth=1e-5):
       def _dice_loss_with_CE(y_true, y_pred):
           y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
    
           CE_loss = - y_true[...,:-1] * K.log(y_pred)
           CE_loss = K.mean(K.sum(CE_loss, axis = -1))
    
           tp = K.sum(y_true[...,:-1] * y_pred, axis=[0,1,2])
           fp = K.sum(y_pred         , axis=[0,1,2]) - tp
           fn = K.sum(y_true[...,:-1], axis=[0,1,2]) - tp
    
           score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
           score = tf.reduce_mean(score)
           dice_loss = 1 - score
           # dice_loss = tf.Print(dice_loss, [dice_loss, CE_loss])
           return CE_loss + dice_loss
       return _dice_loss_with_CE
    
    def CE():
       def _CE(y_true, y_pred):
           y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
    
           CE_loss = - y_true[...,:-1] * K.log(y_pred)
           CE_loss = K.mean(K.sum(CE_loss, axis = -1))
           # dice_loss = tf.Print(CE_loss, [CE_loss])
           return CE_loss
       return _CE
    

    For more details, see: 语义分割损失函数总结 (a summary of loss functions for semantic segmentation).
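
    A quick numeric sanity check of the combined loss on a random dummy batch (a sketch; the shapes follow the Generator above, whose last y_true channel is the extra class that the loss drops via y_true[..., :-1]):

    import numpy as np

    num_classes = 2
    y_true = K.constant(np.eye(num_classes + 1)[np.random.randint(0, num_classes, (1, 4, 4))])
    y_pred = K.softmax(K.constant(np.random.rand(1, 4, 4, num_classes)))
    print(K.eval(dice_loss_with_CE()(y_true, y_pred)))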

  • Recording the training loss (a usage sketch follows the class below)

    import os
    import keras
    import matplotlib.pyplot as plt
    import scipy.signal

    class LossHistory(keras.callbacks.Callback):
        def __init__(self, log_dir):
            import datetime
            curr_time = datetime.datetime.now()
            time_str = datetime.datetime.strftime(curr_time,'%Y_%m_%d_%H_%M_%S')
            self.log_dir    = log_dir
            self.time_str   = time_str
            self.save_path  = os.path.join(self.log_dir, "loss_" + str(self.time_str))  
            self.losses     = []
            self.val_loss   = []
            
            os.makedirs(self.save_path)
    
        def on_epoch_end(self, epoch, logs={}):
            self.losses.append(logs.get('loss'))
            self.val_loss.append(logs.get('val_loss'))
            with open(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".txt"), 'a') as f:
                f.write(str(logs.get('loss')))
                f.write("\n")
            with open(os.path.join(self.save_path, "epoch_val_loss_" + str(self.time_str) + ".txt"), 'a') as f:
                f.write(str(logs.get('val_loss')))
                f.write("\n")
            # self.loss_plot()
    
        def loss_plot(self):
            iters = range(len(self.losses))
    
            plt.figure()
            plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
            try:
                if len(self.losses) < 25:
                    num = 5
                else:
                    num = 15
                
                plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
                plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
            except:
                pass
    
            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('A Loss Curve')
            plt.legend(loc="upper right")
    
            plt.savefig(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".png"))
    
            plt.cla()
            plt.close("all")
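
    A hypothetical training sketch that wires the callback into Keras training; train_lines, val_lines, num_classes, dataset_path and the hyper-parameters are placeholders, and model is the compiled UNet from the Model section above:

    batch_size   = 2
    gen          = Generator(batch_size, train_lines, (512, 512), num_classes, dataset_path)
    val_gen      = Generator(batch_size, val_lines, (512, 512), num_classes, dataset_path)
    loss_history = LossHistory("logs")

    model.fit_generator(gen.generate(True),
                        steps_per_epoch=max(1, gen.train_batches // batch_size),
                        validation_data=val_gen.generate(True),   # generate(False) needs the letterbox_image helper, which is not shown above
                        validation_steps=max(1, val_gen.train_batches // batch_size),
                        epochs=50,
                        callbacks=[loss_history])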
    
    

Predict

During training only the weights were saved, so for prediction we first rebuild the network and then load the weights. In the generate() method below, the lowercase unet refers to the model-construction function from the Model section above (assumed here to be imported under a lowercase alias so that it does not clash with this class name). The prediction code is as follows:

import copy
import colorsys
import time

import numpy as np
from PIL import Image

class Unet(object):

    #---------------------------------------------------#
    #   Initialize the UNET predictor
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        _defaults = {
            "model_path"        : kwargs['model'],
            "model_image_size"  : kwargs['model_image_size'],
            "num_classes"       : kwargs['num_classes']
        }
        self.__dict__.update(_defaults)
        self.generate()

    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self):
        #-------------------------------#
        #   Build the network and load the trained weights
        #-------------------------------#
        self.model = unet(self.model_image_size, self.num_classes)

        self.model.load_weights(self.model_path)
        print('{} model loaded.'.format(self.model_path))
        
        if self.num_classes <= 21:
            self.colors = [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0), (0, 0, 128), (128, 0, 128), (0, 128, 128), 
                    (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0), (192, 128, 0), (64, 0, 128), (192, 0, 128), 
                    (64, 128, 128), (192, 128, 128), (0, 64, 0), (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128), (128, 64, 12)]
        else:
            # assign a distinct color to each class
            hsv_tuples = [(x / self.num_classes, 1., 1.)
                        for x in range(self.num_classes)]
            self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
            self.colors = list(
                map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
                    self.colors))

    def letterbox_image(self, image, size):
        image = image.convert("RGB")
        iw, ih = image.size
        w, h = size
        scale = min(w/iw, h/ih)
        nw = int(iw*scale)
        nh = int(ih*scale)

        image = image.resize((nw,nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128,128,128))
        new_image.paste(image, ((w-nw)//2, (h-nh)//2))
        return new_image,nw,nh
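
    # Worked example of the letterbox resize above (illustrative numbers): an 800x600
    # input with target size 512x512 gives scale = min(512/800, 512/600) = 0.64, so the
    # image is resized to 512x384 and pasted centered on a gray 512x512 canvas; the
    # 64-pixel bands above and below are the "gray bars" cropped off again in detect_image.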

    #---------------------------------------------------#
    #   检测图片
    #---------------------------------------------------#
    def detect_image(self, image):
        #---------------------------------------------------------#
        #   Convert the image to RGB here so that grayscale inputs do not raise errors at prediction time.
        #---------------------------------------------------------#
        image = image.convert('RGB')
        
        #---------------------------------------------------#
        #   Keep a copy of the input image for blending later
        #---------------------------------------------------#
        old_img = copy.deepcopy(image)
        original_h = np.array(image).shape[0]
        original_w = np.array(image).shape[1]

        #---------------------------------------------------#
        #   Letterbox resize without distortion (adding gray bars) and normalize the image
        #---------------------------------------------------#
        img, nw, nh = self.letterbox_image(image,(self.model_image_size[1],self.model_image_size[0]))
        img = np.asarray([np.array(img)/255])
        pr = self.model.predict(img)[0]
        #---------------------------------------------------#
        #   Take the predicted class of each pixel
        #---------------------------------------------------#
        pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
        #--------------------------------------#
        #   Crop away the gray-bar padding
        #--------------------------------------#
        pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]

        #------------------------------------------------#
        #   Create a new image and color each pixel according to its predicted class
        #------------------------------------------------#
        seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
        for c in range(self.num_classes):
            seg_img[:,:,0] += ((pr[:,: ] == c )*( self.colors[c][0] )).astype('uint8')
            seg_img[:,:,1] += ((pr[:,: ] == c )*( self.colors[c][1] )).astype('uint8')
            seg_img[:,:,2] += ((pr[:,: ] == c )*( self.colors[c][2] )).astype('uint8')

        image = Image.fromarray(np.uint8(seg_img)).resize((original_w,original_h), Image.NEAREST)
        blend_image = Image.blend(old_img,image,0.7)

        return image, blend_image

    def get_FPS(self, image, test_interval):
        original_h = np.array(image).shape[0]
        original_w = np.array(image).shape[1]

        img, nw, nh = self.letterbox_image(image,(self.model_image_size[1],self.model_image_size[0]))
        img = np.asarray([np.array(img)/255])

        pr = self.model.predict(img)[0]
        pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
        pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]
        
        image = Image.fromarray(np.uint8(pr)).resize((original_w,original_h), Image.NEAREST)

        t1 = time.time()
        for _ in range(test_interval):
            pr = self.model.predict(img)[0]
            pr = pr.argmax(axis=-1).reshape([self.model_image_size[0],self.model_image_size[1]])
            pr = pr[int((self.model_image_size[0]-nh)//2):int((self.model_image_size[0]-nh)//2+nh), int((self.model_image_size[1]-nw)//2):int((self.model_image_size[1]-nw)//2+nw)]
            image = Image.fromarray(np.uint8(pr)).resize((original_w,original_h), Image.NEAREST)
            
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time
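
A hypothetical prediction sketch; the weights path, test image name and sizes below are placeholders:

from PIL import Image

unet_predictor = Unet(model="logs/unet_building_weights.h5",   # hypothetical weights file
                      model_image_size=(512, 512, 3),
                      num_classes=2)
mask, blend = unet_predictor.detect_image(Image.open("test_building.jpg"))
mask.save("test_building_mask.png")
blend.save("test_building_blend.png")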
        

References

  1. The original UNet paper (U-Net: Convolutional Networks for Biomedical Image Segmentation)

  2. keras实现unet模型 (the Keras UNet implementation this code is adapted from)

Source: https://blog.csdn.net/u012655441/article/details/120373759
