From da1caca48bf9df43a5370c89833de1c318f0997f Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Fri, 3 Apr 2020 10:29:54 -0700 Subject: [PATCH] Add more tests for tf.is_tensor. PiperOrigin-RevId: 304644225 Change-Id: I75e507a8592258f98f662fce8996110b9a666fc5 --- .../python/framework/tensor_util_test.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index b2ab779386b..35b5c0a3b1e 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -26,6 +26,8 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph +from tensorflow.python.framework import indexed_slices +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_util @@ -33,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test @@ -759,15 +762,39 @@ class TensorUtilTest(test.TestCase): self.assertFalse(tensor_util.ShapeEquals(t, [4])) +@test_util.run_all_in_graph_and_eager_modes class IsTensorTest(test.TestCase): - @test_util.run_in_graph_and_eager_modes def testConstantTensor(self): np_val = np.random.rand(3).astype(np.int32) tf_val = constant_op.constant(np_val) self.assertFalse(tensor_util.is_tensor(np_val)) self.assertTrue(tensor_util.is_tensor(tf_val)) + def testRaggedTensor(self): + rt = ragged_factory_ops.constant([[1, 2], [3]]) + rt_value = self.evaluate(rt) + self.assertTrue(tensor_util.is_tensor(rt)) + self.assertFalse(tensor_util.is_tensor(rt_value)) + + def testSparseTensor(self): + st = sparse_tensor.SparseTensor([[1, 2]], [3], [10, 10]) + st_value = self.evaluate(st) + self.assertTrue(tensor_util.is_tensor(st)) + self.assertFalse(tensor_util.is_tensor(st_value)) + + def testIndexedSlices(self): + x = indexed_slices.IndexedSlices( + constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30])) + x_value = indexed_slices.IndexedSlicesValue( + np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100])) + self.assertTrue(tensor_util.is_tensor(x)) + self.assertFalse(tensor_util.is_tensor(x_value)) + + def testVariable(self): + v = variables.Variable([1, 2, 3]) + self.assertTrue(tensor_util.is_tensor(v)) + class ConstantValueTest(test.TestCase):