From 4e0dcb47a7c1293a007772a59d6b4f6a4c89a933 Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Thu, 5 Sep 2019 19:19:27 -0700 Subject: [PATCH] Extend tf.identity to work with CompositeTensors (such as SparseTensor) PiperOrigin-RevId: 267507615 --- .../kernel_tests/identity_op_py_test.py | 22 +++++++++++++------ tensorflow/python/ops/array_ops.py | 3 +++ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py index 40ec9db4226..013502dfe09 100644 --- a/tensorflow/python/kernel_tests/identity_op_py_test.py +++ b/tensorflow/python/kernel_tests/identity_op_py_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops @@ -56,11 +57,11 @@ class IdentityOpTest(test.TestCase): shape = [2, 3] array_2x3 = [[1, 2, 3], [6, 5, 4]] tensor = constant_op.constant(array_2x3) - self.assertEquals(shape, tensor.get_shape()) - self.assertEquals(shape, array_ops.identity(tensor).get_shape()) - self.assertEquals(shape, array_ops.identity(array_2x3).get_shape()) - self.assertEquals(shape, - array_ops.identity(np.array(array_2x3)).get_shape()) + self.assertEqual(shape, tensor.get_shape()) + self.assertEqual(shape, array_ops.identity(tensor).get_shape()) + self.assertEqual(shape, array_ops.identity(array_2x3).get_shape()) + self.assertEqual(shape, + array_ops.identity(np.array(array_2x3)).get_shape()) @test_util.run_v1_only("b/120545219") def testRefIdentityShape(self): @@ -69,8 +70,15 @@ class IdentityOpTest(test.TestCase): tensor = variables.VariableV1( constant_op.constant( [[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32)) - self.assertEquals(shape, tensor.get_shape()) - self.assertEquals(shape, gen_array_ops.ref_identity(tensor).get_shape()) + self.assertEqual(shape, tensor.get_shape()) + self.assertEqual(shape, gen_array_ops.ref_identity(tensor).get_shape()) + + def testCompositeTensor(self): + original = sparse_tensor.SparseTensor([[3]], [1.0], [100]) + copied = array_ops.identity(original) + self.assertAllEqual(original.indices, copied.indices) + self.assertAllEqual(original.values, copied.values) + self.assertAllEqual(original.dense_shape, copied.dense_shape) if __name__ == "__main__": diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 8eed8150c89..1bcef69bd7b 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -25,6 +25,7 @@ import six from tensorflow.python.compat import compat from tensorflow.python.eager import context from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -194,6 +195,8 @@ def identity(input, name=None): # pylint: disable=redefined-builtin Returns: A `Tensor`. Has the same type as `input`. """ + if isinstance(input, composite_tensor.CompositeTensor): + return nest.map_structure(identity, input, expand_composites=True) if context.executing_eagerly() and not hasattr(input, "graph"): # Make sure we get an input with handle data attached from resource # variables. Variables have correct handle data when graph building.