ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

GAN动漫人物头像生成

2022-02-24 16:01:16  阅读:306  来源: 互联网

标签:loss nn 动漫 torch 生成器 GAN 头像 fake size


GAN动漫人物头像生成

1.简介

搭建了一个简单的DCGAN网络生成动漫人物的头像,其中动漫人物头像数据集取自kaggle,网址如下
link

2.网络结构

  1. 数据集
  2. 生成器
  3. 判别器

2.1数据集

数据大小为64x64x3,样例如下
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

2.2生成器

由于生成器的原始输入是n维噪声,若想生成与数据集大小相同的图片,则需要进行上采样,这里我们用到的方法是转置卷积,通过pytorch中的ConvTransposed2d来实现。
生成器代码如下:

class Generator(nn.Module):

    def __init__(self, noise_dim=100):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # out_shape = (1-1)*1-2*0+4 = 4*4
            nn.ConvTranspose2d(noise_dim, 256, kernel_size=4),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # out_shape = (4-1)*2-2*1+4 = 8*8
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # out_shape = (8-1)*2-2*1+4 = 16*16
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # out_shape = (16-1)*2-2*1+4 = 32*32
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # out_shape = (32-1)*2-2*1+4 = 64*64
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, input):

        output = self.net(input)
        return output

在训练阶段,我们会生成batch_sizex100x1x1大小的随机噪声,然后经过生成器的上采样,实现与数据集图片大小相同的伪图片,然后送到判别器中进行真假图片的辨别。

2.3判别器

判别器的输入为从数据集中采样的真实图片与生成器生成的伪图片,输出为0-1之间的数值,因此网络尾端使用了Sigmoid激活函数。
判别器的目的是对真实图片判别为“1”(真),对伪图片判别为“0”(假),而生成器的目的是生成的伪图片足够好,足够逼近数据集的分布,以此骗过生成器,因此生成器希望自己生成的伪图片在判别器中得分越接近“1”(真)越好。这样判别器与生成器不断“对抗”,最后达到平衡或接近平衡。
判别器代码如下:

class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            # 32*32*32
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 16*16*64
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 8*8*128
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # 4*4*256
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(4*4*256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, input):

        output = self.net(input)
        return output.view(-1)

判别器的网络就是简单的前馈神经网络,经过卷积不断的下采样,提取图片的特征,最后输出为真或为假的0-1之间的得分。

3.训练阶段

训练阶段的大体流程跟深度学习训练流程相差无几,最重要的部分是label与损失函数的设计与计算。
先贴上训练阶段的代码:

import torch
import torch.nn as nn
from torchvision import transforms
from create_dataset import My_dataset, save_img
from torch.utils.data import DataLoader
from net import Generator, Discriminator

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


dataset = My_dataset('./data', transform=transform)
batch_size, epochs = 256, 200
my_dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

discriminator = Discriminator()
generator = Generator()
if torch.cuda.is_available():

    discriminator = discriminator.cuda()
    generator = generator.cuda()


d_optimizer = torch.optim.Adam(discriminator.parameters(), betas=(0.5, 0.99), lr=1e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), betas=(0.5, 0.99), lr=1e-4)
criterion = nn.BCELoss()

for epoch in range(epochs):

    for i, img in enumerate(my_dataloader):

        noise = torch.randn(batch_size, 100, 1, 1).cuda()
        real_img = img.cuda()
        fake_img = generator(noise)

        real_label = torch.ones(batch_size).cuda()
        fake_label = torch.zeros(batch_size).cuda()
        real_out = discriminator(real_img)
        fake_out = discriminator(fake_img)
        real_loss = criterion(real_out, real_label)
        fake_loss = criterion(fake_out, fake_label)

        d_loss = real_loss + fake_loss
        d_optimizer.zero_grad()

        d_loss.backward()
        d_optimizer.step()

        noise = torch.randn(batch_size, 100, 1, 1).cuda()
        fake_img = generator(noise)
        output = discriminator(fake_img)

        g_loss = criterion(output, real_label)
        g_optimizer.zero_grad()

        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 5 == 0:
            print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f} '
                  'D_real: {:.6f},D_fake: {:.6f}'.format(
                epoch, epochs, d_loss.data.item(), g_loss.data.item(),
                real_out.data.mean(), fake_out.data.mean()  # 打印的是真实图片的损失均值
            ))
        if epoch == 0 and i == len(my_dataloader) - 1:
            save_img(img[:64, :, :, :], './sample/real_images.png')
        if (epoch+1) % 10 == 0 and i == len(my_dataloader)-1:
            save_img(fake_img[:64, :, :, :], './sample/fake_images_{}.png'.format(epoch + 1))

torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')

在训练之前,首先要人为的设置图片为真、假的label,这里我们设置为1,使用torch.ones函数实现,设置为0,使用torch.zeros函数实现。
然后就是对数据集中的图片以及生成器生成的伪图片进行判别损失的计算,如代码中d_loss。
接下来是对生成器损失的计算,因为生成器的目的是生成的图片越真越好,所以生成器损失的计算的label是1。如代码中g_loss。

4.反归一化及结果

4.1反归一化

因为对数据集进行了归一化及标准化处理,所以在显示生成器结果时需要进行反归一化,在这里我先是使用了torchvision中的save_image去保存生成器的结果,但是该官方函数的反归一化与我们的归一化过程不符,导致该函数保存的图片有些暗,如下所示(下图为数据集中的真实图片):
在这里插入图片描述
所以这里另进行了数据的反归一化过程并使用torchvision中的make_grid函数进行保存,结果如下(为数据集中的真实图片):
在这里插入图片描述

4.2结果

训练200轮,每10轮保存一次结果,其中10,50,100,150,200轮的结果如下图所示:
epoch10
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
可以看到,生成器生成的图片在逐步清晰,且越来越逼近数据集的分布

5.总结

最后效果还是不太好,GAN的训练过程也不是太稳定,尤其是如何让图片更加清晰,不模糊仍然是一个比较“棘手”的问题。
(新手小白第一次写博客,大神勿喷)

最后,全部代码可见我的github

标签:loss,nn,动漫,torch,生成器,GAN,头像,fake,size
来源: https://blog.csdn.net/weixin_43706434/article/details/123110332

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有