Update np.data.Dataset.as_numpy_iterator to support ragged tensors.

PiperOrigin-RevId: 323645392
Change-Id: Id80daa55a676bfd523f23f72394acb79f5088fd5
This commit is contained in:
Edward Loper 2020-07-28 13:47:45 -07:00 committed by TensorFlower Gardener
parent 5f8e3e8d54
commit 69ca56e7f4
2 changed files with 10 additions and 5 deletions

View File

@ -27,7 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
@ -74,9 +74,11 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations())
def testRaggedElement(self):
self._testInvalidElement(
ragged_tensor_value.RaggedTensorValue(
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64)))
lst = [[1, 2], [3], [4, 5, 6]]
rt = ragged_factory_ops.constant(lst)
ds = dataset_ops.Dataset.from_tensor_slices(rt)
for actual, expected in zip(ds.as_numpy_iterator(), lst):
self.assertTrue(np.array_equal(actual, expected))
@combinations.generate(test_base.eager_only_combinations())
def testDatasetElement(self):

View File

@ -67,6 +67,7 @@ from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as tracking_base
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import deprecation
@ -522,7 +523,9 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
raise RuntimeError("as_numpy_iterator() is not supported while tracing "
"functions")
for component_spec in nest.flatten(self.element_spec):
if not isinstance(component_spec, tensor_spec.TensorSpec):
if not isinstance(
component_spec,
(tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec)):
raise TypeError(
"Dataset.as_numpy_iterator() does not support datasets containing "
+ str(component_spec.value_type))