diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index df78cffa4a2..af324a87c7e 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -580,9 +580,15 @@ class Flatten(Layer):
       permutation.append(1)
       inputs = array_ops.transpose(inputs, perm=permutation)
 
-    outputs = array_ops.reshape(
-        inputs, (tensor_shape.dimension_value(inputs.shape[0]) or
-                 array_ops.shape(inputs)[0], -1))
+    input_shape = inputs.shape
+    if input_shape[1:].is_fully_defined():
+      flattened_dim = tensor_shape.dimension_value(
+          np.prod(input_shape[1:], dtype=int))
+      outputs = array_ops.reshape(inputs, (-1, flattened_dim))
+    else:
+      outputs = array_ops.reshape(
+          inputs, (tensor_shape.dimension_value(inputs.shape[0]) or
+                   array_ops.shape(inputs)[0], -1))
     if not context.executing_eagerly():
       outputs.set_shape(self.compute_output_shape(inputs.shape))
     return outputs