标签:主要参数 GRU torch nn Recurrent Gated input LSTM
GRU(Gated Recurrent Unit)也称门控循环单元结构
nn.GRU类初始化主要参数解释:
input_size: 输入张量x中特征维度的大小.
hidden_size: 隐层张量h中特征维度的大小.
num_layers: 隐含层的数量.
nonlinearity: 激活函数的选择, 默认是tanh.
bidirectional: 是否选择使用双向LSTM, 如果为True, 则使用; 默认不使用.
nn.GRU类实例化对象主要参数解释:
input: 输入张量x.
h0: 初始化的隐层张量h.
代码示例:
import torch
import torch.nn as nn
rnn = nn.GRU(5,6,2) #数据向量维数5, 隐藏元维度6, 2个LSTM层串联(如果是1,可以省略,默认为1)
input = torch.randn(1,3,5) # 序列长度seq_len=1, batch_size=3, 数据向量维数=5
h0 = torch.randn(2,3,6) # 2个LSTM层,batch_size=3,隐藏元维度6
output, hn = rnn(input,h0)
print(output)
print(output.type())
print(output.shape)
print(hn)
代码运行结果
tensor([[[-0.0307, 0.0718, -0.2517, 0.0565, 0.0613, 0.5001],
[-0.6239, 1.0618, 0.7506, 0.3475, -0.8536, -0.8410],
[-0.0949, 0.5698, 0.4491, -0.0122, 0.5413, -0.2383]]],
grad_fn=<StackBackward0>)
torch.FloatTensor
torch.Size([1, 3, 6])
tensor([[[-0.5540, 0.3067, -1.2936, -0.3727, -0.4141, 0.2967],
[-1.2364, 0.7779, 0.4355, -1.2783, -0.0382, 0.5875],
[ 1.4438, 1.2898, -0.3959, -0.5599, -1.1615, 0.3538]],
[[-0.0307, 0.0718, -0.2517, 0.0565, 0.0613, 0.5001],
[-0.6239, 1.0618, 0.7506, 0.3475, -0.8536, -0.8410],
[-0.0949, 0.5698, 0.4491, -0.0122, 0.5413, -0.2383]]],
grad_fn=<StackBackward0>)
GRU的优势:
GRU和LSTM作用相同, 在捕捉长序列语义关联时, 能有效抑制梯度消失或爆炸, 效果都优于传统RNN且计算复杂度相比LSTM要小.
GRU的缺点:
GRU仍然不能完全解决梯度消失问题, 同时其作用RNN的变体, 有着RNN结构本身的一大弊端, 即不可并行计算, 这在数据量和模型体量逐步增大的未来, 是RNN发展的关键瓶颈.
标签:主要参数,GRU,torch,nn,Recurrent,Gated,input,LSTM 来源: https://blog.csdn.net/weixin_41862755/article/details/123140739
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。