Salvar o classificador em disco no scikit-learn


191

Como salvar um classificador treinado Naive Bayes em disco e usá-lo para prever dados?

Eu tenho o seguinte programa de amostra no site scikit-learn:

from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum()

Respostas:


201

Classificadores são apenas objetos que podem ser decapados e descartados como qualquer outro. Para continuar seu exemplo:

import cPickle
# save the classifier
with open('my_dumped_classifier.pkl', 'wb') as fid:
    cPickle.dump(gnb, fid)    

# load it again
with open('my_dumped_classifier.pkl', 'rb') as fid:
    gnb_loaded = cPickle.load(fid)

1
Funciona como um encanto! Eu estava tentando usar o np.savez e carregá-lo de volta o tempo todo e isso nunca ajudou. Muito obrigado.
Kartos

7
em python3, use o módulo pickle, que funciona exatamente assim.
MCSH 25/11

212

Você também pode usar joblib.dump e joblib.load, que são muito mais eficientes no tratamento de matrizes numéricas do que o seletor python padrão.

O joblib está incluído no scikit-learn:

>>> import joblib
>>> from sklearn.datasets import load_digits
>>> from sklearn.linear_model import SGDClassifier

>>> digits = load_digits()
>>> clf = SGDClassifier().fit(digits.data, digits.target)
>>> clf.score(digits.data, digits.target)  # evaluate training error
0.9526989426822482

>>> filename = '/tmp/digits_classifier.joblib.pkl'
>>> _ = joblib.dump(clf, filename, compress=9)

>>> clf2 = joblib.load(filename)
>>> clf2
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5,
       n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0,
       shuffle=False, verbose=0, warm_start=False)
>>> clf2.score(digits.data, digits.target)
0.9526989426822482

Edit: no Python 3.8+, agora é possível usar pickle para pickling eficiente de objetos com grandes matrizes numéricas como atributos se você usar o protocolo 5 de pickle (que não é o padrão).


1
Mas, pelo meu entendimento, o pipelining funciona se fizer parte de um único fluxo de trabalho. Se eu quiser construir o modelo, armazene-o no disco e interrompa a execução por lá. Então eu voltar uma semana mais tarde e tentar carregar o modelo do disco isto me lança um erro:
venuktan

2
Não há como parar e retomar a execução do fitmétodo se é isso que você está procurando. Dito isto, joblib.loadnão deve gerar uma exceção após um êxito, joblib.dumpse você a chamar de um Python com a mesma versão da biblioteca scikit-learn.
ogrisel

10
Se você estiver usando IPython, não use o --pylabsinalizador de linha de comando ou a %pylabmágica, pois a sobrecarga implícita de espaço para nome é conhecida por interromper o processo de decapagem. Use importações explícitas e a %matplotlib inlinemágica.
ogrisel

2
consulte a documentação do scikit
user1448319

1
É possível treinar novamente o modelo salvo anteriormente? Modelos SVC especificamente?
Uday Sawant

108

O que você está procurando é chamado de persistência do modelo no sklearn words e está documentado nas seções introdução e persistência do modelo .

Então você inicializou seu classificador e o treinou por um longo tempo com

clf = some.classifier()
clf.fit(X, y)

Depois disso, você tem duas opções:

1) Usando Pickle

import pickle
# now you can save it to a file
with open('filename.pkl', 'wb') as f:
    pickle.dump(clf, f)

# and later you can load it
with open('filename.pkl', 'rb') as f:
    clf = pickle.load(f)

2) Usando o Joblib

from sklearn.externals import joblib
# now you can save it to a file
joblib.dump(clf, 'filename.pkl') 
# and later you can load it
clf = joblib.load('filename.pkl')

Mais uma vez, é útil ler os links mencionados acima


30

Em muitos casos, principalmente na classificação de texto, não basta armazenar o classificador, mas você também precisará armazenar o vetorizador para poder vetorizar sua entrada no futuro.

import pickle
with open('model.pkl', 'wb') as fout:
  pickle.dump((vectorizer, clf), fout)

caso de uso futuro:

with open('model.pkl', 'rb') as fin:
  vectorizer, clf = pickle.load(fin)

X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)

