ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

pytorch中self.register_buffer()

2022-09-11 12:30:42  阅读:225  来源: 互联网

标签:parameters nn buffer self register pytorch 参数 net


PyTorch中定义模型时,有时候会遇到self.register_buffer(‘name’, Tensor)的操作,该方法的作用是定义一组参数,该组参数的特别之处在于:模型训练时不会更新(即调用 optimizer.step() 后该组参数不会变化,只可人为地改变它们的值),但是保存模型时,该组参数又作为模型参数不可或缺的一部分被保存。

为了更好地理解这句话,按照惯例,我们通过一个例子实验来解释:

首先,定义一个模型并实例化:

import torch 
import torch.nn as nn
from collections import OrderedDict

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # (1)常见定义模型时的操作
        self.param_nn = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(1, 1, 3, bias=False)),
            ('fc', nn.Linear(1, 2, bias=False))
        ]))

        # (2)使用register_buffer()定义一组参数
        self.register_buffer('param_buf', torch.randn(1, 2))

        # (3)使用形式类似的register_parameter()定义一组参数
        self.register_parameter('param_reg', nn.Parameter(torch.randn(1, 2)))

        # (4)按照类的属性形式定义一组变量
        self.param_attr = torch.randn(1, 2) 

    def forward(self, x):
        return x

net = Model()
 

上例中,我们通过继承nn.Module类定义了一个模型,在模型参数的定义中,我们分别以(1)常见的nn.Module类形式、(2)self.register_buffer()形式、(3)self.register_parameter()形式,以及(4)python类的属性形式定义了4组参数。

(1)哪些参数可以在模型训练时被更新?
这可以通过net.parameters()查看,因为定义优化器时是这样的:optimizer = SGD(net.parameters(), lr=0.1)。为了方便查看,我们使用 net.named_parameters():

In [8]: list(net.named_parameters())
Out[8]:
[('param_reg',
  Parameter containing:
  tensor([[-0.0617, -0.8984]], requires_grad=True)),
 ('param_nn.conv.weight',
  Parameter containing:
  tensor([[[[-0.3183, -0.0426, -0.2984],
            [-0.1451,  0.2686,  0.0556],
            [-0.3155,  0.0451,  0.0702]]]], requires_grad=True)),
 ('param_nn.fc.weight',
  Parameter containing:
  tensor([[-0.4647],
          [ 0.7753]], requires_grad=True))]
 

可以看到,我们定义的4组参数中,只有(1)和(3)定义的参数可以被更新,而self.register_buffer()和以python类的属性形式定义的参数都不能被更新。也就是说,modules和parameters可以被更新,而buffers和普通类属性不行。

那既然这两种形式定义的参数都不能被更新,二者可以互相替代吗?答案是不可以,原因看下一节:

(2)这其中哪些才算是模型的参数呢?
模型的所有参数都装在 state_dict 中,因为保存模型参数时直接保存 net.state_dict()。我们看一下其中究竟是哪些参数:

In [9]: net.state_dict()
Out[9]:
OrderedDict([('param_reg', tensor([[-0.0617, -0.8984]])),
             ('param_buf', tensor([[-1.0517,  0.7663]])),
             ('param_nn.conv.weight',
              tensor([[[[-0.3183, -0.0426, -0.2984],
                        [-0.1451,  0.2686,  0.0556],
                        [-0.3155,  0.0451,  0.0702]]]])),
             ('param_nn.fc.weight',
              tensor([[-0.4647],
                      [ 0.7753]]))])
 

可以看到,通过 nn.Module 类、self.register_buffer() 以及 self.register_parameter() 定义的参数都在 state-dict 中,只有用python类的属性形式定义的参数不包含其中。也就是说,保存模型时,buffers,modules和parameters都可以被保存,但普通属性不行。

(3)self.register_buffer() 的使用方法
在用self.register_buffer(‘name’, tensor) 定义模型参数时,其有两个形参需要传入。第一个是字符串,表示这组参数的名字;第二个就是tensor 形式的参数。

在模型定义中调用这个参数时(比如改变这组参数的值),可以使用self.name 获取。本文例中,就可用self.param_buf 引用。这和类属性的引用方法是一样的。

在实例化模型后,获取这组参数的值时,可以用 net.buffers() 方法获取,该方法返回一个生成器(可迭代变量):

In [10]: net.buffers()
Out[10]: <generator object Module.buffers at 0x00000289CA0032E0>

In [11]: list(net.buffers())
Out[11]: [tensor([[-1.0517,  0.7663]])]

# 也可以用named_buffers() 方法同时获取名字
In [12]: list(net.named_buffers())
Out[12]: [('param_buf', tensor([[-1.0517,  0.7663]]))]
 

(4)modules, parameters 和 buffers
实际上,PyTorch 定义的模型用OrderedDict() 的方式记录这三种类型,分别保存在self._modules, self._parameters 和 self._buffers 三个私有属性中。调试模式时就可以看到每个模型都有这几个私有属性:

在这里插入图片描述
调试模式 变量窗口
由于是私有属性,我们无法在实例化的变量上调用这些属性,可以在模型定义中调用它们:
在模型的实例化变量上调用时,三者有着相似的方法:

net.modules()
net.named_modules()

net.parameters()
net.named_parameters()

net.buffers()
net.named_buffers()
 

细心的读着可能会发现,self._parameters 和 net.parameters() 的返回值并不相同。这里self._parameters 只记录了使用 self.register_parameter() 定义的参数,而net.parameters() 返回所有可学习参数,包括self._modules 中的参数和self._parameters 参数的并集。

实际上,由nn.Module类定义的参数和self.register_parameter() 定义的参数性质是一样的,都是nn.Parameter 类型。

from:https://blog.csdn.net/dagouxiaohui/article/details/125649813

标签:parameters,nn,buffer,self,register,pytorch,参数,net
来源: https://www.cnblogs.com/chentiao/p/16683798.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有