tf.data.Dateset中batch, repeat, shuffle顺序问题
引言
本文旨在介绍tf.data.Dataset中batch, repeat, shuffle以及三者的顺序问题。首先介绍了这三个函数单独作用的结果,而后给出了相互作用下的影响。
一、单独作用
shuffle()
shuffle(buffsize) 用于将数据打乱,其中buffsize的大小越大,数据的混乱程度越高,因为shuffle的实现思路为:** 开辟一可容纳buffsize个数据的缓冲区,初始时将数据的前buffsize个读入缓冲区,而后随机在缓冲区里选择一个输出,同时将数据的第buffsize+1个读入缓冲区**。由此不难理解,如果buffsize很小,比如为1时就根本没有打乱。给出示例代码和结果如下:
data=tf.range(0,10) data=tf.data.Dataset.from_tensor_slices(data) data1=data.shuffle(5) for i in data1: print(i.numpy()) #结果 4 0 5 2 3 1 6 8 7 9
可以测试,如果多次执行代码中的输出程序,每次的打乱结果都会发生变化,但是第一个输出的值永远都是在0~4的范围之内,这是因为我们设置的buffsize=5。
repeat()
repeat(count) 用于将数据重复count次,相当于我们训练时的epoch,示例代码及结果如下:
data1=data.repeat(2) for i in data: print(i.numpy()) #结果 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9
batch()
batch(batch_size)用于将数据划分为多个batch,同时tensorflow中有着很好的调整功能,当最后一个batch不满足batchsize时就以当前长度输出。
data=tf.range(0,10.)[:,None] data=tf.data.Dataset.from_tensor_slices(data) data1=data.batch(4) for i in data1: print(i.numpy()) #结果 [[0.] [1.] [2.] [3.]] [[4.] [5.] [6.] [7.]] [[8.] [9.]]
可以看到最后一个batch只含有2个数据。
相互作用
1. 先repeat再shuffle
其结果就是对repeat后的数据进行打乱,这会使得不同epcoh间的数据被打乱,即前一个epcoh中数据未加载完,下一个epoch中数据可能插入,导致一个epoch中可能数据重复多次。
temp1=data.repeat(2).shuffle(5) for i in temp1: print(i.numpy()) # 结果 # [4.] [0.] [2.] [7.] [8.] [5.] [3.] [6.] [1.] #[1.] [0.] [2.] [5.] [4.] [9.] [9.] [3.] [6.] [7.] [8.]
可以看出,在第一个epoch还未结束(未出现9时)就已经出现了下个epoch的0
2.先shuffle再repeat
先shuffle再repeat,epoch内部打乱,一定先输出完一个epoch内所有值。
temp2=data.shuffle(5).repeat(2) for i in temp2: print(i.numpy()) #结果 # [2.] [4.] [5.] [6.] [0.] [1.] [3.] [9.] [8.] [7.] #[3.] [1.] [4.] [7.] [0.] [6.] [9.] [2.] [8.] [5.]
3. 先batch再repeat
先batch再repeat,为对于batch的复制。
temp3=data.batch(4).repeat(2) for i in temp3: print(i.numpy()) #结果 # [[0.] [1.] [2.] [3.]] [[4.] [5.] [6.] [7.]] [[8.] [9.]] #[[0.] [1.] [2.] [3.]] [[4.] [5.] [6.] [7.]] [[8.] [9.]
4. 先repeat再batch
先repeat再batch是对于重复后数组的分组。
temp4=data.repeat(2).batch(4) for i in temp4: print(i.numpy()) # [[0.] [1.] [2.] [3.]] [[4.] [5.] [6.] [7.]] [[8.] #[9.] [0.] [1.]] [[2.] [3.] [4.] [5.]] [[6.] [7.] #[8.] [9.]]
5.先batch再shuffle
先batch再shuffle是对不同组间的打乱
temp5=data.batch(4).shuffle(5) for i in temp5: print(i.numpy()) # [[4.] [5.] [6.] [7.]] #[[0.] [1.] [2.] [3.]] #[[8.] [9.]]
6.先shuffle再batch
先shuffle再batch是在对打乱后的数据分组
temp6=data.shuffle(5).batch(4) for i in temp6: print(i.numpy()) # [[3.] [2.] [1.] [7.]] #[[6.] [5.] [4.] [0.]] [[9.] [8.]]