如何训练深度网络的LSTM层


13

我正在使用lstm和前馈网络对文本进行分类。

我将文本转换为一键向量,然后将其输入到lstm中,这样我就可以将其总结为单个表示形式。然后,我将其馈送到另一个网络。

但是我如何训练LSTM?我只想按顺序对文本进行分类-是否应在未经培训的情况下进行输入?我只想将段落表示为单个项目,然后将其输入分类器的输入层。

我将不胜感激与此有关的任何建议!

更新:

所以我有一个lstm和一个分类器。我将lstm的所有输出并平均池化,然后将平均值输入分类器。

我的问题是我不知道如何训练lstm或分类器。我知道对于lstm的输入应该是什么,对于该输入的分类器的输出应该是什么。由于它们是两个单独的网络,它们只是按顺序激活的,因此我需要知道和不知道lstm的理想输出应该是什么,它也是分类器的输入。有没有办法做到这一点?

Answers:


10

从LSTM开始的最佳位置是A. Karpathy的博客文章http://karpathy.github.io/2015/05/21/rnn-efficiency/。如果您使用的是Torch7(强烈建议),则可从github https://github.com/karpathy/char-rnn获得源代码。

我也会尝试稍微更改您的模型。我将使用多对一方法,以便您通过查找表输入单词,并在每个序列的末尾添加一个特殊的词,这样,只有当您输入“序列的末尾”符号时,您才能阅读分类根据您的训练标准输出并计算误差。这样,您将直接在有监督的上下文中进行训练。

另一方面,一种更简单的方法是使用para2vec(https://radimrehurek.com/gensim/models/doc2vec.html)提取输入文本的特征,然后在特征之上运行分类器。段落矢量特征提取非常简单,在python中将是:

class LabeledLineSentence(object):
    def __init__(self, filename):
        self.filename = filename

    def __iter__(self):
        for uid, line in enumerate(open(self.filename)):
            yield LabeledSentence(words=line.split(), labels=['TXT_%s' % uid])

sentences = LabeledLineSentence('your_text.txt')

model = Doc2Vec(alpha=0.025, min_alpha=0.025, size=50, window=5, min_count=5, dm=1, workers=8, sample=1e-5)
model.build_vocab(sentences)

for epoch in range(epochs):
    try:
        model.train(sentences)
    except (KeyboardInterrupt, SystemExit):
        break

感谢您的回覆。我会考虑的。您对问题中的特定问题有任何建议吗-我已对其进行了更新。
wordSmith,2015年

我认为您描述的过程不会产生任何结果。您将针对LSTM进行哪些培训?我不确定我是否理解为什么在这种情况下将LSTM用于整个句子的无监督特征学习。您是否有任何有关我可以帮助您的方法的文献?arxiv.org/abs/1306.3584和您可能也会对此感兴趣。
Yannis Assael,2015年

我将基于过去的文字段落及其类的数据集来训练lstm。我不打算使用无监督学习。我想手动训练它,但不知道如何。这是我在没有机器学习库的情况下实现的lstm和分类器的实现,我知道它可以工作:pastebin.com/63Cqrnef lstm具有deepActivate的功能,该功能可以激活lstm,然后激活我在课题中提到的分类器。这类似于我要实现的内容:deeplearning.net/tutorial/lstm.html
wordSmith 2015年

但是,当我尝试将它们作为一个网络激活时,每个输出层都没有定义。有关更多信息,请访问:stats.stackexchange.com/q/159922/81435
wordSmith,2015年

1
非常感谢你!您提供的帮助超出了所需。感谢您超越自我。
wordSmith,2015年
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.