Entendendo o einsum de NumPy


190

Estou lutando para entender exatamente como einsum funciona. Eu olhei para a documentação e alguns exemplos, mas parece que não fica.

Aqui está um exemplo que abordamos na aula:

C = np.einsum("ij,jk->ki", A, B)

para duas matrizes AeB

Eu acho que isso levaria A^T * B, mas não tenho certeza (está aceitando a transposição de um deles, certo?). Alguém pode me explicar exatamente o que está acontecendo aqui (e em geral ao usar einsum)?


7
Na verdade será (A * B)^T, ou equivalente B^T * A^T.
Tigran Saluev

20
Eu escrevi um pequeno post sobre o básico einsum aqui . (Estou feliz em transplantar os bits mais relevantes para uma resposta no Stack Overflow, se útil).
Alex Riley

1
@ajcr - link bonito. Obrigado. A numpydocumentação é lamentavelmente inadequada ao explicar os detalhes.
rayryeng

Obrigado pelo voto de confiança! Tardiamente, contribuí com uma resposta abaixo .
Alex Riley

Note que no Python *não é multiplicação de matrizes, mas multiplicação por elementos. Cuidado!
ComputerScientist

Respostas:


368

(Nota: esta resposta é baseada em uma pequena postagem no blog sobreeinsum que escrevi há um tempo.)

O que einsum faz?

Imagine que temos duas matrizes multidimensionais Ae B. Agora vamos supor que queremos ...

  • multiplique A comB de um modo particular para criar uma nova gama de produtos; e então talvez
  • soma essa nova matriz ao longo de eixos específicos; e então talvez
  • transponha os eixos da nova matriz em uma ordem específica.

Há uma boa chance de que einsumisso nos ajude a fazer isso mais rapidamente e com mais eficiência de memória que as combinações do NumPy funcionam multiply, sume transposepermitirão.

Como einsumfunciona?

Aqui está um exemplo simples (mas não completamente trivial). Pegue as duas matrizes a seguir:

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

Vamos multiplicar Ae Belemento-wise e depois somar ao longo das linhas da nova matriz. Em NumPy "normal", escreveríamos:

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

Então aqui, a operação de indexação em A alinha os primeiros eixos das duas matrizes para que a multiplicação possa ser transmitida. As linhas da matriz de produtos são somadas para retornar a resposta.

Agora, se quiséssemos usar einsum, poderíamos escrever:

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

A string de assinatura'i,ij->i' é a chave aqui e precisa de um pouco de explicação. Você pode pensar nisso em duas partes. No lado esquerdo (à esquerda da ->), rotulamos as duas matrizes de entrada. À direita ->, rotulamos a matriz com a qual queremos terminar.

Aqui está o que acontece a seguir:

  • Atem um eixo; nós o rotulamos i. E Btem dois eixos; rotulamos o eixo 0 como ie o eixo 1 como j.

  • Ao repetir o rótulo iem ambas as matrizes de entrada, estamos dizendo einsumque estes dois eixos devem ser multiplicados juntos. Em outras palavras, estamos multiplicando a matriz Acom cada coluna da matriz B, assim como A[:, np.newaxis] * Bfaz.

  • Observe que jnão aparece como um rótulo em nossa saída desejada; acabamos de usar i(queremos terminar com uma matriz 1D). Ao omitir o rótulo, estamos dizendo einsumpara somar ao longo deste eixo. Em outras palavras, estamos somando as linhas dos produtos, assim como .sum(axis=1)faz.

Isso é basicamente tudo que você precisa saber para usar einsum. Ajuda a brincar um pouco; se deixarmos os dois rótulos na saída, 'i,ij->ij'obteremos uma matriz de produtos 2D (igual a A[:, np.newaxis] * B). Se dissermos que não há etiquetas de saída, 'i,ij->retornamos um único número (o mesmo que fazer (A[:, np.newaxis] * B).sum()).

O melhor de tudo einsum, porém, é que isso não cria primeiro uma matriz temporária de produtos; apenas soma os produtos como vai. Isso pode levar a grandes economias no uso da memória.

Um exemplo um pouco maior

Para explicar o produto escalar, aqui estão duas novas matrizes:

