Eu tenho um conjunto de dados com 3 classes com os seguintes itens:
- Classe 1: 900 elementos
- Classe 2: 15000 elementos
- Classe 3: 800 elementos
Preciso prever as classes 1 e 3, que sinalizam desvios importantes da norma. A classe 2 é o caso "normal" padrão com o qual não me importo.
Que tipo de função de perda eu usaria aqui? Eu estava pensando em usar CrossEntropyLoss, mas como há um desequilíbrio de classe, isso precisaria ser ponderado, suponho? Como isso funciona na prática? Assim (usando PyTorch)?
summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)
Ou o peso deve ser invertido? ou seja, 1 / peso?
Essa é a abordagem correta para começar ou existem outros / melhores métodos que eu poderia usar?
obrigado