Qual é o papel de “Flatten” em Keras?


108

Estou tentando entender o papel da Flattenfunção em Keras. Abaixo está meu código, que é uma rede simples de duas camadas. Ele recebe dados bidimensionais de forma (3, 2) e produz dados unidimensionais de forma (1, 4):

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print y.shape

Isso imprime que ytem forma (1, 4). No entanto, se eu remover a Flattenlinha, ela imprime que ytem forma (1, 3, 4).

Eu não entendo isso. Do meu conhecimento de redes neurais, a model.add(Dense(16, input_shape=(3, 2)))função é criar uma camada oculta totalmente conectada, com 16 nós. Cada um desses nós está conectado a cada um dos elementos de entrada 3x2. Portanto, os 16 nós na saída dessa primeira camada já são "planos". Portanto, a forma de saída da primeira camada deve ser (1, 16). Em seguida, a segunda camada leva isso como uma entrada e produz dados de forma (1, 4).

Portanto, se a saída da primeira camada já é "plana" e tem a forma (1, 16), por que preciso achatá-la ainda mais?

Respostas:


123

Se você ler a entrada da documentação de Keras para Dense, verá que esta chamada:

Dense(16, input_shape=(5,3))

resultaria em uma Denserede com 3 entradas e 16 saídas que seriam aplicadas independentemente para cada uma das 5 etapas. Portanto, se D(x)transforma o vetor tridimensional em vetor 16-d, o que você obterá como saída de sua camada será uma sequência de vetores: [D(x[0,:]), D(x[1,:]),..., D(x[4,:])]com forma (5, 16). Para ter o comportamento especificado, você pode primeiro Flatteninserir um vetor de 15 d e depois aplicar Dense:

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

EDITAR: Como algumas pessoas tiveram dificuldade em entender - aqui você tem uma imagem explicativa:

insira a descrição da imagem aqui


Obrigado pela sua explicação. Só para esclarecer: comDense(16, input_shape=(5,3) , cada neurônio de saída do conjunto de 16 (e, para todos os 5 conjuntos desses neurônios), será conectado a todos (3 x 5 = 15) neurônios de entrada? Ou cada neurônio no primeiro conjunto de 16 será conectado apenas aos 3 neurônios no primeiro conjunto de 5 neurônios de entrada, e então cada neurônio no segundo conjunto de 16 será conectado apenas aos 3 neurônios no segundo conjunto de 5 neurônios de entrada neurônios, etc .... Estou confuso sobre o que é!
Karnivaurus

1
Você tem uma camada densa que obtém 3 neurônios e a saída 16 que é aplicada a cada um dos 5 conjuntos de 3 neurônios.
Marcin Możejko

1
Ah ok. O que estou tentando fazer é pegar uma lista de 5 pixels de cor como entrada e quero que eles passem por uma camada totalmente conectada. Isso input_shape=(5,3)significa que existem 5 pixels, e cada pixel possui três canais (R, G, B). Mas, de acordo com o que você está dizendo, cada canal seria processado individualmente, enquanto eu quero que todos os três canais sejam processados ​​por todos os neurônios na primeira camada. Então, aplicar a Flattencamada imediatamente no início me daria o que eu quero?
Karnivaurus

8
Um pequeno desenho com e sem Flattenpode ajudar a entender.
Xvolks

2
Ok, pessoal - eu forneci uma imagem para vocês. Agora você pode deletar seus votos negativos.
Marcin Możejko


35

leitura curta:

Achatar um tensor significa remover todas as dimensões, exceto uma. Isso é exatamente o que a camada Flatten faz.

longa leitura:

Se levarmos em consideração o modelo original (com a camada Flatten) criado, podemos obter o seguinte resumo do modelo:

Layer (type)                 Output Shape              Param #   
=================================================================
D16 (Dense)                  (None, 3, 16)             48        
_________________________________________________________________
A (Activation)               (None, 3, 16)             0         
_________________________________________________________________
F (Flatten)                  (None, 48)                0         
_________________________________________________________________
D4 (Dense)                   (None, 4)                 196       
=================================================================
Total params: 244
Trainable params: 244
Non-trainable params: 0

Para este resumo, a próxima imagem irá fornecer um pouco mais de sentido sobre os tamanhos de entrada e saída para cada camada.

A forma de saída para a camada Flatten como você pode ler é (None, 48). Aqui está a dica. Você deve ler (1, 48)ou (2, 48)ou ... ou (16, 48)... ou (32, 48), ...

Na verdade, Nonenessa posição significa qualquer tamanho de lote. Para que as entradas sejam recuperadas, a primeira dimensão significa o tamanho do lote e a segunda significa o número de recursos de entrada.

A função da camada Flatten em Keras é super simples:

Uma operação de achatamento em um tensor remodela o tensor para ter a forma que é igual ao número de elementos contidos no tensor, não incluindo a dimensão do lote .

insira a descrição da imagem aqui


Observação: usei o model.summary()método para fornecer a forma de saída e os detalhes dos parâmetros.


1
Diagrama muito perspicaz.
Shrey Joshi

1
Obrigado pelo diagrama. Isso me dá uma imagem clara.
Sultão Ahmed Sagor

0

Flatten torna explícito como você serializa um tensor multidimensional (tipicamente o de entrada). Isso permite o mapeamento entre o tensor de entrada (achatado) e a primeira camada oculta. Se a primeira camada oculta for "densa", cada elemento do tensor de entrada (serializado) será conectado a cada elemento da matriz oculta. Se você não usar Flatten, a maneira como o tensor de entrada é mapeado na primeira camada oculta seria ambígua.


0

Eu me deparei com isso recentemente, certamente me ajudou a entender: https://www.cs.ryerson.ca/~aharley/vis/conv/

Portanto, há uma entrada, um Conv2D, MaxPooling2D etc, as camadas Flatten estão no final e mostram exatamente como são formadas e como vão para definir as classificações finais (0-9).

Ao utilizar nosso site, você reconhece que leu e compreendeu nossa Política de Cookies e nossa Política de Privacidade.
Licensed under cc by-sa 3.0 with attribution required.