在Keras中使用不同长度的示例训练RNN


59

我正在尝试开始学习RNN,并且正在使用Keras。我了解香草RNN和LSTM层的基本前提,但是我无法理解培训的某些技术要点。

keras文档中,它说到RNN层的输入必须具有形状(batch_size, timesteps, input_dim)。这表明所有训练示例都具有固定的序列长度,即timesteps

但这不是特别典型,是吗?我可能想让RNN对不同长度的句子进行运算。当我在某种语料库上对其进行训练时,我将为它提供成批的句子,这些句子的长度各不相同。

我想要做的显而易见的事情是找到训练集中任何序列的最大长度并将其零填充。但这是否意味着我无法在测试时进行输入长度大于该长度的预测?

我想这是一个关于Keras特定实现的问题,但是我也想问人们通常在遇到这种问题时通常会做什么。


@kbrose是正确的。但是,我有一个担忧。在示例中,您有一个非常特殊的无限收益产生器。更重要的是,它设计为可生产1000个批次的批次。实际上,这即使不是不可能的,也很难满足。您需要重新组织条目,以便将相同长度的条目排列在一起,并且需要仔细设置批次拆分位置。此外,您没有机会在所有批次中进行洗牌。所以我的观点是:除非您完全知道自己在做什么,否则请不要在Keras中使用可变长度的输入。使用填充并将Masking图层设置为可忽略
Bs He

Answers:


55

这表明所有训练示例都具有固定的序列长度,即timesteps

这不是很正确,因为该尺寸可以是None,即可变长度。在一个批处理中,您必须具有相同数量的时间步长(通常在这里看到0填充和掩码)。但是在批次之间没有这种限制。在推断期间,您可以有任何长度。

创建随机时间长度批次的训练数据的示例代码。

from keras.models import Sequential
from keras.layers import LSTM, Dense, TimeDistributed
from keras.utils import to_categorical
import numpy as np

model = Sequential()

model.add(LSTM(32, return_sequences=True, input_shape=(None, 5)))
model.add(LSTM(8, return_sequences=True))
model.add(TimeDistributed(Dense(2, activation='sigmoid')))

print(model.summary(90))

model.compile(loss='categorical_crossentropy',
              optimizer='adam')

def train_generator():
    while True:
        sequence_length = np.random.randint(10, 100)
        x_train = np.random.random((1000, sequence_length, 5))
        # y_train will depend on past 5 timesteps of x
        y_train = x_train[:, :, 0]
        for i in range(1, 5):
            y_train[:, i:] += x_train[:, :-i, i]
        y_train = to_categorical(y_train > 2.5)
        yield x_train, y_train

model.fit_generator(train_generator(), steps_per_epoch=30, epochs=10, verbose=1)

这就是它的打印内容。请注意,输出形状(None, None, x)指示可变的批次大小和可变的时间步长大小。

__________________________________________________________________________________________
Layer (type)                            Output Shape                        Param #
==========================================================================================
lstm_1 (LSTM)                           (None, None, 32)                    4864
__________________________________________________________________________________________
lstm_2 (LSTM)                           (None, None, 8)                     1312
__________________________________________________________________________________________
time_distributed_1 (TimeDistributed)    (None, None, 2)                     18
==========================================================================================
Total params: 6,194
Trainable params: 6,194
Non-trainable params: 0
__________________________________________________________________________________________
Epoch 1/10
30/30 [==============================] - 6s 201ms/step - loss: 0.6913
Epoch 2/10
30/30 [==============================] - 4s 137ms/step - loss: 0.6738
...
Epoch 9/10
30/30 [==============================] - 4s 136ms/step - loss: 0.1643
Epoch 10/10
30/30 [==============================] - 4s 142ms/step - loss: 0.1441

这次真是万分感谢。但是,如果我们对序列进行0填充,则会影响隐藏状态和存储单元,因为我们继续将x_t传递为0(如果实际上不传递任何内容)。通常fit(),我们可以传递sequence_lenth参数来指定要排除的序列长度。似乎生成器方法不允许忽略0个序列?
GRS

1
@GRS您的生成器可以返回3的元组(inputs, targets, sample_weights),并且您可以将sample_weights0-pad设置为0。但是,我不确定这对于双向RNN能否完美工作。
kbrose

这很有帮助,但我希望它也包括一个使用model.predict_generator测试集的示例。当我尝试使用生成器进行预测时,会出现有关级联的错误(测试集也具有可变长度的序列)。我的解决方案model.predict是以一种怪异的方式使用该标准。也许这将更适合一个新问题?
米奇

@mickey听起来像是另一个问题。这个问题是关于训练,而不是预测。
kbrose

7

@kbrose似乎有更好的解决方案

我想要做的显而易见的事情是找到训练集中任何序列的最大长度并将其零填充。

这通常是一个很好的解决方案。也许尝试最大序列长度+100。使用最适合您的应用程序的东西。

但这是否意味着我不能在测试时进行输入长度大于该长度的预测?

不必要。在喀拉拉邦使用固定长度的原因是因为它通过创建固定形状的张量极大地提高了性能。但这仅用于培训。经过培训,您将学到适合您任务的权重。

假设经过几个小时的训练,您意识到模型的最大长度不够大/小,现在需要更改时间步长,只需从旧模型中提取学习的权重,并使用新的时间步长构建新模型并将学习到的权重注入其中。

您可能可以使用类似的方法进行此操作:

new_model.set_weights(old_model.get_weights())

我自己还没有尝试过。请尝试一下,并在此处发布您的结果,以使所有人受益。这里有一些链接: 一个 2


1
您确实可以有可变长度的输入,而无需引入hack之类的东西max length + 100。请参阅我的答案以获取示例代码。
kbrose

1
将权重转移到具有更多时间步长的模型确实可以很好地工作!我增加了Bidirectional(LSTM)()和的时间步长RepeatVector(),并且预测是完全可行的。
komodovaran_

@kbrose这不是一个技巧,而是您通常的做法。使用batch_size为1太慢了,并且keras启用了遮罩层,因此遮罩不会影响丢失。
Ferus
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.