Simplificando, torch.Tensor.view()
inspirado em numpy.ndarray.reshape()
ou numpy.reshape()
, cria uma nova visualização do tensor, desde que a nova forma seja compatível com a forma do tensor original.
Vamos entender isso em detalhes usando um exemplo concreto.
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
Com esse tensor t
de forma (18,)
, novas vistas podem ser criadas apenas para as seguintes formas:
(1, 18)
ou equivalentemente (1, -1)
ou ou equivalentemente ou ou equivalentemente ou ou equivalentemente ou ou equivalentemente ou ou equivalentemente ou(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
Como já podemos observar pelas tuplas de forma acima, a multiplicação dos elementos da tupla de forma (por exemplo 2*9
, 3*6
etc.) deve sempre ser igual ao número total de elementos no tensor original ( 18
no nosso exemplo).
Outra coisa a observar é que usamos um -1
em um dos lugares em cada uma das tuplas de forma. Usando a -1
, estamos sendo preguiçosos ao fazer o cálculo e delegar a tarefa ao PyTorch para fazer o cálculo desse valor para a forma quando ela cria a nova exibição . Uma coisa importante a ser observada é que só podemos usar uma única -1
na tupla de forma. Os valores restantes devem ser explicitamente fornecidos por nós. Outro PyTorch irá reclamar, lançando um RuntimeError
:
RuntimeError: apenas uma dimensão pode ser deduzida
Portanto, com todas as formas mencionadas acima, o PyTorch sempre retornará uma nova visualização do tensor original t
. Isso basicamente significa que apenas altera as informações de passada do tensor para cada uma das novas visualizações solicitadas.
Abaixo estão alguns exemplos que ilustram como as passadas dos tensores são alteradas a cada nova vista .
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
Agora, veremos os avanços para as novas visualizações :
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
Então essa é a mágica da view()
função. Ele apenas altera os passos do tensor (original) para cada uma das novas vistas , desde que a forma da nova vista seja compatível com a forma original.
Outra coisa interessante pode observar a partir dos tuplos Strides é que o valor do elemento no 0 ª posição é igual ao valor do elemento no 1 st posição da tupla forma.
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
Isto é porque:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
o passo (6, 1)
diz que para ir de um elemento para o próximo elemento ao longo da 0ª dimensão, temos que pular ou dar 6 passos. (ou seja, para ir de 0
para 6
, alguém tem que tomar 6 passos.) Mas, para ir de um elemento para o próximo elemento no 1 st dimensão, só precisamos de apenas um passo (por exemplo, para ir a partir 2
de 3
).
Assim, as informações de passada estão no centro de como os elementos são acessados da memória para realizar o cálculo.
Essa função retornaria uma vista e é exatamente a mesma que usar torch.Tensor.view()
, desde que a nova forma seja compatível com a forma do tensor original. Caso contrário, ele retornará uma cópia.
No entanto, as notas de torch.reshape()
adverte que:
entradas contíguas e entradas com passos compatíveis podem ser remodeladas sem copiar, mas não se deve depender do comportamento de cópia versus exibição.
reshape
no PyTorch ?!