Fix the data_adapter for dataset.Iterator.
Currently both Generator and CompositeTensor handler could handle it, which cause error like https://github.com/tensorflow/tensorflow/pull/43874. PiperOrigin-RevId: 337987774 Change-Id: I706079fbe57e0e87687ceeb10e14e265a754e08e
This commit is contained in:
parent
ecd5184dd2
commit
bbefe66945
@ -31,6 +31,7 @@ import six
|
||||
from tensorflow.python.data.experimental.ops import cardinality
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.distribute import distribution_strategy_context as ds_context
|
||||
from tensorflow.python.distribute import input_lib
|
||||
from tensorflow.python.eager import context
|
||||
@ -523,9 +524,11 @@ class CompositeTensorDataAdapter(DataAdapter):
|
||||
flat_inputs += nest.flatten(y)
|
||||
|
||||
def _is_composite(v):
|
||||
# Dataset inherits from CompositeTensor but shouldn't be handled here.
|
||||
# Dataset/iterator inherits from CompositeTensor but should be handled
|
||||
# by DatasetAdapter and GeneratorAdapter.
|
||||
if (tf_utils.is_extension_type(v) and
|
||||
not isinstance(v, dataset_ops.DatasetV2)):
|
||||
not isinstance(v, (dataset_ops.DatasetV2,
|
||||
iterator_ops.IteratorBase))):
|
||||
return True
|
||||
# Support Scipy sparse tensors if scipy is installed
|
||||
if scipy_sparse is not None and scipy_sparse.issparse(v):
|
||||
|
@ -953,6 +953,25 @@ class DataHandlerTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(returned_data, [[([0],), ([1],),
|
||||
([2],)], [([0],), ([1],), ([2],)]])
|
||||
|
||||
def test_iterator(self):
|
||||
def generator():
|
||||
for _ in range(2):
|
||||
for step in range(3):
|
||||
yield (ops.convert_to_tensor_v2_with_dispatch([step]),)
|
||||
|
||||
it = iter(dataset_ops.Dataset.from_generator(
|
||||
generator, output_types=('float32',)))
|
||||
data_handler = data_adapter.DataHandler(it, epochs=2, steps_per_epoch=3)
|
||||
returned_data = []
|
||||
for _, iterator in data_handler.enumerate_epochs():
|
||||
epoch_data = []
|
||||
for _ in data_handler.steps():
|
||||
epoch_data.append(next(iterator))
|
||||
returned_data.append(epoch_data)
|
||||
returned_data = self.evaluate(returned_data)
|
||||
self.assertEqual(returned_data, [[([0],), ([1],), ([2],)],
|
||||
[([0],), ([1],), ([2],)]])
|
||||
|
||||
def test_list_of_scalars(self):
|
||||
data_handler = data_adapter.DataHandler([[0], [1], [2]],
|
||||
epochs=2,
|
||||
|
Loading…
x
Reference in New Issue
Block a user