ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

详解L1和L2正则化

2020-11-24 20:28:44  阅读:514  来源: 互联网

标签:weight decay 正则 L2 L1 范数


大纲:

  • L1和L2的区别以及范数相关知识
  • 对参数进行L1和L2正则化的作用与区别
  • pytorch实现L1与L2正则化
  • 对特征进行L2正则化的作用

L1和L2的区别以及范数

  使用机器学习方法解决实际问题时,我们通常要用L1或L2范数做正则化(regularization),从而限制权值大小,减少过拟合风险,故其又称为权重衰减。特别是在使用梯度下降来做目标函数优化时。

L1和L2的区别
在机器学习中,

  • L1范数(L2 normalization)是指向量中各个元素绝对值之和,通常表述为 ∥ w i ∥ 1 \|\boldsymbol{w_i}\|_1 ∥wi​∥1​,线性回归中使用L1正则的模型也叫Lasso regularization
    比如 向量A=[1,-1,3], 那么A的L1范数为 |1|+|-1|+|3|.

  • L2范数指权值向量w中各个元素的平方和然后再求平方根(可以看到Ridge回归的L2正则化项有平方符号),通常表示为 ∥ w i ∥ 2 \|\boldsymbol{w_i}\|_2 ∥wi​∥2​, 线性回归中使用L2正则的模型又叫岭回归(Ringe regularization)。

简单总结一下就是:

  • L1范数: 为x向量各个元素绝对值之和。
  • L2范数: 为x向量各个元素平方和的1/2次方,L2范数又称Euclidean范数或者Frobenius范数
  • Lp范数: 为x向量各个元素绝对值p次方和的1/p次方.

下图为p从无穷到0变化时,三维空间中到原点的距离(范数)为1的点构成的图形的变化情况。以常见的L-2范数(p=2)为例,此时的范数也即欧氏距离,空间中到原点的欧氏距离为1的点构成了一个球面
在这里插入图片描述

参数正则化作用

  • L1: 为模型加入先验, 简化模型, 使权值稀疏,由于权值的稀疏,从而过滤掉一些无用特征,防止过拟合
  • L2: 根据L2的特性,它会使得权值减小,即使平滑权值,一定程度上也能和L1一样起到简化模型,加速训练的作用,同时可防止模型过拟合

关于为什么L1会使得权重稀疏,而L2会使得权值平滑,可以参考知乎上一位答主的台大林轩田老师人工智能基石课笔记,从凸优化,梯度更新,概率分布三个角度诠释L1和L2正则化的原理和区别。我把笔记搬运到这:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

pytorch实现L1与L2正则化

网上很多关于L2和L1正则化的对象都是针对参数的,或者说权重,即权重衰减,可以用pytorch很简单的实现L2惩罚:

class torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

如上,weight_decay参数即为L2惩罚项前的系数
举个栗子,对模型中的某些参数进行惩罚时

#定义一层感知机
net = nn.Linear(num_inputs, 1)
#自定义参数初始化
nn.init.normal_(net.weight, mean=0, std=1)
nn.init.normal_(net.bias, mean=0, std=1)
optimizer_w = torch.optim.SGD(params=[net.weight], lr=lr, weight_decay=wd) # 对权重参数衰减,惩罚项前的系数为wd
optimizer_b = torch.optim.SGD(params=[net.bias], lr=lr)  # 不对偏差参数衰减

而对于L1正则化或者其他的就比较麻烦了,因为pytorch优化器只封装了L2惩罚功能,参考pytorch实现L2和L1正则化regularization的方法

