tl; dr Mesmo que este é um conjunto de dados de classificação de imagem, ele continua a ser um muito fácil tarefa, para a qual se pode facilmente encontrar um mapeamento direto de entradas para previsões.
Responda:
Essa é uma pergunta muito interessante e, graças à simplicidade da regressão logística, você pode encontrar a resposta.
O que a regressão logística faz é que cada imagem aceite entradas e multiplique-as com pesos para gerar sua previsão. O interessante é que, devido ao mapeamento direto entre entrada e saída (ou seja, nenhuma camada oculta), o valor de cada peso corresponde ao quanto cada uma das entradas é levada em consideração ao calcular a probabilidade de cada classe. Agora, pegando os pesos de cada classe e remodelando-os em (ou seja, a resolução da imagem), podemos dizer quais pixels são mais importantes para o cálculo de cada classe .78478428×28
Note, novamente, que esses são os pesos .
Agora, dê uma olhada na imagem acima e foque nos dois primeiros dígitos (ou seja, zero e um). Os pesos azuis significam que a intensidade desse pixel contribui muito para essa classe e os valores vermelhos significam que contribui negativamente.
Agora imagine como uma pessoa desenha um ? Ele desenha uma forma circular vazia no meio. Isso é exatamente o que os pesos captaram. De fato, se alguém desenha o meio da imagem, conta negativamente como um zero. Portanto, para reconhecer zeros, você não precisa de filtros sofisticados e recursos de alto nível. Você pode apenas olhar para os locais dos pixels desenhados e julgar de acordo com isso.0
A mesma coisa para o . Sempre tem uma linha vertical reta no meio da imagem. Tudo o resto conta negativamente.1
O resto dos dígitos é um pouco mais complicado, mas com pouca imaginação, você pode ver o , o , o e o . O restante dos números é um pouco mais difícil, que é o que realmente limita a regressão logística de atingir os anos 90.2378
Com isso, você pode ver que a regressão logística tem uma chance muito boa de acertar muitas imagens e é por isso que é tão alta.
O código para reproduzir a figura acima é um pouco datado, mas aqui está:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b
y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) #
correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Train model
batch_size = 64
with tf.Session() as sess:
loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []
sess.run(tf.global_variables_initializer())
for step in range(1, 1001):
x_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})
l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
loss_tr.append(l_tr)
acc_tr.append(a_tr)
loss_ts.append(l_ts)
acc_ts.append(a_ts)
weights = sess.run(W)
print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# Plotting:
for i in range(10):
plt.subplot(2, 5, i+1)
weight = weights[:,i].reshape([28,28])
plt.title(i)
plt.imshow(weight, cmap='RdBu') # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
frame1 = plt.gca()
frame1.axes.get_xaxis().set_visible(False)
frame1.axes.get_yaxis().set_visible(False)