diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index af324a87c7e..df78cffa4a2 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -580,15 +580,9 @@ class Flatten(Layer): permutation.append(1) inputs = array_ops.transpose(inputs, perm=permutation) - input_shape = inputs.shape - if input_shape[1:].is_fully_defined(): - flattened_dim = tensor_shape.dimension_value( - np.prod(input_shape[1:], dtype=int)) - outputs = array_ops.reshape(inputs, (-1, flattened_dim)) - else: - outputs = array_ops.reshape( - inputs, (tensor_shape.dimension_value(inputs.shape[0]) or - array_ops.shape(inputs)[0], -1)) + outputs = array_ops.reshape( + inputs, (tensor_shape.dimension_value(inputs.shape[0]) or + array_ops.shape(inputs)[0], -1)) if not context.executing_eagerly(): outputs.set_shape(self.compute_output_shape(inputs.shape)) return outputs