ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

GFPGAN源码分析—第八篇

2021-12-26 23:35:13  阅读:203  来源: 互联网

标签:opt load network 第八篇 self 源码 GFPGAN path net


2021SC@SDUSC

源码:

models\init.py

models\gfpgan_model.py

本篇主要分析init.py与models\gfpgan_model.py下的

class GFPGANModel(BaseModel) 类init(self, opt) 方法

目录

init.py

gfpgan_model.py

class GFPGANModel(BaseModel)

init(self, opt)

init_training_settings(self)


init.py

自动扫描和导入注册表的模型模块

#在models文件夹下扫描所有以 '_model.py' 结尾的文件
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# 导入所有模型模块
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]

那么实际上就是导入models文件夹下gfpgan_model.py文件,接下来我们来看一下

gfpgan_model.py

本文件中只包含GFPGANModel(BaseModel)一个类

创建了一个MODEL_REGISTRY对象,并在类定义的时候用装饰器装饰它,以装饰器的形式调用MODEL_REGISTRY类的register函数

@MODEL_REGISTRY.register()
class GFPGANModel(BaseModel):
    """GFPGAN model for <Towards real-world blind faces restoratin with generative facial prior>"""

class GFPGANModel(BaseModel)

基于生成性人脸先验信息的真实盲脸修复 的 GFPGAN 模型

init(self, opt)

简单看一下代码

super(GFPGANModel, self).__init__(opt)
self.idx = 0

# 网络定义
self.net_g = build_network(opt['network_g'])
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)

# 读取预训练的模型
load_path = self.opt['path'].get('pretrain_network_g', None)
#如果路径不为空
if load_path is not None:
    param_key = self.opt['path'].get('param_key_g', 'params')
    self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)

self.log_size = int(math.log(self.opt['network_g']['out_size'], 2))

if self.is_train:
    self.init_training_settings()

在读取预训练的模型时,实际上就是从train_gfpgan_v1.yml配置文件中读取到相应的参数的数值与路径。

init_training_settings(self)

初始化训练设置

1.读取opt['train']

train_opt = self.opt['train']

2.定义net_d

#构建网络
self.net_d = build_network(self.opt['network_d'])
#将模型放到gpu(cuda)上
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# 读取与训练好的模型
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
    self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))

3.定义net_g

# net_g_ema 仅用于在一个GPU上测试并保存
# 不需要使用DistributedDataParallel进行包装
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# 读取预训练模型
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
    self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
    self.model_ema(0)  # copy net_g weight

self.net_g.train()
self.net_d.train()
self.net_g_ema.eval()

根据配置文件:net_g读取预训练模型为arcface_resnet18.pth

4.面部组件网络

if ('network_d_left_eye' in self.opt and 'network_d_right_eye' in self.opt and 'network_d_mouth' in self.opt):
            self.use_facial_disc = True
        else:
            self.use_facial_disc = False

        if self.use_facial_disc:
            # left eye
            self.net_d_left_eye = build_network(self.opt['network_d_left_eye'])
            self.net_d_left_eye = self.model_to_device(self.net_d_left_eye)
            self.print_network(self.net_d_left_eye)
            load_path = self.opt['path'].get('pretrain_network_d_left_eye')
            if load_path is not None:
                self.load_network(self.net_d_left_eye, load_path, True, 'params')
            # right eye
            self.net_d_right_eye = build_network(self.opt['network_d_right_eye'])
            self.net_d_right_eye = self.model_to_device(self.net_d_right_eye)
            self.print_network(self.net_d_right_eye)
            load_path = self.opt['path'].get('pretrain_network_d_right_eye')
            if load_path is not None:
                self.load_network(self.net_d_right_eye, load_path, True, 'params')
            # mouth
            self.net_d_mouth = build_network(self.opt['network_d_mouth'])
            self.net_d_mouth = self.model_to_device(self.net_d_mouth)
            self.print_network(self.net_d_mouth)
            load_path = self.opt['path'].get('pretrain_network_d_mouth')
            if load_path is not None:
                self.load_network(self.net_d_mouth, load_path, True, 'params')

            self.net_d_left_eye.train()
            self.net_d_right_eye.train()
            self.net_d_mouth.train()

            # ----------- 定义面部组件的 gan loss ----------- #
            self.cri_component = build_loss(train_opt['gan_component_opt']).to(self.device)

5.定义损失

if train_opt.get('pixel_opt'):
    self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
    self.cri_pix = None

if train_opt.get('perceptual_opt'):
    self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
    self.cri_perceptual = None
    # pyramid loss, component style loss, identity loss 都使用L1损失
    self.cri_l1 = build_loss(train_opt['L1_opt']).to(self.device)

    # gan loss (wgan)
    self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)

6.identity loss的定义

if 'network_identity' in self.opt:
    self.use_identity = True
else:
    self.use_identity = False

if self.use_identity:
    # 定义 identity network
    self.network_identity = build_network(self.opt['network_identity'])
    self.network_identity = self.model_to_device(self.network_identity)
    self.print_network(self.network_identity)
    load_path = self.opt['path'].get('pretrain_network_identity')
    if load_path is not None:
        self.load_network(self.network_identity, load_path, True, None)
    self.network_identity.eval()
    for param in self.network_identity.parameters():
        param.requires_grad = False

# 正则化权重
self.r1_reg_weight = train_opt['r1_reg_weight']  # for discriminator
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_init_iters = train_opt.get('net_d_init_iters', 0)
self.net_d_reg_every = train_opt['net_d_reg_every']

# 设置优化器和调度程序
self.setup_optimizers()
self.setup_schedulers()

标签:opt,load,network,第八篇,self,源码,GFPGAN,path,net
来源: https://blog.csdn.net/Vaifer233/article/details/122162924

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有