Pytorch: maneira correta de usar mapas de peso personalizados na arquitetura unet


8

Existe um truque famoso na arquitetura u-net para usar mapas de peso personalizados para aumentar a precisão. Abaixo estão os detalhes:

insira a descrição da imagem aqui

Agora, perguntando aqui e em vários outros lugares, conheço duas abordagens. Quero saber qual é a correta ou se existe outra abordagem correta que seja mais correta?

1) Primeiro é usar o torch.nn.Functionalmétodo no loop de treinamento

loss = torch.nn.functional.cross_entropy(output, target, w) onde w será o peso personalizado calculado.

2) O segundo é usar reduction='none'na chamada da função de perda fora do loop de treinamento criterion = torch.nn.CrossEntropy(reduction='none')

e depois no ciclo de treinamento, multiplicando-se pelo peso personalizado

gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W*loss # Ensure that weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
loss = torch.mean(loss) # Average across a batch

Agora, estou meio confuso qual é o certo ou existe alguma outra maneira, ou ambos estão certos?

Respostas:


3

A parte de ponderação parece simplesmente entropia cruzada ponderada, que é executada dessa maneira para o número de classes (2 no exemplo abaixo).

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)

EDITAR:

Você viu essa implementação de Patrick Black?

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities
logp = F.log_softmax(logits)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()

A coisa é peso é calculado por uma determinada função aqui e não é discreto. Para obter mais informações, aqui está um artigo - arxiv.org/abs/1505.04597
Marque

1
@ Mark oh eu vejo agora. Portanto, é uma saída de perda pixel a pixel. E as bordas são pré-calculadas usando alguma biblioteca opencvou algo assim, e essas posições de pixel são salvas para cada imagem e depois multiplicadas pelos tensores de perda mais tarde durante o treinamento, para que o algoritmo se concentre em reduzir a perda nessas áreas.
jchaykow

Obrigado. Este legítimo parece uma resposta, tentarei verificar e implementá-lo mais e aceitarei sua resposta depois.
Mark

Você pode explicar a intuição por trás dessa linhalogp = logp.gather(1, target.view(batch_size, 1, H, W))
Mark

0

Observe que torch.nn.CrossEntropyLoss () é uma classe que chama torch.nn.funcional. Consulte https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLoss

Você pode usar os pesos ao definir os critérios. Comparando-os funcionalmente, ambos os métodos são os mesmos.

Agora, não entendo sua idéia de calcular a perda dentro do loop de treinamento no método 1 e fora do loop de treinamento no método 2. se você calcular a perda fora do loop, como irá retropropagar?


Eu não estava confuso entre usar torch.nn.CrossEntropyLoss() e torch.nn.functional.cross_entropy(output, target, w), eu estava confuso como usar o peso mapas personalizados no loss.Please ver este papel - arxiv.org/abs/1505.04597 e deixe-me saber, se você ainda não são capazes de descobrir o que eu sou pergunta
Mark

1
Se entendi direito, acho que o método 2 é o correto. Os pesos (w) dentro da perda torch.nn.functional.cross_entropy (output, target, w) são pesos para classes que não w (x) na fórmula. Podemos testá-lo facilmente com um pequeno script.
Devansh Bisla

Sim, até eu estou chegando à mesma conclusão. Voltarei a você se minha rede funcionar como esperado e marcará a resposta como aceita.
Mark

tudo bem, não é working.I estou recebendo grad can be implicitly created only for scalar outputsquando eu corro perda = perda * w método
Mark

Tem certeza de que está resumindo ou fazendo a média?
Devansh Bisla
Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.