Add test for get_tensor_from_tensor_info with RaggedTensor.
PiperOrigin-RevId: 305904499 Change-Id: I090146c24cad2bc257a87911d58c12e51e3bfa27
This commit is contained in:
parent
c179055de4
commit
d9039d350a
@ -30,6 +30,7 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import nested_structure_coder
|
||||
from tensorflow.python.saved_model import utils
|
||||
@ -129,6 +130,15 @@ class UtilsTest(test.TestCase):
|
||||
self.assertEqual(expected.indices.name, actual.indices.name)
|
||||
self.assertEqual(expected.dense_shape.name, actual.dense_shape.name)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testGetTensorFromInfoRagged(self):
|
||||
expected = ragged_factory_ops.constant([[1, 2], [3]], name="x")
|
||||
tensor_info = utils.build_tensor_info(expected)
|
||||
actual = utils.get_tensor_from_tensor_info(tensor_info)
|
||||
self.assertIsInstance(actual, ragged_tensor.RaggedTensor)
|
||||
self.assertEqual(expected.values.name, actual.values.name)
|
||||
self.assertEqual(expected.row_splits.name, actual.row_splits.name)
|
||||
|
||||
def testGetTensorFromInfoInOtherGraph(self):
|
||||
with ops.Graph().as_default() as expected_graph:
|
||||
expected = array_ops.placeholder(dtypes.float32, 1, name="right")
|
||||
|
Loading…
Reference in New Issue
Block a user