ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

2021-07-23 22:31:57  阅读:234  来源: 互联网

标签:loss 教程 plt nn torch PyTorch fake MNIST size


@Author:Runsen

GAN 是使用两个神经网络模型训练的生成模型。一种模型称为生成网络模型,它学习生成新的似是而非的样本。另一个模型被称为判别网络,它学习区分生成的例子和真实的例子。

生成性对抗网络

2014,蒙特利尔大学的Ian Goodfellow和他的朋友发明了生成性对抗网络(GAN)。自它出版以来,有许多它的变体和客观功能来解决它的问题

论文在这里找到.

论文提出了两种模型:生成模型和判别模型。两个模型竞争,以产生真实和假的样本。2016年,Yann LeCun将GANs描述为“过去二十年机器学习中最酷的想法”。

GAN 的大部分研究和应用都集中在计算机视觉领域。

其原因是卷积神经网络 (CNN) 等深度学习模型在过去 5 到 7 年中在计算机视觉领域取得了巨大成功,例如在具有挑战性的任务(如对象检测和人脸识别。

GAN 的典型例子是生成新的逼真的照片,最令人吃惊的是生成照片般逼真的人脸的例子。

在本教程中,我们将实现一个简单的GAN生成假的MNIST样本。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils

import numpy as np
import matplotlib.pyplot as plt
# CPU / GPU Setting
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)  #cuda

使用MNIST数据集,具有最小大小的数据集。

它由60000个训练图像和10000个测试图像组成,每个图像有28*28的大小和一个彩色通道。

# Define a transform 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = (0.5, ), std = (0.5, ))
])

# batch_size是一个前向和后向传播过程中的图像数。
batch_size = 100

mnist = datasets.MNIST('./data/MNIST', 
                       download = True, 
                       train = True, 
                       transform = transform)

mnist_loader = DataLoader(dataset = mnist, 
                          batch_size = batch_size, 
                          shuffle = True)
# CPU
def imshow(img, title):
    img = utils.make_grid(img.cpu().detach())
    img = (img+1)/2
    npimg = img.detach().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()
#GPU
def imshow(img, title):
    npimg = img.detach().numpy()
    fig = plt.figure(figsize = (10, 10))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()

images, labels = iter(mnist_loader).next()
imshow(images[0:16, :, :], "MNIST Images")

建立一个GANs模型。一个Generator和Discriminator

GANs由完全连接的层组成。它将从100维高斯分布采样的噪声转换为MNIST图像。鉴别器网络也由完全连接的层组成,用于区分输入数据是真是假。

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        latent_size = 100
        output = 28*28
        
        self.main = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(inplace=True),
            
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, output),
            nn.Tanh()
        )
        
    def forward(self, x):
        out = self.main(x)
        out = out.view(-1, 1, 28, 28)
        return out


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        n_features = 28 * 28
        n_out = 1
        
        self.main = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(inplace=True),
            
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            
            nn.Linear(64, n_out),
            nn.Sigmoid()        
        )
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        out = self.main(x)
        return out

G = Generator().to(device)
D = Discriminator().to(device)

生成性对抗网络训练过程的损失函数是二进制交叉熵损失,由torch.nn.BCELoss实现。

这两种模型都使用torch.optim.Adam作为优化工具,学习率设置为0.002。

# Objective Function
criterion = nn.BCELoss()

# Optimizer
G_optimizer = optim.Adam(G.parameters(), lr = 0.0002)
D_optimizer = optim.Adam(D.parameters(), lr = 0.0002)

# Constants
noise_dim = 100
num_epochs = 50
total_batch = len(mnist_loader)

# Lists
G_losses = []
D_losses = []

# Noise
sample_size = 16
fixed_noise = torch.randn(sample_size, noise_dim).to(device)

# Train
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(mnist_loader):
        
        # Images #
        images = images.reshape(batch_size, -1).float().to(device)
        
        # Labels #
        ones = torch.ones(batch_size, 1).to(device)
        zeros = torch.zeros(batch_size, 1).to(device)
        
        # Noise #
        noise = torch.randn(batch_size, noise_dim).to(device)
        
        # Initialize Optimizers
        D_optimizer.zero_grad()
        G_optimizer.zero_grad()
        
        #######################
        # Train Discriminator #
        #######################
        
        # Forward Images #
        prob_real = D(images)
        D_real_loss = criterion(prob_real, ones)
        
        # Generate Samples #
        fake_images = G(noise)
        prob_fake = D(fake_images)
        
        # Forward Fake Samples and Calculate Discriminator Loss #
        D_fake_loss = criterion(prob_fake, zeros)
        D_loss = (D_real_loss + D_fake_loss).mean()
        
        # Back Propagation and Update
        D_loss.backward()
        D_optimizer.step()
        
        ###################
        # Train Generator #
        ###################
        
        fake_images = G(noise)
        prob_fake = D(fake_images)
        
        # According to the section 3 in paper,
        # early in learning, when G is very poor, D can reject samples from G.
        # In this case, log(1-D(G(z))) saturates. 
        # thus, train G to maximiaze log(D(G(z))) instead of minimizing log(1-D(G(z)))
        G_loss = criterion(prob_fake, ones)
        
        # Back Propagation and Update
        G_loss.backward()
        G_optimizer.step()
        
        # Save Losses for Plotting Later
        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())
        
        # Print Statistics #
        if (i + 1) % 100 == 0:
            print("Epoch [%d/%d] Iter [%d/%d], D_Loss: %.4f G_Loss: %.4f"
                  %(epoch+1, num_epochs, i+1, total_batch, D_loss.item(), G_loss.item()))
    
    # Generate Samples #
    if epoch % 1 == 0:
        fake_samples = G(fixed_noise)
        imshow(fake_samples, "Generated MNIST Images")
    
# Save Model Weights for Digit Generation
torch.save(G.state_dict(), './data/GAN.pkl')

plt.figure(figsize = (8, 6))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="Generator")
plt.plot(D_losses, label="Discriminator")
plt.xlabel("Iterations")
plt.ylabel("Losses")
plt.legend()
plt.show()

sample_size = 64
noise_dim = 100

noise = torch.randn(sample_size, noise_dim).to(device)

G.load_state_dict(torch.load('GAN.pkl'))
fake_samples = G(fixed_noise)
imshow(fake_samples, "Generated MNIST Images")

GAN生成性对抗网络的运用

  • 将语义图像翻译成城市景观和建筑物的照片。
  • 将卫星照片翻译成地图。
  • 从白天到晚上的照片翻译。
  • 将黑白照片翻译成彩色。


- 论文在这里找到.

- 上述代码的论文.

- 上述代码.

标签:loss,教程,plt,nn,torch,PyTorch,fake,MNIST,size
来源: https://blog.csdn.net/weixin_44510615/article/details/119044979

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有