在keras中model.fit_generator()和model.fit()有什么区别
首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然用起来很方便,但是如果我们数据量很大,那么是不可能将所有数据载入内存的,必将导致内存泄漏,这时候我们可以用fit_generator函数来进行训练。
fit
fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)
以给定数量的轮次(数据集上的迭代)训练模型。
参数
返回
一个 History 对象。其 History.history 属性是连续 epoch 训练损失和评估值,以及验证集损失和评估值的记录(如果适用)。
异常
-
RuntimeError: 如果模型从未编译。 ValueError: 在提供的输入数据与模型期望的不匹配的情况下。
fit_generator
fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
使用 Python 生成器(或 Sequence 实例)逐批生成的数据,按批次训练模型。
生成器与模型并行运行,以提高效率。 例如,这可以让你在 CPU 上对图像进行实时数据增强,以在 GPU 上训练模型。
keras.utils.Sequence 的使用可以保证数据的顺序, 以及当 use_multiprocessing=True 时 ,保证每个输入在每个 epoch 只使用一次。
参数
返回
一个 History 对象。其 History.history 属性是连续 epoch 训练损失和评估值,以及验证集损失和评估值的记录(如果适用)。
异常
-
ValueError: 如果生成器生成的数据格式不正确。
例
def generate_arrays_from_file(path): while True: with open(path) as f: for line in f: # 从文件中的每一行生成输入数据和标签的 numpy 数组, x1, x2, y = process_line(line) yield ({input_1: x1, input_2: x2}, {output: y}) f.close() model.fit_generator(generate_arrays_from_file(/my_file.txt), steps_per_epoch=10000, epochs=10)
总结:
在使用fit函数的时候,需要有batch_size,但是在使用fit_generator时需要有steps_per_epoch