class Regularization(torch.nn.Module):
    def __init__(self,model,weight_decay,p=2):
        '''
        :param model 模型
        :param weight_decay:正则化参数
        :param p: 范数计算中的幂指数值,默认求2范数,
                  当p=0为L2正则化,p=1为L1正则化
        '''
        super(Regularization, self).__init__()
        if weight_decay <= 0:
            print("param weight_decay can not <=0")
            exit(0)
        self.model=model
        self.weight_decay=weight_decay
        self.p=p
        self.weight_list=self.get_weight(model)
        self.weight_info(self.weight_list)
 
    def to(self,device):
        '''
        指定运行模式
        :param device: cude or cpu
        :return:
        '''
        self.device=device
        super().to(device)
        return self
 
    def forward(self, model):
        self.weight_list=self.get_weight(model)#获得最新的权重
        reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
        return reg_loss
 
    def get_weight(self,model):
        '''
        获得模型的权重列表
        :param model:
        :return:
        '''
        weight_list = []
        for name, param in model.named_parameters():
            if 'weight' in name:
                weight = (name, param)
                weight_list.append(weight)
        return weight_list
 
    def regularization_loss(self,weight_list, weight_decay, p=2):
        '''
        计算张量范数
        :param weight_list:
        :param p: 范数计算中的幂指数值,默认求2范数
        :param weight_decay:
        :return:
        '''
        # weight_decay=Variable(torch.FloatTensor([weight_decay]).to(self.device),requires_grad=True)
        # reg_loss=Variable(torch.FloatTensor([0.]).to(self.device),requires_grad=True)
        # weight_decay=torch.FloatTensor([weight_decay]).to(self.device)
        # reg_loss=torch.FloatTensor([0.]).to(self.device)
        reg_loss=0
        for name, w in weight_list:
            l2_reg = torch.norm(w, p=p)
            reg_loss = reg_loss + l2_reg
 
        reg_loss=weight_decay*reg_loss
        return reg_loss
 
    def weight_info(self,weight_list):
        '''
        打印权重列表信息
        :param weight_list:
        :return:
        '''
        print("---------------regularization weight---------------")
        for name ,w in weight_list:
            print(name)
        print("---------------------------------------------------")

class Regularization的使用


# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
print("-----device:{}".format(device))
print("-----Pytorch version:{}".format(torch.__version__))
 
weight_decay=100.0 # 正则化参数
 
model = my_net().to(device)
# 初始化正则化
if weight_decay>0:
   reg_loss=Regularization(model, weight_decay, p=2).to(device)
else:
   print("no regularization")
 
 
criterion= nn.CrossEntropyLoss().to(device) # CrossEntropyLoss=softmax+cross entropy
optimizer = optim.Adam(model.parameters(),lr=learning_rate)#不需要指定参数weight_decay
 
# train
batch_train_data=...
batch_train_label=...
 
out = model(batch_train_data)
 
# loss and regularization
loss = criterion(input=out, target=batch_train_label)
if weight_decay > 0:
   loss = loss + reg_loss(model)
total_loss = loss.item()
 
# backprop
optimizer.zero_grad()#清除当前所有的累积梯度
total_loss.backward()
optimizer.step()

特征正则化作用

上面介绍了对于权重进行正则化的作用以及具体实现,其实在很多模型中,也会对特征采用L2归一化,有的时候在训练模型时,经过几个batch后,loss会变成nan,此时,如果你在特征后面加上L2归一化,可能可以很好的解决这个问题,而且有时会影响训练的效果,深有体会。
L2正则的原理比较简单,如下公式:
y = x i ∑ i = 0 D x i 2 \boldsymbol{y} = \frac{\boldsymbol{x_i}}{\sum_{i=0}^D\boldsymbol{{x_i}}^2 } y=∑i=0D​xi​2xi​​
其中D为向量的长度,经过l2正则后 x i \boldsymbol{x_i} xi​向量的元素平方和等于1

python实现

def l2norm(X, dim=-1, eps=1e-12):
    """L2-normalize columns of X
    """
    norm = torch.pow(X, 2).sum(dim=dim, keepdim=True).sqrt() + eps
    X = torch.div(X, norm)
    return X

在SSD目标检测的conv4_3层便使用了L2Norm

对特征进行L2正则的具体作用如下:

  • 防止梯度消失或者梯度爆炸
  • 统一量纲,加快模型收敛

参考:

机器学习中L1和L2的直观理解
几种范数的介绍

标签:weight,decay,正则,L2,L1,范数
来源: https://blog.csdn.net/baidu_32885165/article/details/110085688

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有