A biblioteca pickle Python implementa protocolos binários para serializar e desserializar um objeto Python.
Quando você import torch
(ou quando você usa PyTorch) que vai import pickle
para você e você não precisa chamar pickle.dump()
e pickle.load()
diretamente, o que são os métodos para salvar e carregar o objeto.
Na verdade, torch.save()
e torch.load()
vai embrulhar pickle.dump()
e pickle.load()
para você.
A state_dict
outra resposta mencionada merece apenas mais algumas notas.
O state_dict
que temos dentro do PyTorch? Na verdade, existem dois state_dict
s.
O modelo PyTorch é torch.nn.Module
tem model.parameters()
chamada para obter parâmetros learnable (w eb). Esses parâmetros aprendíveis, uma vez definidos aleatoriamente, serão atualizados ao longo do tempo à medida que aprendemos. Parâmetros aprendíveis são os primeiros state_dict
.
O segundo state_dict
é o ditado de estado do otimizador. Você se lembra que o otimizador é usado para melhorar nossos parâmetros aprendíveis. Mas o otimizador state_dict
é fixo. Nada a aprender lá.
Como os state_dict
objetos são dicionários Python, eles podem ser facilmente salvos, atualizados, alterados e restaurados, adicionando uma grande modularidade aos modelos e otimizadores do PyTorch.
Vamos criar um modelo super simples para explicar isso:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Este código produzirá o seguinte:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Observe que este é um modelo mínimo. Você pode tentar adicionar uma pilha de
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Observe que apenas as camadas com parâmetros que podem ser aprendidos (camadas convolucionais, camadas lineares etc.) e buffers registrados (camadas batchnorm) têm entradas nos modelos state_dict
.
Coisas não aprendidas pertencem ao objeto otimizador state_dict
, que contém informações sobre o estado do otimizador, bem como os hiperparâmetros usados.
O resto da história é o mesmo; na fase de inferência (essa é uma fase em que usamos o modelo após o treinamento) para prever; nós previmos com base nos parâmetros que aprendemos. Portanto, para a inferência, só precisamos salvar os parâmetros model.state_dict()
.
torch.save(model.state_dict(), filepath)
E para usar mais tarde model.load_state_dict (torch.load (filepath)) model.eval ()
Nota: Não esqueça a última linha, model.eval()
isto é crucial após o carregamento do modelo.
Também não tente salvar torch.save(model.parameters(), filepath)
. O model.parameters()
é apenas o objeto gerador.
Por outro lado, torch.save(model, filepath)
salva o próprio objeto do modelo, mas lembre-se de que o modelo não possui o otimizador state_dict
. Verifique a outra excelente resposta de @Jadiel de Armas para salvar o ditado de estado do otimizador.