Fix the data_adapter for dataset.Iterator.

Currently both the Generator and CompositeTensor handlers can handle it, which causes errors like https://github.com/tensorflow/tensorflow/pull/43874.

PiperOrigin-RevId: 337987774
Change-Id: I706079fbe57e0e87687ceeb10e14e265a754e08e
This commit is contained in:
Scott Zhu 2020-10-19 20:34:35 -07:00 committed by TensorFlower Gardener
parent ecd5184dd2
commit bbefe66945
2 changed files with 24 additions and 2 deletions

View File

@ -31,6 +31,7 @@ import six
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.eager import context
@ -523,9 +524,11 @@ class CompositeTensorDataAdapter(DataAdapter):
flat_inputs += nest.flatten(y)
def _is_composite(v):
# Dataset inherits from CompositeTensor but shouldn't be handled here.
# Dataset/iterator inherits from CompositeTensor but should be handled
# by DatasetAdapter and GeneratorAdapter.
if (tf_utils.is_extension_type(v) and
not isinstance(v, dataset_ops.DatasetV2)):
not isinstance(v, (dataset_ops.DatasetV2,
iterator_ops.IteratorBase))):
return True
# Support Scipy sparse tensors if scipy is installed
if scipy_sparse is not None and scipy_sparse.issparse(v):

View File

@ -953,6 +953,25 @@ class DataHandlerTest(keras_parameterized.TestCase):
self.assertEqual(returned_data, [[([0],), ([1],),
([2],)], [([0],), ([1],), ([2],)]])
def test_iterator(self):
  """Checks that DataHandler accepts an iterator over a generator dataset.

  The iterator is passed to DataHandler with epochs=2 and steps_per_epoch=3,
  and every epoch is expected to produce the values [0], [1], [2] in order.
  """

  def six_samples():
    # Two consecutive passes over 0, 1, 2; each sample is a 1-tuple holding
    # a single-element tensor.
    for step in [0, 1, 2] * 2:
      yield (ops.convert_to_tensor_v2_with_dispatch([step]),)

  dataset = dataset_ops.Dataset.from_generator(
      six_samples, output_types=('float32',))
  dataset_iterator = iter(dataset)
  handler = data_adapter.DataHandler(
      dataset_iterator, epochs=2, steps_per_epoch=3)

  per_epoch_results = []
  for _, epoch_iterator in handler.enumerate_epochs():
    # Pull exactly one element from the epoch iterator per reported step.
    per_epoch_results.append(
        [next(epoch_iterator) for _ in handler.steps()])

  per_epoch_results = self.evaluate(per_epoch_results)
  expected_epoch = [([0],), ([1],), ([2],)]
  self.assertEqual(per_epoch_results, [expected_epoch, expected_epoch])
def test_list_of_scalars(self):
data_handler = data_adapter.DataHandler([[0], [1], [2]],
epochs=2,