Improve code based on reviewer feedback

This commit is contained in:
Trevor Morris 2019-07-19 15:54:40 -07:00
parent 2f61f75e24
commit 18a9074060

View File

@ -580,9 +580,10 @@ class Flatten(Layer):
permutation.append(1)
inputs = array_ops.transpose(inputs, perm=permutation)
input_shape = tensor_shape.TensorShape(inputs.shape).as_list()
if input_shape and all(input_shape[1:]):
outputs = array_ops.reshape(inputs, (-1, int(np.prod(input_shape[1:]))))
input_shape = inputs.shape
if input_shape[1:].is_fully_defined():
outputs = array_ops.reshape(
inputs, (-1, tensor_shape.dimension_value(np.prod(input_shape[1:]))))
else:
outputs = array_ops.reshape(
inputs, (tensor_shape.dimension_value(inputs.shape[0]) or