A = array([[1, 1, 1],
           [2, 2, 2],
           [5, 5, 5]])

B = array([[0, 1, 0],
           [1, 1, 0],
           [1, 1, 1]])

Vamos calcular o produto escalar usando np.einsum('ij,jk->ik', A, B). Aqui está uma foto mostrando a rotulagem do Ae Be a matriz de saída que começa a partir da função:

insira a descrição da imagem aqui

Você pode ver que o rótulo jé repetido. Isso significa que multiplicamos as linhas Acom as colunas de B. Além disso, o rótulo jnão está incluído na saída - estamos somando esses produtos. Etiquetas iek são mantidos para a saída, por isso recuperamos uma matriz 2D.

Pode ser ainda mais claro comparar esse resultado com a matriz em que o rótulo nãoj é somado. Abaixo, à esquerda, você pode ver a matriz 3D resultante da escrita (ou seja, mantemos o rótulo ):np.einsum('ij,jk->ijk', A, B)j

insira a descrição da imagem aqui

O eixo somador jfornece o produto de ponto esperado, mostrado à direita.

Alguns exercícios

Para entender melhor einsum, pode ser útil implementar operações familiares de array NumPy usando a notação subscrita. Qualquer coisa que envolva combinações de eixos multiplicadores e somatórios pode ser escrita usando einsum .

Sejam A e B duas matrizes 1D com o mesmo comprimento. Por exemplo, A = np.arange(10)e B = np.arange(5, 15).

  • A soma de Apode ser escrita:

    np.einsum('i->', A)
  • A multiplicação por elementos A * B, pode ser escrita:

    np.einsum('i,i->i', A, B)
  • O produto interno ou produto de ponto, np.inner(A, B)ou np.dot(A, B), pode ser escrito:

    np.einsum('i,i->', A, B) # or just use 'i,i'
  • O produto externo,, np.outer(A, B)pode ser escrito:

    np.einsum('i,j->ij', A, B)

Para matrizes 2D Ce D, desde que os eixos tenham comprimentos compatíveis (ambos com o mesmo comprimento ou um deles com comprimento 1), aqui estão alguns exemplos:

  • O traço de C(soma da diagonal principal) np.trace(C), pode ser escrito:

    np.einsum('ii', C)
  • Multiplicação elemento-wise de Ce a transposta de D, C * D.T, pode ser escrito:

    np.einsum('ij,ji->ij', C, D)
  • Multiplicando cada elemento Cpela matriz D(para criar uma matriz 4D) C[:, :, None, None] * D, pode ser escrito:

    np.einsum('ij,kl->ijkl', C, D)  

1
Muito boa explicação, obrigado. "Observe que eu não aparece como um rótulo na saída desejada" - não é?
21816 Ian Hincks

Obrigado @IanHincks! Isso parece um erro de digitação; Eu o corrigi agora.
Alex Riley #

1
Resposta muito boa. Também vale a pena notar que ij,jkpoderia funcionar por si só (sem as setas) para formar a multiplicação da matriz. Mas, para maior clareza, é melhor colocar as setas e as dimensões da saída. Está no post do blog.
ComputerScientist

1
@ Peaceful: esta é uma daquelas ocasiões em que é difícil escolher a palavra certa! Eu sinto que "coluna" se encaixa um pouco melhor aqui, pois Atem comprimento 3, o mesmo que o comprimento das colunas B(enquanto as linhas Btêm comprimento 4 e não podem ser multiplicadas por elementos A).
Alex Riley

1
Observe que omitir ->afeta a semântica: "No modo implícito, os subscritos escolhidos são importantes, pois os eixos da saída são reordenados em ordem alfabética. Isso significa que np.einsum('ij', a)isso não afeta uma matriz 2D, enquanto np.einsum('ji', a)faz sua transposição".
precisa

40

Compreender a idéia de numpy.einsum()é muito fácil se você a entender intuitivamente. Como exemplo, vamos começar com uma descrição simples envolvendo multiplicação de matrizes .


Para usar numpy.einsum(), tudo o que você precisa fazer é passar a chamada sequência de subscritos como argumento, seguida pelas matrizes de entrada .

