Qual é o uso do torch.no_grad no pytorch?


19

Eu sou novo no pytorch e comecei com este código do github. Eu não entendo o comentário na linha 60-61 no código "because weights have requires_grad=True, but we don't need to track this in autograd". Entendi que mencionamos requires_grad=Trueàs variáveis ​​as quais precisamos calcular os gradientes para usar o autograd, mas o que significa ser "tracked by autograd"?

Respostas:


21

O wrapper "com torch.no_grad ()" configurou temporariamente todos os sinalizadores require_grad como false. Um exemplo do tutorial oficial do PyTorch ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

Fora:

True
True
False

Eu recomendo que você leia todos os tutoriais no site acima.

No seu exemplo: acho que o autor não deseja que o PyTorch calcule os gradientes das novas variáveis ​​definidas w1 e w2, pois ele apenas deseja atualizar seus valores.


5
with torch.no_grad()

fará com que todas as operações no bloco não tenham gradientes.

No pytorch, você não pode alterar o posicionamento de w1 e w2, que são duas variáveis ​​com require_grad = True . Penso que evitar a alteração de colocação de w1 e w2 é porque causará erro no cálculo da propagação traseira. Uma vez que a mudança na colocação muda totalmente w1 e w2.

No entanto, se você usar isso no_grad(), poderá controlar o novo w1 e o novo w2 não terão gradientes, pois são gerados por operações, o que significa que você altera apenas o valor de w1 e w2, e não a parte do gradiente, eles ainda têm informações de gradiente de variável definidas anteriormente e a propagação traseira pode continuar.

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.