ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

batch_first 选择True/False的区别

2022-04-25 09:03:44  阅读:293  来源: 互联网

标签:False batch shape output input Size True size


rnn = nn.RNN(input_size=4,hidden_size=3,num_layers=2,batch_first=True)
input = torch.randn(1,5,4)
output , h_n = rnn(input)
print(output.shape) 
print(h_n.shape)

输出结果:

#output

tensor([[[-0.4026, -0.2417, -0.1307],
             [-0.0122, 0.4269, -0.7256],
             [ 0.2228, 0.7731, -0.9092],
             [-0.3735, 0.4446, -0.6930],
             [-0.1539, 0.5937, -0.8616]]], grad_fn=<TransposeBackward1>)

#h_n
tensor([[[-0.5664, 0.0416, -0.9316]],
             [[-0.1539, 0.5937, -0.8616]]], grad_fn=<StackBackward0>)


input.Size([1, 5, 4]) output.Size([1, 5, 3]) h_n.Size([2, 1, 3]) seq_len = 1 num_layers = 2 hidden_size = 3 input_size = 4 batch_size = 5

当batch_first = False

rnn = nn.RNN(input_size=4,hidden_size=3,num_layers=2,batch_first=False)
input = torch.randn(1,5,4)
output , h_n = rnn(input)
print(output.shape)
print(h_n.shape)

输出结果:
input.Size([1, 5, 3])
output.Size([1, 5, 3])
h_n.Size([2, 5, 3])

seq_len = 1  num_layers = 2 hidden_size = 3  input_size = 4 batch_size = 5

#output
tensor([[[-0.6884,  0.0477, -0.3248],
         [-0.5575, -0.0757, -0.4916],
         [-0.6645,  0.2197, -0.4582],
         [-0.6820,  0.1047, -0.4033],
         [-0.6624,  0.0487, -0.3798]]], grad_fn=<StackBackward0>)
#h_n
tensor([[[ 0.4404, -0.4511,  0.2594],
         [ 0.8487,  0.3987,  0.2429],
         [ 0.0287, -0.7793, -0.4574],
         [-0.4603, -0.7794,  0.4563],
         [ 0.4304, -0.3424,  0.1715]],

        [[-0.6884,  0.0477, -0.3248],
         [-0.5575, -0.0757, -0.4916],
         [-0.6645,  0.2197, -0.4582],
         [-0.6820,  0.1047, -0.4033],
         [-0.6624,  0.0487, -0.3798]]], grad_fn=<StackBackward0>)

  

  

标签:False,batch,shape,output,input,Size,True,size
来源: https://www.cnblogs.com/conpi/p/16188684.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有