ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

Pytorch框架中余弦相似度(Cosine similarity)、欧氏距离(Euclidean distance)源码解析

2021-12-29 18:03:39  阅读:1071  来源: 互联网

标签:distance dist similarity ## torch 矩阵 源码 keepdim expand


一、矩阵操作用于计算余弦相似度

余弦相似度:

cos(u,v)=\frac{uv}{||u||||v||}

我们知道,分子是矩阵的乘法,分母是两个标量的乘积。分母好办,关键是如何在计算分子?很简单,我们可以将公式变变形:

cos(u,v)=\frac{1}{||u||||v||}uv

那么我们只需在矩阵乘法前,使其归一化,乘法之后就是余弦相似度了,来看一下代码(参考:https://zhuanlan.zhihu.com/p/383675457

import torch

##计算两个特征的余弦相似度


def normalize(x, axis=-1):
    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
    return x


##特征向量a
a=torch.rand(4,512)

##特征向量b
b=torch.rand(6,512)

##特征向量进行归一化
a,b=normalize(a),normalize(b)

##矩阵乘法求余弦相似度
cos=1-torch.mm(a,b.permute(1,0))
cos.shape

#输出
torch.Size([4, 6])

 我们来逐行解析一下这段代码吧。

x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)

这是归一化的公式,为什么是这个公式,我也不太明白。不过不妨碍我们解析。看到 torch.norm(x, 2, axis, keepdim=True) ,这是一个非常重要的知识点:torch.norm( input, p, din, out = None, keepdim = False )该函数的功能是求指定维度上的范数;其次看到 expand_as(tensor)函数,这是将张量scale扩展为参数tensor的大小。我这么说可能不太明白?那就糊涂着吧。

##特征向量a
a=torch.rand(4,512)

##特征向量b
b=torch.rand(6,512)

##特征向量进行归一化
a,b=normalize(a),normalize(b)

 这三行就很简单了,从最后一行来看,就是把特征向量a,b归一化。这里面主要要知道 torch.rand(*sizes,out=None) 函数的用法。torch.rand(*sizes,out=None) 是均匀分布,返回的张量包含从区间(0,1)的均匀分布中随机抽取的一组随机数。第一个参数*size定义了输出张量的形状,也就是一个多大的矩阵。不明白?举个例子,比如 t1 = torch.rand(2,3),那它返回一个张量,张量的大小就是一个二行三列的矩阵,结果就是在(0,1)上随机抽取的随机数:

cos=1-torch.mm(a,b.permute(1,0))
cos.shape

这就在求余弦相似度了,注意一下permute()函数,permute作用为调换Tensor的维度,参数为调换的维度。例如对于一个二维Tensor来说,调用tensor.permute(1,0)意为将1轴(列轴)与0轴(行轴)调换,相当于进行转置

二、矩阵操作用于计算欧式距离

代码来自Triplet Loss,实质上都是这样写的,没有大碍。

先搞清楚原理(参考:https://blog.csdn.net/frankzd/article/details/80251042)现在我们有大小为 M X D 的矩阵P,和大小为 N X D 的矩阵C。记P_{i} 是矩阵P的第i行,P_{i}=(P_{i1},P_{i2},...,{P_{iD}}) ;C_{j} 是矩阵C的第j行,C_{j}=(C_{j1},C_{j2},...,C_{jD})

接着我们来看一下源代码怎么实现的:

def euclidean_dist(x, y):
  """
  Args:
    x: pytorch Variable, with shape [m, d]
    y: pytorch Variable, with shape [n, d]
  Returns:
    dist: pytorch Variable, with shape [m, n]
  """
  m, n = x.size(0), y.size(0)
  xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
  yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
  dist = xx + yy
  dist.addmm_(1, -2, x, y.t())
  dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
  return dist

 现在我们来逐行解析 (参考:https://blog.csdn.net/IT_forlearn/article/details/100022244):

m, n = x.size(0), y.size(0)

这一行比较简单,x的维度是[m,d],y的维度是[n,d],x.size(0) 就表示取x的第一个维度,即m。同理y.size(0)

xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)

这一行就比较难理解了,xx经过pow()方法对每单个数据进行二次方操作后,在axis=1 方向(横向,就是第一列向最后一列的方向)加和,此时xx的shape为(m, 1),经过expand()方法,扩展n-1次,此时xx的shape为(m, n).

yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()

与上一行相比,yy会在上述操作后,再进行转置的操作。

dist = xx + yy

这很简单,矩阵的加法

dist.addmm_(1, -2, x, y.t())

这里要特别注意,代码是dist.addmm_不是dist.addmm,具体区别参考:https://blog.csdn.net/qq_36556893/article/details/90638449。dist.addmm_(1, -2, x, y.t()) 实现的公式为:dist=1*dist-2*(xy^{T})

dist = dist.clamp(min=1e-12).sqrt()

clamp()函数可以限定dist内元素的最大最小范围,dist最后开方,得到样本之间的距离矩阵。

标签:distance,dist,similarity,##,torch,矩阵,源码,keepdim,expand
来源: https://blog.csdn.net/m0_57541899/article/details/122211836

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有