ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

tf中WGAN-GP实战

2021-10-03 13:02:34  阅读:217  来源: 互联网

标签:map GP image batch dataset tf fn WGAN


tf中WGAN-GP实战

文章目录

1. 任务

利用 WGAN-GP(网络结构沿用 DCGAN)在 Anime 数据集上生成动漫图片

2. WGAN模型

# 定义WGAN-GP过程
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class Generator(keras.Model):
    """DCGAN-style generator: latent vector -> 64x64x3 image with values in [-1, 1].

    Spatial size grows 3x3 (after reshape) -> 9x9 -> 21x21 -> 64x64 through
    three 'valid' transposed convolutions; the final tanh bounds the output.
    """

    def __init__(self):
        super().__init__()
        # Attribute names (fc, conv*, bn*) and creation order are kept stable:
        # Keras tracks weights through them.
        self.fc = layers.Dense(3 * 3 * 512)
        self.conv1 = layers.Conv2DTranspose(256, kernel_size=3, strides=3, padding='valid')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2DTranspose(128, kernel_size=5, strides=2, padding='valid')
        self.bn2 = layers.BatchNormalization()
        # Three output channels so generated samples have the same RGB layout
        # as the real images fed to the discriminator.
        self.conv3 = layers.Conv2DTranspose(3, kernel_size=4, strides=3, padding='valid')

    def call(self, inputs, training=None):
        # leaky_relu instead of relu to keep gradients flowing for negative units.
        h = tf.nn.leaky_relu(tf.reshape(self.fc(inputs), [-1, 3, 3, 512]))
        for deconv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            # BatchNormalization behaves differently in training and inference,
            # so the mode flag is forwarded explicitly.
            h = tf.nn.leaky_relu(norm(deconv(h), training=training))
        return tf.tanh(self.conv3(h))


