Como funciona o parâmetro class_weight no scikit-learn?


115

Estou tendo muitos problemas para entender como funciona o class_weightparâmetro na regressão logística do scikit-learn.

A situação

Quero usar a regressão logística para fazer a classificação binária em um conjunto de dados muito desequilibrado. As classes são rotuladas 0 (negativo) e 1 (positivo) e os dados observados estão em uma proporção de cerca de 19: 1 com a maioria das amostras tendo resultado negativo.

Primeira tentativa: preparando manualmente os dados de treinamento

Divido os dados que tinha em conjuntos separados para treinamento e teste (cerca de 80/20). Em seguida, fiz uma amostra aleatória dos dados de treinamento à mão para obter dados de treinamento em proporções diferentes de 19: 1; de 2: 1 -> 16: 1.

Em seguida, treinei a regressão logística nesses diferentes subconjuntos de dados de treinamento e a recordação plotada (= TP / (TP + FN)) como uma função das diferentes proporções de treinamento. Claro, o recall foi calculado nas amostras de TESTE disjuntas que tinham as proporções observadas de 19: 1. Observe que, embora eu tenha treinado os diferentes modelos em dados de treinamento diferentes, calculei a recuperação de todos eles nos mesmos dados de teste (separados).

Os resultados foram os esperados: o recall foi de cerca de 60% nas proporções de treinamento de 2: 1 e caiu bem rápido quando chegou a 16: 1. Havia várias proporções 2: 1 -> 6: 1 onde o recall estava decentemente acima de 5%.

Segunda tentativa: Pesquisa de grade

Em seguida, eu queria testar diferentes parâmetros de regularização e então usei GridSearchCV e fiz uma grade de vários valores do Cparâmetro, bem como do class_weightparâmetro. Para traduzir minhas proporções n: m de amostras de treinamento negativo: positivo para a linguagem do dicionário, class_weightpensei que apenas especificaria vários dicionários da seguinte forma:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

e eu também incluí Nonee auto.

Desta vez, os resultados foram totalmente malucos. Todas as minhas recuperações foram mínimas (<0,05) para cada valor de class_weightexceto auto. Portanto, só posso supor que meu entendimento de como definir o class_weightdicionário está errado. Curiosamente, o class_weightvalor de 'auto' na pesquisa de grade foi em torno de 59% para todos os valores de C, e imaginei que equilibra para 1: 1?

Minhas perguntas

  1. Como você usa adequadamente class_weightpara obter balanços diferentes nos dados de treinamento do que você realmente fornece? Especificamente, que dicionário devo class_weightusar para usar proporções n: m de amostras de treinamento negativas: positivas?

  2. Se você passar vários class_weightdicionários para o GridSearchCV, durante a validação cruzada ele irá reequilibrar os dados da dobra de treinamento de acordo com o dicionário, mas usar as verdadeiras proporções de amostra fornecidas para calcular minha função de pontuação na dobra de teste? Isso é crítico, pois qualquer métrica só é útil para mim se vier de dados nas proporções observadas.

  3. O que o autovalor de class_weightfaz em relação às proporções? Eu li a documentação e presumo que "equilibra os dados inversamente proporcionais à sua frequência" significa apenas 1: 1. Isso está correto? Se não, alguém pode esclarecer?


Quando se usa class_weight, a função de perda é modificada. Por exemplo, em vez de entropia cruzada, ela se torna entropia cruzada ponderada. paradatascience.com/…
prashanth

Respostas:


123

Em primeiro lugar, pode não ser bom ir apenas por recall. Você pode simplesmente atingir um recall de 100% classificando tudo como a classe positiva. Normalmente sugiro usar AUC para selecionar parâmetros e, em seguida, encontrar um limite para o ponto operacional (digamos, um determinado nível de precisão) no qual você está interessado.

