ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

TensorFlow学习笔记之DataSet中shuffle,batch和repeat的用法详解

2020-12-05 17:00:10  阅读:239  来源: 互联网

标签:repeat shuffle buffer batch print 数据 ds size


话不多说,看代码
代码git链家:https://github.com/lankuohsing/TensorFlowStudy/blob/master/dataset_usage/shuffle_batch_repeat.py

# -*- coding: utf-8 -*-
"""
Created on Fri Dec  4 21:08:13 2020

@author: lankuohsing
"""
import tensorflow as tf
import numpy as np

# In[]

ori_data = np.arange(20).reshape((4, 5))
ds = tf.data.Dataset.from_tensor_slices(ori_data)
print(ori_data)

'''
shuffle: 维持一个buffer_size大小的缓存,打乱后供后续打包成batch输出
具体来说,从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本
buffer中样本个数不足buffer_size,继续从data数据集中按顺序填充至buffer_size,
此时会再次打乱
batch: 打包成一个batch
repeat: 重复多次,构造成多个epoch
'''
# In[]
def f1(ds):
    # 最常用的顺序
    # 解释:相当于把所有数据先打乱,然后打包成batch输出,整体数据重复2个epoch
    # 特点:1.一个batch中的数据不会重复;2.每个epoch的最后一个batch的尺寸小于等于batch_size
    ds = ds.shuffle(buffer_size=100)
    ds = ds.batch(3)
    ds = ds.repeat(count=2)
    # 构造获取数据的迭代器
    iters = ds.make_one_shot_iterator()
    # 每次从迭代器中获取一批数据
    batch = iters.get_next()
    sess = tf.Session()
    # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError
    for i in range(0,4):
        print(i)
        print(sess.run(batch))
# In[]
def f2(ds):
    # 解释:相当于把所有数据先打乱,再把所有数据重复两个epoch,然后将重复两个epoch的数据放在一起,最后打包成batch_size输出
    # 特点:1.因为把数据复制两份,还进行打乱,因此某个batch数据可能会重复,而且出现重复数据的batch只会是两个batch交叉的位置;2.最后一个batch的尺寸小于等于batch_size
    ds = ds.shuffle(buffer_size=100)
    ds = ds.repeat(count=2)
    ds = ds.batch(3)
    # 构造获取数据的迭代器
    iters = ds.make_one_shot_iterator()
    # 每次从迭代器中获取一批数据
    batch = iters.get_next()
    sess = tf.Session()
    # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError
    for i in range(0,3):
        print(i)
        print(sess.run(batch))
# In[]
def f3(ds):
    # 解释:相当于把所有数据先打包成batch,然后把打包成batch的数据重复两遍,最后再将所有batch打乱进行输出
    # 1.打乱的是batch;2.某些batch的尺寸小于等于batch_size,因为是对batch进行打乱,所以这些batch不一定是最后一个
    ds = ds.batch(3)
    ds = ds.repeat(count=2)
    ds = ds.shuffle(buffer_size=100)
    # 构造获取数据的迭代器
    iters = ds.make_one_shot_iterator()
    # 每次从迭代器中获取一批数据
    batch = iters.get_next()
    sess = tf.Session()
    # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError
    for i in range(0,4):
        print(i)
        print(sess.run(batch))
# In[]
def f4(ds):
    # 解释:相当于把所有数据先打包成batch,然后再将所有batch打乱打,最后包成batch的数据重复两遍并输出
    # 1.打乱的是batch;2.某些batch的尺寸小于等于batch_size,因为是对batch进行打乱,所以这些batch不一定是最后一个
    ds = ds.batch(3)
    ds = ds.shuffle(buffer_size=100)
    ds = ds.repeat(count=2)
    # 构造获取数据的迭代器
    iters = ds.make_one_shot_iterator()
    # 每次从迭代器中获取一批数据
    batch = iters.get_next()
    sess = tf.Session()
    # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError
    for i in range(0,4):
        print(i)
        print(sess.run(batch))
# In[]
def f5(ds):
    # 解释:相当于把所有数据先重复两遍,然后打乱,最后打包成batch
    # 1.某些batch的数据可能重复;2最后一个batch的尺寸小于等于batch_size.
    ds = ds.repeat(count=2)
    ds = ds.shuffle(buffer_size=100)
    ds = ds.batch(3)

    # 构造获取数据的迭代器
    iters = ds.make_one_shot_iterator()
    # 每次从迭代器中获取一批数据
    batch = iters.get_next()
    sess = tf.Session()
    # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError
    for i in range(0,3):
        print(i)
        print(sess.run(batch))
# In[]
def f6(ds):
    # 解释:相当于把所有数据先重复两遍,然后打包成batch,最后打乱
    # 1.batch内部的数据不会重复;2.某一个batch的尺寸小于等于batch_size,但是打乱了所以不一定在最后一个.
    ds = ds.repeat(count=2)
    ds = ds.batch(3)
    ds = ds.shuffle(buffer_size=100)


    # 构造获取数据的迭代器
    iters = ds.make_one_shot_iterator()
    # 每次从迭代器中获取一批数据
    batch = iters.get_next()
    sess = tf.Session()
    # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError
    for i in range(0,3):
        print(i)
        print(sess.run(batch))

标签:repeat,shuffle,buffer,batch,print,数据,ds,size
来源: https://blog.csdn.net/THUChina/article/details/110699546

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有