diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index df78cffa4a2..af324a87c7e 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -580,9 +580,15 @@ class Flatten(Layer): permutation.append(1) inputs = array_ops.transpose(inputs, perm=permutation) - outputs = array_ops.reshape( - inputs, (tensor_shape.dimension_value(inputs.shape[0]) or - array_ops.shape(inputs)[0], -1)) + input_shape = inputs.shape + if input_shape[1:].is_fully_defined(): + flattened_dim = tensor_shape.dimension_value( + np.prod(input_shape[1:], dtype=int)) + outputs = array_ops.reshape(inputs, (-1, flattened_dim)) + else: + outputs = array_ops.reshape( + inputs, (tensor_shape.dimension_value(inputs.shape[0]) or + array_ops.shape(inputs)[0], -1)) if not context.executing_eagerly(): outputs.set_shape(self.compute_output_shape(inputs.shape)) return outputs