Extend tf.identity to work with CompositeTensors (such as SparseTensor)
PiperOrigin-RevId: 267507615
This commit is contained in:
parent
fb657674fc
commit
4e0dcb47a7
@ -22,6 +22,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_array_ops
|
from tensorflow.python.ops import gen_array_ops
|
||||||
@ -56,11 +57,11 @@ class IdentityOpTest(test.TestCase):
|
|||||||
shape = [2, 3]
|
shape = [2, 3]
|
||||||
array_2x3 = [[1, 2, 3], [6, 5, 4]]
|
array_2x3 = [[1, 2, 3], [6, 5, 4]]
|
||||||
tensor = constant_op.constant(array_2x3)
|
tensor = constant_op.constant(array_2x3)
|
||||||
self.assertEquals(shape, tensor.get_shape())
|
self.assertEqual(shape, tensor.get_shape())
|
||||||
self.assertEquals(shape, array_ops.identity(tensor).get_shape())
|
self.assertEqual(shape, array_ops.identity(tensor).get_shape())
|
||||||
self.assertEquals(shape, array_ops.identity(array_2x3).get_shape())
|
self.assertEqual(shape, array_ops.identity(array_2x3).get_shape())
|
||||||
self.assertEquals(shape,
|
self.assertEqual(shape,
|
||||||
array_ops.identity(np.array(array_2x3)).get_shape())
|
array_ops.identity(np.array(array_2x3)).get_shape())
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
def testRefIdentityShape(self):
|
def testRefIdentityShape(self):
|
||||||
@ -69,8 +70,15 @@ class IdentityOpTest(test.TestCase):
|
|||||||
tensor = variables.VariableV1(
|
tensor = variables.VariableV1(
|
||||||
constant_op.constant(
|
constant_op.constant(
|
||||||
[[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32))
|
[[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32))
|
||||||
self.assertEquals(shape, tensor.get_shape())
|
self.assertEqual(shape, tensor.get_shape())
|
||||||
self.assertEquals(shape, gen_array_ops.ref_identity(tensor).get_shape())
|
self.assertEqual(shape, gen_array_ops.ref_identity(tensor).get_shape())
|
||||||
|
|
||||||
|
def testCompositeTensor(self):
|
||||||
|
original = sparse_tensor.SparseTensor([[3]], [1.0], [100])
|
||||||
|
copied = array_ops.identity(original)
|
||||||
|
self.assertAllEqual(original.indices, copied.indices)
|
||||||
|
self.assertAllEqual(original.values, copied.values)
|
||||||
|
self.assertAllEqual(original.dense_shape, copied.dense_shape)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -25,6 +25,7 @@ import six
|
|||||||
from tensorflow.python.compat import compat
|
from tensorflow.python.compat import compat
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import common_shapes
|
from tensorflow.python.framework import common_shapes
|
||||||
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
@ -194,6 +195,8 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
|
|||||||
Returns:
|
Returns:
|
||||||
A `Tensor`. Has the same type as `input`.
|
A `Tensor`. Has the same type as `input`.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(input, composite_tensor.CompositeTensor):
|
||||||
|
return nest.map_structure(identity, input, expand_composites=True)
|
||||||
if context.executing_eagerly() and not hasattr(input, "graph"):
|
if context.executing_eagerly() and not hasattr(input, "graph"):
|
||||||
# Make sure we get an input with handle data attached from resource
|
# Make sure we get an input with handle data attached from resource
|
||||||
# variables. Variables have correct handle data when graph building.
|
# variables. Variables have correct handle data when graph building.
|
||||||
|
Loading…
Reference in New Issue
Block a user