Melhor maneira de salvar um modelo treinado no PyTorch?


192

Eu estava procurando maneiras alternativas de salvar um modelo treinado no PyTorch. Até agora, encontrei duas alternativas.

  1. torch.save () para salvar um modelo e torch.load () para carregar um modelo.
  2. model.state_dict () para salvar um modelo treinado e model.load_state_dict () para carregar o modelo salvo.

Eu me deparei com essa discussão em que a abordagem 2 é recomendada sobre a abordagem 1.

Minha pergunta é: por que a segunda abordagem é preferida? É apenas porque os módulos torch.nn têm essas duas funções e somos incentivados a usá-los?


2
Eu acho que é porque torch.save () salva todas as variáveis ​​intermediárias também, como saídas intermediárias para uso de propagação traseira. Mas você só precisa salvar os parâmetros do modelo, como peso / tendência, etc. Às vezes, o primeiro pode ser muito maior que o último.
Dawei Yang

2
Eu testei torch.save(model, f)e torch.save(model.state_dict(), f). Os arquivos salvos têm o mesmo tamanho. Agora eu estou confuso. Além disso, eu achei o uso de pickle para salvar model.state_dict () extremamente lento. Eu acho que a melhor maneira é usar, torch.save(model.state_dict(), f)já que você lida com a criação do modelo, e a tocha lida com o carregamento dos pesos do modelo, eliminando possíveis problemas. Referência: discuss.pytorch.org/t/saving-torch-models/838/4
Dawei Yang

Parece que o PyTorch abordou isso de forma um pouco mais explícita em sua seção de tutoriais - há muitas informações boas que não estão listadas nas respostas aqui, incluindo salvar mais de um modelo de cada vez e modelos de inicialização quentes.
whlteXbread

o que há de errado em usar pickle?
Charlie Parker

1
@CharlieParker torch.save é baseado em picles. A seguir, é apresentado o tutorial vinculado acima: "[torch.save] salvará o módulo inteiro usando o módulo pickle do Python. A desvantagem dessa abordagem é que os dados serializados são vinculados às classes específicas e à estrutura de diretórios exata usada quando o modelo O motivo é que pickle não salva a própria classe do modelo. Em vez disso, salva um caminho para o arquivo que contém a classe, que é usada durante o tempo de carregamento. Por isso, seu código pode ser quebrado de várias maneiras quando usado em outros projetos ou após refatoradores ".
David Miller

Respostas:


214

Encontrei esta página no repositório do github, colarei o conteúdo aqui.


Abordagem recomendada para salvar um modelo

Existem duas abordagens principais para serializar e restaurar um modelo.

O primeiro (recomendado) salva e carrega apenas os parâmetros do modelo:

torch.save(the_model.state_dict(), PATH)

Depois, mais tarde:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

O segundo salva e carrega o modelo inteiro:

torch.save(the_model, PATH)

Depois, mais tarde:

the_model = torch.load(PATH)

No entanto, nesse caso, os dados serializados são vinculados às classes específicas e à estrutura de diretórios exata usada, para que possam ser quebrados de várias maneiras quando usados ​​em outros projetos ou após alguns refatores sérios.


8
De acordo com @smth discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/…, o modelo é recarregado para treinar o modelo por padrão. portanto, é necessário chamar manualmente the_model.eval () após o carregamento, se você estiver carregando por inferência, não retomando o treinamento.
WillZ

o segundo método dá stackoverflow.com/questions/53798009/... erro no Windows 10. não foi capaz de resolvê-lo
Gulzar

Existe alguma opção para salvar sem a necessidade de acesso à classe de modelo?
Michael D

Com essa abordagem, como você monitora os * args e ** kwargs que precisa transmitir para o caso de carga?
Mariano Kamp

o que há de errado em usar pickle?
Charlie Parker

144

Depende do que você quer fazer.

Caso nº 1: salve o modelo para usá-lo por inferência : salve o modelo, restaure-o e altere o modelo para o modo de avaliação. Isso é feito porque você geralmente tem BatchNorme Dropoutcamadas que por padrão estão no modo de trem na construção:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Caso nº 2: salve o modelo para continuar o treinamento mais tarde : se você precisar continuar treinando o modelo que está prestes a salvar, precisará salvar mais do que apenas o modelo. Você também precisa salvar o estado do otimizador, épocas, pontuação etc. Você faria assim:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Para retomar o treinamento, você faria coisas como: state = torch.load(filepath)e, em seguida, para restaurar o estado de cada objeto individual, algo como isto:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Como você está retomando o treinamento, NÃO ligue model.eval()depois de restaurar os estados ao carregar.

