ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

深度学习:GAN案例练习-minst手写数字

2021-10-13 20:58:42  阅读:242  来源: 互联网

标签:loss 判别 nn self img minst GAN fake 手写


目录)

理论

参考:GAN原理详解

目标

最终期望两个网络达到一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近0.5(相当于随机猜测类别)。
目标:
对于生成器来说,传给辨别器的生成图片,生成器希望辨别器打上标签1。因为它要不断训练减小损失,以期望骗过判别器。
对于判别器来说,给定的真实图片,辨别器要为其打上标签1;给定的生成图片,辨别器要为其打上标签0;它要能够识别真假。

优化网络(定义损失)

GAN有两个网络,那么自然就有两个损失函数。
**生成网络的损失函数:**制造一个可以瞒过识别网络的输出
代表判断结果与1的距离。
识别网络的损失函数:
实数据就是真实数据,生成数据就是虚假数据(即真实数据与1的距离小,生成数据与0的距离小)

训练过程

GAN对抗网络的训练过程通常是两个网络单独且交替训练:先训练识别网络,再训练生成网络,再训练识别网络,如此反复,直到达到纳什均衡。

1.当生成器损失从很大的值迅速变为0,而判别器损失维持不变。
有可能时生成器生成能力较弱,因此一种可行的方法是增加生成器的层数来增加非线性。

2.某些文献采用生成器与判别器交叉训练的方法,即先训练判别器,再训练生成器,其目的是先训练判别器并更新其参数,先让其具有较好判别能力,而在训练生成器时因为判别器已具有一定判定能力,生成器的目的是尽可能骗过判别器,所以生成器会朝着生成更真实的图像前进;
也可以采用先训练生成器,再训练判别器,但是此种训练方法不推荐;同时也可以采用先更新生成器或判别器多次,再更新另一个一次的方法。

  1. 生成器损失、判别器损失,其中一个很大或者逐渐变大,另一个很小或者逐渐变小。
  2. 生成器和判别器的目的相反,也就是说两个生成器网络和判别器网络互为对抗,此消彼长。不可能Loss一直降到一个收敛的状态。

对于生成器,其Loss下降快,很有可能是判别器太弱,导致生成器很轻易的就"愚弄"了判别器。

对于判别器,其Loss下降快,意味着判别器很强,判别器很强则说明生成器生成的图像不够逼真,才使得判别器轻易判别,导致Loss下降很快。

也就是说,无论是判别器,还是生成器。loss的高低不能代表生成器的好坏。一个好的GAN网络,其GAN Loss往往是不断波动的。

技巧

训练GAN技巧
1.输入的图片经过处理,将0-255的值变为-1到1的值。
images = (images/255.0)*2 - 1

2 在generator输出层使用tanh激励函数,使得输出范围在 (-1,1)

3 保存生成的图片时,将矩阵值缩放到[0,1]之间
gen_image = (gen_image+1) / 2

4 使用leaky_relu激励函数,使得负值可以有一定的比重

5 使用BatchNormalization,使分布更均匀,最后一层不要使用。

6 在训练generator和discriminator的时候,一定要保证另外一个的参数是不变的,不能同时更新两个网络的参数。

7 如果训练很容易卡住,可以考虑使用WGAN
可以选择使用RMSprop optimizer

代码1(保存生成图片、loss可视化)

参考:GAN手写数字
代码位置:E:\项目例程\GNN\手写数字\3_可视化
评价:代码解释少,生产图片效果可以,
可学习保存生成图片代码
代码(加了可视化loss):

import torch
import torch.nn as nn
from torch import optim
from torch.autograd import variable
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt

G_in_dim = 100  # 模型的参数参考别人的网络设置
D_in_dim = 784
hidden1_dim = 256
hidden2_dim = 256
G_out_dim = 784
D_out_dim = 1

epoch = 50
batch_num = 60
lr_rate = 0.0003


