ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

『论文笔记』SuperGlue

2021-09-12 21:32:22  阅读:290  来源: 互联网

标签:dim SuperGlue torch log matches self 论文 笔记 scores


https://zhuanlan.zhihu.com/p/342105673

特征处理部分比较好理解,点的self、cross注意力机制实现建议看下源码(MultiHeadedAttention),

def attention(query, key, value):
    dim = query.shape[1]
    scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    prob = torch.nn.functional.softmax(scores, dim=-1)
    return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob


class MultiHeadedAttention(nn.Module):
    """ Multi-head attention to increase model expressivitiy """
    def __init__(self, num_heads: int, d_model: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.dim = d_model // num_heads
        self.num_heads = num_heads
        self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])

    def forward(self, query, key, value):
        batch_dim = query.size(0)
        query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
                             for l, x in zip(self.proj, (query, key, value))]
        x, prob = attention(query, key, value)
        self.prob.append(prob)
        return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))

这里直接跳到最后的逻辑部分,这部分论文写的比较粗略,需要看下源码才知道在讲啥(也许有大佬不用看)。

看这里,即是说推理时检出的匹配关系是不考虑最后一行和最后一列的,而是设定阈值,将不合格的匹配过滤掉

        # Get the matches with score above "match_threshold".
        max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
        indices0, indices1 = max0.indices, max1.indices
        mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)  # [0,0...,1,..0]
        mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
        zero = scores.new_tensor(0)
        mscores0 = torch.where(mutual0, max0.values.exp(), zero)
        mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
        valid0 = mutual0 & (mscores0 > self.config['match_threshold'])
        valid1 = mutual1 & valid0.gather(1, indices1)
        indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
        indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))

推理时代码如下,可见图A和图B互相匹配的结果(按照score的行列取最大值的index)不必完全一致:

                kpts0, kpts1 = pred['keypoints0'].cpu().numpy()[0], pred['keypoints1'].cpu().numpy()[0]
                matches, conf = pred['matches0'].cpu().detach().numpy(), pred['matching_scores0'].cpu().detach().numpy()
                image0 = read_image_modified(image0, opt.resize, opt.resize_float)
                image1 = read_image_modified(image1, opt.resize, opt.resize_float)
                valid = matches > -1
                mkpts0 = kpts0[valid]
                mkpts1 = kpts1[matches[valid]]
                mconf = conf[valid]

然后看求解分配矩阵的部分,couplings为相似度得分矩阵,为其添加了最后一行一列,并赋值为1,在原文提到的约束下,使用sinkhorn(待看)算法求解,求出分配矩阵Z,

# b(m+1)(n+1), b(m+1), b(n+1)
def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int):
    """ Perform Sinkhorn Normalization in Log-space for stability"""
    u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
    for _ in range(iters):
        # [log(m+n) ..., log(n)+log(m+n)] - []
        u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)  # b(m+1)
        v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
    return Z + u.unsqueeze(2) + v.unsqueeze(1)


def log_optimal_transport(scores, alpha, iters: int):
    """ Perform Differentiable Optimal Transport in Log-space for stability"""
    b, m, n = scores.shape
    one = scores.new_tensor(1)
    ms, ns = (m*one).to(scores), (n*one).to(scores)

    bins0 = alpha.expand(b, m, 1)  # only a new view
    bins1 = alpha.expand(b, 1, n)
    alpha = alpha.expand(b, 1, 1)

    # b(m+1)(n+1), 额外行列值为1
    couplings = torch.cat([torch.cat([scores, bins0], -1),  # bmn,bm1->bm(n+1)
                           torch.cat([bins1, alpha], -1)], 1)  # b1n,b11->b1(n+1)

    norm = - (ms + ns).log()
    log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])  # m+1: [log(m+n) ..., log(n)+log(m+n)]
    log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])  # n+1: [log(m+n) ..., log(m)+log(m+n)]
    log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)  # b(m+1), b(n+1)

    Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
    Z = Z - norm  # multiply probabilities by M+N
    return Z

损失函数就是最大化这个分配矩阵Z,即下面的scores矩阵,匹配对中肯定不包含dustbin点的,也就是说对dustbin的考量蕴含在sinkhorn中,注意下面的函数调用的参数self.bin_score,这是superglue网络的一个可学习的参数:

        bin_score = torch.nn.Parameter(torch.tensor(1.))         self.register_parameter('bin_score', bin_score) 回头看上面的log_optimal_transport代码,每次给couplings的额外行列赋的值就是这个值。
        all_matches = data['all_matches'].permute(1,2,0) # shape=torch.Size([1, 87, 2])

        ……

        # Run the optimal transport.
        scores = log_optimal_transport(
            scores, self.bin_score,
            iters=self.config['sinkhorn_iterations'])

        ……

        # check if indexed correctly
        loss = []
        for i in range(len(all_matches[0])):
            x = all_matches[0][i][0]
            y = all_matches[0][i][1]
            loss.append(-torch.log( scores[0][x][y].exp() )) # check batch size == 1 ?

原文里对分配矩阵的约束如下,这个应该是引入sinkhorn的作用,在代码中分配矩阵P_head并没有显式出现,所以没法辅助我理解这个公式:

 相对应的,P的约束就很好理解:

 

标签:dim,SuperGlue,torch,log,matches,self,论文,笔记,scores
来源: https://www.cnblogs.com/hellcat/p/15260145.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有