class Discriminator(keras.Model):
    """Convolutional critic: 64x64x3 image -> one unbounded score per sample.

    Three 5x5, stride-3 'valid' convolutions shrink the feature map
    64 -> 20 -> 6 -> 1 before a single-unit dense head.

    NOTE(review): the WGAN-GP paper advises against BatchNorm in the critic,
    because it makes each score depend on the rest of the batch and distorts
    the per-sample gradient penalty. Switching to LayerNormalization would
    change the checkpoint variables, so it is left as is here.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = layers.Conv2D(64, kernel_size=5, strides=3, padding='valid')
        self.conv2 = layers.Conv2D(128, kernel_size=5, strides=3, padding='valid')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(256, kernel_size=5, strides=3, padding='valid')
        self.bn3 = layers.BatchNormalization()
        # Flatten is a layer (so it can live in a Sequential container),
        # unlike a raw tf.reshape call.
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        # The first convolution is deliberately not batch-normalised.
        h = tf.nn.leaky_relu(self.conv1(inputs))
        for conv, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
            h = tf.nn.leaky_relu(norm(conv(h), training=training))
        return self.fc(self.flatten(h))


def main():
    """Smoke-test both networks on one random sample and verify output shapes.

    Raises
    ------
    RuntimeError
        If either network produces an output of unexpected shape.
    """
    d = Discriminator()
    g = Generator()
    x = tf.random.normal([1, 64, 64, 3])
    z = tf.random.normal([1, 100])
    score = d(x)   # raw critic score (a logit, not a probability)
    x_hat = g(z)
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if tuple(score.shape) != (1, 1):
        raise RuntimeError(f'unexpected discriminator output shape: {score.shape}')
    if tuple(x_hat.shape) != (1, 64, 64, 3):
        raise RuntimeError(f'unexpected generator output shape: {x_hat.shape}')
    print('discriminator output:', score.shape, 'generator output:', x_hat.shape)


if __name__ == '__main__':
    main()

3. WGAN训练

# WGAN-GP与DCGAN的区别只在损失函数部分
# WGAN-GP在损失函数部分添加了梯度惩罚项
import numpy as np
import tensorflow as tf
from tensorflow import keras
from PIL import Image
import glob
from wgan import Generator, Discriminator
from dataset import make_anime_dataset
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# dataset.py 已把图片处理为 64x64


def save_result(val_out, val_block_size, image_path, color_mode):
    """Tile a batch of generated images into a grid and save it as one image.

    Parameters
    ----------
    val_out : np.ndarray
        Images of shape [N, H, W, C] with values expected in [-1, 1]
        (the generator's tanh range). Values outside that range are clipped.
    val_block_size : int
        Number of images per grid row. Trailing images that do not fill a
        complete row are dropped.
    image_path : str
        Destination file path.
    color_mode : str
        Currently unused; kept for API compatibility.

    Raises
    ------
    ValueError
        If ``val_out`` holds fewer than ``val_block_size`` images.
    """
    # [-1, 1] -> [0, 255]. Clip before the cast: out-of-range floats would
    # otherwise wrap around or be undefined when converted to uint8.
    images = np.clip((val_out + 1.0) * 127.5, 0, 255).astype(np.uint8)
    final_image = _tile_images(images, val_block_size)
    # Single-channel grids are saved as 2-D grayscale arrays.
    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image).save(image_path)


def _tile_images(images, block_size):
    """Arrange [N, H, W, C] images row-major into a [rows*H, block_size*W, C] grid."""
    n, h, w, c = images.shape
    rows = n // block_size
    if rows == 0:
        raise ValueError(
            f'need at least {block_size} images to fill one grid row, got {n}')
    grid = images[:rows * block_size].reshape(rows, block_size, h, w, c)
    # [rows, cols, h, w, c] -> [rows, h, cols, w, c]: each row's images end up
    # side by side, then rows are stacked vertically.
    return grid.transpose(0, 2, 1, 3, 4).reshape(rows * h, block_size * w, c)


def celoss_zeros(output):
    """Mean sigmoid cross-entropy of the ``output`` logits against all-zero ("fake") labels."""
    targets = tf.zeros_like(output)
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=output)
    return tf.reduce_mean(per_element)


def celoss_ones(output):
    """Mean sigmoid cross-entropy of the ``output`` logits against all-one ("real") labels."""
    targets = tf.ones_like(output)
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=output)
    return tf.reduce_mean(per_element)


def gradient_penalty(discriminator, batch_x, fake_image):
    """WGAN-GP gradient penalty: mean over the batch of (||dD(x_hat)/dx_hat||_2 - 1)^2.

    ``x_hat`` is a random point on the segment between each real image and its
    paired fake image, with one interpolation coefficient per sample.

    Parameters
    ----------
    discriminator : callable
        The critic; called with ``training=True``.
    batch_x : Tensor
        Real images; assumed [b, h, w, c] (the coefficient has shape [b, 1, 1, 1]).
    fake_image : Tensor
        Generated images, same shape as ``batch_x``.

    Returns
    -------
    Scalar tensor with the penalty.
    """
    # Dynamic batch size, so this also works when the static shape is unknown.
    batch_size0 = tf.shape(batch_x)[0]
    # One coefficient per image; broadcasting spreads it over h, w and c.
    t = tf.random.uniform([batch_size0, 1, 1, 1], dtype=batch_x.dtype)
    inter_plate = t * batch_x + (1. - t) * fake_image
    with tf.GradientTape() as tape:
        # inter_plate is a plain tensor, not a Variable, so it must be watched.
        tape.watch(inter_plate)
        d_inter_plate_output = discriminator(inter_plate, training=True)
    grads = tape.gradient(d_inter_plate_output, inter_plate)
    # Flatten per sample: [b, h, w, c] -> [b, h*w*c].
    grads = tf.reshape(grads, [batch_size0, -1])
    # L2 norm via sqrt(sum + eps) rather than tf.norm: the norm's derivative
    # is undefined at 0 and would inject NaNs into the critic update.
    gp = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=1) + 1e-12)
    return tf.reduce_mean((gp - 1.) ** 2)


def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training, gp_weight=10.):
    """WGAN-GP critic loss: E[D(G(z))] - E[D(x)] + gp_weight * GP.

    The critic outputs raw scores; unlike DCGAN, no sigmoid or cross-entropy
    is applied, because the critic estimates a Wasserstein distance.

    Parameters
    ----------
    generator, discriminator : callable
        Models called with ``training=is_training``.
    batch_z : Tensor
        Latent batch fed to the generator.
    batch_x : Tensor
        Real image batch.
    is_training : bool
        Forwarded as ``training`` to both models.
    gp_weight : float, optional
        Gradient-penalty coefficient (lambda), 10 by default.

    Returns
    -------
    (loss, gp) : total critic loss and the unweighted gradient penalty.
    """
    fake_image = generator(batch_z, training=is_training)
    d_fake_output = discriminator(fake_image, training=is_training)
    d_real_output = discriminator(batch_x, training=is_training)
    # Minimising this maximises the critic's score gap between real and fake.
    w_loss = tf.reduce_mean(d_fake_output) - tf.reduce_mean(d_real_output)
    gp = gradient_penalty(discriminator, batch_x, fake_image)
    loss = w_loss + gp_weight * gp
    return loss, gp


def g_loss_fn(generator, discriminator, batch_z, is_training):
    """WGAN generator loss: -E[D(G(z))].

    The generator is trained to raise the critic's raw score on its samples.
    No sigmoid or cross-entropy is applied, matching the Wasserstein critic.

    Parameters
    ----------
    generator, discriminator : callable
        Models called with ``training=is_training``.
    batch_z : Tensor
        Latent batch fed to the generator.
    is_training : bool
        Forwarded as ``training`` to both models.

    Returns
    -------
    Scalar tensor with the generator loss.
    """
    fake_image = generator(batch_z, training=is_training)
    d_fake_output = discriminator(fake_image, training=is_training)
    return -tf.reduce_mean(d_fake_output)


def main():
    """Train the GAN on the anime dataset and periodically save sample grids.

    Each loop iteration does one critic update and then one generator update
    on a fresh batch. Every 100 iterations it prints the losses and writes a
    10x10 grid of samples to ``images/gan-<iteration>.png``.

    Raises
    ------
    FileNotFoundError
        If no training images match the glob pattern.
    """
    z_dim = 100
    epochs = 1000000  # number of training iterations (one batch each), not data passes
    batch_size = 256
    lr = 1e-3
    is_training = True
    img_path = glob.glob(r'E:\PyCharm Community Edition 2020.3.5\workspace\wgan\data\anime\*.jpg')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not img_path:
        raise FileNotFoundError('no *.jpg training images found')
    dataset, _, _ = make_anime_dataset(img_path, batch_size)
    dataset = dataset.repeat()  # keep sampling indefinitely
    db_iter = iter(dataset)
    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))
    # GANs train better with a lowered Adam beta_1.
    # NOTE(review): the WGAN-GP paper uses lr=1e-4, beta_1=0, beta_2=0.9 and
    # 5 critic steps per generator step; the original values are kept here.
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)
    # save_result writes into images/, which must exist before the first save.
    os.makedirs('images', exist_ok=True)
    for epoch in range(epochs):
        batch_z = tf.random.normal([batch_size, z_dim])
        batch_x = next(db_iter)

        # Critic step: only the discriminator's variables are updated.
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # Generator step: only the generator's variables are updated.
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd_loss:', float(d_loss), 'g_loss:', float(g_loss), 'gp:', float(gp))
            z = tf.random.normal([100, z_dim])
            fake_image = generator(z, training=False)
            image_path = os.path.join('images', 'gan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, image_path, color_mode='P')


if __name__ == '__main__':
    main()

4. 数据集处理

import multiprocessing

import tensorflow as tf


def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    """Build a batched tf.data pipeline of images resized and scaled to [-1, 1].

    Parameters
    ----------
    img_paths : list of str
        Image files to load.
    batch_size : int
        Images per batch.
    resize : int, optional
        Output height and width.
    drop_remainder, shuffle, repeat
        Forwarded to ``disk_image_batch_dataset``.

    Returns
    -------
    dataset : tf.data.Dataset
        Batched pipeline of [resize, resize, 3] images.
    img_shape : tuple
        ``(resize, resize, 3)``.
    len_dataset : int
        Number of batches in one pass over ``img_paths``.
    """
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        # [0, 255] -> [-1, 1], the same range as the generator's tanh output.
        img = img / 127.5 - 1
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)
    n_images = len(img_paths)
    # A trailing partial batch only counts when it is kept.
    if drop_remainder:
        len_dataset = n_images // batch_size
    else:
        len_dataset = -(-n_images // batch_size)  # ceiling division

    return dataset, img_shape, len_dataset


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    """Apply shuffle -> filter/map -> batch -> repeat -> prefetch to ``dataset``.

    By default ``filter_fn`` runs before ``map_fn``. ``filter_after_map=True``
    swaps that order, which is slower because every element gets mapped.
    ``n_map_threads`` defaults to the CPU count. ``shuffle_buffer_size``
    defaults to ``max(batch_size * 128, 2048)``. ``repeat`` is passed
    straight to ``Dataset.repeat``.
    """
    threads = multiprocessing.cpu_count() if n_map_threads is None else n_map_threads

    # Shuffling first is cheap: map/filter may be expensive, and shuffling
    # only reorders references to the source elements.
    if shuffle:
        if shuffle_buffer_size is None:
            buffer_size = max(batch_size * 128, 2048)  # never smaller than 2048
        else:
            buffer_size = shuffle_buffer_size
        dataset = dataset.shuffle(buffer_size)

    def _apply_filter(ds):
        return ds.filter(filter_fn) if filter_fn else ds

    def _apply_map(ds):
        return ds.map(map_fn, num_parallel_calls=threads) if map_fn else ds

    stages = (_apply_map, _apply_filter) if filter_after_map else (_apply_filter, _apply_map)
    for stage in stages:
        dataset = stage(dataset)

    return (dataset
            .batch(batch_size, drop_remainder=drop_remainder)
            .repeat(repeat)
            .prefetch(n_prefetch_batch))


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batched tf.data pipeline over in-memory data.

    ``memory_data`` (a nested structure of tensors / ndarrays / lists) is
    sliced along its first dimension with ``Dataset.from_tensor_slices``.
    Every other argument is forwarded unchanged to ``batch_dataset``.
    """
    sliced = tf.data.Dataset.from_tensor_slices(memory_data)
    return batch_dataset(sliced,
                         batch_size,
                         drop_remainder=drop_remainder,
                         n_prefetch_batch=n_prefetch_batch,
                         filter_fn=filter_fn,
                         map_fn=map_fn,
                         n_map_threads=n_map_threads,
                         filter_after_map=filter_after_map,
                         shuffle=shuffle,
                         shuffle_buffer_size=shuffle_buffer_size,
                         repeat=repeat)


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batched tf.data pipeline that reads PNG/JPEG images from disk.

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
        Image file paths.
    labels : nested structure of tensors/ndarrays/lists, optional
        If given, each element becomes ``(image, *labels)``.

    The other arguments are forwarded to ``memory_data_batch_dataset``.
    A user ``map_fn`` is fused with file decoding into a single map call.
    """
    memory_data = img_paths if labels is None else (img_paths, labels)

    def parse_fn(path, *label):
        # Read the file and decode it with the channel count fixed to 3.
        raw = tf.io.read_file(path)
        return (tf.image.decode_png(raw, 3),) + label

    def fused_map_fn(*args):
        # Decode first, then apply the caller's map_fn to the decoded element.
        return map_fn(*parse_fn(*args))

    return memory_data_batch_dataset(memory_data,
                                     batch_size,
                                     drop_remainder=drop_remainder,
                                     n_prefetch_batch=n_prefetch_batch,
                                     filter_fn=filter_fn,
                                     map_fn=fused_map_fn if map_fn else parse_fn,
                                     n_map_threads=n_map_threads,
                                     filter_after_map=filter_after_map,
                                     shuffle=shuffle,
                                     shuffle_buffer_size=shuffle_buffer_size,
                                     repeat=repeat)

标签:map,GP,image,batch,dataset,tf,fn,WGAN
来源: https://blog.csdn.net/qq_46588746/article/details/120593601

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有