ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

python – keras model.fit()用tf.Dataset对象的初始化迭代器

2019-07-10 14:07:34  阅读:1005  来源: 互联网

标签:python tensorflow keras


我正在使用tf.keras API来构建我的CNN模型,使用tf.Dataset API为我的模型创建输入管道.来自tf.keras.datasets的mnist数据集用于测试,并通过执行代码在内存中准备:

(train_images,train_labels),(test_images,test_labels) = tf.keras.datasets.mnist.load_data()

以及一些与我的keras模型兼容的预处理:

Train_images = np.expand_dims(train_images,3).astype('float')/255.0
Test_images = np.expand_dims(test_images,3).astype('float')/255.0

Train_labels = tf.keras.utils.to_categorical(train_labels)
Test_labels = tf.keras.utils.to_categorical(test_labels)

这些数据作为数组存储在内存中,有两个选项可用于创建数据集对象.第一个是使用tf.data.Dataset.from_tensor_slices:

image = tf.data.Dataset.from_tensor_slices((Train_images,Train_labels))

并将此结果对象输入到model.fit():

model.fit(x=image,steps_per_epoch=1000)

或者通过以下方式输入此数据集的迭代器:

iterator = image.make_one_shot_iterator()

model.fit(x=iterator,steps_per_epoch=1000)

这两个选项都可以正常工作,因为这里名为image的数据集是使用内存中的数据创建的.但是,根据这里的Importing Data,我们可能希望避免这样做,因为它会多次复制数据并占用内存.所以另一种选择是基于tf.placeholder和初始化迭代器创建这样的数据集对象:

X = tf.placeholder(tf.float32,shape = [60000,28,28,1])
Y = tf.placeholder(tf.float32,shape = [60000,10])
image2 = tf.data.Dataset.from_tensor_slices((X,Y))
iterator2 = image.make_initializable_iterator()

with tf.Session() as sess:
  sess.run(iterator2.initializer,feed_dict={X:Train_images,Y:Train_labels}
  sess.run(iterator2.get_next())

当使用tf.Session()进行内存中的数据并避免数据的多个副本时,这种迭代器工作正常.但是我找不到让它与keras.model.fit()一起使用的方法,因为你无法真正调用iterator.initializer或者在那里提供任何数据.有没有办法使用这种迭代器?

解决方法:

我不认为keras正式支持传递可初始化迭代器的情况,正如您所指出的,没有地方可以提供占位符和值映射.

但是,使用keras callbacks可以解决方法:

import tensorflow as tf
import numpy as np
import pandas as pd

# Make sure only tensorflow.keras is imported, don't mix with keras
from tensorflow.keras import layers
import tensorflow.keras.backend as K

# example data
x_values = np.random.randn(200, 100).astype(np.float32)
y_labels = np.random.randint(low=0, high=9, size=200)

graph = tf.Graph()
with graph.as_default():
    # make datasets from placeholders as in https://www.tensorflow.org/guide/datasets#reading_input_data
    # X:
    features_placeholder = tf.placeholder(tf.float32, x_values.shape, name='features')
    dataset_x = tf.data.Dataset.from_tensor_slices({'x': features_placeholder})
    # Y:
    labels_placeholder = tf.placeholder(tf.float32, [None], name='labels')
    dataset_y = tf.data.Dataset.from_tensor_slices({'y': labels_placeholder})

    # compose datasets to make X-Y pairs for training
    dataset0 = tf.data.Dataset.zip((dataset_x, dataset_y))
    dataset0 = dataset0.batch(16).repeat()

    # build model with keras
    inputs = tf.keras.Input(name='x', shape=(x_values.shape[1],))
    mlp1 = layers.Dense(16, name='mlp-1', activation='relu')
    mlp1_out = mlp1(inputs)
    output = layers.Dense(1, name='y', activation='linear')
    output_out = output(mlp1_out)
    model = tf.keras.Model(inputs=inputs, outputs=output_out)
    # The compile step specifies the training configuration.
    model.compile(optimizer=tf.train.RMSPropOptimizer(0.001), loss='mse', metrics=['mse'])

    iterator = dataset0.make_initializable_iterator()
    feed_dict = { labels_placeholder: y_labels, features_placeholder: x_values }

    class InitIteratorCallback(tf.keras.callbacks.Callback):
        """
        Ensures that placeholders in dataset are initialized before each epoch begins
        """

        def on_epoch_begin(self, epoch, logs=None):
            sess = K.get_session()
            sess.run(iterator.initializer, feed_dict=feed_dict)


    model.fit(iterator, callbacks=[InitIteratorCallback()],
              epochs=10, steps_per_epoch=300)

标签:python,tensorflow,keras
来源: https://codeday.me/bug/20190710/1424879.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有