Digamos que você tenha duas matrizes 2D Ae B, e deseja fazer a multiplicação de matrizes. Então você faz:

np.einsum("ij, jk -> ik", A, B)

Aqui, a sequência de subscritos ij corresponde à matriz, Aenquanto a sequência de subscritos jk corresponde à matriz B. Além disso, o mais importante a ser observado aqui é que o número de caracteres em cada sequência de caracteres subscrito deve corresponder às dimensões da matriz. (ou seja, dois caracteres para matrizes 2D, três caracteres para matrizes 3D e assim por diante.) E se você repetir os caracteres entre as seqüências de caracteres subscritas ( jno nosso caso), isso significa que você deseja que a einsoma aconteça nessas dimensões. Assim, eles serão reduzidos pela soma. (ou seja, essa dimensão desaparecerá )

A cadeia de caracteres subscrita após isso ->será nossa matriz resultante. Se você deixar em branco, tudo será somado e um valor escalar será retornado como resultado. Caso contrário, a matriz resultante terá dimensões de acordo com a cadeia de caracteres subscrita . No nosso exemplo, será ik. Isso é intuitivo porque sabemos que, para multiplicação de matrizes, o número de colunas na matriz Adeve corresponder ao número de linhas na matriz, o Bque está acontecendo aqui (ou seja, codificamos esse conhecimento repetindo o caractere jna sequência de caracteres subscrito )


Aqui estão mais alguns exemplos que ilustram o uso / potência np.einsum()na implementação de algumas operações comuns de tensor ou nd-array , de forma sucinta.

Entradas

# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])

# an array
In [198]: A
Out[198]: 
array([[11, 12, 13, 14],
       [21, 22, 23, 24],
       [31, 32, 33, 34],
       [41, 42, 43, 44]])

# another array
In [199]: B
Out[199]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

1) Multiplicação de matrizes (semelhante a np.matmul(arr1, arr2))

In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]: 
array([[130, 130, 130, 130],
       [230, 230, 230, 230],
       [330, 330, 330, 330],
       [430, 430, 430, 430]])

2) Extraia elementos ao longo da diagonal principal (semelhante a np.diag(arr))

In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])

3) Produto Hadamard (isto é, produto de duas matrizes) (semelhante a arr1 * arr2)

In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]: 
array([[ 11,  12,  13,  14],
       [ 42,  44,  46,  48],
       [ 93,  96,  99, 102],
       [164, 168, 172, 176]])

4) Quadratura elemento a elemento (semelhante a np.square(arr)ou arr ** 2)

In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]: 
array([[ 1,  1,  1,  1],
       [ 4,  4,  4,  4],
       [ 9,  9,  9,  9],
       [16, 16, 16, 16]])

5) Rastreio (ou seja, soma dos elementos da diagonal principal) (semelhante a np.trace(arr))

In [217]: np.einsum("ii -> ", A)
Out[217]: 110

6) Transposição da matriz (semelhante a np.transpose(arr))

In [221]: np.einsum("ij -> ji", A)
Out[221]: 
array([[11, 21, 31, 41],
       [12, 22, 32, 42],
       [13, 23, 33, 43],
       [14, 24, 34, 44]])

7) Produto externo (de vetores) (semelhante a np.outer(vec1, vec2))

In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]: 
array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]])

8) Produto interno (de vetores) (semelhante a np.inner(vec1, vec2))

In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14

9) Soma ao longo do eixo 0 (semelhante a np.sum(arr, axis=0))

In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])

10) Soma ao longo do eixo 1 (semelhante a np.sum(arr, axis=1))

In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4,  8, 12, 16])

11) Multiplicação de matrizes em lote

In [287]: BM = np.stack((A, B), axis=0)

In [288]: BM
Out[288]: 
array([[[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]],

       [[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3],
        [ 4,  4,  4,  4]]])

In [289]: BM.shape
Out[289]: (2, 4, 4)

# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)

In [293]: BMM
Out[293]: 
array([[[1350, 1400, 1450, 1500],
        [2390, 2480, 2570, 2660],
        [3430, 3560, 3690, 3820],
        [4470, 4640, 4810, 4980]],

       [[  10,   10,   10,   10],
        [  20,   20,   20,   20],
        [  30,   30,   30,   30],
        [  40,   40,   40,   40]]])

