标签:pairs positive sq pdist 评价 tf 量化 pow 对比
在超球面上通过对齐和一致实现理解对比表示学习 —— 论文阅读笔记
两个对比损失最关键的要素:
- 正例对特征的对齐(就是找最接近的正例对)。
- 超球面特征分布的均匀分布(可以保存最多的信息
torch 版本代码:
# bsz : batch size (number of positive pairs)
# d : latent dim
# x : Tensor, shape=[bsz, d]
# latents for one side of positive pairs
# y : Tensor, shape=[bsz, d]
# latents for the other side of positive pairs
# lam : hyperparameter balancing the two losses
def lalign(x, y, alpha=2):
return (x - y).norm(dim=1).pow(alpha).mean()
def lunif(x, t=2):
sq_pdist = torch.pdist(x, p=2).pow(2)
return sq_pdist.mul(-t).exp().mean().log()
loss = lalign(x, y) + lam * (lunif(x) + lunif(y)) / 2
tensorflow 版本:
def lalign(x, y, alpha=2):
"""
x: [bs, d] latents for one side of positive pairs
y: [bs,d] latents for the other side of positive pairs
"""
# 第二范数
return tf.reduce_mean(tf.pow(tf.norm(x - y, axis=1), alpha))
def lunif(x, t=2):
"""
x: [bs, d]
"""
batch_size = tf.shape(x)[0]
# 实现torch.pdist
x=tf.cast(x, tf.float32)
pdist_matrix = tf.norm(x[:, None]-x, axis=2)
bool_mask = tf.cast(1-tf.linalg.band_part(tf.ones((batch_size,batch_size)),-1,0), bool) # 右上对角线
pdist = pdist_matrix[bool_mask]
sq_pdist = tf.pow(pdist, 2)
return tf.math.log(tf.reduce_mean(tf.exp(-t*sq_pdist)))
标签:pairs,positive,sq,pdist,评价,tf,量化,pow,对比 来源: https://www.cnblogs.com/carolsun/p/16419791.html
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。