标签:Linear nn 补充 self init weight pytorch 李沐 net
1、注册带有参数的层时候就要使用nn.Parameter()
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.v=nn.Parameter()
self.v=nn.Linear()
要将左侧的v注册成为参数,右侧就需要进行nn.Parameter这个操作,第二个没有是因为Linear本身就是封装好了的。
这也就意味着,所有有学习参数的层,必须在__init__内部封装好。
2、介绍一下**nn._modules()**层的容器,可以用来复写sequential
class MySequential(nn.Module):
def __init__(self,*args):
super().__init__()
for block in args:
#._modules是包装层的容器,按顺序输入的插入层
self._modules[block]=block
def forward(self,x):
for block in self._modules.values():
x=block(x)
return x
net = MySequential(nn.Linear(20,256),nn.ReLU(),nn.Linear(256,10))
net(x)
3、在前项传播中可以写代码,不仅仅包括一些层的传导,还可以书写代码。
class FixedHiddenMLP(nn.Module):
def __init__(self):
super().__init__()
self.rand_weight = torch.rand((20,20),requires_grad=False)
self.linear = nn.Linear(20,20)
def forward(self,x):
x=self.linear(x)
x=F.relu(torch.mm(x,self.rand_weight)+1)
x=self.linear(x)
while x.abs().sum()>1:
x/=2
return x.sum()
net=FixedHiddenMLP()
net(x)
4、参数的访问
函数如图所示:
import torch
from torch import nn
net = nn.Sequential(nn.Linear(4,8),nn.ReLU(),nn.Linear(8,1))
x=torch.rand(size=(2,4))
net(x)
print(net[2].state_dict())
得到nn.Linear(8,1)的所有参数
print(net[2].bias)
访问nn.Linear(8,1)特定的参数bias
print(net[2].bias.data)
得到参数的数值
5、net.named_parameters()可以得到网络的名字和参数,还可以使用net[i]取得特定块的名字和参数
net.named_parameters()中param是len为2的tuple
param[0]是name,fc1.weight、fc1.bias等
param[1]是fc1.weight、fc1.bias等对应的值
23就是另一组
56又是另一组
6、对指定层进行参数初始化
def init_normal(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight,mean=0,std=0.1)
nn.init.zeros_(m.bias)
net.apply(init_normal)
net[0].weight.data[0],net[0].bias.data[0]
def init_constant(m):
if type(m) == nn.Linear:
nn.init.constant_(m.weight,1)
nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0],net[0].bias.data[0]
7、直接对指定的参数进行操作
net[0].weight.data[:] +=1
标签:Linear,nn,补充,self,init,weight,pytorch,李沐,net 来源: https://blog.csdn.net/M_arshal_/article/details/120453539
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。