Fix 0D case by ensuring result of np.prod is always integer type

This commit is contained in:
Trevor Morris 2019-08-01 14:32:58 -07:00
parent 18a9074060
commit 0798838e3f

View File

@ -582,8 +582,9 @@ class Flatten(Layer):
input_shape = inputs.shape
if input_shape[1:].is_fully_defined():
outputs = array_ops.reshape(
inputs, (-1, tensor_shape.dimension_value(np.prod(input_shape[1:]))))
flattened_dim = tensor_shape.dimension_value(
np.prod(input_shape[1:], dtype=int))
outputs = array_ops.reshape(inputs, (-1, flattened_dim))
else:
outputs = array_ops.reshape(
inputs, (tensor_shape.dimension_value(inputs.shape[0]) or