Automated rollback of commit 23e33f871b2bf2879b40ebf3b883e104f30f389b. Revert #31450.
PiperOrigin-RevId: 262675086
parent 019da8c2af
commit bf62fcec00
@@ -26,7 +26,6 @@ import warnings
 import numpy as np
 
 from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
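The dropped `constant_op` import was used only by the overflow branch removed in the next hunk, where the reshape target shape was built as an explicit int32/int64 constant. A minimal sketch of that pattern, assuming the public `tf` API in place of the internal `constant_op` module the file uses:

import tensorflow as tf

# Pin the reshape shape tensor to an explicit dtype, as the removed branch
# did with constant_op.constant((-1, flattened_dim), dtype=shape_dtype).
shape = tf.constant((-1, 4), dtype=tf.int64)
y = tf.reshape(tf.zeros((2, 2, 2)), shape)  # 8 elements -> shape (2, 4)
print(y.shape)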
@@ -581,29 +580,9 @@ class Flatten(Layer):
       permutation.append(1)
       inputs = array_ops.transpose(inputs, perm=permutation)
 
-    input_shape = inputs.shape
-    if input_shape[1:].is_fully_defined():
-      flattened_dim = tensor_shape.dimension_value(
-          np.prod(input_shape[1:], dtype=int))
-      # Temporary fix for integer overflow issue.
-      if flattened_dim > np.iinfo(np.int32).max:
-        shape_dtype = dtypes.int64
-      else:
-        shape_dtype = dtypes.int32
-      outputs = array_ops.reshape(
-          inputs, constant_op.constant((-1, flattened_dim), dtype=shape_dtype))
-    else:
-      batch_size = tensor_shape.dimension_value(inputs.shape[0])
-      if batch_size:
-        # Temporary fix for integer overflow issue.
-        if batch_size > np.iinfo(np.int32).max:
-          shape_dtype = dtypes.int64
-        else:
-          shape_dtype = dtypes.int32
-        outputs = array_ops.reshape(
-            inputs, constant_op.constant((batch_size, -1), dtype=shape_dtype))
-      else:
-        outputs = array_ops.reshape(inputs, (array_ops.shape(inputs)[0], -1))
+    outputs = array_ops.reshape(
+        inputs, (tensor_shape.dimension_value(inputs.shape[0]) or
+                 array_ops.shape(inputs)[0], -1))
     if not context.executing_eagerly():
       outputs.set_shape(self.compute_output_shape(inputs.shape))
     return outputs
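This hunk swaps the reverted workaround, which promoted the reshape shape to int64 whenever the flattened or batch dimension exceeded np.iinfo(np.int32).max, back to the original single reshape: use the statically known batch dimension when `tensor_shape.dimension_value` yields one, otherwise fall back to the dynamic `array_ops.shape(inputs)[0]`. A self-contained sketch of the restored idiom, assuming the public `tf` API in place of the internal `array_ops` module (the `flatten` helper name is illustrative, not from the file):

import tensorflow as tf
from tensorflow.python.framework import tensor_shape

def flatten(inputs):
  # Static batch size when known; dynamic scalar tensor otherwise.
  batch = (tensor_shape.dimension_value(inputs.shape[0]) or
           tf.shape(inputs)[0])
  return tf.reshape(inputs, (batch, -1))

print(flatten(tf.zeros((2, 3, 4))).shape)  # (2, 12)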
@@ -556,20 +556,6 @@ class FlattenTest(test.TestCase):
     self.assertEqual(list(np_output.shape), [5, 6])
     self.assertEqual(y.get_shape().as_list(), [5, None])
 
-  @test_util.run_deprecated_v1
-  def testFlattenLargeDim(self):
-    x = array_ops.placeholder(shape=(None, 21316, 21316, 80), dtype='float32')
-    y = core_layers.Flatten()(x)
-    self.assertEqual(y.shape.as_list(), [None, 21316 * 21316 * 80])
-
-  @test_util.run_deprecated_v1
-  def testFlattenLargeBatchDim(self):
-    batch_size = np.iinfo(np.int32).max + 10
-    x = array_ops.placeholder(
-        shape=(batch_size, None, None, 1), dtype='float32')
-    y = core_layers.Flatten()(x)
-    self.assertEqual(y.shape.as_list(), [batch_size, None])
-
 
 if __name__ == '__main__':
   test.main()
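Both deleted regression tests pushed a dimension past the int32 limit, which is exactly the case the reverted workaround guarded. A plain numpy check of the magnitudes, using only values taken from the tests above:

import numpy as np

int32_max = np.iinfo(np.int32).max   # 2147483647
flattened = 21316 * 21316 * 80       # 36349748480, from testFlattenLargeDim
batch = np.iinfo(np.int32).max + 10  # from testFlattenLargeBatchDim
print(flattened > int32_max)         # True: an int32 shape would overflow
print(batch > int32_max)             # True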