关于mnist.train.next_batch
在训练mnist数据集的过程中,我们会采用在线的学习方法,利用next_batch功能来不断地获取新的数据集进行训练。关于next_batch的功能以及其返回的数据格式学习一下
功能
我们通过next_batch来获取下一个数据集来对我们的参数进行调整,用法为
#其中的n代表返回多少个训练数据集和对应的标签 batch_n=mnist_data.train.next_batch(n)
输入数据
import numpy as np import tensorflow as tf import input_data #导入mnist数据(以one_hot的格式) mnist_data=input_data.read_data_sets("MNIST_data/",one_hot=True) mnist=mnist_data #遍历两次 for i in range(2): #每次返回两个训练集的数据 batch=mnist_data.train.next_batch(2) #输出它的内容 print (batch) #输出它的类型 print (type(batch))
返回结果
(array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), array([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])) <type tuple> (array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), array([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])) <type tuple>
可以看到数据返回的是一个元组,元组的第一个元素为一个阵列,2行,784列,第二个元素为预测的标签,为两行十列,是one_hot的数据格式
上一篇:
JS实现多线程数据分片下载