ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

SO-Net中分类器(classifier)的实现过程

2021-11-03 22:03:08  阅读:220  来源: 互联网

标签:center self mask 节点 分类器 lr SO Net classifier


一、数据准备

  该部分代码位于./data/modelnet_shrec_loader.py中。读取的数据为pc_np(点的坐标),surface_normal_np(法向量),som_node_np(som节点坐标)和class_id(类别)。然后对数据增强,包括旋转、微扰、尺度变换和位移。返回点的坐标、法向量、类别、som节点和每个som节点在som节点中的k个近邻点的索引。

二、模型

  该部分代码位于./models/classifier.py。主要包括两个部分:编码网络和分类网络。
在这里插入图片描述

2.1编码网络

  编码器如上图蓝色底色的部分所示。具体的定义在./models/networks.py中。在Encoder类forward中可以看到,其输入为点的坐标、法向量、节点坐标和节点到节点的knn索引。

2.1.1 SOM层

  如上图所示,第一部分为SOM层。通过./util/som.pyBatchSOM类中的query_topk()函数实现映射,返回mask,mask_row_max和min_idx。
(1)mask:其大小为[B,kN,M],第i个N*M行表示节点是否为N个点的第i近邻节点。
在这里插入图片描述

上图中红色的1表示第2个节点是第一个点的第一近邻点。蓝色的1表示第3个节点是第四个点的第一近邻点,而最下方的1表示第5个节点是第三个点的第二近邻点。
(2)mask_row_max:其大小为[B,M],表示每个节点是否存在近邻的点
在这里插入图片描述

其中1表示第i个节点是某个点的近邻点。
(3)min_idx:其大小为[B,kN],第i个N*1行表示N个点的第i近邻节点的索引。
在这里插入图片描述

与mask对应,表示最近邻索引值。
  然后获取以每个节点为近邻的所有点的中心。

		self.mask, mask_row_max, min_idx = self.som_builder.query_topk(x.data, k=self.opt.k)  # BxkNxnode_num, Bxnode_num
        mask_row_sum = torch.sum(self.mask, dim=1)  # BxM
        mask = self.mask.unsqueeze(1)  # Bx1xkNxM

        #将x和sn堆叠
        x_list, sn_list = [], []
        for i in range(self.opt.k):
            x_list.append(x)
            sn_list.append(sn)
        x_stack = torch.cat(tuple(x_list), dim=2)  # B x C x kN
        sn_stack = torch.cat(tuple(sn_list), dim=2)# B x C x kN

        # 计算以每个点为近邻的所有点的平均坐标,作为新的节点坐标
        x_stack_data_unsqueeze = x_stack.data.unsqueeze(3)  # BxCxkNx1
        x_stack_data_masked = x_stack_data_unsqueeze * mask.float()  # BxCxkNxM
        cluster_mean = torch.sum(x_stack_data_masked, dim=2) / (mask_row_sum.unsqueeze(1).float()+1e-5)  # BxCxM,为了防止数值不稳定,即没有点以该节点为近邻点
        self.som_builder.node = cluster_mean
        self.som_node = self.som_builder.node

然后对于每个点进行去中心化,并与sn拼接到一起,作为输入。

		node_expanded = self.som_node.data.unsqueeze(2)  # BxCx1xM
        self.centers = torch.sum(mask.float() * node_expanded, dim=3).detach()  # BxCxkN

        self.x_decentered = (x_stack - self.centers).detach()  # Bx3xkN
        x_augmented = torch.cat((self.x_decentered, sn_stack), dim=1)  # Bx6xkN

2.1.2 first_pointnet

  该部分代码在./models/layers.py。其实质是一个残差网络,每一层是一个EquivariantLayer的结构,其定义在同样在./models/layers.py中。

在这里插入图片描述

(index_max是cuda的c++扩展,还没弄懂是什么意思,太菜了555)

