diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index eb45636e677..b21801786d9 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -582,11 +582,11 @@ class Flatten(Layer): input_shape = tensor_shape.TensorShape(inputs.shape).as_list() if input_shape and all(input_shape[1:]): - outputs = array_ops.reshape(inputs, (-1, np.prod(input_shape[1:]))) + outputs = array_ops.reshape(inputs, (-1, int(np.prod(input_shape[1:])))) else: outputs = array_ops.reshape( inputs, (tensor_shape.dimension_value(inputs.shape[0]) or - array_ops.shape(inputs)[0], -1)) + array_ops.shape(inputs)[0], -1)) if not context.executing_eagerly(): outputs.set_shape(self.compute_output_shape(inputs.shape)) return outputs