def to_img(x):  # 这个函数参考自别人的网络,是将生成的假图像经过一系列操作能更清晰的显示出来,具体为什么这样设置没研究过
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


class G_Net(nn.Module):  # 生成网络,或者叫生成器,负责生成假数据
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(G_in_dim, hidden1_dim),
            nn.ReLU(),
            nn.Linear(hidden1_dim, hidden2_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden2_dim, G_out_dim),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.layer(x)
        return x


class D_Net(nn.Module):  # 判别网络,或者叫判别器,用来判别数据真假
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(D_in_dim, hidden1_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden1_dim, hidden2_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden2_dim, D_out_dim),
            nn.Sigmoid())

    def forward(self, x):
        x = self.layer(x)
        return x


data_tf = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize([0.5], [0.5])])
train_set = datasets.MNIST(root='data', train=True, transform=data_tf, download=True)
train_loader = DataLoader(train_set, batch_size=batch_num, shuffle=True)
g_net = G_Net()
d_net = D_Net()


G_losses = []
D_losses = []

criterion = nn.BCELoss()
G_optimizer = optim.Adam(g_net.parameters(), lr=lr_rate)
D_optimizer = optim.Adam(d_net.parameters(), lr=lr_rate)

iter_count = 0
for e in range(epoch):
    for data in train_loader:
        img, l = data
        img = img.view(img.size(0), -1)
        img = variable(img)
        r_label = variable(torch.ones(batch_num))
        f_label = variable(torch.zeros(batch_num))
        g_input = variable(torch.randn(batch_num, G_in_dim))

        r_output = d_net(img)
        r_loss = criterion(r_output.squeeze(-1), r_label)
        f_output = g_net(g_input)
        d_f_output = d_net(f_output)
        f_loss = criterion(d_f_output.squeeze(-1), f_label)
        sum_loss = r_loss + f_loss
        D_optimizer.zero_grad()
        sum_loss.backward()
        D_optimizer.step()

        g_input1 = variable(torch.randn(batch_num, G_in_dim))
        g_output = g_net(g_input1)
        d_output = d_net(g_output)
        d_loss = criterion(d_output.squeeze(-1), r_label)
        G_optimizer.zero_grad()
        d_loss.backward()
        G_optimizer.step()

        if (iter_count % 250 == 0):
            print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_loss.item(), sum_loss.item()))
        iter_count += 1

        G_losses.append(sum_loss.item())
        D_losses.append(d_loss.item())


    # g_img = g_net(variable(torch.randn(batch_num, G_in_dim)))
    # images = to_img(g_img)
    #save_image(images, './img/fake_images-{}.png'.format(e))

x=[i for i in range(len(G_losses))]
figure = plt.figure(figsize=(20, 8), dpi=80)
plt.plot(x,G_losses,label='G_losses')
plt.plot(x,D_losses,label='D_losses')
plt.xlabel("iterations",fontsize=15)
plt.ylabel("loss",fontsize=15)
plt.legend()
plt.show()

结果:

生成图片保存结果:
第1、10、30、40、50个:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

代码2-BP全连接网络

参考:GANminst数字
代码位置:E:\项目例程\GNN\手写数字\1
评价:可视化,图片保存

代码说明:

生成器端,g_loss表示它希望让判别器对自己生成的图片尽可能输出为1,相当于它在于判别器进行对抗。
判别器端,real_loss对应着真实图片的loss,它尽可能让判别器的输出接近于1,real_loss与 fake_loss加起来就是整个判别器的损失。

判别网络

将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
定义判别网络,这一步其实就是构造一个数字识别网络,只不过略微有些区别,这里不是识别具体的数字,而是识别是不是真实的图片,输出只有两个(0或者1),1代表是真实的图片,0代表的是构造的虚假图片。输出其实是个概率值。