2.1.3 knnlayer

  该部分代码在./models/layers.py中。首先计算每个center的前k个距离的center的索引:

        coordinate_tensor = coordinate.data  # Bx3xM  以节点为近邻点的所有点的中心
        if precomputed_knn_I is not None:
            assert precomputed_knn_I.size()[2] >= K
            knn_I = precomputed_knn_I[:, :, 0:K]
        else:
            coordinate_Mx1 = coordinate_tensor.unsqueeze(3)  # Bx3xMx1
            coordinate_1xM = coordinate_tensor.unsqueeze(2)  # Bx3x1xM
            norm = torch.sum((coordinate_Mx1 - coordinate_1xM) ** 2, dim=1)  # BxMxM, each row corresponds to each coordinate - other coordinates
            knn_D, knn_I = torch.topk(norm, k=K, dim=2, largest=False, sorted=True)  # BxMxK 每个center到其他center的前k个最近距离的距离和索引

然后对于每个center,计算k个近邻center的坐标、均值及去中心化后的坐标值:

        neighbors = operations.knn_gather_wrapper(coordinate_tensor, knn_I)  # Bx3xMxK 每个center最近的k个center的坐标
        if center_type == 'avg':  # 如果以k个平均值为中心
            neighbors_center = torch.mean(neighbors, dim=3, keepdim=True)  # Bx3xMx1 每个center最近的k个center的坐标中心
        elif center_type == 'center':  # 以center本身为中心
            neighbors_center = coordinate_tensor.unsqueeze(3)  # Bx3xMx1 每个center的坐标
        neighbors_decentered = (neighbors - neighbors_center).detach() # Bx3xMxK 每个center最近的k个center的去中心坐标
        neighbors_center = neighbors_center.squeeze(3).detach()  # Bx3xM  中心坐标

最后得到每个center最近邻的k个center的特征向量,并作为卷积层的输入,该卷积层是在同一个文件中定义的。返回值是center点的坐标和特征向量。

        x_neighbors = operations.knn_gather_by_indexing(x, knn_I)  # BxCxMxK 每个center最近邻k个center的特征向量
        x_augmented = torch.cat((neighbors_decentered, x_neighbors), dim=1)  # Bx(3+C)xMxK 与中心坐标拼接

2.1.4 final_pointnet

  该部分代码在./models/layers.py中。是一个常规的pointnet网络结构。得到全局的特征向量(图中global feature)。

2.2 分类器

  该部分代码位于models/networks.py中。其实质是一个三层的全连接层,输出点云对于每个类别的概率,最后用交叉熵损失进行训练。

三、测试与保存

3.1 get_current_errors & visualizer.plot_current_errors

  该部分代码在./models/classifier.py中。统计预测的准确率,并进行可视化(loss-time曲线)。

3.2 model.save_network

  保存模型到指定路径。

    def save_network(self, network, network_label, epoch_label, gpu_id):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.opt.checkpoints_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if gpu_id>=0 and torch.cuda.is_available():
            # torch.cuda.device(gpu_id)
            network.to(self.opt.device)

3.3 model.update_learning_rate

  更新学习速率。

def update_learning_rate(self, ratio):
        lr_clip = 0.00001

        # encoder
        lr_encoder = self.old_lr_encoder * ratio
        if lr_encoder < lr_clip:
            lr_encoder = lr_clip
        for param_group in self.optimizer_encoder.param_groups:
            param_group['lr'] = lr_encoder
        print('update encoder learning rate: %f -> %f' % (self.old_lr_encoder, lr_encoder))
        self.old_lr_encoder = lr_encoder

        # classifier
        lr_classifier = self.old_lr_classifier * ratio
        if lr_classifier < lr_clip:
            lr_classifier = lr_clip
        for param_group in self.optimizer_classifier.param_groups:
            param_group['lr'] = lr_classifier
        print('update classifier learning rate: %f -> %f' % (self.old_lr_classifier, lr_classifier))
        self.old_lr_classifier = lr_classifier

标签:center,self,mask,节点,分类器,lr,SO,Net,classifier
来源: https://blog.csdn.net/qq_43173635/article/details/121107760

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有