ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

变分自编码器(Variational Auto-Encoder,VAE)

2022-01-17 14:35:14  阅读:255  来源: 互联网

标签:编码器 nn Auto torch 变分 VAE hat mnist 28


VAE网络结构较AE只有部分改变

import torch
import numpy as np
from torch import nn




class VAE(nn.Module):

    def __init__(self):
        super(VAE, self).__init__()

        # [b,784] => [b,20]
        # u: [b,10]
        # sigma: [b,10]
        self.encoder = nn.Sequential(
            nn.Linear(784,256),
            nn.ReLU(),
            nn.Linear(256,64),
            nn.ReLU(),
            nn.Linear(64,20),
            nn.ReLU()
        )
        # [b,20] => [b,784]
        self.decoder = nn.Sequential(
            nn.Linear(10,64),
            nn.ReLU(),
            nn.Linear(64,256),
            nn.ReLU(),
            nn.Linear(256,784),
            nn.Sigmoid()  # 压缩到0-1
        )


    def forward(self,x):
        """

        :param x:[b,1,28,28]
        :return:
        """
        batchsz = x.size(0)
        # flatten
        x=x.view(batchsz,784)
        # encoder
        # [b,20] 包含了mean 和 sigma
        h_ = self.encoder(x)

        # VAE 和 AE 的不同之处
        # 把mu和sigma拆分出来,用chunk(拆分的个数,位置)
        # [b,20]-——>[b,10] and [b,10]
        mu,sigma = h_.chunk(2,dim=1)
        # reparametrize trick ,epison~N(0,1)
        h = mu + sigma * torch.randn_like(sigma)  # 后边这个是sigma的正态分布
        # decoder
        x_hat = self.decoder(h)
        # reshape  因为是打平过的,还需要再变回照片
        x_hat = x_hat.view(batchsz,1,28,28)

        # 计算KL divergence,网上可以查它的公式,这里u2=0,sigma2=1
        kld = 0.5 * torch.sum(
            torch.pow(mu,2)+
            torch.pow(sigma,2)-
            torch.log(1e-8 + torch.pow(sigma,2))-1
        ) / (batchsz*28*28)





        return x_hat , kld

只是多个这个kld

主函数部分变化有

import torch
from torchvision import transforms,datasets  # datasets自带数据集MNIST
from torch.utils.data import DataLoader
from AE import AE
from VAE import VAE
from torch import nn,optim
import visdom





def main():
    # 把MNIST数据集加载进来
    mnist_train = datasets.MNIST('mnist',True,transform=transforms.Compose([
        transforms.ToTensor()
    ]),download=True)

    # 把数据集加载到DataLoader中
    mnist_train = DataLoader(mnist_train,batch_size=32,shuffle=True)


    # 把MNIST数据集加载进来
    mnist_test = datasets.MNIST('mnist',True,transform=transforms.Compose([
        transforms.ToTensor()
    ]),download=True)

    # 把数据集加载到DataLoader中
    mnist_test = DataLoader(mnist_test,batch_size=32,shuffle=True)

    # 构建一个迭代器
    x,_ = iter(mnist_train).next()  # 不返回label,因为这是无监督学习
    print('x:',x.shape)  # x:torch.Size([32, 1, 28, 28])

    device = torch.device('cuda')
    model = VAE().to(device)
    criteon = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(),lr=1e-3)
    print(model)

    viz = visdom.Visdom()


    for epoch in range(1000):
        for batchidx,(x,_) in enumerate(mnist_train):
            # [b,1,28,28]
            x=x.to(device)

            # x_hat表示重建过的x
            x_hat,kld = model(x)
            loss = criteon(x_hat,x)

            # VAE才有的
            if kld is not None:
                elbo = -loss - 1.0 * kld
                loss = -elbo

            # backprop
            optimizer.zero_grad() #第一步梯度清零
            loss.backward()  # 第二步backward
            optimizer.step()  # 第三步更新梯度


        print(epoch,'loss',loss.item(),'kld',kld.item())

        # 从test中取一些图片进行重构
        x, _ = iter(mnist_test).next()
        x=x.to(device)
        with torch.no_grad():
            x_hat,kld = model(x)
        # 可视化
        viz.images(x,nrow=8,win='x',opts=dict(title='x'))
        viz.images(x_hat,nrow=8,win='x_hat',opts=dict(title='x_hat'))








if __name__ == '__main__':
    main()

得到结果

具体原因可能是任务太简单了,体现不出VAE的好用。

标签:编码器,nn,Auto,torch,变分,VAE,hat,mnist,28
来源: https://blog.csdn.net/weixin_62637793/article/details/122539052

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有