ICode9


GAN Series 2: Generating Handwritten Digit Images with a Simple GAN

2021-05-22 19:02:09  Views: 208  Source: Internet

Tags: loss, generator, nn, fake, GAN, image, handwritten digits, data, size


Goal:

Generate handwritten digit images with a GAN, using PyTorch.

Series topics:

1. The basic GAN architecture;

2. Training the generator and discriminator;

3. The loss functions of the GAN's generator and discriminator;

4. Architectures for various GAN applications;

 

Training the discriminator:

1) Get the real data and real labels (real labels are marked as 1); the label vector's length should equal the batch size;

2) Forward pass: feed the real data to the discriminator and get its output on real data;

3) Compute the discriminator loss from the real output and the real labels, and backpropagate it;

4) Using generated data from a forward pass through the generator, compute the discriminator's output on the fake data and the fake-data loss, and backpropagate it; the overall loss is the sum of the real-data loss and the fake-data loss;

5) Update the discriminator's parameters;

 

Training the generator:

1) Get generated data from a forward pass through the generator, and label it as 1 (real);

2) Forward pass the generated data through the discriminator;

3) Compute the loss and backpropagate it;

4) Update the generator's parameters via its optimizer;
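The training loop later in the post calls a train_generator function that is never shown. The four steps above can be sketched as follows; the simplified discriminator, criterion, and label_real here are stand-ins so the sketch runs on its own (the post defines richer versions of them further down):

```python
import torch
import torch.nn as nn

# Stand-ins so the sketch is self-contained; the post defines the full
# versions of discriminator, criterion, and label_real later.
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(784, 1), nn.Sigmoid())
criterion = nn.BCELoss()

def label_real(size):
    return torch.ones(size, 1)

def train_generator(optimizer, data_fake):
    # 1) the fake batch comes from the generator's forward pass, labeled 1
    b_size = data_fake.size(0)
    real_label = label_real(b_size)
    optimizer.zero_grad()
    # 2) forward pass through the discriminator
    output = discriminator(data_fake)
    # 3) compute the loss and backpropagate
    loss = criterion(output, real_label)
    loss.backward()
    # 4) update the generator's parameters
    optimizer.step()
    return loss
```

The fakes are deliberately labeled as real (1s): the loss is then large when the discriminator spots the fakes, so its gradient pushes the generator toward fooling the discriminator.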

 

File structure:

├───input
├───outputs
└───src
        vanilla_gan.py

 

Code implementation:

All of the code lives in vanilla_gan.py.

1) Importing the packages

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import imageio
import numpy as np
import matplotlib
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm
matplotlib.style.use('ggplot')

make_grid() and save_image() help with saving the images.

 

2) Defining the learning parameters:

# learning parameters
batch_size = 512
epochs = 200
sample_size = 64 # fixed sample size
nz = 128 # latent vector size
k = 1 # number of steps to apply to the discriminator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

 

3) Preparing the dataset

transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,)),
])
to_pil_image = transforms.ToPILImage()

to_pil_image converts tensors to PIL image format; this is needed when we want to save the images generated by the GAN, since they must be converted to PIL format first.

train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

Defining the GAN's generator and discriminator

Generator:

The generator uses simple linear layers.

 

class Generator(nn.Module):
    def __init__(self, nz):
        super(Generator, self).__init__()
        self.nz = nz
        self.main = nn.Sequential(
            nn.Linear(self.nz, 256),  # 128 input features, 256 output features
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.main(x).view(-1, 1, 28, 28)
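To sanity-check the reshape in forward, here is the same architecture re-declared so it runs standalone, fed with a small batch of noise:

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """Same stack as above: nz -> 256 -> 512 -> 1024 -> 784."""
    def __init__(self, nz):
        super().__init__()
        self.nz = nz
        self.main = nn.Sequential(
            nn.Linear(nz, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 1024), nn.LeakyReLU(0.2),
            nn.Linear(1024, 784), nn.Tanh(),
        )

    def forward(self, x):
        # reshape the 784 outputs into 1x28x28 images
        return self.main(x).view(-1, 1, 28, 28)

g = Generator(nz=128)
out = g(torch.randn(4, 128))
print(out.shape)  # torch.Size([4, 1, 28, 28])
```

The final Tanh keeps every pixel in [-1, 1], matching the Normalize((0.5,), (0.5,)) applied to the MNIST data.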

 

Discriminator:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_input = 784
        self.main = nn.Sequential(
            nn.Linear(self.n_input, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(-1, 784)
        return self.main(x)
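The same kind of shape check works for the discriminator; re-declared standalone, it maps a batch of 1x28x28 images to one sigmoid score each:

```python
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Same stack as above: 784 -> 1024 -> 512 -> 256 -> 1."""
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 1024), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(1024, 512), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(512, 256), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        # flatten 1x28x28 images into 784-dim vectors
        return self.main(x.view(-1, 784))

d = Discriminator()
scores = d(torch.randn(4, 1, 28, 28))
print(scores.shape)  # torch.Size([4, 1]) -- one probability per image
```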

Initializing the networks and defining the optimizers

generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)
print('##### GENERATOR #####')
print(generator)
print('######################')
print('\n##### DISCRIMINATOR #####')
print(discriminator)
print('######################')

Optimizers:

# optimizers
optim_g = optim.Adam(generator.parameters(), lr=0.0002)
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002)

Loss function:

# loss function
criterion = nn.BCELoss()
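BCE loss compares the discriminator's sigmoid outputs against the 1/0 labels; a quick worked check with hand-picked scores:

```python
import math
import torch
import torch.nn as nn

criterion = nn.BCELoss()
output = torch.tensor([[0.9], [0.1]])  # discriminator scores
target = torch.tensor([[1.0], [0.0]])  # real label, fake label
loss = criterion(output, target)
# BCE = -mean(t*log(o) + (1-t)*log(1-o)); both rows contribute -log(0.9)
print(abs(loss.item() - (-math.log(0.9))) < 1e-5)  # True
```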

Storing the losses after each epoch:

losses_g = [] # to store generator loss after each epoch
losses_d = [] # to store discriminator loss after each epoch
images = [] # to store images generated by the generator

Defining a few helper functions:

During GAN training we need labels for the real and the generated images in order to compute the losses.

Define two functions that create the 1s and the 0s:

# to create real labels (1s)
def label_real(size):
    data = torch.ones(size, 1)
    return data.to(device)

# to create fake labels (0s)
def label_fake(size):
    data = torch.zeros(size, 1)
    return data.to(device)

The generator also needs a noise vector, whose size should equal nz (128), from which to generate images:

# function to create the noise vector
def create_noise(sample_size, nz):
    return torch.randn(sample_size, nz).to(device)

This function takes two parameters, sample_size and nz, and returns a random vector that is later fed to the generator to produce fake images.
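Dropping the .to(device) calls for a standalone check, the three helpers produce exactly the shapes the loss function and generator expect:

```python
import torch

def label_real(size):
    return torch.ones(size, 1)   # targets for real images

def label_fake(size):
    return torch.zeros(size, 1)  # targets for generated images

def create_noise(sample_size, nz):
    return torch.randn(sample_size, nz)  # latent input for the generator

print(label_real(4).shape, label_fake(4).shape)  # torch.Size([4, 1]) twice
print(create_noise(64, 128).shape)               # torch.Size([64, 128])
```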

 

Finally, a function to save the generated images:

# to save the images generated by the generator
def save_generator_image(image, path):
    save_image(image, path)

 

The function to train the discriminator:

# function to train the discriminator network
def train_discriminator(optimizer, data_real, data_fake):
    b_size = data_real.size(0)
    real_label = label_real(b_size)
    fake_label = label_fake(b_size)
    optimizer.zero_grad()
    output_real = discriminator(data_real)
    loss_real = criterion(output_real, real_label)
    output_fake = discriminator(data_fake)
    loss_fake = criterion(output_fake, fake_label)
    loss_real.backward()
    loss_fake.backward()
    optimizer.step()
    return loss_real + loss_fake

Training the GAN

# create the fixed noise vector used to visualize progress each epoch
noise = create_noise(sample_size, nz)

 

generator.train()
discriminator.train()

Start training:

for epoch in range(epochs):
    loss_g = 0.0
    loss_d = 0.0
    for bi, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        image, _ = data
        image = image.to(device)
        b_size = len(image)
        # run the discriminator for k steps
        for step in range(k):
            data_fake = generator(create_noise(b_size, nz)).detach()
            data_real = image
            # train the discriminator network; .item() keeps a plain float
            loss_d += train_discriminator(optim_d, data_real, data_fake).item()
        data_fake = generator(create_noise(b_size, nz))
        # train the generator network
        loss_g += train_generator(optim_g, data_fake).item()
    # create the final fake images for the epoch from the fixed noise
    generated_img = generator(noise).cpu().detach()
    # arrange the images as a grid
    generated_img = make_grid(generated_img)
    # save the generated images to disk
    save_generator_image(generated_img, f"../outputs/gen_img{epoch}.png")
    images.append(generated_img)
    epoch_loss_g = loss_g / len(train_loader) # mean generator loss for the epoch
    epoch_loss_d = loss_d / len(train_loader) # mean discriminator loss for the epoch
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)

    print(f"Epoch {epoch} of {epochs}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")

 

Source: https://www.cnblogs.com/xmd-home/p/14799520.html
