机器学习-LSTM中的几个参数理解
背景介绍
时间序列的处理使用RNN更为有效。但RNN中的一些参数理解起来与CNN差别很大,这篇文章主要梳理一下RNN中LSTM架构的几个关键参数以及如何理解这些参数。
以pytorch为例,我们首先看一下LSTM网络的构建过程
class RNN(nn.Module): def __init__(self): super(RNN, self).__init__() self.rnn = nn.LSTM( input_size=1, hidden_size=64, num_layers=1, batch_first=True, ) self.out = nn.Linear(64, 2) def forward(self, x): r_out, (h_n, h_c) = self.rnn(x, None) out = self.out(h_n[0]) return out
我们使用一个比较简单理解的例子来解释一下这几个主要参数的含义,比如我们用30天的买东西的数据来预测第31天的,每天采集一组数据,这组数据可以表示为
day1 : {面包:5个,泡面3个,火腿肠2个,卤蛋2个,可乐2个} day2 : {面包:3个,泡面1个,火腿肠2个,卤蛋1个,可乐1个} 以此类推
这里我们可以看到,我们一共有30天的数据,每天的数据包含5个种类 。
input_size
特征的长度,在我们的例子中,就是每一天的数据中包含几个维度,这里就是5。如果是做自然语言处理,那embedding之后的size就是这个input_size。通常数据维度越多可能会使预测更准,但是会带来维度灾难的问题,维度达到某个程度以后,不但性能不一定会提升,还会打来巨大的计算消耗,这里需要使用者自己去权衡。
hidden_size
这个是隐藏层的参数,通常决定了网络的复杂程度,这是一个魔调的参数,比如你用多层感知机,里边有多少个神经元,也是并没有一个最优的值,通常用64,但是不绝对,需要在实际任务的执行中再用。
num_layers
使用几层LSTM
batch_first
这个参数与最终输入网络的格式相关,正常情况下,我们构建好LSTM网络后,输入的X的格式为
X = [SEQ_LEN, BS, input_size]
这里的SEQ_LEN指的是数据的长度,在这里就是30。BS就是batch_size,当batch_first参数为true以后,输入X的形式就变成了
X = [BS, SEQ_LEN, input_size]
这个东西在训练过程中不要配置错误。
output和h_n的关系
output是每个时间步的输出,h_n 是最后一个时间步的输出,即是 h_n = output[:, -1, :],通常使用这两种都可以,感兴趣的读者可以自己写段代码比较一下。
上一篇:
通过多线程提高代码的执行效率例子