“ Flatten”在Keras中的作用是什么?


108

我试图了解该Flatten功能在Keras中的作用。下面是我的代码,它是一个简单的两层网络。它接收形状为(3,2)的二维数据,并输出形状为(1,4)的一维数据:

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

x = np.array([[[1, 2], [3, 4], [5, 6]]])

y = model.predict(x)

print y.shape

打印出y形状为(1、4)的图形。但是,如果我删除该Flatten行,则会打印出y形状为(1、3、4)的行。

我不明白 根据我对神经网络的理解,该model.add(Dense(16, input_shape=(3, 2)))函数正在创建一个具有16个节点的隐藏的全连接层。这些节点中的每个都连接到3x2输入元素中的每个。因此,该第一层输出处的16个节点已经“平坦”。因此,第一层的输出形状应为(1、16)。然后,第二层将此作为输入,并输出形状为(1、4)的数据。

因此,如果第一层的输出已经“平坦”并且形状为(1,16),为什么还要进一步使其平坦?

Answers:


123

如果阅读的Keras文档条目Dense,您将看到以下调用:

Dense(16, input_shape=(5,3))

这样将形成一个Dense具有3个输入和16个输出的网络,这些网络将独立应用于5个步骤中的每个步骤。因此,如果D(x)将3维矢量转换为16维矢量,则从图层输出的输出将是一系列矢量:[D(x[0,:]), D(x[1,:]),..., D(x[4,:])]shape (5, 16)。为了具有您指定的行为,您可以先将Flatten输入输入到15维矢量,然后应用Dense

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

编辑: 由于有些人难以理解-在这里,您有一个解释性的图像:

在此处输入图片说明


感谢您的解释。只是为了澄清一下:用Dense(16, input_shape=(5,3),将来自16组(对于这些神经元的所有5组)的每个输出神经元都连接到所有(3 x 5 = 15)输入神经元吗?或者将第一个16组中的每个神经元仅连接到第一个5组输入神经元中的3个神经元,然后将第二个16组中的每个神经元仅连接到第二组5个输入神经元中的3个神经元神经元等。。。。
Karnivaurus

1
您有一个密集层,其中包含3个神经元,输出16个层,将其应用于5组3个神经元中的每一个。
MarcinMożejko'17

1
喔好吧。我想做的是将5个彩色像素的列表作为输入,我希望它们通过完全连接的层。因此input_shape=(5,3)意味着有5个像素,每个像素具有三个通道(R,G,B)。但是根据您所说的,每个通道将被单独处理,而我希望第一层中的所有神经元都对所有三个通道进行处理。那么Flatten在开始时立即应用该层会给我我想要的吗?
Karnivaurus

8
有和没有的一点绘图Flatten可能有助于理解。
Xvolks

2
好的,伙计们-我为您提供了一张图片。现在,您可以删除您的否决票了。
MarcinMożejko17年


35

简短阅读:

展平张量意味着除去除一个以外的所有尺寸。这正是Flatten层所做的。

长读:

如果我们考虑创建的原始模型(具有Flatten层),则可以得到以下模型摘要:

Layer (type)                 Output Shape              Param #   
=================================================================
D16 (Dense)                  (None, 3, 16)             48        
_________________________________________________________________
A (Activation)               (None, 3, 16)             0         
_________________________________________________________________
F (Flatten)                  (None, 48)                0         
_________________________________________________________________
D4 (Dense)                   (None, 4)                 196       
=================================================================
Total params: 244
Trainable params: 244
Non-trainable params: 0

对于此摘要,下一张图像有望对每一层的输入和输出大小提供更多的了解。

您可以阅读的Flatten层的输出形状为(None, 48)。这里是提示。您应该阅读(1, 48)(2, 48)或...或(16, 48)...或(32, 48),...

实际上,None在该位置上意味着任何批量。对于召回的输入,第一维表示批处理大小,第二维表示输入要素的数量。

的作用展平层在Keras是超级简单:

对张量进行展平操作可将张量整形,使其形状等于不包含批次尺寸的张量中包含的元素数量。

在此处输入图片说明


注意:我使用了该model.summary()方法来提供输出形状和参数详细信息。


1
非常有见地的图表。
Shrey Joshi

1
感谢您的图表。它给我一个清晰的画面。
Sultan Ahmed Sagor

0

展平表示您如何序列化多维张量(通常是输入张量)。这允许在(拉平的)输入张量和第一隐藏层之间进行映射。如果第一个隐藏层是“致密的”,则(序列化的)输入张量的每个元素都将与隐藏数组的每个元素连接。如果不使用Flatten,则输入张量映射到第一个隐藏层的方式将是模棱两可的。


By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.