Caso nº 3: Modelo a ser usado por outra pessoa sem acesso ao seu código : No Tensorflow, você pode criar um .pbarquivo que define a arquitetura e os pesos do modelo. Isso é muito útil, especialmente ao usar Tensorflow serve. A maneira equivalente de fazer isso no Pytorch seria:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Dessa forma, ainda não é à prova de balas e, como o pytorch ainda está passando por muitas alterações, eu não o recomendaria.


1
Existe um arquivo recomendado para os três casos? Ou é sempre .pth?
Verena Haunschmid 12/02/19

1
No Caso 3, torch.loadretorna apenas um OrderedDict. Como você obtém o modelo para fazer previsões?
Alber8295

Olá, posso saber como fazer o mencionado "Caso 2: salvar modelo para retomar o treinamento mais tarde"? Eu consegui carregar o ponto de verificação para modelo, então eu não conseguir executar ou retomar o modelo de trem como "modelo model.to (dispositivo) = train_model_epoch (modelo, critério, otimizador, sched, épocas)"
dnez

1
Olá, no caso de uma inferência, no documento oficial do pytorch, diga que deve salvar o otimizador state_dict para inferência ou para concluir o treinamento. "Ao salvar um ponto de verificação geral, para ser usado para inferir ou retomar o treinamento, você deve salvar mais do que apenas o estado_dict do modelo. É importante salvar também o estado_dict do otimizador, pois ele contém buffers e parâmetros que são atualizados à medida que o modelo treina . "
Mohammed Awney

1
No caso 3, a classe do modelo deve ser definida em algum lugar.
Michael D

12

A biblioteca pickle Python implementa protocolos binários para serializar e desserializar um objeto Python.

Quando você import torch(ou quando você usa PyTorch) que vai import picklepara você e você não precisa chamar pickle.dump()e pickle.load()diretamente, o que são os métodos para salvar e carregar o objeto.

Na verdade, torch.save()e torch.load()vai embrulhar pickle.dump()e pickle.load()para você.

A state_dictoutra resposta mencionada merece apenas mais algumas notas.

O state_dictque temos dentro do PyTorch? Na verdade, existem dois state_dicts.

O modelo PyTorch é torch.nn.Moduletem model.parameters()chamada para obter parâmetros learnable (w eb). Esses parâmetros aprendíveis, uma vez definidos aleatoriamente, serão atualizados ao longo do tempo à medida que aprendemos. Parâmetros aprendíveis são os primeiros state_dict.

O segundo state_dicté o ditado de estado do otimizador. Você se lembra que o otimizador é usado para melhorar nossos parâmetros aprendíveis. Mas o otimizador state_dicté fixo. Nada a aprender lá.

Como os state_dictobjetos são dicionários Python, eles podem ser facilmente salvos, atualizados, alterados e restaurados, adicionando uma grande modularidade aos modelos e otimizadores do PyTorch.

Vamos criar um modelo super simples para explicar isso:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Este código produzirá o seguinte:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Observe que este é um modelo mínimo. Você pode tentar adicionar uma pilha de

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Observe que apenas as camadas com parâmetros que podem ser aprendidos (camadas convolucionais, camadas lineares etc.) e buffers registrados (camadas batchnorm) têm entradas nos modelos state_dict.

Coisas não aprendidas pertencem ao objeto otimizador state_dict, que contém informações sobre o estado do otimizador, bem como os hiperparâmetros usados.

O resto da história é o mesmo; na fase de inferência (essa é uma fase em que usamos o modelo após o treinamento) para prever; nós previmos com base nos parâmetros que aprendemos. Portanto, para a inferência, só precisamos salvar os parâmetros model.state_dict().

torch.save(model.state_dict(), filepath)

E para usar mais tarde model.load_state_dict (torch.load (filepath)) model.eval ()

Nota: Não esqueça a última linha, model.eval()isto é crucial após o carregamento do modelo.

Também não tente salvar torch.save(model.parameters(), filepath). O model.parameters()é apenas o objeto gerador.

Por outro lado, torch.save(model, filepath)salva o próprio objeto do modelo, mas lembre-se de que o modelo não possui o otimizador state_dict. Verifique a outra excelente resposta de @Jadiel de Armas para salvar o ditado de estado do otimizador.


Embora não seja uma solução direta, a essência do problema é profundamente analisada! Voto a favor.
Jason Young

7

Uma convenção comum do PyTorch é salvar modelos usando uma extensão de arquivo .pt ou .pth.

Salvar / carregar modelo inteiro Salvar:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Carga:

A classe do modelo deve ser definida em algum lugar

model = torch.load(PATH)
model.eval()

4

Se você deseja salvar o modelo e deseja retomar o treinamento posteriormente:

GPU única: Salvar:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Carga:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

GPU múltipla: Salvar

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Carga:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.