PyTorch中的model.train()有什么作用?


80

它会打电话吗 forward()nn.Module?我以为我们在调用模型时forward会使用方法。为什么我们需要指定train()?

Answers:


109

model.train()告诉模型您正在训练模型。因此,在培训和测试过程中表现不同的有效层(如辍学,batchnorm等)可以知道发生了什么,因此可以相应地表现。

更多详细信息:它设置训练模式(请参阅源代码)。您可以致电model.eval()model.train(mode=False)告诉您正在测试。期望train功能训练模型有些直观,但是并没有做到这一点。它只是设置模式。


46

这是以下代码module.train()

def train(self, mode=True):
        r"""Sets the module in training mode."""      
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

这是module.eval

def eval(self):
        r"""Sets the module in evaluation mode."""
        return self.train(False)

模式traineval是我们可以在其中设置模块的仅有的两种模式,它们是完全相反的。

那只是一个self.training标志,目前 辍学bachnorm关心该标志。

默认情况下,此标志设置为True


9

有两种方法可以让模型知道您的意图,即您要训练模型还是要使用模型进行评估。在model.train()模型知道的情况下,它必须学习层,并且当我们使用model.eval()它时,表明该模型没有新知识要学习,并且该模型用于测试。 model.eval()这也是必要的,因为在pytorch中,如果我们使用batchnorm,而在测试过程中,如果我们只想传递单个图像,则pytorch会抛出错误(如果model.eval()未指定)。

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.