ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

torch中乘法整理,*&torch.mul()&torch.mv()&torch.mm()&torch.dot()&@&torch.mutmal()

2022-07-24 21:01:18  阅读:193  来源: 互联网

标签:mat2 tensor mat1 mm mutmal torch arange vec


*位置乘

符号*在pytorch中是按位置相乘,存在广播机制。

例子:

vec1 = torch.arange(4)
vec2 = torch.tensor([4,3,2,1])
mat1 = torch.arange(12).reshape(4,3)
mat2 = torch.arange(12).reshape(3,4)

print(vec1 * vec2)
print(mat2 * vec1)
print(mat1 * mat1)

Output:
tensor([0, 3, 4, 3])
tensor([[ 0,  1,  4,  9],
        [ 0,  5, 12, 21],
        [ 0,  9, 20, 33]])
tensor([[  0,   1,   4],
        [  9,  16,  25],
        [ 36,  49,  64],
        [ 81, 100, 121]])

torch.mul():数乘

官方解释:

就是两个变量对应元素相乘,other可以为一个数,也可以为一个tensor变量

torch.mul()支持广播机制

例子1:
'''python
In[1]: vec = torch.randn(3)
In[2]: vec
Out[1]: tensor([0.3550, 0.0975, 1.3870])
In[3]: torch.mul(vec, 5)
Out[2]: tensor([1.7752, 0.4874, 6.9348])
'''

例子2:
'''python
In[1]: vec = torch.randn(3)
In[2]: vec
Out[1]: tensor([1.7752, 0.4874, 6.9348])
In[3]: mat = torch.randn(4).view(-1,1)
In[4]: mat
Out[2]: tensor([[-1.5181],
[ 0.4905],
[-0.3388],
[ 0.5626]])

In[5]:torch.mul(vec,mat)
Out[3]:tensor([[-0.5390, -0.1480, -2.1055],
[ 0.1741, 0.0478, 0.6803],
[-0.1203, -0.0330, -0.4699],
[ 0.1998, 0.0548, 0.7803]])
'''

torch.mv():矩阵向量乘法

官方文档写道:Performs a matrix-vector product of the matrix input and the vector vec.
说明torch.mv(input, vec, *, out=None)->tensor只支持矩阵向量乘法,如果input为\(n\times m\)的,vec向量的长度为m,那么输出为\(n\times 1\)的向量。
torch.mv()不支持广播机制
例子:

In[1]: vec = torch.arange(4)
In[2]: mat = torch.arange(12).reshape(3,4)
In[3]: torch.mv(mat, vec)
Out[1]: tensor([14, 38, 62])

torch.mm() 矩阵乘法

官方文档写道:Performs a matrix multiplication of the matrices input and mat2.
torch.mm(input , mat2, *, out=None) → Tensor
对矩阵input 和mat2进行相乘。 如果input 是一个n×m张量,mat2 是一个 m×p张量,将会输出一个 n×p张量out。
torch.mm()不支持广播机制
这个就是线性代数中的矩阵乘法。

例子:

In[1]: mat1 = torch.arange(12).reshape(3,4)
In[2]: mat2 = torch.arange(12).reshape(4,3)
In[3]: torch.mm(mat1, mat2)
Out[1]: tensor([[ 42,  48,  54],
        [114, 136, 158],
        [186, 224, 262]])

torch.dot() 点乘积

官方文档写道:Computes the dot product of two 1D tensors.
只能支持两个一维向量,与numpy中dot()方法不同。
torch.dot(input, other, *, out=None) → Tensor

例子:

In[1]: torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
Out[1]: tensor[7]

@操作

torch中的@操作是可以实现前面几个函数,是一种强大的操作。
mat1 @ mat2

  • 若mat1和mat2都是两个一维向量,那么对应操作就是torch.dot()
  • 若mat1是二维向量,mat2是一维向量,那么对应操作就是torch.mv()
  • 若mat1和mat2都是两个二维向量,那么对应操作就是torch.mm()
vec1 = torch.arange(4)
vec2 = torch.tensor([4,3,2,1])
mat1 = torch.arange(12).reshape(4,3)
mat2 = torch.arange(12).reshape(3,4)

print(vec1 @ vec2)
print(mat2 @ vec1)
print(mat1 @ mat2)

Output:
tensor(10)
tensor([14, 38, 62])
tensor([[ 20,  23,  26,  29],
        [ 56,  68,  80,  92],
        [ 92, 113, 134, 155],
        [128, 158, 188, 218]])

torch.matmul() 待整理

标签:mat2,tensor,mat1,mm,mutmal,torch,arange,vec
来源: https://www.cnblogs.com/CharlesLC/p/16515354.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有