ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

PyTorch 剪枝

2022-08-16 19:33:47  阅读:178  来源: 互联网

标签:剪枝 修剪 prune weight module PyTorch print


pytorch 实现剪枝的思路是 生成一个掩码,然后同时保存 原参数、mask、新参数,如下图

 

pytorch 剪枝分为 局部剪枝、全局剪枝、自定义剪枝;

局部剪枝 是对 模型内 的部分模块 的 部分参数 进行剪枝,全局剪枝是对  整个模型进行剪枝;

 

本文旨在记录 pytorch 剪枝模块的用法,首先让我们构建一个模型

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

下面对 这个模型进行剪枝

 

局部剪枝

以修剪 第一层卷积  模块 为例

module = model.conv1
print(list(module.named_parameters()))
print(list(module.buffers()))

# 修剪是从 模块 中 删除 参数(如 weight),并用 weight_orig 保存该参数
# random_unstructured 是一种裁剪技术,随机非结构化裁剪
prune.random_unstructured(module, name="weight", amount=0.3)      # weight    bias
print(list(module.named_parameters()))

# 通过修剪技术会创建一个mask命名为 weight_mask 的模块缓冲区
print(list(module.named_buffers()))

# 新的参数保存为模块 的weight属性
print(module.weight)
# print(module.bias)

print(module._forward_pre_hooks)
# OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002695EBCEC18>)])

named_parameters() 内 存储的对象 除非手动删除,否则在剪枝过程中对其无影响

 

迭代剪枝

迭代剪枝 是 对 同一模块 进行 多种剪枝,执行逻辑是 顺序执行各剪枝操作

在之前  随机非结构化剪枝 的基础上进行  L1 L2 非结构化剪枝

## 增加一个修剪,看看变化
# l1范数修剪bias中3个最小条目
prune.l1_unstructured(module, name="bias", amount=3)
print(module.bias)
print(module._forward_pre_hooks)
# OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x000002695EBCEC18>),
#              (1, <torch.nn.utils.prune.L1Unstructured object at 0x000002695DE5CEB8>)])

print(list(module.named_parameters()))
print(list(module.named_buffers()))


### 迭代修剪
# 一个模块中的同一参数可以被多次修剪,多次修剪会顺序执行
# 如在之前的基础上,对 weight 参数继续修剪
# l2 结构化裁剪,n=2代表l2,dim=0代表在weight的第0轴进行结构化裁剪
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# 查看 weight 参数的 剪枝 操作
for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))
# [<torch.nn.utils.prune.RandomUnstructured object at 0x0000020AE2A6EC18>,
# <torch.nn.utils.prune.LnStructured object at 0x0000020AA872DE80>]

print(module.state_dict().keys())
# odict_keys(['weight_orig', 'bias_orig', 'weight_mask', 'bias_mask'])

 

修剪模型中的多个参数

### 修剪模型中的多个参数
new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

 

全局剪枝

以上研究通常被称为“局部”修剪方法,即通过比较每个条目的统计信息(权重,激活度,梯度等)来逐一修剪模型中的张量的做法。

但是,一种常见且可能更强大的技术是通过删除整个模型中最低的 20%的连接,

而不是删除每一层中最低的 20%的连接来修剪模型。

这很可能导致每个层的修剪百分比不同。

让我们看看如何使用torch.nn.utils.prune中的global_unstructured进行操作

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
# 检查每个修剪参数的稀疏性,该稀疏性不等于每层中的 20%。 但是,全局稀疏度将(大约)为 20%

 

自定义剪枝

见  参考资料3

 

训练中剪枝实例

见参考资料1

 

 

 

 

参考资料:

https://blog.csdn.net/qq_40268672/article/details/108631518  pytorch剪枝实战     训练时剪枝,类似 dropout 

https://blog.csdn.net/ssunshining/article/details/125121066  PyTorch--模型剪枝案例

https://www.w3cschool.cn/pytorch/pytorch-rnmi3bti.html  PyTorch 修剪教程

标签:剪枝,修剪,prune,weight,module,PyTorch,print
来源: https://www.cnblogs.com/yanshw/p/16592678.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有