Extend tf.identity to work with CompositeTensors (such as SparseTensor)

PiperOrigin-RevId: 267507615
This commit is contained in:
Edward Loper 2019-09-05 19:19:27 -07:00 committed by TensorFlower Gardener
parent fb657674fc
commit 4e0dcb47a7
2 changed files with 18 additions and 7 deletions

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@ -56,11 +57,11 @@ class IdentityOpTest(test.TestCase):
shape = [2, 3]
array_2x3 = [[1, 2, 3], [6, 5, 4]]
tensor = constant_op.constant(array_2x3)
self.assertEquals(shape, tensor.get_shape())
self.assertEquals(shape, array_ops.identity(tensor).get_shape())
self.assertEquals(shape, array_ops.identity(array_2x3).get_shape())
self.assertEquals(shape,
array_ops.identity(np.array(array_2x3)).get_shape())
self.assertEqual(shape, tensor.get_shape())
self.assertEqual(shape, array_ops.identity(tensor).get_shape())
self.assertEqual(shape, array_ops.identity(array_2x3).get_shape())
self.assertEqual(shape,
array_ops.identity(np.array(array_2x3)).get_shape())
@test_util.run_v1_only("b/120545219")
def testRefIdentityShape(self):
@ -69,8 +70,15 @@ class IdentityOpTest(test.TestCase):
tensor = variables.VariableV1(
constant_op.constant(
[[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32))
self.assertEquals(shape, tensor.get_shape())
self.assertEquals(shape, gen_array_ops.ref_identity(tensor).get_shape())
self.assertEqual(shape, tensor.get_shape())
self.assertEqual(shape, gen_array_ops.ref_identity(tensor).get_shape())
def testCompositeTensor(self):
original = sparse_tensor.SparseTensor([[3]], [1.0], [100])
copied = array_ops.identity(original)
self.assertAllEqual(original.indices, copied.indices)
self.assertAllEqual(original.values, copied.values)
self.assertAllEqual(original.dense_shape, copied.dense_shape)
if __name__ == "__main__":

View File

@ -25,6 +25,7 @@ import six
from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -194,6 +195,8 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
Returns:
A `Tensor`. Has the same type as `input`.
"""
if isinstance(input, composite_tensor.CompositeTensor):
return nest.map_structure(identity, input, expand_composites=True)
if context.executing_eagerly() and not hasattr(input, "graph"):
# Make sure we get an input with handle data attached from resource
# variables. Variables have correct handle data when graph building.