Estou tentando entender o papel da Flatten
função em Keras. Abaixo está meu código, que é uma rede simples de duas camadas. Ele recebe dados bidimensionais de forma (3, 2) e produz dados unidimensionais de forma (1, 4):
model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')
x = np.array([[[1, 2], [3, 4], [5, 6]]])
y = model.predict(x)
print y.shape
Isso imprime que y
tem forma (1, 4). No entanto, se eu remover a Flatten
linha, ela imprime que y
tem forma (1, 3, 4).
Eu não entendo isso. Do meu conhecimento de redes neurais, a model.add(Dense(16, input_shape=(3, 2)))
função é criar uma camada oculta totalmente conectada, com 16 nós. Cada um desses nós está conectado a cada um dos elementos de entrada 3x2. Portanto, os 16 nós na saída dessa primeira camada já são "planos". Portanto, a forma de saída da primeira camada deve ser (1, 16). Em seguida, a segunda camada leva isso como uma entrada e produz dados de forma (1, 4).
Portanto, se a saída da primeira camada já é "plana" e tem a forma (1, 16), por que preciso achatá-la ainda mais?
Dense(16, input_shape=(5,3)
, cada neurônio de saída do conjunto de 16 (e, para todos os 5 conjuntos desses neurônios), será conectado a todos (3 x 5 = 15) neurônios de entrada? Ou cada neurônio no primeiro conjunto de 16 será conectado apenas aos 3 neurônios no primeiro conjunto de 5 neurônios de entrada, e então cada neurônio no segundo conjunto de 16 será conectado apenas aos 3 neurônios no segundo conjunto de 5 neurônios de entrada neurônios, etc .... Estou confuso sobre o que é!