From 69ca56e7f41910bb330f4d353a04f36080e606fe Mon Sep 17 00:00:00 2001 From: Edward Loper Date: Tue, 28 Jul 2020 13:47:45 -0700 Subject: [PATCH] Update np.data.Dataset.as_numpy_iterator to support ragged tensors. PiperOrigin-RevId: 323645392 Change-Id: Id80daa55a676bfd523f23f72394acb79f5088fd5 --- .../python/data/kernel_tests/as_numpy_iterator_test.py | 10 ++++++---- tensorflow/python/data/ops/dataset_ops.py | 5 ++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py b/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py index ea80389b0a5..a69e49439c4 100644 --- a/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py +++ b/tensorflow/python/data/kernel_tests/as_numpy_iterator_test.py @@ -27,7 +27,7 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops.ragged import ragged_tensor_value +from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.platform import test @@ -74,9 +74,11 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase): @combinations.generate(test_base.eager_only_combinations()) def testRaggedElement(self): - self._testInvalidElement( - ragged_tensor_value.RaggedTensorValue( - np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64))) + lst = [[1, 2], [3], [4, 5, 6]] + rt = ragged_factory_ops.constant(lst) + ds = dataset_ops.Dataset.from_tensor_slices(rt) + for actual, expected in zip(ds.as_numpy_iterator(), lst): + self.assertTrue(np.array_equal(actual, expected)) @combinations.generate(test_base.eager_only_combinations()) def testDatasetElement(self): diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index bd75d0a735a..512cd2db90a 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -67,6 +67,7 @@ from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.training.tracking import base as tracking_base from tensorflow.python.training.tracking import tracking from tensorflow.python.util import deprecation @@ -522,7 +523,9 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable, raise RuntimeError("as_numpy_iterator() is not supported while tracing " "functions") for component_spec in nest.flatten(self.element_spec): - if not isinstance(component_spec, tensor_spec.TensorSpec): + if not isinstance( + component_spec, + (tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec)): raise TypeError( "Dataset.as_numpy_iterator() does not support datasets containing " + str(component_spec.value_type))