ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Tensorflow【实战Google深度学习框架】使用 HDF5 处理大型数据集with TFLearn

2021-06-10 16:55:59  阅读:281  来源: 互联网

标签:h5f Google network HDF5 cifar10 TFLearn test import data


文章目录

1.HDF5文件

在这里插入图片描述到目前为止,我们使用的数据集都能够全部加载到内存中。对于小数据集,我们可以加载全部图像数据到内存中,进行预处理,并进行前向传播处理。然而,对于大规模数据集(比如ImageNet),我们需要创建数据生成器,每次只访问一小部分数据集(比如mini-batch),然后对batch数据进行预处理和前向传播。

Keras模块很方便进行数据加载,可以使用磁盘上的原始文件路径作为训练过程的输入。你不需要将整个数据集存储在内存中——只需为Keras数据生成器提供图像路径,生成器会自动从路径中加载数据并进行前向传播。

然而,这种方法非常低效。读取磁盘上的每张图像都需要一个I/O操作,这样会造成一定的延迟。训练深度学习网络本身已经够慢了,所以我们应该尽可能避免I/O瓶颈。

一个比较合理的解决方案是将原始图像生成HDF5数据集,,只是这一次我们存储的是原始图像,而不是提取的特征。HDF5不仅可以存储大量的数据集,而且还可以用于I/O操作,特别是用于从文件中提取batch(称为“片”)。我们将在磁盘上的原始图像保存到HDF5文件中,这可以让模型快速的遍历数据集并在其上训练深度学习网络。

2.github代码:
# -*- coding: utf-8 -*-
"""
Example on how to use HDF5 dataset with TFLearn. HDF5 is a data model,
library, and file format for storing and managing data. It can handle large
dataset that could not fit totally in ram memory. Note that this example
just give a quick compatibility demonstration. In practice, there is no so
real need to use HDF5 for small dataset such as CIFAR-10.
"""

from __future__ import division, print_function, absolute_import

import tflearn
from tflearn.layers.core import *
from tflearn.layers.conv import *
from tflearn.data_utils import *
from tflearn.layers.normalization import *
from tflearn.layers.estimator import regression

# CIFAR-10 Dataset
from tflearn.datasets import cifar10
(X, Y), (X_test, Y_test) = cifar10.load_data()
Y = to_categorical(Y)
Y_test = to_categorical(Y_test)

# Create a hdf5 dataset from CIFAR-10 numpy array
import h5py
h5f = h5py.File('data.h5', 'w')
h5f.create_dataset('cifar10_X', data=X)
h5f.create_dataset('cifar10_Y', data=Y)
h5f.create_dataset('cifar10_X_test', data=X_test)
h5f.create_dataset('cifar10_Y_test', data=Y_test)
h5f.close()

# Load hdf5 dataset
h5f = h5py.File('data.h5', 'r')
X = h5f['cifar10_X']
Y = h5f['cifar10_Y']
X_test = h5f['cifar10_X_test']
Y_test = h5f['cifar10_Y_test']

# Build network
network = input_data(shape=[None, 32, 32, 3], dtype=tf.float32)
network = conv_2d(network, 32, 3, activation='relu')
network = max_pool_2d(network, 2)
network = conv_2d(network, 64, 3, activation='relu')
network = conv_2d(network, 64, 3, activation='relu')
network = max_pool_2d(network, 2)
network = fully_connected(network, 512, activation='relu')
network = dropout(network, 0.5)
network = fully_connected(network, 10, activation='softmax')
network = regression(network, optimizer='adam',
                     loss='categorical_crossentropy',
                     learning_rate=0.001)

# Training
model = tflearn.DNN(network, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=50, shuffle=True, validation_set=(X_test, Y_test),
          show_metric=True, batch_size=96, run_id='cifar10_cnn')

h5f.close()
3.源码地址:

https://github.com/tflearn/tflearn/blob/master/examples/basics/use_hdf5.py

4.tf入门合集

https://blog.csdn.net/bigquant/article/details/85339665?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

标签:h5f,Google,network,HDF5,cifar10,TFLearn,test,import,data
来源: https://blog.51cto.com/u_14013325/2890574

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有