Parâmetro “estratificar” do método “train_test_split” (scikit Learn)


96

Estou tentando usar train_test_splitdo pacote scikit Learn, mas estou tendo problemas com o parâmetro stratify. A seguir está o código:

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

No entanto, continuo tendo o seguinte problema:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

Alguém tem ideia do que está acontecendo? Abaixo está a documentação da função.

[...]

stratify : array-like ou None (o padrão é None)

Se não for nenhum, os dados são divididos de forma estratificada, usando isso como a matriz de rótulos.

Novo na versão 0.17: divisão estratificada

[...]


Não, tudo resolvido.
Daneel Olivaw

Respostas:


62

O Scikit-Learn está apenas dizendo a você que não reconhece o argumento "estratificar", não que você o esteja usando incorretamente. Isso ocorre porque o parâmetro foi adicionado na versão 0.17 conforme indicado na documentação que você citou.

Então você só precisa atualizar o Scikit-Learn.


Estou recebendo o mesmo erro, embora tenha a versão 0.21.2 do scikit-learn. scikit-learn 0.21.2 py37h2a6a0b8_0 conda-forge
Kareem Jeiroudi

338

Este stratifyparâmetro faz uma divisão de forma que a proporção dos valores na amostra produzida seja a mesma que a proporção dos valores fornecidos ao parâmetro stratify.

Por exemplo, se a variável yé uma variável categórica binário com valores 0e 1e há 25% de zeros e 75% dos queridos, stratify=yvai se certificar de que a sua divisão aleatória tem 25% de 0's e 75% de 1' s.


124
Isso realmente não responde à pergunta, mas é muito útil apenas para entender como funciona. Muito obrigado.
Reed Jessen

6
Ainda tenho dificuldade em entender por que essa estratificação é necessária: se houver classe em equilíbrio nos dados, ela não seria preservada em média ao fazer uma divisão aleatória dos dados?
Holger Brandl

15
@HolgerBrandl será preservado em média; com estratificar, ele será preservado com certeza.
Yonatan

7
@HolgerBrandl com conjuntos de dados muito pequenos ou muito desequilibrados, é bem possível que a divisão aleatória possa eliminar completamente uma classe de uma das divisões.
cddt

1
@HolgerBrandl Boa pergunta! Talvez pudéssemos adicionar isso primeiro, você tem que dividir em treinamento e conjunto de teste usando stratify. Em segundo lugar, para corrigir o desequilíbrio, você eventualmente precisa executar a sobreamostragem ou subamostragem no conjunto de treinamento. Muitos classificadores Sklearn têm um parâmetro chamado peso da classe que você pode definir como balanceado. Finalmente, você também pode usar uma métrica mais apropriada do que precisão para conjunto de dados desequilibrado. Tente F1 ou área sob ROC.
Claude COULOMBE

64

Para meu futuro eu, que vem aqui via Google:

train_test_splitestá agora dentro model_selection, portanto:

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

é a maneira de usá-lo. A configuração de random_stateé desejável para reprodutibilidade.


Esta deve ser a resposta :) Obrigado
SwimBikeRun

16

Nesse contexto, a estratificação significa que o método train_test_split retorna subconjuntos de treinamento e teste que têm as mesmas proporções de rótulos de classe que o conjunto de dados de entrada.


3

Tente executar este código, ele "simplesmente funciona":

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])

@ user5767535 Como você pode ver, está funcionando na minha máquina Ubuntu, com sklearna versão '0,17', distribuição Anaconda para Python 3,5. Só posso sugerir que verifique mais uma vez se você inserir o código corretamente e atualize seu software.
Sergey Bushmanov

2
@ user5767535 BTW, "Novo na versão 0.17: divisão estratificada" me dá quase certeza de que você precisa atualizar seu sklearn...
Sergey Bushmanov
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.