# 判别网络
class discriminator(torch.nn.Module):
    def __init__(self, noise_dim=NOISE_DIM):
        # 调用父类的初始化函数,必须要的
        super(discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, img):
        img = self.net(img)
        return img

生成网络

输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,然后通过ReLU激活函数,接着进行一个线性变换,再经过一个ReLU激活函数,然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间。
输出数据将会别送到判别网络中去做判别。

# 生成网络
class generator(torch.nn.Module):
    def __init__(self, noise_dim=NOISE_DIM):
        # 调用父类的初始化函数,必须要的
        super(generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, img):
        img = self.net(img)
        return img

定义损失函数和优化器

优化器采用了Adam优化器,损失函数采用了二分类的交叉熵损失函数

# 二分类的交叉熵损失函数
bce_loss = nn.BCEWithLogitsLoss()

# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer

定义两个计算loss函数

分别计算判别网络和生成网络的代价估算,对于判别网络来说,希望真实的图片预测都是输出1,期望标签是1,对于假的图片希望都是模型输出0,期望标签是0。
而对于生成网络来说,希望模型输出是1,因此期望标签是1。

def discriminator_loss(logits_real, logits_fake):  # 判别器的 loss
    size = logits_real.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    size = logits_fake.shape[0]
    false_labels = Variable(torch.zeros(size, 1)).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss


def generator_loss(logits_fake):  # 生成器的 loss
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss

训练流程函数

def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
                noise_size=NOISE_DIM, num_epochs=25):
    G_losses = []
    D_losses = []
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]
            # 判别网络
            real_data = Variable(x).view(bs, -1)  # 真实数据
            logits_real = D_net(real_data)  # 判别网络得分

            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均匀分布
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # 生成的假的数据
            logits_fake = D_net(fake_images)  # 判别网络得分

            d_total_error = discriminator_loss(logits_real, logits_fake)  # 判别器的 loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # 优化判别网络

            # 生成网络
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # 生成的假的数据

            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)  # 生成网络的 loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # 优化生成网络

            if (iter_count % show_every == 0):
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))

            iter_count += 1
            #print('iter_count: ', iter_count)
            G_losses.append(d_total_error.item())
            D_losses.append(g_error.item())

        g_img = G(variable(torch.randn(batch_size, NOISE_DIM)))
        images = to_img(g_img)
        save_image(images, './img_epoch50/fake_images-{}.png'.format(epoch))

    x = [i for i in range(len(G_losses))]
    figure = plt.figure(figsize=(20, 8), dpi=80)
    plt.plot(x, G_losses, label='G_losses')
    plt.plot(x, D_losses, label='D_losses')
    plt.xlabel("iterations", fontsize=15)
    plt.ylabel("loss", fontsize=15)
    plt.legend()
    plt.grid()
    plt.show()

开始训练

D = discriminator()
G = generator()

D_optim = get_optimizer(D)
G_optim = get_optimizer(G)

train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

结果

epoch=10

在这里插入图片描述
在这里插入图片描述
生成图片:
第一次:
在这里插入图片描述
第二次
在这里插入图片描述
第三次在这里插入图片描述

第10次
在这里插入图片描述
总体趋势是随着迭代次数的增加,图像会变得稍微清晰一点点,数字的轮廓也明显一些。

epoch=30

在这里插入图片描述
在这里插入图片描述
第20次
在这里插入图片描述
第30次
在这里插入图片描述

代码

import torch
from torch import nn
from torch.autograd import Variable
from torch.autograd import variable
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
from torchvision import datasets, transforms

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torchvision.utils import save_image

NUM_TRAIN = 60000

NOISE_DIM = 100
batch_size = 128

data_tf = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize([0.5], [0.5])])
train_set = MNIST('./data', train=True, transform=data_tf)
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)

def to_img(x):  # 这个函数参考自别人的网络,是将生成的假图像经过一系列操作能更清晰的显示出来,具体为什么这样设置没研究过
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out

#show_images(imgs)

