Add more tests for tf.is_tensor.

PiperOrigin-RevId: 304644225
Change-Id: I75e507a8592258f98f662fce8996110b9a666fc5
This commit is contained in:
Edward Loper 2020-04-03 10:29:54 -07:00 committed by TensorFlower Gardener
parent 9eb54c3c64
commit da1caca48b

View File

@ -26,6 +26,8 @@ import numpy as np
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph from tensorflow.python.framework import func_graph
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -33,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -759,15 +762,39 @@ class TensorUtilTest(test.TestCase):
self.assertFalse(tensor_util.ShapeEquals(t, [4])) self.assertFalse(tensor_util.ShapeEquals(t, [4]))
@test_util.run_all_in_graph_and_eager_modes
class IsTensorTest(test.TestCase): class IsTensorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testConstantTensor(self): def testConstantTensor(self):
np_val = np.random.rand(3).astype(np.int32) np_val = np.random.rand(3).astype(np.int32)
tf_val = constant_op.constant(np_val) tf_val = constant_op.constant(np_val)
self.assertFalse(tensor_util.is_tensor(np_val)) self.assertFalse(tensor_util.is_tensor(np_val))
self.assertTrue(tensor_util.is_tensor(tf_val)) self.assertTrue(tensor_util.is_tensor(tf_val))
def testRaggedTensor(self):
rt = ragged_factory_ops.constant([[1, 2], [3]])
rt_value = self.evaluate(rt)
self.assertTrue(tensor_util.is_tensor(rt))
self.assertFalse(tensor_util.is_tensor(rt_value))
def testSparseTensor(self):
st = sparse_tensor.SparseTensor([[1, 2]], [3], [10, 10])
st_value = self.evaluate(st)
self.assertTrue(tensor_util.is_tensor(st))
self.assertFalse(tensor_util.is_tensor(st_value))
def testIndexedSlices(self):
x = indexed_slices.IndexedSlices(
constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30]))
x_value = indexed_slices.IndexedSlicesValue(
np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100]))
self.assertTrue(tensor_util.is_tensor(x))
self.assertFalse(tensor_util.is_tensor(x_value))
def testVariable(self):
v = variables.Variable([1, 2, 3])
self.assertTrue(tensor_util.is_tensor(v))
class ConstantValueTest(test.TestCase): class ConstantValueTest(test.TestCase):