Estou tentando atualizar / alterar os parâmetros de um modelo de rede neural e, em seguida, fazer com que a passagem direta da rede neural atualizada esteja no gráfico de computação (não importa quantas alterações / atualizações fazemos).
Tentei essa idéia, mas sempre que faço isso, o pytorch define meus tensores atualizados (dentro do modelo) como folhas, o que mata o fluxo de gradientes para as redes que desejo receber gradientes. Isso mata o fluxo de gradientes porque os nós das folhas não fazem parte do gráfico de computação da maneira que eu quero que sejam (já que não são verdadeiramente folhas).
Eu tentei várias coisas, mas nada parece funcionar. Criei um código fictício independente que imprime os gradientes das redes que desejo ter gradientes:
import torch
import torch.nn as nn
import copy
from collections import OrderedDict
# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2
criterion = nn.CrossEntropyLoss()
#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))
hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
nb_updates = 2
for i in range(nb_updates):
print(f'i = {i}')
new_params = copy.deepcopy( loss_net.state_dict() )
## w^<t> := f(w^<t-1>,delta^<t-1>)
for (name, w) in loss_net.named_parameters():
print(f'name = {name}')
print(w.size())
hidden = updater_net(hidden).view(1)
print(hidden.size())
#delta = ((hidden**2)*w/2)
delta = w + hidden
wt = w + delta
print(wt.size())
new_params[name] = wt
#del loss_net.fc0.weight
#setattr(loss_net.fc0, 'weight', nn.Parameter( wt ))
#setattr(loss_net.fc0, 'weight', wt)
#loss_net.fc0.weight = wt
#loss_net.fc0.weight = nn.Parameter( wt )
##
loss_net.load_state_dict(new_params)
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
se alguém souber como fazer isso, por favor, me dê um ping ... Defino o número de vezes que a atualização é 2, pois a operação de atualização deve estar no gráfico de cálculo um número arbitrário de vezes ... portanto, DEVE trabalhar para 2)
Post fortemente relacionado:
- SO: Como os parâmetros de um modelo de pitoneira não podem ser folheados e estar no gráfico de computação?
- Fórum pytorch: https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076
Postagem cruzada:
backward
? Nomeadamenteretain_graph=True
e / oucreate_graph=True
?