ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

用AVE和GAN生成图像

2021-11-25 12:58:31  阅读:198  来源: 互联网

标签:nn image torch batch GAN 图像 AVE self size


AVE相当于把图片压缩在解压的感觉。

自编码器       x ——编码器——z——解码器——\widetilde{x}

变分自编码器       对z有一个正态分布的约束

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

sample_dir = 'samples'
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 0.001
dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)

class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    # 编码  学习高斯分布均值与方差
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    # 将高斯分布均值与方差参数重表示,生成隐变量z  若x~N(mu, var*var)分布,则(x-mu)/var=z~N(0, 1)分布
    # 用mu,log_var生成一个潜在空间点z,mu,log_var为两个统计参数假设分布能生成图像
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    # 解码隐变量z
    def decode(self, z):
        h = F.relu(self.fc4(z))
        return F.sigmoid(self.fc5(h))

    # 计算重构值和隐变量z的分布参数
    def forward(self, x):
        mu, log_var = self.encode(x)  # 从原始样本x中学习隐变量z的分布,即学习服从高斯分布均值与方差
        z = self.reparameterize(mu, log_var)  # 将高斯分布均值与方差参数重表示,生成隐变量z
        x_reconst = self.decode(z)  # 解码隐变量z,生成重构x’
        return x_reconst, mu, log_var  # 返回重构值和隐变量的分布参数

model = VAE()
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# 开始训练
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        x = x.view(-1,image_size)  # 将batch_size*1*28*28 ---->batch_size*image_size  其中,image_size=1*28*28=784
        x_reconst, mu, log_var = model(x)  # 将batch_size*748的x输入模型进行前向传播计算,重构值和服从高斯分布的隐变量z的分布参数(均值和方差)
        reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)      # 计算重构损失和KL散度
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())          # KL散度
        loss = reconst_loss + kl_div                  # 计算误差(重构误差和KL散度值)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 10 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader), reconst_loss.item(), kl_div.item()))

    with torch.no_grad():
        # 保存采样图像,即潜在向量z通过解码器生成的图像
        z = torch.randn(batch_size, z_dim)  # z的大小为batch_size * z_dim = 128*20
        out = model.decode(z).view(-1, 1, 28, 28)      # 对随机数 z 进行解码decode输出
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))    # 保存结果值
        # 保存重构图像,即原图像通过解码器生成的图像
        out, _, _ = model(x)            # 将batch_size*748的x输入模型进行前向传播计算,获取重构值out
        # 将输入与输出拼接在一起输出保存  batch_size*1*28*(28+28)=batch_size*1*28*56
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))

GAN是基于博弈论的,所以又称生成式对抗网络。他要解决的问题是如何从训练样本中学习出新样本,训练样本是图像就生成新图像,训练样本是文章就生成新文章。

GAN既不依赖标签来优化,也不根据对结果的奖惩来调整参数,他是依据生成器和判别器之间的博弈来不断优化。

import os
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

sample_dir = 'samples'
image_size = 784
hidden_size = 256
latent_size = 64
num_epochs = 5
batch_size = 100
learning_rate = 0.001
dataset = torchvision.datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)
# 构建判断器
D = nn.Sequential(nn.Linear(image_size,hidden_size),
                  nn.LeakyReLU(0.2),
                  nn.Linear(hidden_size,hidden_size),
                  nn.LeakyReLU(0.2),
                  nn.Linear(hidden_size,1),
                  nn.Sigmoid())
# 构建生成器
G = nn.Sequential(nn.Linear(latent_size,hidden_size),
                  nn.ReLU(),
                  nn.Linear(hidden_size,hidden_size),
                  nn.ReLU(),
                  nn.Linear(hidden_size,image_size),
                  nn.Tanh())
criterion = nn.BCELoss()                     #用于二分类
optimizerG = torch.optim.Adam(G.parameters(), lr=0.0002)   #生成器的优化器
optimizerD = torch.optim.Adam(D.parameters(), lr=0.0002)   #判别器的优化器

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1) # 将out张量每个元素的范围限制到区间 [min,max]
def reset_grad():
    optimizerG.zero_grad()  # 清空鉴别器的梯度器上一步的残余更新参数值
    optimizerD.zero_grad()  # 清空生成器的梯度器上一步的残余更新参数值
for epoch in range(num_epochs):
    for i,(images,_) in enumerate(data_loader):
        images = images.reshape(batch_size,-1)
        # 定义图像真假的标签
        real_labels = torch.ones(batch_size,1)
        fake_labels = torch.zeros(batch_size,1)
        # =========================================#
        #              训练判别器                    #
        # =========================================#
        # 定义判断器对真图像的损失函数
        outputs = D(images)
        d_loss_real = criterion(outputs,real_labels)
        real_score = outputs
        # 定义判别器对加图像的损失函数
        z = torch.randn(batch_size,latent_size)
        fake_image = G(z)
        outputs = D(fake_image)
        d_loss_fake = criterion(outputs,fake_labels)
        fake_score = outputs
        d_loss = d_loss_real + d_loss_fake     # 判别器总的损失函数
        # 对生成器判别器梯度清零
        reset_grad()
        d_loss.backward()
        optimizerD.step()
        # =========================================#
        #              训练生成器                    #
        # =========================================#
        # 定义生成器对假图像的损失函数
        # 我们要求判别器生成的图像越来越像真图片,故损失函数中的标签改为真图像的标签,即希望生成的加图像越来越靠近真图像
        z = torch.randn(batch_size,latent_size)
        fake_image = G(z)
        outputs = D(fake_image)
        g_loss = criterion(outputs,real_labels)
        reset_grad()
        g_loss.backward()
        optimizerG.step()
        if (i + 1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i + 1, len(data_loader), d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))
    # 保存真图像
    if(epoch+1) == 1:
        images = images.reshape(images.size(0),1,28,28)
        save_image(denorm(images),os.path.join(sample_dir,'real_images.png'))
    # 保存假图像
    fake_images = images.reshape(images.size(0),1,28,28)
    save_image(denorm(fake_images),os.path.join(sample_dir,'fake_images-{}.png'.format(epoch+1)))
# 保存模型
torch.save(G.state_dict(),'G.ckpt')
torch.save(D.state_dict('D.ckpt'))

 fake

real 

 

 

 

标签:nn,image,torch,batch,GAN,图像,AVE,self,size
来源: https://blog.csdn.net/z1139269312/article/details/121534975

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有