
"A Two-Step Disentanglement Method": a Keras implementation

Posted 2019-07-11 21:40:57

Tags: DIM, enc, Disentanglement, clf, keras, Two, train, plt, test


Notes

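The paper is about disentangled feature representations. The network is built on an autoencoder, but the encoder is split into two, $Enc_s$ and $Enc_z$, so that the latent code is divided into two parts, $z' = (s, z)$: $s$ encodes the information related to the label, and $z$ encodes everything else.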

[Figure: model diagram with nodes x, Enc_s, Enc_z, s, z, z', Clf_s, Clf_z, l_s, l_z, x_hat]

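Disentanglement is enforced by two classifiers:

  • $Clf_s$: classifies $s$, which forces $s$ to capture the label information;
  • $Clf_z$: trained adversarially against $Enc_z$. $Clf_z$ tries to classify $z$ correctly, while $Enc_z$ tries to produce a $z$ that $Clf_z$ cannot classify (every entry of the output probability vector is $\frac{1}{N_{class}}$ in the single-label case, or 0.5 in the multi-label case), which keeps label information out of $z$;

At the same time, a reconstruction loss on the decoder forces $z' = (s, z)$ to encode all the information in the original data, so that nothing is lost.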
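
To see why the uniform vector works as the adversarial target, note that the cross-entropy against a uniform target is smallest when the prediction is itself uniform. The following is a minimal pure-Python sketch of that observation; the helper `cross_entropy` and the example vectors are illustrative assumptions and are not part of the original post.

```python
import math


def cross_entropy(target, pred, eps=1e-12):
    """Categorical cross-entropy between two probability vectors."""
    return -sum(t * math.log(p + eps) for t, p in zip(target, pred))


n_class = 10
uniform = [1.0 / n_class] * n_class   # adversarial target for Clf_z
confident = [0.91] + [0.01] * 9       # a prediction that still "knows" the label

# Loss when Clf_z outputs the uniform vector (equals log(n_class)).
loss_uniform = cross_entropy(uniform, uniform)
# Loss when Clf_z is still confident about one class.
loss_confident = cross_entropy(uniform, confident)

print(loss_uniform, loss_confident)
```

The confident prediction is penalized more than the uniform one, so minimizing this loss with respect to $Enc_z$ pushes the output of $Clf_z$ toward $\frac{1}{N_{class}}$ in every entry.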

Practice

Model

I modified the model slightly by adding a classifier $Clf_x$, which classifies both $x$ and $\hat x$:

[Figure: modified model diagram with nodes x, Enc_s, Enc_z, s, z, z', Clf_s, Clf_z, Clf_x, l_s, l_z, x_hat, l_x_sup, l_x_uns]

Objectives

The computation graph is trained as three sub-modules, updated in turn:

  1. $Enc_s, Clf_s, Clf_x$
    Classification loss on $l_s = Clf_s(Enc_s(x))$ and $l_x^{sup} = Clf_x(x)$;
  2. $Clf_z$
    With $Enc_z$ frozen, classification loss on $l_z = Clf_z(Enc_z(x))$;
  3. $Enc_z, Dec$
    With $Enc_s$ frozen, reconstruction loss on $\hat x = Dec(Enc_s(x), Enc_z(x))$;
    In addition, sample another image and label $x', l'$, and apply a classification loss to $l_x^{uns} = Clf_x(Dec(Enc_s(x'), Enc_z(x)))$, with the newly sampled $l'$ as the target label;
    With $Clf_z$ frozen, classification loss on $l_z = Clf_z(Enc_z(x))$, but this time the target label is $\tilde l = (\frac{1}{N_{class}}, \dots, \frac{1}{N_{class}})$;
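
Written out as loss functions, the three steps might look as follows. This is a sketch reconstructed from the loss settings in the code of this post (the weights 10, 10, 1 are the `loss_weights` of `m_ae`). The names $\mathcal{L}_{sup}$, $\mathcal{L}_{adv}$, $\mathcal{L}_{ae}$, $CE$ and $BCE$ are notation introduced here: $CE(t, p)$ is the categorical cross-entropy with target $t$ and prediction $p$, $BCE$ is the pixel-wise binary cross-entropy, and $l$ is the true label of $x$.

```latex
\begin{aligned}
\mathcal{L}_{sup} &= CE\big(l,\ Clf_s(Enc_s(x))\big) + CE\big(l,\ Clf_x(x)\big) \\
\mathcal{L}_{adv} &= CE\big(l,\ Clf_z(Enc_z(x))\big) \\
\mathcal{L}_{ae}  &= 10\, BCE\big(x,\ \hat x\big)
                   + 10\, CE\big(\tilde l,\ Clf_z(Enc_z(x))\big)
                   + CE\big(l',\ Clf_x(Dec(Enc_s(x'), Enc_z(x)))\big)
\end{aligned}
```

$\mathcal{L}_{adv}$ and the middle term of $\mathcal{L}_{ae}$ apply to the same output $l_z$ but with different targets and different trainable parts, which is what makes the training adversarial.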

See the code for details.

Code

  • Run with the default arguments
from time import time
import argparse
import numpy as np
from sklearn import manifold
import matplotlib.pyplot as plt

import keras
import keras.backend as K
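# The lowercase `adam`/`sgd` aliases are not available in every Keras release,
# so import the SGD class and keep the name `sgd` used further down.
from keras.optimizers import SGD as sgd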
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.models import Model
from keras.layers import Dense, Dropout, Activation, Input, Concatenate, LeakyReLU

np.random.seed(int(time()))

parser = argparse.ArgumentParser()
parser.add_argument('--EPOCH', type=int, default=30)
parser.add_argument('--BATCH', type=int, default=128)
parser.add_argument('--DIM_Z', type=int, default=16)
parser.add_argument('--DIM_H', type=int, default=256)
parser.add_argument('--DIM_FEA', type=int, default=16)
opt = parser.parse_args()
print(opt)

(I_train, L_train), (I_test, L_test) = mnist.load_data()
N_PIX = I_train.shape[1]
I_train = I_train.reshape(I_train.shape[0], -1) / 255.
I_test = I_test.reshape(I_test.shape[0], -1) / 255.
L_train = to_categorical(L_train, 10)
L_test = to_categorical(L_test, 10)
print(I_train.shape, L_test.shape)

N_CLASS = L_train.shape[-1]
DIM_IMG = I_train.shape[-1]
DIM_FEA = opt.DIM_FEA
DIM_Z = opt.DIM_Z
DIM_H = opt.DIM_H
EPOCH = opt.EPOCH
BATCH = opt.BATCH


def Encoder(dim_in=DIM_IMG, dim_z=DIM_Z, name='encoder'):
    inputs = Input([dim_in])
    x = inputs
    x = Dense(DIM_H, activation='relu')(x)
    x = Dropout(0.2)(x)
    x = Dense(DIM_H, activation='relu')(x)
    x = Dropout(0.2)(x)
    z = Dense(dim_z)(x)
    return Model(inputs, z, name=name)


def Decoder(dim_z=DIM_Z, dim_a=DIM_FEA, dim_out=DIM_IMG, name='decoder'):
    z = Input([dim_z])
    a = Input([dim_a])
    inputs = [z, a]
    x = Concatenate()([z, a])
    for _ in range(2):
        x = Dense(DIM_H, activation='relu')(x)
        # x = LeakyReLU(alpha=0.2)(x)
        x = Dropout(0.3)(x)
    x = Dense(dim_out)(x)
    x = Activation("sigmoid")(x)
    output = x
    return Model(inputs, output, name=name)


def Classifier(dim_in=DIM_Z, n_class=N_CLASS, name='classifier'):
    inputs = Input([dim_in])
    x = inputs
    # x = Dense(DIM_H, activation='relu')(x)
    # x = Dropout(0.2)(x)
    x = Dense(n_class, activation='softmax')(x)
    output = x
    return Model(inputs, output, name=name)


def _set_train(m, is_train=True):
    m.trainable = is_train
    for ly in m.layers:
        ly.trainable = is_train


# network
in_lab = Input([N_CLASS])
in_img = Input([DIM_IMG])
other_i = Input([DIM_IMG])

enc_z = Encoder(DIM_IMG, DIM_Z, 'enc_z')
enc_s = Encoder(DIM_IMG, DIM_FEA, 'enc_s')
dec = Decoder(DIM_Z, DIM_FEA, DIM_IMG, 'dec')
clf_z = Classifier(DIM_Z, N_CLASS, 'clf_z')
clf_s = Classifier(DIM_FEA, N_CLASS, 'clf_s')
clf_x = Classifier(DIM_IMG, N_CLASS, 'clf_x')

z = enc_z(in_img)
s = enc_s(in_img)
x_hat = dec([z, s])
l_z = clf_z(z)
l_s = clf_s(s)
l_x_sup = clf_x(in_img)
other_s = enc_s(other_i)
other_x_hat = dec([z, other_s])
l_x_uns = clf_x(other_x_hat)


# enc_s & clf_s & clf_x
m_sup = Model([in_img, in_lab], [l_s, l_x_sup],
              name='train_EncF_ClfF_ClfI')
m_sup.compile('adam',
              loss=['categorical_crossentropy',
                    'categorical_crossentropy'],
              loss_weights=[1, 1],
              metrics=['categorical_accuracy'])


# adv: clf_z
m_adv = Model(in_img, l_z, name='train_EncZ')
_set_train(enc_z, False)
m_adv.compile(sgd(0.001),
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])


# AE: enc_z & dec
m_ae = Model([in_img, other_i], [x_hat, l_z, l_x_uns], name='train_ae')
_set_train(enc_z, True)
_set_train(dec, True)
_set_train(enc_s, False)
_set_train(clf_z, False)
_set_train(clf_x, False)
# _set_train(clf_s, False)
# _set_train(model_lab, False)
m_ae.compile('adam',
             loss=['binary_crossentropy', 'categorical_crossentropy',
                   'categorical_crossentropy'],
             loss_weights=[10, 10, 1],
             metrics=['categorical_accuracy'])


def TSNE(X, label, title="", save_f=None):
    n_points = len(X)
    n_components = 2
    color = np.argmax(label, axis=-1)
    fig = plt.figure(figsize=(15, 8))
    if title == "":
        plt.suptitle("%s Manifold Learning with %i points"
                     % (title, n_points), fontsize=14)
    else:
        plt.suptitle(title)

    if X[0].size == 3:
        ax = fig.add_subplot(251, projection='3d')
        ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=color,
                   cmap=plt.get_cmap("rainbow"))
        ax.view_init(4, -72)

    t0 = time()
    tsne = manifold.TSNE(n_components=n_components, init='pca', random_state=0)
    Y = tsne.fit_transform(X)
    t1 = time()
    print("t-SNE: %.2g sec" % (t1 - t0))
    plt.scatter(Y[:, 0], Y[:, 1], c=color, cmap=plt.get_cmap("rainbow"))
    plt.colorbar()
    plt.title("t-SNE (%.2g sec)" % (t1 - t0))
    plt.axis('tight')
    if save_f is not None:
        assert isinstance(save_f, str)
        fig.savefig(f'./picture/{save_f}.png')
    plt.show()


def test():
    idx = np.random.choice(L_test.shape[0], 10)
    other_idx = np.random.choice(L_test.shape[0], 10)

    print('original')
    x = I_test[idx].reshape(-1, N_PIX, N_PIX)
    x = np.hstack(x)
    plt.imshow(x, cmap='Greys')
    plt.show()

    print('reconstruct')
    x_gen = dec.predict([enc_z.predict(I_test[idx]),
                         enc_s.predict(I_test[idx])])
    x = x_gen.reshape(-1, N_PIX, N_PIX)
    x = np.hstack(x)
    plt.imshow(x, cmap='Greys')
    plt.show()

    print('change s:', np.argmax(L_test[other_idx], axis=-1))
    x_gen = dec.predict([enc_z.predict(I_test[idx]),
                         enc_s.predict(I_test[other_idx])])  # changed
    x = x_gen.reshape(-1, N_PIX, N_PIX)
    x = np.hstack(x)
    plt.imshow(x, cmap='Greys')
    plt.show()

    print('change z:', np.argmax(L_test[idx], axis=-1))
    x_gen = dec.predict([enc_z.predict(I_test[other_idx]),  # changed
                         enc_s.predict(I_test[idx])])
    x = x_gen.reshape(-1, N_PIX, N_PIX)
    x = np.hstack(x)
    plt.imshow(x, cmap='Greys')
    plt.show()

    print('real label:', np.argmax(L_test[idx[0]], axis=-1))
    print('clf_z:', clf_z.predict(enc_z.predict(I_test[idx[0:1]]))[0])
    print('clf_s:', clf_s.predict(enc_s.predict(I_test[idx[0:1]]))[0])


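def gen_data(dataset, batch_size):
    """Data generator: yields random (images, labels) batches forever."""
    if dataset == "train":
        I, L = I_train, L_train
    elif dataset == "test":
        I, L = I_test, L_test
    else:
        raise ValueError(f"unknown dataset: {dataset}")
    size = I.shape[0]
    while True:
        # np.random.choice draws the indices with replacement
        idx = np.random.choice(size, batch_size)
        yield I[idx], L[idx]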


xjb_label = np.ones((BATCH, N_CLASS)) / N_CLASS  # fake (uniform) label, the adversarial target
gen_train = gen_data('train', BATCH)
for epoch in range(EPOCH):
    print(f'--- {epoch} ---')
    for b in range(I_train.shape[0] // BATCH):
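        # step 1: supervised update of enc_s, clf_s and clf_x
        for _ in range(1):
            i, l = next(gen_train)
            loss_sup = m_sup.train_on_batch([i, l], [l, l])
        # step 2: three updates of clf_z on the true labels (enc_z frozen)
        for _ in range(3):
            i, l = next(gen_train)
            loss_adv = m_adv.train_on_batch(i, l)
        # step 3: update enc_z and dec; targets are the input image,
        # the uniform fake label, and the label l2 of the swapped-in image i2
        for _ in range(1):
            i, l = next(gen_train)
            i2, l2 = next(gen_train)
            loss_ae = m_ae.train_on_batch([i, i2], [i, xjb_label, l2])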

    print(loss_sup)
    print(loss_adv)
    print(loss_ae)
    if epoch % 10 == 0:
        test()


print('\n--- after ---')
test()
TSNE(enc_z.predict(I_test), L_test, 'z distribution')
TSNE(enc_s.predict(I_test), L_test, 's distribution')
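
`test()` ends by printing the raw probability vectors of `clf_z` and `clf_s` for a single sample. One possible way to condense such a vector into one number is its normalized entropy: values near 1 mean the classifier is close to the uniform output, values near 0 mean it is confident. The sketch below is pure Python; the helper `normalized_entropy` and the two example vectors are hypothetical and do not come from the original code.

```python
import math


def normalized_entropy(probs, eps=1e-12):
    """Entropy of a probability vector divided by log(N); 1.0 means uniform."""
    h = -sum(p * math.log(p + eps) for p in probs)
    return h / math.log(len(probs))


uniform_like = [0.1] * 10        # the kind of output hoped for from clf_z
peaked = [0.91] + [0.01] * 9     # the kind of output hoped for from clf_s

score_z = normalized_entropy(uniform_like)
score_s = normalized_entropy(peaked)

print(score_z, score_s)
```

With the trained models, the same helper could be applied to a vector such as `clf_z.predict(enc_z.predict(I_test[:1]))[0]`.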

Renderings

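Original images, reconstructions, s swapped with z kept, z swapped with s kept:
[Figure: MNIST]
Distribution of s:
[Figure: s_distribution]
Distribution of z:
[Figure: z_distribution]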

References

  1. paper: A Two-Step Disentanglement Method
  2. code: naamahadad/A-Two-Step-Disentanglement-Method

Source: https://blog.csdn.net/HackerTom/article/details/95509729
