Update np.data.Dataset.as_numpy_iterator to support ragged tensors.
PiperOrigin-RevId: 323645392 Change-Id: Id80daa55a676bfd523f23f72394acb79f5088fd5
This commit is contained in:
parent
5f8e3e8d54
commit
69ca56e7f4
@ -27,7 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -74,9 +74,11 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testRaggedElement(self):
|
||||
self._testInvalidElement(
|
||||
ragged_tensor_value.RaggedTensorValue(
|
||||
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64)))
|
||||
lst = [[1, 2], [3], [4, 5, 6]]
|
||||
rt = ragged_factory_ops.constant(lst)
|
||||
ds = dataset_ops.Dataset.from_tensor_slices(rt)
|
||||
for actual, expected in zip(ds.as_numpy_iterator(), lst):
|
||||
self.assertTrue(np.array_equal(actual, expected))
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testDatasetElement(self):
|
||||
|
@ -67,6 +67,7 @@ from tensorflow.python.ops import gen_io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.training.tracking import base as tracking_base
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import deprecation
|
||||
@ -522,7 +523,9 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
||||
raise RuntimeError("as_numpy_iterator() is not supported while tracing "
|
||||
"functions")
|
||||
for component_spec in nest.flatten(self.element_spec):
|
||||
if not isinstance(component_spec, tensor_spec.TensorSpec):
|
||||
if not isinstance(
|
||||
component_spec,
|
||||
(tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec)):
|
||||
raise TypeError(
|
||||
"Dataset.as_numpy_iterator() does not support datasets containing "
|
||||
+ str(component_spec.value_type))
|
||||
|
Loading…
x
Reference in New Issue
Block a user