Fix the data_adapter for dataset.Iterator.
Currently both the Generator and CompositeTensor handlers could handle it, which causes errors like https://github.com/tensorflow/tensorflow/pull/43874.

PiperOrigin-RevId: 337987774
Change-Id: I706079fbe57e0e87687ceeb10e14e265a754e08e
This commit is contained in:
parent ecd5184dd2
commit bbefe66945
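For context, the user-facing pattern this change targets is passing a dataset iterator, rather than the dataset itself, to model.fit. The sketch below is a hedged, minimal reproduction; the model, data shapes, and step counts are illustrative assumptions, not taken from the linked PR. Before this fix, the CompositeTensor adapter could also claim the iterator instead of leaving it to the generator path.

# Minimal sketch (assumed model and shapes); only the iter(dataset) pattern
# comes from this change, everything else is illustrative.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([8, 4]), tf.random.normal([8, 1]))).batch(2).repeat()

# Unlike a Dataset, an iterator carries no cardinality the adapter can infer,
# so steps_per_epoch must be given explicitly.
model.fit(iter(ds), steps_per_epoch=4, epochs=2)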
@@ -31,6 +31,7 @@ import six
 from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.experimental.ops import distribute_options
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.eager import context
@@ -523,9 +524,11 @@ class CompositeTensorDataAdapter(DataAdapter):
       flat_inputs += nest.flatten(y)
 
     def _is_composite(v):
-      # Dataset inherits from CompositeTensor but shouldn't be handled here.
+      # Dataset/iterator inherits from CompositeTensor but should be handled
+      # by DatasetAdapter and GeneratorAdapter.
       if (tf_utils.is_extension_type(v) and
-          not isinstance(v, dataset_ops.DatasetV2)):
+          not isinstance(v, (dataset_ops.DatasetV2,
+                             iterator_ops.IteratorBase))):
         return True
       # Support Scipy sparse tensors if scipy is installed
       if scipy_sparse is not None and scipy_sparse.issparse(v):
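As a rough illustration of the guard above: Dataset and iterator objects inherit from CompositeTensor, so without the explicit exclusion they would be claimed by CompositeTensorDataAdapter instead of falling through to the Dataset and Generator adapters. The snippet below is a standalone sketch that reuses the same internal modules as this diff; it assumes a TensorFlow source tree around this revision and is not part of the commit.

# Standalone sketch of the fixed check; internal module paths are as used in
# this diff and may differ in other versions.
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.keras.utils import tf_utils

def _is_composite(v):
  # Datasets and dataset iterators subclass CompositeTensor, but they should
  # be routed to DatasetAdapter / GeneratorAdapter, so they are excluded here.
  if (tf_utils.is_extension_type(v) and
      not isinstance(v, (dataset_ops.DatasetV2, iterator_ops.IteratorBase))):
    return True
  return False

ds = dataset_ops.DatasetV2.range(3)
print(_is_composite(ds))        # False -> DatasetAdapter handles the dataset
print(_is_composite(iter(ds)))  # False after this fix -> GeneratorAdapter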
@@ -953,6 +953,25 @@ class DataHandlerTest(keras_parameterized.TestCase):
     self.assertEqual(returned_data, [[([0],), ([1],),
                                       ([2],)], [([0],), ([1],), ([2],)]])
 
+  def test_iterator(self):
+    def generator():
+      for _ in range(2):
+        for step in range(3):
+          yield (ops.convert_to_tensor_v2_with_dispatch([step]),)
+
+    it = iter(dataset_ops.Dataset.from_generator(
+        generator, output_types=('float32',)))
+    data_handler = data_adapter.DataHandler(it, epochs=2, steps_per_epoch=3)
+    returned_data = []
+    for _, iterator in data_handler.enumerate_epochs():
+      epoch_data = []
+      for _ in data_handler.steps():
+        epoch_data.append(next(iterator))
+      returned_data.append(epoch_data)
+    returned_data = self.evaluate(returned_data)
+    self.assertEqual(returned_data, [[([0],), ([1],), ([2],)],
+                                     [([0],), ([1],), ([2],)]])
+
   def test_list_of_scalars(self):
     data_handler = data_adapter.DataHandler([[0], [1], [2]],
                                             epochs=2,
|
Loading…
x
Reference in New Issue
Block a user