Como funciona o python numpy.where ()?


94

Estou brincando numpye vasculhando documentação e descobri alguma mágica. A saber, estou falando sobre numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

Como eles conseguem internamente que você seja capaz de passar algo semelhante x > 5a um método? Acho que tem a ver com isso, __gt__mas estou procurando uma explicação detalhada.

Respostas:


75

Como eles conseguem internamente que você seja capaz de passar algo como x> 5 em um método?

A resposta curta é que não.

Qualquer tipo de operação lógica em uma matriz numpy retorna uma matriz booleana. (ou seja __gt__, __lt__etc, todos retornam matrizes booleanas onde a condição fornecida é verdadeira).

Por exemplo

x = np.arange(9).reshape(3,3)
print x > 5

rendimentos:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

Este é o mesmo motivo pelo qual algo como if x > 5:gera um ValueError se xfor um array numpy. É uma matriz de valores True / False, não um único valor.

Além disso, matrizes numpy podem ser indexadas por matrizes booleanas. Por exemplo , x[x>5]rendimentos [6 7 8], neste caso.

Honestamente, é bastante raro que você realmente precise, numpy.wheremas ele apenas retorna os indicies onde está um array booleano True. Normalmente, você pode fazer o que precisa com a indexação booleana simples.


10
Apenas para apontar que numpy.wheretem 2 'modos operacionais', primeiro um retorna o indices, onde condition is Truee se os parâmetros opcionais xe yestão presentes (o mesmo formato condition, ou transmitido para tal formato!), Ele retornará valores de xquando de condition is Trueoutra forma y. Portanto, isso torna wheremais versátil e permite que seja usado com mais frequência. Obrigado
coma em

1
Também pode haver sobrecarga em alguns casos, usando a __getitem__sintaxe []over numpy.whereou numpy.take. Como __getitem__também precisa oferecer suporte ao fatiamento, há alguma sobrecarga. Tenho visto diferenças de velocidade perceptíveis ao trabalhar com as estruturas de dados Python Pandas e indexar colunas muito grandes de forma lógica. Nesses casos, se você não precisa fatiar, então takee whereé realmente melhor.
ely

24

Resposta Antiga é meio confusa. Dá a você os LOCAIS (todos eles) de onde sua declaração é verdadeira.

tão:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

Eu o uso como uma alternativa para list.index (), mas ele também tem muitos outros usos. Nunca o usei com matrizes 2D.

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

Nova resposta Parece que a pessoa estava perguntando algo mais fundamental.

A questão era como VOCÊ poderia implementar algo que permita a uma função (como onde) saber o que foi solicitado.

Primeiro, observe que chamar qualquer um dos operadores de comparação é algo interessante.

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

Isso é feito sobrecarregando o método "__gt__". Por exemplo:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

Como você pode ver, "a> 4" era um código válido.

Você pode obter uma lista completa e documentação de todas as funções sobrecarregadas aqui: http://docs.python.org/reference/datamodel.html

Algo incrível é como é simples fazer isso. TODAS as operações em python são feitas dessa forma. Dizer a> b é equivalente a a. gt (b)!


3
Essa sobrecarga de operador de comparação não parece funcionar bem com expressões lógicas mais complexas - por exemplo, eu não posso fazer np.where(a > 30 and a < 50)ou np.where(30 < a < 50)porque acaba tentando avaliar o AND lógico de dois arrays de booleanos, o que não faz sentido. Existe uma maneira de escrever tal condição com np.where?
davidA

@meowsqueaknp.where((a > 30) & (a < 50))
tibalt

Por que np.where () está retornando uma lista em seu exemplo?
Andreas Yankopolus

0

np.whereretorna uma tupla de comprimento igual à dimensão do ndarray numpy no qual é chamado (em outras palavras ndim) e cada item da tupla é um ndarray numpy de índices de todos aqueles valores no ndarray inicial para o qual a condição é True. (Por favor, não confunda dimensão com forma)

Por exemplo:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y é uma tupla de comprimento 2 porque x.ndimé 2. O primeiro item na tupla contém números de linha de todos os elementos maiores que 4 e o segundo item contém números de coluna de todos os itens maiores que 4. Como você pode ver, [1,2,2 , 2] corresponde a números de linha de 5,6,7,8 e [2,0,1,2] corresponde a números de coluna de 5,6,7,8 Observe que o ndarray é percorrido ao longo da primeira dimensão (em linha )

Similarmente,

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


retornará uma tupla de comprimento 3 porque x tem 3 dimensões.

Mas espere, há mais para np.where!

quando dois argumentos adicionais são adicionados a np.where; ele fará uma operação de substituição para todas aquelas combinações linha-coluna emparelhadas que são obtidas pela tupla acima.

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.