ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

python – 在pytorch(或Numpy)中实现此等式的更有效方法

2019-07-10 13:07:29  阅读:257  来源: 互联网

标签:python machine-learning numpy pytorch


我正在实现这个功能的分析形式

enter image description here

其中k(x,y)是RBF核k(x,y)= exp( – || x-y || ^ 2 /(2h))

我的函数原型是

def A(X, Y, grad_log_px,Kxy):
   pass

X,Y是NxD矩阵,其中N是批量大小,D是维度.因此,X是上面等式中具有大小N的一批x,grad_log_px是使用autograd计算的一些NxD矩阵.

Kxy是NxN矩阵,其中每个条目(i,j)是RBF内核K(X [i],Y [j])

这里的挑战是,在上面的等式中,y只是一个维数为D的向量.我想要传递给一批y. (所以传递NxD大小的矩阵Y)

使用循环批量大小的方程式很好,但我无法以更整洁的方式实现

这是我尝试的循环解决方案:

def A(X, Y, grad_log_px,Kxy):
   res = []
   for i in range(Y.shape[0]):
       temp = 0
       for j in range(X.shape[0]):
           # first term of equation
           temp += grad_log_px[j].reshape(D,1)@(Kxy[j,i] * (X[i] - Y[j]) / h).reshape(1,D)
           temp += Kxy[j,i] * np.identity(D) - ((X[i] - Y[j]) / h).reshape(D,1)@(Kxy[j,i] * (X[i] - Y[j]) / h).reshape(1,D) # second term of equation
       temp /= X.shape[0]

        res.append(temp)
    return np.asarray(res) # return NxDxD array 

在等式中:grad_ {x}和grad_ {y}都是维D

解决方法:

鉴于我正确推断了各种术语的所有维度,这是一种解决方法.但首先是维度的摘要(截图,因为它更容易用数学类型设置解释;请验证它们是否正确):

Explanation

另请注意第二项的双重导数,它给出:

Derivative

下标表示样本,上标表示功能.

因此,我们可以使用np.einsum(类似于torch.einsum)和array broadcasting创建两个术语:

grad_y_K = (X[:, None, :] - Y) / h * K[:, :, None]  # Shape: N_x, N_y, D
term_1 = np.einsum('ij,ikl->ikjl', grad_log_px, grad_y_K)  # Shape: N_x, N_y, D_x, D_y
term_2_h = np.einsum('ij,kl->ijkl', K, np.eye(D)) / h  # Shape: N_x, N_y, D_x, D_y
term_2_h2_xy = np.einsum('ijk,ijl->ijkl', grad_y_K, grad_y_K)  # Shape: N_x, N_y, D_x, D_y
term_2_h2 = K[:, :, None, None] * term_2_h2_xy / h**2  # Shape: N_x, N_y, D_x, D_y
term_2 = term_2_h - term_2_h2  # Shape: N_x, N_y, D_x, D_y

然后结果由下式给出:

(term_1 + term_2).sum(axis=0) / N  # Shape: N_y, D_x, D_y

标签:python,machine-learning,numpy,pytorch
来源: https://codeday.me/bug/20190710/1424382.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有