Add test for get_tensor_from_tensor_info with RaggedTensor.

PiperOrigin-RevId: 305904499
Change-Id: I090146c24cad2bc257a87911d58c12e51e3bfa27
This commit is contained in:
Edward Loper 2020-04-10 10:51:46 -07:00 committed by TensorFlower Gardener
parent c179055de4
commit d9039d350a

View File

@ -30,6 +30,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import utils
@ -129,6 +130,15 @@ class UtilsTest(test.TestCase):
self.assertEqual(expected.indices.name, actual.indices.name)
self.assertEqual(expected.dense_shape.name, actual.dense_shape.name)
@test_util.run_v1_only("b/120545219")
def testGetTensorFromInfoRagged(self):
expected = ragged_factory_ops.constant([[1, 2], [3]], name="x")
tensor_info = utils.build_tensor_info(expected)
actual = utils.get_tensor_from_tensor_info(tensor_info)
self.assertIsInstance(actual, ragged_tensor.RaggedTensor)
self.assertEqual(expected.values.name, actual.values.name)
self.assertEqual(expected.row_splits.name, actual.row_splits.name)
def testGetTensorFromInfoInOtherGraph(self):
with ops.Graph().as_default() as expected_graph:
expected = array_ops.placeholder(dtypes.float32, 1, name="right")