In [294]: BMM.shape
Out[294]: (2, 4, 4)

12) Soma ao longo do eixo 2 (semelhante a np.sum(arr, axis=2))

In [330]: np.einsum("ijk -> ij", BM)
Out[330]: 
array([[ 50,  90, 130, 170],
       [  4,   8,  12,  16]])

13) Soma todos os elementos da matriz (semelhante a np.sum(arr))

In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480

14) Soma sobre eixos múltiplos (ou seja, marginalização)
(semelhante a np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7)))

# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))

# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))

In [365]: np.allclose(esum, nsum)
Out[365]: True

15) Produtos de ponto duplo (semelhante ao np.sum (produto hadamard), cf. 3 )

In [772]: A
Out[772]: 
array([[1, 2, 3],
       [4, 2, 2],
       [2, 3, 4]])

In [773]: B
Out[773]: 
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124

16) multiplicação de matrizes 2D e 3D

Essa multiplicação pode ser muito útil para resolver o sistema linear de equações ( Ax = b ) onde você deseja verificar o resultado.

# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)

# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)

# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)

# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True

Pelo contrário, se for necessário usar np.matmul()essa verificação, precisamos executar algumas reshapeoperações para obter o mesmo resultado, como:

# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)

# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True

Bônus : Leia mais matemática aqui: Einstein-Summation e definitivamente aqui: Tensor-Notation


7

Vamos criar 2 matrizes, com dimensões diferentes, mas compatíveis, para destacar sua interação

In [43]: A=np.arange(6).reshape(2,3)
Out[43]: 
array([[0, 1, 2],
       [3, 4, 5]])


In [44]: B=np.arange(12).reshape(3,4)
Out[44]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

Seu cálculo utiliza um 'ponto' (soma dos produtos) de a (2,3) com a (3,4) para produzir uma matriz (4,2). ié a 1ª dim de A, a última de C; ko último de B, primeiro de C. jé "consumido" pelo somatório.

In [45]: C=np.einsum('ij,jk->ki',A,B)
Out[45]: 
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])

É o mesmo que np.dot(A,B).T- é a saída final que é transposta.

Para ver mais do que acontece j, altere os Csubscritos para ijk:

In [46]: np.einsum('ij,jk->ijk',A,B)
Out[46]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

Isso também pode ser produzido com:

A[:,:,None]*B[None,:,:]

Ou seja, adicione uma kdimensão ao final Ae ià frente de B, resultando em uma matriz (2,3,4).

0 + 4 + 16 = 20, 9 + 28 + 55 = 92etc; Soma je transpõe para obter o resultado anterior:

np.sum(A[:,:,None] * B[None,:,:], axis=1).T

# C[k,i] = sum(j) A[i,j (,k) ] * B[(i,)  j,k]

6

Achei NumPy: Os truques do comércio (Parte II) instrutivos

Usamos -> para indicar a ordem da matriz de saída. Então, pense em 'ij, i-> j' como tendo o lado esquerdo (LHS) e o lado direito (RHS). Qualquer repetição de etiquetas no LHS calcula o elemento do produto de maneira sábia e depois soma. Alterando o rótulo no lado RHS (saída), podemos definir o eixo no qual queremos prosseguir em relação à matriz de entrada, ou seja, somatório ao longo do eixo 0, 1 e assim por diante.

import numpy as np

>>> a
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
>>> b
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> d = np.einsum('ij, jk->ki', a, b)

Observe que existem três eixos, i, j, k, e que j é repetido (no lado esquerdo). i,jrepresenta linhas e colunas para a. j,kpara b.

Para calcular o produto e alinhar o jeixo, precisamos adicionar um eixo a a. ( bserá transmitido ao longo (?) do primeiro eixo)

a[i, j, k]
   b[j, k]

>>> c = a[:,:,np.newaxis] * b
>>> c
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16]],

       [[ 0,  3,  6],
        [ 9, 12, 15],
        [18, 21, 24]]])

jestá ausente do lado direito, então somamos jqual é o segundo eixo da matriz 3x3x3

