ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

CGAN实现过程

2021-10-23 11:30:00  阅读:1022  来源: 互联网

标签:10 ... 实现 CGAN 28 0.0000 100 过程 输入


本文目录

本文用MNIST数据集进行训练,并用图解的方法展示了CGAN与GAN中输入的区别,帮助理解CGAN的运行过程

一、原理

如下图所示,我们在输入噪声z时,额外加上一个限制条件conditionz和c通过生成器G得到生成的图片

二、参数初始化

有了上面的原理解释,我们就可以来初始化我们的参数了,大致可以看出我们有如下几个参数:噪声z,条件c,真实图片x,生成器和判别器的初始化参数

  • G的输入:z_y_vec_
  • D的输入:xy_fill_
  • 模型参数的初始化
  • 测试时用的噪声sample_z_以及对应的标签sample_y_

这里输入的单个噪声维度为z_dim=62,当然这里还有很多其他的初始化,比如optimizer等,因为本文主要介绍模型的的具体执行过程,所以只对变量得初始化做介绍

1. G的输入

  • 输入噪声z:z_: (64, 62)
  • 输入条件c:y_vec_:(64, 10)

最终G的输入:横向拼接z+c (64, 72)

G:
torch.Size([64, 72])
tensor([[0.8920, 0.9742, 0.6876,  ..., 0.0000, 0.0000, 0.0000],
        [0.5271, 0.6423, 0.7480,  ..., 0.0000, 1.0000, 0.0000],
        [0.9545, 0.6324, 0.9603,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.1931, 0.7773, 0.8154,  ..., 0.0000, 0.0000, 0.0000],
        [0.0049, 0.7129, 0.3272,  ..., 0.0000, 0.0000, 0.0000],
        [0.2902, 0.1194, 0.0020,  ..., 0.0000, 1.0000, 0.0000]])

在这里插入图片描述

2. D的输入

  • 输入真实数据:x: (64, 1, 28, 28)
  • 输入生成数据:G(z):(64, 1, 28, 28)
  • 输入条件:c:y_fill_:(64, 10, 28, 28)

最终D的输入:横向拼接x+c (64, 11, 28, 28),也就是说取batch中的一个值,维度为(1,28, 28),将其作为(11, 28, 28)的第一维,剩下的十维如果标签为0则第二维为全1,剩下的为全0,如果标签为1则第三维为全1,剩下的为全0,以此类推

D:
torch.Size([64, 11, 28, 28])
tensor([[[[ 0.1099, -0.5590,  0.9668,  ...,  3.0843,  0.6788, -0.4171],
          [ 0.8949, -0.3523, -0.4086,  ..., -0.8257, -2.1445,  1.0512],
          [ 1.5333, -0.0918, -1.1146,  ..., -1.1746, -0.4689,  0.3702],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

在这里插入图片描述

3. 模型参数初始化

def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

4. 测试噪声

在测试时我们只需要设置G的输入就可以了,也就是说我们需要:

  • 输入噪声z:z_: (100, 62)
  • 输入条件c:y_vec_:(100, 10)

最终G的输入:横向拼接z+c (100, 72)

下面给出代码和输出

# fixed noise
sample_z_ = torch.randn((100, 62))
for i in range(10):
    sample_z_[i*10] = torch.rand(1, 62)
    for j in range(1, 10):
        sample_z_[i*10 + j] = sample_z_[i*10]
print(sample_z_)
"""
sample_z_:(100, 62)
          0-9:    same value
          10-19:  same value
          ...
          90-99:  same value
"""
temp = torch.zeros((10, 1))     # (10,1)---> 0,0,0,0,0,0,0,0,0,0
for i in range(10):
    temp[i, 0] = i                     # (10, 1) ---> 0,1,2,3,4,5,6,7,8,9
# print("temp:      ", temp)

temp_y = torch.zeros((100, 1))  #(100,1)---> 0,0,0,0,...,0,0,0,0
for i in range(10):             #(100,1)---> 0,1,2,3,...,6,7,8,9
    temp_y[i*10: (i+1)*10] = temp
# print("temp_y:    ", temp_y)           
sample_y_ = torch.zeros((100, 10)).scatter_(1, temp_y.type(torch.LongTensor), 1)
print(sample_y_)                       #(100,10)
'''
tensor([[0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],
        [0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],
        [0.3944, 0.9880, 0.4956,  ..., 0.0602, 0.9869, 0.5094],
        ...,
        [0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332],
        [0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332],
        [0.2845, 0.7694, 0.9878,  ..., 0.3211, 0.0242, 0.0332]])
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
'''

下面给出详细的解释,我们知道G的输入有噪声以及条件,这里我们有100组噪声,每10组噪声的组内取值是完全相同的,但是组内的10个噪声每个噪声的条件是不同的,分别代表了数字0-9

也就是说我们希望用相同的噪声生成0-9一共十个数字,生成十组

三、执行过程

图中的红线代表一个执行流程,绿线代表一个执行流程,红色的方框为这一步反向传播的网络。因为判别器与生成器是分开训练的,用两个图来表示,左边是第一步训练判别器,右边是第二步训练生成器

  • step1:首先将样本进行输入,用BCE_loss来评估得到D_real_loss,然后将G生成的数据进行输入,同理评估得到D_fake_loss,将二者相加进行反向传播优化D。注意这一步不要优化G
  • step2:直接将G生成的数据进行输入,评估得到G_loss,反向传播优化G。注意这一步虽然是G生成的数据,但是通过D以后要与real进行求损失

在这里插入图片描述

四、测试

训练完后直接进行测试即可,最后测试生成的图片如下:

标签:10,...,实现,CGAN,28,0.0000,100,过程,输入
来源: https://blog.csdn.net/weixin_41978699/article/details/120883330

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有