在PyTorch中保存经过训练的模型的最佳方法?


191

我一直在寻找其他方法来在PyTorch中保存经过训练的模型。到目前为止,我发现了两种选择。

  1. 使用torch.save()保存模型,使用torch.load()加载模型。
  2. model.state_dict()保存训练的模型,model.load_state_dict()加载保存的模型。

我碰到过这种讨论,其中建议方法2优于方法1。

我的问题是,为什么选择第二种方法呢?仅仅是因为torch.nn模块具有这两个功能,我们被鼓励使用它们吗?


2
我认为这是因为torch.save()也保存所有中间变量,例如用于反向传播的中间输出。但是您只需要保存模型参数,例如权重/偏差等。有时前者可能比后者大得多。
杨大为

2
我测试torch.save(model, f)torch.save(model.state_dict(), f)。保存的文件大小相同。现在我很困惑。另外,我发现使用泡菜保存model.state_dict()非常慢。我认为最好的方法是使用,torch.save(model.state_dict(), f)因为您可以处理模型的创建,而手电筒可以处理模型权重的加载,从而消除可能的问题。参考:exploring.pytorch.org/t/saving-torch-models/838/4
杨大为

好像PyTorch在其教程部分中更明确地解决了此问题-此处的答案中未列出很多好的信息,其中包括一次保存多个模型和热启动模型。
whlteXbread

Answers:


211

我在他们的github仓库中找到了此页面,我将内容粘贴在这里。


推荐的模型保存方法

序列化和还原模型有两种主要方法。

第一个(推荐)仅保存和加载模型参数:

torch.save(the_model.state_dict(), PATH)

然后再:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

torch.save(the_model, PATH)

然后再:

the_model = torch.load(PATH)

但是,在这种情况下,序列化的数据将绑定到所使用的特定类和确切的目录结构,因此在其他项目中使用时或经过一些严重的重构后,它可能以各种方式中断。


6
据@smth discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/...模型重新加载默认情况下训练模式。因此,在加载后需要手动调用the_model.eval(),如果要加载以进行推理,则不要恢复训练。
WillZ '18年

第二种方法给出了stackoverflow.com/questions/53798009/…Windows 10上的错误。无法解决
Gulzar

是否可以保存任何选项,而无需访问模型类?
Michael D

使用这种方法,如何跟踪需要传递给负载情况的* args和** kwargs?
Mariano Kamp

143

这取决于您想做什么。

案例1:保存模型以供您自己进行推断:保存模型,还原模型,然后将模型更改为评估模式。这样做是因为您通常在构造上具有BatchNormDropout图层,这些图层默认情况下处于训练模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例2:保存模型以便以后继续训练:如果您需要继续训练将要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器的状态,时期,得分等。您可以这样操作:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

要恢复训练,您将执行以下操作:state = torch.load(filepath),然后恢复每个对象的状态,如下所示:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由于您正在恢复训练,因此在加载时恢复状态后,请勿model.eval()再致电。

案例3:无法访问您的代码的其他人可以使用的模型:在Tensorflow中,您可以创建一个.pb文件,该文件定义了体系结构和模型权重。这非常方便,尤其是在使用时Tensorflow serve。在Pytorch中执行此操作的等效方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

这种方式仍然不能保证安全,并且由于pytorch仍在进行大量更改,因此我不建议这样做。


1
这3种情况下是否有推荐的档案结尾?还是总是.pth?
Verena Haunschmid

1
在情况3中,torch.load仅返回OrderedDict。您如何获得模型以进行预测?
Alber8295

嗨,我可以知道如何执行上面提到的“案例2:保存模型以稍后恢复训练”吗?我设法将检查点加载到模型,然后无法运行或恢复训练模型,例如“ model.to(设备)模型= train_model_epoch(模型,条件,优化器,
预定时间

1
嗨,对于一种推理的情况,在pytorch官方文档中说,必须保存优化器state_dict以便推理或完成训练。“保存通用检查点以用于推理或继续训练时,您必须保存的不仅仅是模型的state_dict。保存优化器的state_dict也是很重要的,因为它包含随着模型训练而更新的缓冲区和参数“
Mohammed Awney

在情况#3中,应在某处定义模型类。
Michael D

11

泡菜的Python库实现二进制协议的序列化和反序列化Python对象。

当您import torch(或当您使用PyTorch)时,它将import pickle为您而您不需要调用pickle.dump()pickle.load()直接调用,这是保存和加载对象的方法。

事实上,torch.save()torch.load()将包裹pickle.dump()pickle.load()为您服务。

一个state_dict对方的回答值得提及的只是几个音符。

什么state_dict我们有内部PyTorch?实际上有两个state_dict秒。

PyTorch模型torch.nn.Module具有model.parameters()调用以获取可学习的参数(w和b)。这些可学习的参数,一旦被随机设置,将随着我们的学习而随着时间而更新。可学习的参数是第一个state_dict

第二个state_dict是优化器状态字典。您还记得优化器用于改善我们的可学习参数。但是优化器state_dict是固定的。在那没什么可学的。

由于state_dict对象是Python字典,因此可以轻松地保存,更新,更改和还原对象,从而为PyTorch模型和优化器增加了很多模块化。

让我们创建一个超级简单的模型来解释这一点:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

此代码将输出以下内容:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

请注意,这是最小模型。您可以尝试添加顺序堆栈

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

请注意,只有具有可学习参数的层(卷积层,线性层等)和已注册的缓冲区(batchnorm层)才在模型的中具有条目state_dict

不可学习的东西属于优化器对象state_dict,该对象包含有关优化器状态以及所用超参数的信息。

故事的其余部分是相同的。在推理阶段(这是我们训练后使用模型的阶段)进行预测;我们会根据所学的参数进行预测。因此,为了进行推断,我们只需要保存参数model.state_dict()

torch.save(model.state_dict(), filepath)

并在以后使用model.load_state_dict(torch.load(filepath))model.eval()

注意:不要忘了最后一行,model.eval()在加载模型之后,这是至关重要的。

也不要试图保存torch.save(model.parameters(), filepath)。该model.parameters()只是生成对象。

另一方面,torch.save(model, filepath)保存模型对象本身,但请记住,模型没有优化程序state_dict。检查@Jadiel de Armas的其他出色答案,以保存优化程序的状态字典。


尽管这不是一个简单的解决方案,但对问题的实质进行了深入分析!赞成。
Jason Young

7

常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型。

保存/加载整个模型 保存:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

加载:

模型类必须在某处定义

model = torch.load(PATH)
model.eval()

3

如果您要保存模型并希望以后继续训练,请执行以下操作:

单个GPU: 保存:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

加载:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

多GPU: 保存

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

加载:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.