>>> c = c.sum(1)
>>> c
array([[ 9, 12, 15],
       [18, 24, 30],
       [27, 36, 45]])

Finalmente, os índices são (em ordem alfabética) invertidos no lado direito, para que possamos transpor.

>>> c.T
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])

>>> np.einsum('ij, jk->ki', a, b)
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])
>>>

NumPy: Os truques do comércio (Parte II) parece exigir um convite do proprietário do site, bem como uma conta Wordpress
Tejas Shetty

... link atualizado, felizmente eu o encontrei com uma pesquisa. - Thnx.
wwii

@TejasShetty Muitas respostas melhores aqui agora - talvez eu deva excluir esta.
Wwii

2
Por favor, não apague sua resposta.
Tejas Shetty

4

Ao ler as equações de einsum, achei mais útil apenas reduzi-las mentalmente às suas versões imperativas.

Vamos começar com a seguinte declaração (imponente):

C = np.einsum('bhwi,bhwj->bij', A, B)

Trabalhando primeiro com a pontuação, vemos que temos dois blobs separados por vírgula de quatro letras - bhwie bhwj, antes da seta, e um único blob de três letras bijdepois dela. Portanto, a equação produz um resultado de tensor de classificação 3 a partir de duas entradas de tensor de classificação 4.

Agora, permita que cada letra em cada blob seja o nome de uma variável de intervalo. A posição em que a letra aparece no blob é o índice do eixo no qual ele varia nesse tensor. A soma imperativa que produz cada elemento de C, portanto, deve começar com três aninhados para loops, um para cada índice de C.

for b in range(...):
    for i in range(...):
        for j in range(...):
            # the variables b, i and j index C in the order of their appearance in the equation
            C[b, i, j] = ...

Então, essencialmente, você tem um forloop para cada índice de saída C. Vamos deixar os intervalos indeterminados por enquanto.

A seguir, examinamos o lado esquerdo - existem variáveis ​​de intervalo que não aparecem no lado direito? No nosso caso - sim he w. Adicione um forloop aninhado interno para cada variável:

for b in range(...):
    for i in range(...):
        for j in range(...):
            C[b, i, j] = 0
            for h in range(...):
                for w in range(...):
                    ...

Dentro do loop mais interno, agora temos todos os índices definidos, para que possamos escrever o somatório real e a tradução concluída:

# three nested for-loops that index the elements of C
for b in range(...):
    for i in range(...):
        for j in range(...):

            # prepare to sum
            C[b, i, j] = 0

            # two nested for-loops for the two indexes that don't appear on the right-hand side
            for h in range(...):
                for w in range(...):
                    # Sum! Compare the statement below with the original einsum formula
                    # 'bhwi,bhwj->bij'

                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]

Se você conseguiu seguir o código até agora, parabéns! Isso é tudo o que você precisa para poder ler as equações de einsum. Observe, em particular, como a fórmula original do einsum é mapeada para a instrução final de soma no snippet acima. Os loops de for e os limites do intervalo são apenas leves e essa afirmação final é tudo o que você realmente precisa para entender o que está acontecendo.

Por uma questão de integridade, vamos ver como determinar os intervalos para cada variável de intervalo. Bem, o intervalo de cada variável é simplesmente o comprimento da (s) dimensão (s) que ela indexa. Obviamente, se uma variável indexar mais de uma dimensão em um ou mais tensores, os comprimentos de cada uma dessas dimensões deverão ser iguais. Aqui está o código acima com os intervalos completos:

# C's shape is determined by the shapes of the inputs
# b indexes both A and B, so its range can come from either A.shape or B.shape
# i indexes only A, so its range can only come from A.shape, the same is true for j and B
assert A.shape[0] == B.shape[0]
assert A.shape[1] == B.shape[1]
assert A.shape[2] == B.shape[2]
C = np.zeros((A.shape[0], A.shape[3], B.shape[3]))
for b in range(A.shape[0]): # b indexes both A and B, or B.shape[0], which must be the same
    for i in range(A.shape[3]):
        for j in range(B.shape[3]):
            # h and w can come from either A or B
            for h in range(A.shape[1]):
                for w in range(A.shape[2]):
                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.