From bbefe66945752b49776367b4d090a896d4672d1f Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Mon, 19 Oct 2020 20:34:35 -0700 Subject: [PATCH] Fix the data_adapter for dataset.Iterator. Currently both Generator and CompositeTensor handler could handle it, which cause error like https://github.com/tensorflow/tensorflow/pull/43874. PiperOrigin-RevId: 337987774 Change-Id: I706079fbe57e0e87687ceeb10e14e265a754e08e --- .../python/keras/engine/data_adapter.py | 7 +++++-- .../python/keras/engine/data_adapter_test.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 2cc6f69403e..6afe1840458 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -31,6 +31,7 @@ import six from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.experimental.ops import distribute_options from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.distribute import distribution_strategy_context as ds_context from tensorflow.python.distribute import input_lib from tensorflow.python.eager import context @@ -523,9 +524,11 @@ class CompositeTensorDataAdapter(DataAdapter): flat_inputs += nest.flatten(y) def _is_composite(v): - # Dataset inherits from CompositeTensor but shouldn't be handled here. + # Dataset/iterator inherits from CompositeTensor but should be handled + # by DatasetAdapter and GeneratorAdapter. if (tf_utils.is_extension_type(v) and - not isinstance(v, dataset_ops.DatasetV2)): + not isinstance(v, (dataset_ops.DatasetV2, + iterator_ops.IteratorBase))): return True # Support Scipy sparse tensors if scipy is installed if scipy_sparse is not None and scipy_sparse.issparse(v): diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index 9ca63ec42f0..59613439bf9 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -953,6 +953,25 @@ class DataHandlerTest(keras_parameterized.TestCase): self.assertEqual(returned_data, [[([0],), ([1],), ([2],)], [([0],), ([1],), ([2],)]]) + def test_iterator(self): + def generator(): + for _ in range(2): + for step in range(3): + yield (ops.convert_to_tensor_v2_with_dispatch([step]),) + + it = iter(dataset_ops.Dataset.from_generator( + generator, output_types=('float32',))) + data_handler = data_adapter.DataHandler(it, epochs=2, steps_per_epoch=3) + returned_data = [] + for _, iterator in data_handler.enumerate_epochs(): + epoch_data = [] + for _ in data_handler.steps(): + epoch_data.append(next(iterator)) + returned_data.append(epoch_data) + returned_data = self.evaluate(returned_data) + self.assertEqual(returned_data, [[([0],), ([1],), ([2],)], + [([0],), ([1],), ([2],)]]) + def test_list_of_scalars(self): data_handler = data_adapter.DataHandler([[0], [1], [2]], epochs=2,