Antes de descarregar o vetorizador, é possível excluir a propriedade stop_words_ do vetorizador:

vectorizer.stop_words_ = None

para tornar o dumping mais eficiente. Além disso, se os parâmetros do seu classificador forem escassos (como na maioria dos exemplos de classificação de texto), você poderá converter os parâmetros de denso para esparso, o que fará uma enorme diferença em termos de consumo de memória, carregamento e descarte. Sparsify o modelo por:

clf.sparsify()

O que funcionará automaticamente para SGDClassifier, mas, se você souber que seu modelo é escasso (muitos zeros em clf.coef_), poderá converter manualmente o clf.coef_ em uma matriz esparsa csr scipy :

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_)

e então você pode armazená-lo com mais eficiência.


Resposta perspicaz! Só queria adicionar no caso de SVC, ele retorna um parâmetro de modelo esparso.
Shayan Amani

4

sklearnos estimadores implementam métodos para facilitar o salvamento de propriedades treinadas relevantes de um estimador. Alguns estimadores implementam __getstate__métodos eles mesmos, mas outros, como o GMMapenas usam a implementação base, que simplesmente salva o dicionário interno dos objetos:

def __getstate__(self):
    try:
        state = super(BaseEstimator, self).__getstate__()
    except AttributeError:
        state = self.__dict__.copy()

    if type(self).__module__.startswith('sklearn.'):
        return dict(state.items(), _sklearn_version=__version__)
    else:
        return state

O método recomendado para salvar seu modelo em disco é usar o picklemódulo:

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
    pickle.dump(model,f)

No entanto, você deve salvar dados adicionais para treinar novamente o seu modelo no futuro ou sofrer conseqüências terríveis (como ficar preso a uma versão antiga do sklearn) .

A partir da documentação :

Para reconstruir um modelo semelhante com versões futuras do scikit-learn, metadados adicionais devem ser salvos ao longo do modelo em vinagre:

Os dados de treinamento, por exemplo, uma referência a um instantâneo imutável

O código-fonte python usado para gerar o modelo

As versões do scikit-learn e suas dependências

A pontuação de validação cruzada obtida nos dados de treinamento

Isso é especialmente verdadeiro para os estimadores do Ensemble que dependem do tree.pyxmódulo escrito em Cython (como IsolationForest), uma vez que ele cria um acoplamento à implementação, que não garante a estabilidade entre as versões do sklearn. Viu mudanças incompatíveis com versões anteriores no passado.

Se seus modelos se tornarem muito grandes e o carregamento se tornar um incômodo, você também poderá usar os mais eficientes joblib. A partir da documentação:

No caso específico do scikit, pode ser mais interessante usar a substituição do joblib pelo pickle( joblib.dump& joblib.load), que é mais eficiente em objetos que transportam matrizes numpy grandes internamente, como costuma ser o caso de estimadores de aprendizado do scikit adequados, mas pode apenas decapagem para o disco e não para uma string:


1
but can only pickle to the disk and not to a stringMas você pode incluir isso no StringIO no joblib. É isso que faço o tempo todo.
Matthew Matthew

1

sklearn.externals.joblibfoi descontinuado desde então 0.21e será removido em v0.23:

/usr/local/lib/python3.7/site-packages/sklearn/externals/joblib/ init .py: 15: FutureWarning: sklearn.externals.joblib foi descontinuado em 0,21 e será removido em 0,23. Importe essa funcionalidade diretamente do joblib, que pode ser instalado com: pip install joblib. Se esse aviso for acionado ao carregar modelos em pickled, talvez seja necessário serializar novamente esses modelos com o scikit-learn 0.21+.
warnings.warn (msg, categoria = FutureWarning)


Portanto, você precisa instalar joblib:

pip install joblib

e finalmente escreva o modelo no disco:

import joblib
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier


digits = load_digits()
clf = SGDClassifier().fit(digits.data, digits.target)

with open('myClassifier.joblib.pkl', 'wb') as f:
    joblib.dump(clf, f, compress=9)

Agora, para ler o arquivo despejado, tudo o que você precisa executar é:

with open('myClassifier.joblib.pkl', 'rb') as f:
    my_clf = joblib.load(f)
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.