Code Notes 12: Loading Part of a PyTorch Model's Parameters into Another Model

2022-05-25 23:33:30

Tags: 12, model, backbone, channels, PyTorch, train, dict, net, self


1

First, loading is conditional: a parameter can be loaded only if it has the same name (the same key in the state_dict) in both models.
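The name-matching condition can be checked before loading anything. The sketch below is a minimal illustration, assuming two small hypothetical modules (`Source` and `Target`, which are not part of the original code) that share one layer name. It intersects the keys of the two state dicts to list the parameters that are candidates for transfer.

```python
import torch.nn as nn

# Hypothetical toy modules for illustration only: they share the layer
# name "shared", and each has one layer the other lacks.
class Source(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(4, 4)
        self.source_head = nn.Linear(4, 2)

class Target(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(4, 4)
        self.target_head = nn.Linear(4, 3)

source_keys = set(Source().state_dict().keys())
target_keys = set(Target().state_dict().keys())

# Only keys present in both state dicts are candidates for loading.
common_keys = sorted(source_keys & target_keys)
print(common_keys)  # ['shared.bias', 'shared.weight']
```

Printing the keys this way before loading makes it easy to spot a prefix mismatch, for example a backbone stored under a different attribute name in the two models.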

2

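There are two ways to load: load_state_dict with strict=False, and update on the state dict.
The code below defines the backbone as its own Module class, which makes mistakes less likely.
It shows both ways of loading the backbone parameters of train_net into test_net.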

import torch
import torch.nn as nn
import torch.nn.functional as F


class backbone(nn.Module):
    def __init__(self):
        super(backbone, self).__init__()
        self.backbone_conv1 = nn.Conv2d(in_channels=1,out_channels=3,kernel_size=2,stride=1,padding=1)
        self.normal1 = nn.GroupNorm(num_groups=1,num_channels=3)
        self.backbone_conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=2, stride=1, padding=1)
        self.normal2 = nn.GroupNorm(num_groups=1, num_channels=3)
    def forward(self,input):
        #Stage 1
        conv1 = self.normal1(self.backbone_conv1(input))
        pool1,id1 = F.max_pool2d(F.relu(conv1),kernel_size=2, stride=2, return_indices=True)

        #Stage 2
        conv2 = self.normal2(self.backbone_conv2(pool1))
        pool2,id2 = F.max_pool2d(F.relu(conv2),kernel_size=2, stride=2, return_indices=True)

        return pool2,id2

class train(nn.Module):
    def __init__(self):
        super(train, self).__init__()
        self.backbone_RGB = backbone()
        self.train_conv1 = nn.Conv2d(in_channels=3,out_channels=3,kernel_size=2,stride=1,padding=0)
        self.train_conv2 = nn.Conv2d(in_channels=3, out_channels=3,kernel_size=2,stride=1,padding=0)
        self.train_conv3 = nn.Conv2d(in_channels=3, out_channels=3,kernel_size=2,stride=1,padding=0)

    def forward(self,input):
        #Stage1
        train_input,id = self.backbone_RGB(input)
        x1 = self.train_conv1(train_input)
        x2 = self.train_conv2(x1)
        x3 = self.train_conv3(x2)

        return x3

class test(nn.Module):
    def __init__(self):
        super(test,self).__init__()
        self.backbone_RGB = backbone()
        self.test_conv1 = nn.Conv2d(in_channels=3,out_channels=3,kernel_size=3,stride=1,padding=0)
        self.test_conv2 = nn.Conv2d(in_channels=3, out_channels=3,kernel_size=3,stride=1,padding=0)
        self.test_conv3 = nn.Conv2d(in_channels=3, out_channels=3,kernel_size=3,stride=1,padding=0)

    def forward(self,input):
        #Stage1
        test_input,id = self.backbone_RGB(input)
        x1 = self.test_conv1(test_input)
        x2 = self.test_conv2(x1)
        x3 = self.test_conv3(x2)

        return x3

backbone_net = backbone()
for name,parameter in backbone_net.state_dict().items():
    print(name)


print("------------------------------------------------------------")

train_net = train()
for name,parameter in train_net.state_dict().items():
    print(name)
    print(parameter)


print("------------------------------------------------------------")


test_net = test()
for name,parameter in test_net.state_dict().items():
    print(name)
    print(parameter)

# Method 1: load_state_dict with strict=False

test_net.load_state_dict(train_net.state_dict(), strict=False) # strict=False loads only the parameters whose names match and ignores the rest

for name,parameter in test_net.state_dict().items():
    print(name)
    print(parameter)

# Method 2: update the state dict

model_dict=test_net.state_dict()   # names and parameters of test_net
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in train_net.state_dict().items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
test_net.load_state_dict(model_dict)
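
To confirm what each method actually copied, the result of load_state_dict can be inspected. The following is a minimal sketch, assuming small hypothetical stand-ins (`TinyBackbone`, `TinyTrain`, `TinyTest`) for the classes above rather than the original definitions. With strict=False, load_state_dict returns an object with `missing_keys` and `unexpected_keys`. The sketch also adds a shape check to the update filter, because a key with a matching name but a different shape raises an error even with strict=False.

```python
import torch
import torch.nn as nn

# Hypothetical, reduced stand-ins for the backbone/train/test classes.
class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 3, kernel_size=2)

class TinyTrain(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone_RGB = TinyBackbone()
        self.train_conv1 = nn.Conv2d(3, 3, kernel_size=2)

class TinyTest(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone_RGB = TinyBackbone()
        self.test_conv1 = nn.Conv2d(3, 3, kernel_size=3)

train_net, test_net = TinyTrain(), TinyTest()

# Method 1: strict=False skips non-matching names and reports them.
result = test_net.load_state_dict(train_net.state_dict(), strict=False)
print("missing:", result.missing_keys)        # keys test_net has but the source lacks
print("unexpected:", result.unexpected_keys)  # keys in the source that test_net lacks

# Method 2: keep only entries whose name AND shape match, then update.
model_dict = test_net.state_dict()
pretrained_dict = {
    k: v for k, v in train_net.state_dict().items()
    if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(pretrained_dict)
test_net.load_state_dict(model_dict)

# The backbone weights should now be identical in both networks.
same = torch.equal(test_net.backbone_RGB.conv.weight,
                   train_net.backbone_RGB.conv.weight)
print("backbone copied:", same)
```

Because the second method passes a complete state dict for test_net, the final load_state_dict call can keep the default strict=True, which guards against silently skipping a key.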


For more detail, see the blog posts below, which is where I learned this [1]; they also cover how to freeze parameters [2].
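
As a brief illustration of parameter freezing, here is a minimal sketch that uses the standard `requires_grad` flag. It is not taken from the cited posts, and the `TinyNet` module and its `head` layer are hypothetical. The idea is to switch off gradients for the loaded backbone and hand only the remaining parameters to the optimizer.

```python
import torch
import torch.nn as nn

# Hypothetical network: a one-layer "backbone" plus a trainable head.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone_RGB = nn.Conv2d(1, 3, kernel_size=2)
        self.head = nn.Conv2d(3, 3, kernel_size=2)

    def forward(self, x):
        return self.head(self.backbone_RGB(x))

net = TinyNet()

# Freeze the backbone so no gradients are computed for it.
for p in net.backbone_RGB.parameters():
    p.requires_grad = False

# Pass only the still-trainable parameters to the optimizer.
trainable = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

# One training step on random data; the backbone must stay unchanged.
before = net.backbone_RGB.weight.clone()
loss = net(torch.randn(2, 1, 8, 8)).sum()
loss.backward()
optimizer.step()

unchanged = torch.equal(before, net.backbone_RGB.weight)
trainable_names = [n for n, p in net.named_parameters() if p.requires_grad]
print(unchanged)
print(trainable_names)
```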

References

[1] https://blog.csdn.net/qq_41314786/article/details/112569854
[2] https://blog.csdn.net/weixin_44815943/article/details/113180588?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0-113180588-blog-112569854.pc_relevant_default&spm=1001.2101.3001.4242.1&utm_relevant_index=3

Source: https://www.cnblogs.com/HumbleHater/p/16311425.html
