O que significa a saída da função model.predict de Keras?


14

Criei um modelo LSTM para prever perguntas duplicadas no conjunto de dados oficial do Quora. Os rótulos de teste são 0 ou 1. 1 indica que o par de perguntas está duplicado. Depois de criar o modelo usando model.fit, eu testo o modelo usando model.predictos dados de teste. A saída é uma matriz de valores, como abaixo:

 [ 0.00514298]
 [ 0.15161049]
 [ 0.27588326]
 [ 0.00236167]
 [ 1.80067325]
 [ 0.01048524]
 [ 1.43425131]
 [ 1.99202418]
 [ 0.54853892]
 [ 0.02514757]

Estou apenas mostrando os 10 primeiros valores da matriz. Não entendo o que esses valores significam e qual é o rótulo previsto para cada par de perguntas?


11
Eu acho que você tem problema em sua rede .. as probabilidades deve ser em escala 0-1 .. mas você tem 1.99, eu acho que você tem algo de errado ..
Ghanem

Respostas:


8

A saída de uma rede neural nunca será, por padrão, binária - ou seja, zeros ou uns. A rede trabalha com valores contínuos (não discretos), a fim de otimizar a perda mais livremente na estrutura de descida gradiente.

Dê uma olhada aqui em uma pergunta semelhante que também mostra algum código.

Sem qualquer tipo de ajuste e dimensionamento, é provável que a saída da sua rede caia em algum lugar no intervalo da sua entrada, em termos de seu valor nominal. No seu caso, isso parece estar entre 0 e 2.

Agora você pode escrever uma função que transforma seus valores acima em 0 ou 1, com base em algum limite. Por exemplo, dimensione os valores para estar no intervalo [0, 1]; se o valor estiver abaixo de 0,5, retorne 0; se acima de 0,5, retorne 1.


Obrigado, também pensei em usar um valor limite para classificar os rótulos. Mas qual deve ser a base sobre a qual o valor limite foi decidido?
31418 Dookoto_Sea

@Dookoto_Sea você tem que decidir isso sozinho
Jérémy Blain

@Dookoto_Sea Observe que, se seu rótulo é 0 ou 1, seu valor deve estar nesse intervalo, ter uma escala de valores de previsões de [0, 2] é intrigante, você precisa alterar a saída do modelo
Jérémy Blain

6

Se este é um problema de classificação, você deve mudar sua rede para ter 2 neurônios de saída.

Você pode converter rótulos em vetores codificados de uma só vez usando

y_train_binary = keras.utils.to_categorical(y_train, num_classes)
y_test_binary = keras.utils.to_categorical(y_test, num_classes)

Em seguida, verifique se a camada de saída possui dois neurônios com uma função de ativação softmax.

model.add(Dense(num_classes, activation='softmax'))

Isso resultará em model.predict(x_test_reshaped)uma lista de listas. Onde a lista interna é a probabilidade de uma instância pertencente a cada classe. Isso somará 1 e, evidentemente, o rótulo decidido deve ser o neurônio de saída com a maior probabilidade.

O Keras inclui isso em sua biblioteca, para que você não precise fazer essa comparação. Você pode obter o rótulo da classe diretamente usando model.predict_classes(x_test_reshaped).


3
"Se este é um problema de classificação, você deve mudar sua rede para ter 2 neurônios de saída." .. desculpe Jah, mas ele não deveria, ele pode fazê-lo com um neurônio e sigmóide em vez da função softmax.
Ghanem

@Minion, ambos os métodos são essencialmente equivalentes, o limiar que você precisaria fazer com um único neurônio de saída está implicitamente incorporado na rede. Assim, fornecendo a saída binária.
JahKnows

11
Sim, entendi. Comentei apenas porque ele mencionou: "deve mudar sua rede para ter 2 neurônios de saída". .. thanx
Ghanem 16/11

1

As previsões são baseadas no que você alimenta como saídas de treinamento e na função de ativação.

Por exemplo, com entrada de 0 a 1 e uma função de ativação sigmóide para a saída com uma perda de entropia cruzada binária, você obteria a probabilidade de um 1. Dependendo do custo de tomar a decisão errada em qualquer direção, você pode decidir como lide com essas probabilidades (por exemplo, preveja a categoria "1", se a probabilidade for> 0,5 ou talvez já seja> 0,1).

(-,

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.