Add more tests for tf.is_tensor.
PiperOrigin-RevId: 304644225 Change-Id: I75e507a8592258f98f662fce8996110b9a666fc5
This commit is contained in:
parent
9eb54c3c64
commit
da1caca48b
@ -26,6 +26,8 @@ import numpy as np
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import func_graph
|
from tensorflow.python.framework import func_graph
|
||||||
|
from tensorflow.python.framework import indexed_slices
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -33,6 +35,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import gen_state_ops
|
from tensorflow.python.ops import gen_state_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -759,15 +762,39 @@ class TensorUtilTest(test.TestCase):
|
|||||||
self.assertFalse(tensor_util.ShapeEquals(t, [4]))
|
self.assertFalse(tensor_util.ShapeEquals(t, [4]))
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class IsTensorTest(test.TestCase):
|
class IsTensorTest(test.TestCase):
|
||||||
|
|
||||||
@test_util.run_in_graph_and_eager_modes
|
|
||||||
def testConstantTensor(self):
|
def testConstantTensor(self):
|
||||||
np_val = np.random.rand(3).astype(np.int32)
|
np_val = np.random.rand(3).astype(np.int32)
|
||||||
tf_val = constant_op.constant(np_val)
|
tf_val = constant_op.constant(np_val)
|
||||||
self.assertFalse(tensor_util.is_tensor(np_val))
|
self.assertFalse(tensor_util.is_tensor(np_val))
|
||||||
self.assertTrue(tensor_util.is_tensor(tf_val))
|
self.assertTrue(tensor_util.is_tensor(tf_val))
|
||||||
|
|
||||||
|
def testRaggedTensor(self):
|
||||||
|
rt = ragged_factory_ops.constant([[1, 2], [3]])
|
||||||
|
rt_value = self.evaluate(rt)
|
||||||
|
self.assertTrue(tensor_util.is_tensor(rt))
|
||||||
|
self.assertFalse(tensor_util.is_tensor(rt_value))
|
||||||
|
|
||||||
|
def testSparseTensor(self):
|
||||||
|
st = sparse_tensor.SparseTensor([[1, 2]], [3], [10, 10])
|
||||||
|
st_value = self.evaluate(st)
|
||||||
|
self.assertTrue(tensor_util.is_tensor(st))
|
||||||
|
self.assertFalse(tensor_util.is_tensor(st_value))
|
||||||
|
|
||||||
|
def testIndexedSlices(self):
|
||||||
|
x = indexed_slices.IndexedSlices(
|
||||||
|
constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30]))
|
||||||
|
x_value = indexed_slices.IndexedSlicesValue(
|
||||||
|
np.array([1, 2, 3]), np.array([10, 20, 30]), np.array([100]))
|
||||||
|
self.assertTrue(tensor_util.is_tensor(x))
|
||||||
|
self.assertFalse(tensor_util.is_tensor(x_value))
|
||||||
|
|
||||||
|
def testVariable(self):
|
||||||
|
v = variables.Variable([1, 2, 3])
|
||||||
|
self.assertTrue(tensor_util.is_tensor(v))
|
||||||
|
|
||||||
|
|
||||||
class ConstantValueTest(test.TestCase):
|
class ConstantValueTest(test.TestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user