sklearn
os estimadores implementam métodos para facilitar o salvamento de propriedades treinadas relevantes de um estimador. Alguns estimadores implementam __getstate__
métodos eles mesmos, mas outros, como o GMM
apenas usam a implementação base, que simplesmente salva o dicionário interno dos objetos:
def __getstate__(self):
try:
state = super(BaseEstimator, self).__getstate__()
except AttributeError:
state = self.__dict__.copy()
if type(self).__module__.startswith('sklearn.'):
return dict(state.items(), _sklearn_version=__version__)
else:
return state
O método recomendado para salvar seu modelo em disco é usar o pickle
módulo:
from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
pickle.dump(model,f)
No entanto, você deve salvar dados adicionais para treinar novamente o seu modelo no futuro ou sofrer conseqüências terríveis (como ficar preso a uma versão antiga do sklearn) .
A partir da documentação :
Para reconstruir um modelo semelhante com versões futuras do scikit-learn, metadados adicionais devem ser salvos ao longo do modelo em vinagre:
Os dados de treinamento, por exemplo, uma referência a um instantâneo imutável
O código-fonte python usado para gerar o modelo
As versões do scikit-learn e suas dependências
A pontuação de validação cruzada obtida nos dados de treinamento
Isso é especialmente verdadeiro para os estimadores do Ensemble que dependem do tree.pyx
módulo escrito em Cython (como IsolationForest
), uma vez que ele cria um acoplamento à implementação, que não garante a estabilidade entre as versões do sklearn. Viu mudanças incompatíveis com versões anteriores no passado.
Se seus modelos se tornarem muito grandes e o carregamento se tornar um incômodo, você também poderá usar os mais eficientes joblib
. A partir da documentação:
No caso específico do scikit, pode ser mais interessante usar a substituição do joblib pelo pickle
( joblib.dump
& joblib.load
), que é mais eficiente em objetos que transportam matrizes numpy grandes internamente, como costuma ser o caso de estimadores de aprendizado do scikit adequados, mas pode apenas decapagem para o disco e não para uma string: