Papel de "Aplanar" en Keras


Estoy tratando de entender el papel de la función Flatten en Keras. A continuación se muestra mi código, que es una red simple de dos capas. Toma datos de forma en 2 dimensiones (3, 2), y produce datos de forma en 1 dimensión (1, 4):

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print y.shape

Esto imprime que y tiene forma (1, 4). Sin embargo, si elimino la línea Flatten, entonces imprime que y tiene forma (1, 3, 4).

No entiendo esto. Desde mi comprensión de las redes neuronales, la función model.add(Dense(16, input_shape=(3, 2))) está creando un capa totalmente conectada, con 16 nodos. Cada uno de estos nodos está conectado a cada uno de los elementos de entrada 3x2. Por lo tanto, los 16 nodos en la salida de esta primera capa ya son "planos". Por lo tanto, la forma de salida de la primera capa debe ser (1, 16). Luego, la segunda capa toma esto como una entrada, y produce datos de forma (1, 4).

Entonces, si la salida de la primera capa ya es "plana" y de forma (1, 16), ¿por qué necesito aplanarla aún más?

Gracias!

Author: Marcin Możejko, 2017-04-05

1 answers

Si lee una documentación de Dense aquí verás que:

Dense(16, input_shape=(5,3))

Daría lugar a una red Dense con 3 entradas y 16 salidas que se aplicarían independientemente para cada uno de los 5 pasos. Así que si D(x) transforma un vector de 3 dimensiones a un vector 16-d lo que obtendrás como salida de tu capa sería una secuencia de vectores: [D(x[0,:], D(x[1,:],..., D(x[4,:]] con forma (5, 16). Para tener el comportamiento que especifique primero puede Flatten su entrada a un vector 15-d y luego aplicar Dense:

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

EDITAR: Como algunas personas lucharon por entender, aquí tienes una imagen explicativa:

introduzca la descripción de la imagen aquí

 47
Author: Marcin Możejko,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-09-13 21:33:10