# 判别网络
class discriminator(torch.nn.Module):
    def __init__(self, noise_dim=NOISE_DIM):
        # 调用父类的初始化函数,必须要的
        super(discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, img):
        img = self.net(img)
        return img

# 生成网络
class generator(torch.nn.Module):
    def __init__(self, noise_dim=NOISE_DIM):
        # 调用父类的初始化函数,必须要的
        super(generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 256),
            nn.ReLU(True),
            nn.Linear(256, 784),
            nn.Tanh()
        )

    def forward(self, img):
        img = self.net(img)
        return img

# 二分类的交叉熵损失函数
bce_loss = nn.BCEWithLogitsLoss()

# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer

def discriminator_loss(logits_real, logits_fake):  # 判别器的 loss
    size = logits_real.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    size = logits_fake.shape[0]
    false_labels = Variable(torch.zeros(size, 1)).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss


def generator_loss(logits_fake):  # 生成器的 loss
    size = logits_fake.shape[0]
    true_labels = Variable(torch.ones(size, 1)).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss

def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
                noise_size=NOISE_DIM, num_epochs=10):
    G_losses = []
    D_losses = []
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]
            # 判别网络
            real_data = Variable(x).view(bs, -1)  # 真实数据
            logits_real = D_net(real_data)  # 判别网络得分

            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # -1 ~ 1 的均匀分布
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # 生成的假的数据
            logits_fake = D_net(fake_images)  # 判别网络得分

            d_total_error = discriminator_loss(logits_real, logits_fake)  # 判别器的 loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()  # 优化判别网络

            # 生成网络
            g_fake_seed = Variable(sample_noise)
            fake_images = G_net(g_fake_seed)  # 生成的假的数据

            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)  # 生成网络的 loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()  # 优化生成网络

            if (iter_count % show_every == 0):
                print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))

            iter_count += 1
            #print('iter_count: ', iter_count)
            G_losses.append(d_total_error.item())
            D_losses.append(g_error.item())

        g_img = G(variable(torch.randn(batch_size, NOISE_DIM)))
        images = to_img(g_img)
        save_image(images, './img2/fake_images-{}.png'.format(epoch))

    x = [i for i in range(len(G_losses))]
    figure = plt.figure(figsize=(20, 8), dpi=80)
    plt.plot(x, G_losses, label='G_losses')
    plt.plot(x, D_losses, label='D_losses')
    plt.xlabel("iterations", fontsize=15)
    plt.ylabel("loss", fontsize=15)
    plt.legend()
    plt.grid()
    plt.show()



D = discriminator()
G = generator()

D_optim = get_optimizer(D)
G_optim = get_optimizer(G)

train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

代码3:判别-生成网络模型 cnn

代码位置:E:\项目例程\GNN\手写数字\1_cnn_效果好
改进:
产生的噪声更少了,训练也更加稳定,主要是里面引入了Batchnormalization,另外gan的训练过程是特别困难的,两个对偶网络相互学习,这个时候有一些训练技巧可以使得训练生成更加稳定。

代码

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2)
        )
        self.fc = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.LeakyReLU(0.01),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x


class generator(nn.Module):
    def __init__(self, noise_dim=NOISE_DIM):
        super(generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 7 * 7 * 128),
            nn.ReLU(True),
            nn.BatchNorm1d(7 * 7 * 128)
        )

        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.shape[0], 128, 7, 7)  # reshape 通道是 128,大小是 7x7
        x = self.conv(x)
        return x

流程训练函数train_a_gan的第8行 real_data = Variable(x).view(bs, -1) # 真实数据 需要修改为,real_data = Variable(x) # 真实数据

结果

epoch=25
第1,5 ,10,15,20 ,25 次
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

标签:loss,判别,nn,self,img,minst,GAN,fake,手写
来源: https://blog.csdn.net/zhe470719/article/details/120741901

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有