ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Pytorch 中 torch.flatten() 和 torch.nn.Flatten() 实例详解

2022-03-30 01:03:20  阅读:216  来源: 互联网

标签:tensor nn torch Pytorch flatten print Flatten


torch.flatten()

  torch.flatten(x) 等于 torch.flatten(x,0) 默认将张量拉成一维的向量,也就是说从第一维开始平坦化,torch.flatten(x,1) 代表从第二维开始平坦化。

Example:

import torch
x=torch.randn(2,4,2)
print(x)
  
z=torch.flatten(x)
print(z)
  
w=torch.flatten(x,1)
print(w)
  

输出结果:

输出为:
tensor([[[-0.9814,  0.8251],
         [ 0.8197, -1.0426],
         [-0.8185, -1.3367],
         [-0.6293,  0.6714]],
  
        [[-0.5973, -0.0944],
         [ 0.3720,  0.0672],
         [ 0.2681,  1.8025],
         [-0.0606,  0.4855]]])
  
tensor([-0.9814,  0.8251,  0.8197, -1.0426, -0.8185, -1.3367, -0.6293,  0.6714,
        -0.5973, -0.0944,  0.3720,  0.0672,  0.2681,  1.8025, -0.0606,  0.4855])
  
  
tensor([[-0.9814,  0.8251,  0.8197, -1.0426, -0.8185, -1.3367, -0.6293,  0.6714],
        [-0.5973, -0.0944,  0.3720,  0.0672,  0.2681,  1.8025, -0.0606,  0.4855]])

  torch.flatten(x,0,1) 代表在第一维和第二维之间平坦化 

Example:

import torch
x=torch.randn(2,4,2)
print(x)
  
w=torch.flatten(x,0,1) #第一维长度2,第二维长度为4,平坦化后长度为2*4
print(w.shape)
  
print(w)
  
输出为:
tensor([[[-0.5523, -0.1132],
         [-2.2659, -0.0316],
         [ 0.1372, -0.8486],
         [-0.3593, -0.2622]],
  
        [[-0.9130,  1.0038],
         [-0.3996,  0.4934],
         [ 1.7269,  0.8215],
         [ 0.1207, -0.9590]]])
  
torch.Size([8, 2])
  
tensor([[-0.5523, -0.1132],
        [-2.2659, -0.0316],
        [ 0.1372, -0.8486],
        [-0.3593, -0.2622],
        [-0.9130,  1.0038],
        [-0.3996,  0.4934],
        [ 1.7269,  0.8215],
        [ 0.1207, -0.9590]])

torch.nn.Flatten()

  对于 torch.nn.Flatten(),因为其被用在神经网络中,输入为一批数据,第一维为batch,通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第二维开始平坦化。

Example:

import torch
#随机32个通道为1的5*5的图
x=torch.randn(32,1,5,5)
  
model=torch.nn.Sequential(
    #输入通道为1,输出通道为6,3*3的卷积核,步长为1,padding=1
    torch.nn.Conv2d(1,6,3,1,1),
    torch.nn.Flatten()
)
output=model(x)
print(output.shape)  # 6*(7-3+1)*(7-3+1)
  
输出为:
  
torch.Size([32, 150])

 

标签:tensor,nn,torch,Pytorch,flatten,print,Flatten
来源: https://www.cnblogs.com/BlairGrowing/p/16074632.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有