BatchNorm: Principles and PyTorch Implementation

The BatchNorm Algorithm

In short, BatchNorm computes the mean and variance of the input features per channel and normalizes them to zero mean and unit variance. Because this alone would reduce the network's expressive power, BN follows the normalization with a learnable scale and shift, the parameters $\gamma$ and $\beta$, which are also defined per channel.
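
For a channel with mini-batch mean $\mu$ and variance $\sigma^2$, the transform applied to every element $x$ of that channel is the one from the original BatchNorm paper:

$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\epsilon$ is a small constant that keeps the denominator away from zero.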

Why BatchNorm actually helps is not fully understood: it may reduce Internal Covariate Shift, or it may make the optimization landscape smoother.

Advantages

  • Improves training stability: larger learning rates can be used, parameter initialization matters less, and deeper and wider networks become trainable;
  • Speeds up convergence.

Disadvantages

  • Adds computation and memory overhead and slows down inference;
  • Introduces a discrepancy between training-time and inference-time behavior;
  • Breaks the independence between samples within a minibatch;
  • Works poorly with small batch sizes.

PyTorch Code

Take nn.BatchNorm2d as an example. Its inheritance chain is Module → _NormBase → _BatchNorm → BatchNorm2d, where Module is the parent class of every module used to build networks in PyTorch.
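
The chain can be checked directly (a quick sketch; _NormBase exists as a separate class in recent PyTorch versions):

import torch.nn as nn

# Print the method resolution order; it should list BatchNorm2d, _BatchNorm,
# _NormBase and Module (plus object).
for cls in nn.BatchNorm2d.__mro__:
    print(cls.__name__)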

_NormBase

_NormBase mainly registers and initializes the parameters and buffers:

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    def __init__(
        self,
        num_features: int,  # number of feature channels
        eps: float = 1e-5,  # added to the variance to avoid division by zero
        momentum: float = 0.1,  # factor for the running-stats update (see the momentum section below)
        affine: bool = True,  # whether to apply the learnable scale and shift (gamma and beta)
        track_running_stats: bool = True,  # whether to track running mean/variance for use at inference
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))  # registers gamma, later initialized to 1
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))  # registers beta, later initialized to 0
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))  # running mean, initialized to 0
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))  # running variance, initialized to 1
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    # Parameter initialization: gamma is set to 1, beta to 0.
    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError
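
As a quick check (an illustrative sketch, not part of the original post), a freshly constructed BatchNorm2d exposes gamma/beta as parameters and the running statistics as buffers:

import torch.nn as nn

bn = nn.BatchNorm2d(3)  # affine=True, track_running_stats=True by default
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']  -> gamma and beta
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']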

_BatchNorm

The forward pass calls nn.functional.batch_norm, which computes the statistics for each channel:

class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,	# see the momentum section below
        affine=True,
        track_running_stats=True,
        device=None,
        dtype=None
    ):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        return F.batch_norm(
            input,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean
            if not self.training or self.track_running_stats
            else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )
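
In training mode the statistics come from the current mini-batch, so the module's output can be reproduced by hand. The sketch below is an illustration (not part of the original source); it assumes the default affine initialization (gamma = 1, beta = 0) and uses the biased variance, which is what the normalization itself uses:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)   # (N, C, H, W)
bn = nn.BatchNorm2d(3)        # training mode by default
y = bn(x)                     # normalizes with the batch statistics

mean = x.mean(dim=(0, 2, 3), keepdim=True)                 # per-channel mean
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)   # biased per-channel variance
y_manual = (x - mean) / torch.sqrt(var + bn.eps)
print(torch.allclose(y, y_manual, atol=1e-6))              # True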

BatchNorm2d

BatchNorm2d only specializes the input-dimension check:

class BatchNorm2d(_BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))

The momentum parameter

According to the PyTorch documentation, momentum is used when updating running_mean and running_var. When it is set to None, a simple cumulative moving average is computed instead. The default value is 0.1.

In _BatchNorm it is assigned to exponential_average_factor:

exponential_average_factor = self.momentum

When it is not None, the update is an exponential moving average (EMA), computed as:
$$\bar{x}_t = \beta \mu_t + (1-\beta)\bar{x}_{t-1}$$
where $\mu_t$ is the mean or variance of the current batch and $\beta$ is exponential_average_factor. Expanding:
$$\begin{aligned} \bar{x}_t &= \beta \mu_t + (1-\beta)\bigl(\beta \mu_{t-1} + (1-\beta)(\beta \mu_{t-2} + (1-\beta)\bar{x}_{t-3})\bigr)\\ &= \beta \mu_t + (1-\beta)\beta \mu_{t-1} + (1-\beta)^2\beta \mu_{t-2} + \dots + (1-\beta)^t\beta \mu_0 \end{aligned}$$
The expansion shows that batches closer to the present receive larger weights, with the weights decaying exponentially; the running value is roughly an average over the most recent $\frac{1}{\beta}$ batches.
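
This can be verified against the module directly. The sketch below is an illustration (note that PyTorch updates running_var with the unbiased batch variance); it checks the running_mean update after one training-mode forward pass:

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3, momentum=0.1)
x = torch.randn(8, 3, 4, 4)

old_mean = bn.running_mean.clone()   # initialized to 0
bn(x)                                # one forward pass in training mode

batch_mean = x.mean(dim=(0, 2, 3))
expected = (1 - bn.momentum) * old_mean + bn.momentum * batch_mean
print(torch.allclose(bn.running_mean, expected))   # True

With momentum=None the factor becomes 1 / num_batches_tracked instead, so running_mean ends up as the plain average of all batch means seen so far.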

Source: https://blog.csdn.net/ice__snow/article/details/121472283
