
Generative Adversarial Network (GAN) - PyTorch version

2019-08-31 11:00:07

Tags: loss images generator discriminator GAN Pytorch fake Generative size


import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image


# Configure GPU or CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
latent_size = 64    # dimension of the random noise vector z
hidden_size = 256
image_size = 784    # 28 x 28 MNIST image, flattened
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

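# 22 methods of PyTorch transforms: https://blog.csdn.net/weixin_38533896/article/details/86028509#10transformsNormalize_120
# Normalize the image tensor channel by channel: subtract the mean, then divide by the std.
# Normalize operates on C x H x W tensors (ToTensor has already converted H x W x C to C x H x W).
transform = transforms.Compose([
                transforms.ToTensor(),  # convert a PIL Image or ndarray to a tensor scaled to [0, 1] (uint8 pixels are divided by 255)
                transforms.Normalize(mean=(0.5,),   # MNIST has a single grayscale channel
                                     std=(0.5,))])  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1], matching the generator's Tanh output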

# MNIST dataset: download the data and apply the transform defined above
mnist = torchvision.datasets.MNIST(root='./data/',
                                   train=True,
                                   transform=transform,
                                   download=True)
# Data loader: load batches of batch_size samples, shuffled every epoch
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)
# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())
# Generator
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

# Device setting (move both models to the GPU or CPU)
D = D.to(device)
print(D)
# Sequential((0): Linear(in_features=784, out_features=256, bias=True)
#            (1): LeakyReLU(negative_slope=0.2)
#            (2): Linear(in_features=256, out_features=256, bias=True)
#            (3): LeakyReLU(negative_slope=0.2)
#            (4): Linear(in_features=256, out_features=1, bias=True)
#            (5): Sigmoid())
G = G.to(device)
print(G)
# Sequential( (0): Linear(in_features=64, out_features=256, bias=True)
#             (1): ReLU()
#             (2): Linear(in_features=256, out_features=256, bias=True)
#             (3): ReLU()
#             (4): Linear(in_features=256, out_features=784, bias=True)
#             (5): Tanh())

# Binary cross entropy loss and optimizers
criterion = nn.BCELoss()
# Optimizers: pass in the discriminator's and generator's parameters and the learning rate
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

# Denormalize: map values from [-1, 1] back to [0, 1] before saving images
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)  # clamp every element of out to the interval [0, 1]

# Clear the gradients left over from the previous step
def reset_grad():
    d_optimizer.zero_grad()  # clear the discriminator's gradients
    g_optimizer.zero_grad()  # clear the generator's gradients


# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)

        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ================================================================== #
        #                      Train the discriminator                       #
        # ================================================================== #
        # Compute BCE_Loss using real images, where BCE_Loss(x, y) = - y * log(D(x)) - (1-y) * log(1 - D(x))
        # The second term of the loss is always zero since real_labels == 1
        outputs = D(images)  # feed the real images to the discriminator
        d_loss_real = criterion(outputs, real_labels)  # BCE between D's output on real images and the real labels
        real_score = outputs  # discriminator scores D(x) for the real images

        # Compute BCELoss using fake images
        # The first term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)  # sample random latent noise z
        fake_images = G(z)  # the generator maps the noise to fake images
        outputs = D(fake_images)  # the discriminator classifies the fake images
        d_loss_fake = criterion(outputs, fake_labels)  # BCE between D's output on fake images and the fake labels
        fake_score = outputs  # discriminator scores D(G(z)) for the fake images

        # Backpropagation and optimization
        d_loss = d_loss_real + d_loss_fake  # total discriminator loss: real-image term plus fake-image term
        # Reset the gradients
        reset_grad()
        # Backpropagate
        d_loss.backward()
        # Apply the update to the discriminator's parameters
        d_optimizer.step()

        # ================================================================== #
        #                        Train the generator                         #
        # ================================================================== #
        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)  # sample fresh latent noise z
        fake_images = G(z)  # generate fake images
        outputs = D(fake_images)  # the discriminator classifies the fake images

        # We train G to maximize log(D(G(z))) instead of minimizing log(1 - D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)  # BCE against the real labels, i.e. -log(D(G(z))): G is rewarded for fooling D

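        # Backpropagation and optimization
        # Reset the gradients (this also clears the generator gradients that d_loss.backward() produced through fake_images)
        reset_grad()
        # Backpropagate
        g_loss.backward()
        # Apply the update to the generator's parameters
        g_optimizer.step()
        # Print progress every 200 steps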
        if (i + 1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))
    # Save real images (once, from the last batch of the first epoch)
    if (epoch + 1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save sampled (fake) images from the last generator step of each epoch
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))

# Save the trained generator and discriminator checkpoints
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')
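
With `mean=0.5` and `std=0.5`, `Normalize` computes `(x - 0.5) / 0.5 = 2x - 1`. This maps the `[0, 1]` output of `ToTensor` onto `[-1, 1]`, the same range as the generator's `Tanh` output. `denorm` undoes this with `(x + 1) / 2` and clamps to `[0, 1]` before `save_image`.

The sketch below checks the round trip on plain Python floats for a single channel. The helpers `normalize` and `denorm_scalar` are hypothetical scalar stand-ins for the tensor operations in the script, not PyTorch functions.

```python
def normalize(x, mean=0.5, std=0.5):
    # Scalar version of transforms.Normalize for one channel: (x - mean) / std
    return (x - mean) / std


def denorm_scalar(x):
    # Scalar version of the script's denorm: map [-1, 1] to [0, 1], then clamp
    out = (x + 1) / 2
    return min(max(out, 0.0), 1.0)


pixels = [0.0, 0.25, 0.5, 1.0]          # ToTensor output lies in [0, 1]
normalized = [normalize(p) for p in pixels]
print(normalized)                        # [-1.0, -0.5, 0.0, 1.0]
restored = [denorm_scalar(n) for n in normalized]
print(restored)                          # [0.0, 0.25, 0.5, 1.0]
print(denorm_scalar(1.3))                # 1.0 (out-of-range values are clamped)
```

The clamp matters because `Tanh` is bounded but floating-point rounding can nudge values slightly outside `[-1, 1]`. Clamping keeps `save_image` inputs in a valid range.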

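The discriminator comments state that one BCE term vanishes for each label value. The sketch below re-implements in plain Python the mean-reduced binary cross entropy that `nn.BCELoss()` computes by default, with two simplifications:

- The helper name `bce` and the sample scores are made up for illustration.
- PyTorch's clamping of the log terms is omitted.

The sketch confirms three reductions:

- With all-ones labels, the loss is the mean of `-log D(x)`.
- With all-zeros labels, it is the mean of `-log(1 - D(G(z)))`.
- Scoring fake images against all-ones labels, as the generator step does, gives the mean of `-log D(G(z))`.

```python
import math


def bce(preds, targets):
    # Mean binary cross entropy, as with nn.BCELoss() and its default reduction='mean':
    # loss = mean(-[y * log(p) + (1 - y) * log(1 - p)])
    total = 0.0
    for p, y in zip(preds, targets):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(preds)


# Made-up discriminator scores, for illustration only
real_scores = [0.9, 0.8, 0.7]   # D(x) on real images
fake_scores = [0.2, 0.1, 0.3]   # D(G(z)) on fake images

d_loss_real = bce(real_scores, [1, 1, 1])   # real_labels == 1
d_loss_fake = bce(fake_scores, [0, 0, 0])   # fake_labels == 0
d_loss = d_loss_real + d_loss_fake
g_loss = bce(fake_scores, [1, 1, 1])        # generator step: fake images scored against real labels

# With y == 1 only -log(p) survives; with y == 0 only -log(1 - p) survives
only_first_term = sum(-math.log(p) for p in real_scores) / len(real_scores)
only_second_term = sum(-math.log(1 - p) for p in fake_scores) / len(fake_scores)
print(d_loss_real, only_first_term)
print(d_loss_fake, only_second_term)
```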
  

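The generator step uses `criterion(outputs, real_labels)`. That is, it minimizes `-log D(G(z))` instead of minimizing `log(1 - D(G(z)))`.

Because `D` ends in a `Sigmoid`, one way to see why is to differentiate both objectives with respect to the discriminator's pre-sigmoid logit `a`:

- `log(1 - sigmoid(a))` has gradient `-sigmoid(a)`. This gradient vanishes when the discriminator confidently rejects fakes (`D(G(z))` near 0), which is typical early in training.
- `-log(sigmoid(a))` has gradient `-(1 - sigmoid(a))`, which stays near -1 in that regime.

The plain-Python sketch below checks both analytic gradients against central finite differences. All helper names in it are hypothetical.

```python
import math


def sigmoid(a):
    # Logistic function; D's final nn.Sigmoid applies this to its logit a
    return 1.0 / (1.0 + math.exp(-a))


def saturating_loss(a):
    # Original minimax generator objective (to be minimized): log(1 - D(G(z)))
    return math.log(1.0 - sigmoid(a))


def non_saturating_loss(a):
    # What criterion(outputs, real_labels) computes per sample: -log(D(G(z)))
    return -math.log(sigmoid(a))


def grad_saturating(a):
    # Analytic derivative of saturating_loss with respect to the logit a
    return -sigmoid(a)


def grad_non_saturating(a):
    # Analytic derivative of non_saturating_loss with respect to the logit a
    return -(1.0 - sigmoid(a))


def numeric_grad(f, a, h=1e-6):
    # Central finite difference, used to check the analytic derivatives
    return (f(a + h) - f(a - h)) / (2 * h)


for a in [-6.0, -3.0, 0.0, 2.0]:
    print(f"D(G(z))={sigmoid(a):.4f}  "
          f"grad of log(1-D)={grad_saturating(a):+.4f}  "
          f"grad of -log D={grad_non_saturating(a):+.4f}")
```

Near `D(G(z)) = 0`, the first gradient is close to 0 and the second is close to -1. This matches the argument in the last paragraph of section 3 of the paper linked in the script: the swapped objective "provides much stronger gradients early in learning".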
Source: https://www.cnblogs.com/jeshy/p/11438245.html
