Previsão Probabilística Aleatória da Floresta vs voto da maioria


10

O aprendizado do Scikit parece usar predição probabilística, em vez de voto majoritário, para a técnica de agregação de modelos, sem uma explicação do porquê (1.9.2.1. Florestas aleatórias).

Existe uma explicação clara para o porquê? Além disso, existe um bom artigo em papel ou de revisão para as várias técnicas de agregação de modelos que podem ser usadas para ensacamento de floresta aleatória?

Obrigado!

Respostas:


10

Sempre é melhor responder a essas perguntas olhando o código, se você é fluente em Python.

RandomForestClassifier.predict, pelo menos na versão atual 0.16.1, prevê a classe com maior estimativa de probabilidade, conforme fornecido por predict_proba. ( esta linha )

A documentação para predict_probadiz:

As probabilidades de classe previstas de uma amostra de entrada são calculadas como as probabilidades médias de classe previstas das árvores na floresta. A probabilidade de classe de uma única árvore é a fração de amostras da mesma classe em uma folha.

A diferença do método original é provavelmente apenas para predictfornecer previsões consistentes predict_proba. O resultado às vezes é chamado de "votação branda", em vez do voto majoritário "rígido" usado no documento original de Breiman. Na pesquisa rápida, não consegui encontrar uma comparação apropriada do desempenho dos dois métodos, mas ambos parecem razoáveis ​​nessa situação.

A predictdocumentação é, na melhor das hipóteses, bastante enganadora; Enviei uma solicitação pull para corrigi-la.

Se você deseja fazer a predição por maioria dos votos, aqui está uma função para fazê-lo. Chame como predict_majvote(clf, X)antes clf.predict(X). (Baseado em predict_proba; apenas levemente testado, mas acho que deve funcionar.)

from scipy.stats import mode
from sklearn.ensemble.forest import _partition_estimators, _parallel_helper
from sklearn.tree._tree import DTYPE
from sklearn.externals.joblib import Parallel, delayed
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

def predict_majvote(forest, X):
    """Predict class for X.

    Uses majority voting, rather than the soft voting scheme
    used by RandomForestClassifier.predict.

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csr_matrix``.
    Returns
    -------
    y : array of shape = [n_samples] or [n_samples, n_outputs]
        The predicted classes.
    """
    check_is_fitted(forest, 'n_outputs_')

    # Check data
    X = check_array(X, dtype=DTYPE, accept_sparse="csr")

    # Assign chunk of trees to jobs
    n_jobs, n_trees, starts = _partition_estimators(forest.n_estimators,
                                                    forest.n_jobs)

    # Parallel loop
    all_preds = Parallel(n_jobs=n_jobs, verbose=forest.verbose,
                         backend="threading")(
        delayed(_parallel_helper)(e, 'predict', X, check_input=False)
        for e in forest.estimators_)

    # Reduce
    modes, counts = mode(all_preds, axis=0)

    if forest.n_outputs_ == 1:
        return forest.classes_.take(modes[0], axis=0)
    else:
        n_samples = all_preds[0].shape[0]
        preds = np.zeros((n_samples, forest.n_outputs_),
                         dtype=forest.classes_.dtype)
        for k in range(forest.n_outputs_):
            preds[:, k] = forest.classes_[k].take(modes[:, k], axis=0)
        return preds

No caso sintético idiota que eu tentei, as previsões concordavam com o predictmétodo todas as vezes.


Ótima resposta, Dougal! Obrigado por reservar um tempo para explicar isso com cuidado. Por favor, considere também examinar o estouro da pilha e responder a essa pergunta lá .
usar o seguinte comando

11
Há também um artigo aqui , que aborda a previsão probabilística.
usar o seguinte comando
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.