Update np.data.Dataset.as_numpy_iterator to support ragged tensors.
PiperOrigin-RevId: 323645392 Change-Id: Id80daa55a676bfd523f23f72394acb79f5088fd5
This commit is contained in:
parent
5f8e3e8d54
commit
69ca56e7f4
@ -27,7 +27,7 @@ from tensorflow.python.data.ops import dataset_ops
|
|||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -74,9 +74,11 @@ class AsNumpyIteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testRaggedElement(self):
|
def testRaggedElement(self):
|
||||||
self._testInvalidElement(
|
lst = [[1, 2], [3], [4, 5, 6]]
|
||||||
ragged_tensor_value.RaggedTensorValue(
|
rt = ragged_factory_ops.constant(lst)
|
||||||
np.array([0, 1, 2]), np.array([0, 1, 3], dtype=np.int64)))
|
ds = dataset_ops.Dataset.from_tensor_slices(rt)
|
||||||
|
for actual, expected in zip(ds.as_numpy_iterator(), lst):
|
||||||
|
self.assertTrue(np.array_equal(actual, expected))
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testDatasetElement(self):
|
def testDatasetElement(self):
|
||||||
|
@ -67,6 +67,7 @@ from tensorflow.python.ops import gen_io_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import script_ops
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.ops import string_ops
|
from tensorflow.python.ops import string_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.training.tracking import base as tracking_base
|
from tensorflow.python.training.tracking import base as tracking_base
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
@ -522,7 +523,9 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
|||||||
raise RuntimeError("as_numpy_iterator() is not supported while tracing "
|
raise RuntimeError("as_numpy_iterator() is not supported while tracing "
|
||||||
"functions")
|
"functions")
|
||||||
for component_spec in nest.flatten(self.element_spec):
|
for component_spec in nest.flatten(self.element_spec):
|
||||||
if not isinstance(component_spec, tensor_spec.TensorSpec):
|
if not isinstance(
|
||||||
|
component_spec,
|
||||||
|
(tensor_spec.TensorSpec, ragged_tensor.RaggedTensorSpec)):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Dataset.as_numpy_iterator() does not support datasets containing "
|
"Dataset.as_numpy_iterator() does not support datasets containing "
|
||||||
+ str(component_spec.value_type))
|
+ str(component_spec.value_type))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user