将泡菜的Python库实现二进制协议的序列化和反序列化Python对象。
当您import torch
(或当您使用PyTorch)时,它将import pickle
为您而您不需要调用pickle.dump()
和pickle.load()
直接调用,这是保存和加载对象的方法。
事实上,torch.save()
和torch.load()
将包裹pickle.dump()
和pickle.load()
为您服务。
一个state_dict
对方的回答值得提及的只是几个音符。
什么state_dict
我们有内部PyTorch?实际上有两个state_dict
秒。
PyTorch模型torch.nn.Module
具有model.parameters()
调用以获取可学习的参数(w和b)。这些可学习的参数,一旦被随机设置,将随着我们的学习而随着时间而更新。可学习的参数是第一个state_dict
。
第二个state_dict
是优化器状态字典。您还记得优化器用于改善我们的可学习参数。但是优化器state_dict
是固定的。在那没什么可学的。
由于state_dict
对象是Python字典,因此可以轻松地保存,更新,更改和还原对象,从而为PyTorch模型和优化器增加了很多模块化。
让我们创建一个超级简单的模型来解释这一点:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
此代码将输出以下内容:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
请注意,这是最小模型。您可以尝试添加顺序堆栈
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
请注意,只有具有可学习参数的层(卷积层,线性层等)和已注册的缓冲区(batchnorm层)才在模型的中具有条目state_dict
。
不可学习的东西属于优化器对象state_dict
,该对象包含有关优化器状态以及所用超参数的信息。
故事的其余部分是相同的。在推理阶段(这是我们训练后使用模型的阶段)进行预测;我们会根据所学的参数进行预测。因此,为了进行推断,我们只需要保存参数model.state_dict()
。
torch.save(model.state_dict(), filepath)
并在以后使用model.load_state_dict(torch.load(filepath))model.eval()
注意:不要忘了最后一行,model.eval()
在加载模型之后,这是至关重要的。
也不要试图保存torch.save(model.parameters(), filepath)
。该model.parameters()
只是生成对象。
另一方面,torch.save(model, filepath)
保存模型对象本身,但请记住,模型没有优化程序state_dict
。检查@Jadiel de Armas的其他出色答案,以保存优化程序的状态字典。