Para saber como class_weightfunciona: Penaliza erros em amostras de class[i]com em class_weight[i]vez de 1. Portanto, maior peso da classe significa que você deseja colocar mais ênfase em uma classe. Pelo que você disse, parece que a classe 0 é 19 vezes mais frequente do que a classe 1. Portanto, você deve aumentar o class_weightda classe 1 em relação à classe 0, digamos {0: .1, 1: .9}. Se class_weightnão somar 1, basicamente mudará o parâmetro de regularização.

Para saber como class_weight="auto"funciona, você pode dar uma olhada nesta discussão . Na versão dev você pode usar class_weight="balanced", o que é mais fácil de entender: basicamente significa replicar a classe menor até que você tenha tantas amostras quanto na classe maior, mas de forma implícita.


1
Obrigado! Pergunta rápida: mencionei a lembrança para maior clareza e, na verdade, estou tentando decidir qual AUC usar como minha medida. Meu entendimento é que eu deveria maximizar a área sob a curva ROC ou a área sob recall vs. curva de precisão para encontrar parâmetros. Depois de escolher os parâmetros dessa forma, acredito que escolhi o limite para classificação deslizando ao longo da curva. É isso que você queria dizer? Em caso afirmativo, qual das duas curvas faz mais sentido observar se meu objetivo é capturar o máximo possível de TPs? Além disso, obrigado por seu trabalho e contribuições para o scikit-learn !!!
kilgoretrout de

1
Acho que usar ROC seria o caminho mais padrão a seguir, mas não acho que haverá uma grande diferença. Você precisa de algum critério para escolher o ponto na curva, no entanto.
Andreas Mueller

3
@MiNdFrEaK Acho que o que Andrew quer dizer é que o estimador replica amostras na classe minoritária, de modo que as amostras de classes diferentes são balanceadas. É apenas uma sobreamostragem de uma forma implícita.
Shawn TIAN

8
@MiNdFrEaK e Shawn Tian: classificadores baseados em SV não produzem mais amostras das classes menores quando você usa 'balanceado'. Literalmente penaliza os erros cometidos nas classes menores. Dizer o contrário é um erro e é enganoso, especialmente em grandes conjuntos de dados, quando você não pode se dar ao luxo de criar mais amostras. Esta resposta deve ser editada.
Pablo Rivas

4
scikit-learn.org/dev/glossary.html#term-class-weight Os pesos das classes serão usados ​​de forma diferente dependendo do algoritmo: para modelos lineares (como SVM linear ou regressão logística), os pesos das classes irão alterar a função de perda por ponderando a perda de cada amostra pelo seu peso de classe. Para algoritmos baseados em árvore, os pesos das classes serão usados ​​para reponderar o critério de divisão. Observe, entretanto, que este rebalanceamento não leva em consideração o peso das amostras em cada classe.
prashanth de

2

A primeira resposta é boa para entender como funciona. Mas eu queria entender como deveria usá-lo na prática.

RESUMO

  • para dados moderadamente desequilibrados SEM ruído, não há muita diferença na aplicação de pesos de classe
  • para dados moderadamente desequilibrados COM ruído e fortemente desequilibrados, é melhor aplicar pesos de classe
  • param class_weight="balanced"funciona decentemente na ausência de você querer otimizar manualmente
  • com class_weight="balanced"você captura mais eventos verdadeiros (maior TRUE recall), mas também é mais provável que você receba alertas falsos (menor precisão TRUE)
    • como resultado, o% TRUE total pode ser maior do que o real por causa de todos os falsos positivos
    • O AUC pode te enganar aqui se os alarmes falsos forem um problema
  • não há necessidade de alterar o limite de decisão para a% de desequilíbrio, mesmo para desequilíbrio forte, ok para manter 0,5 (ou algo em torno disso, dependendo do que você precisa)

NB

O resultado pode ser diferente ao usar RF ou GBM. sklearn não tem class_weight="balanced" para GBM, mas lightgbm temLGBMClassifier(is_unbalance=False)

CÓDIGO

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.