Update keras code to use public API of tf.data.Iterator.

PiperOrigin-RevId: 338734148
Change-Id: Id57880d0cef5cb05a5d13c5894a8a1e9cb37f569
Author: Scott Zhu, 2020-10-23 13:25:35 -07:00 (committed by TensorFlower Gardener)
Parent: 5380e77987
Commit: d6f26b5a7f
5 changed files with 10 additions and 10 deletions
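Note for reviewers: in TF 2.x, `IteratorBase` is the class exported publicly as `tf.data.Iterator`, and the `OwnedIterator` returned by `iter(dataset)` in eager mode subclasses it, so widening the `isinstance` checks below from `OwnedIterator` to `IteratorBase` keeps matching eagerly created iterators. A minimal standalone sketch (not part of the diff):

    import tensorflow as tf

    ds = tf.data.Dataset.range(4)
    it = iter(ds)  # an OwnedIterator under the hood
    # OwnedIterator subclasses IteratorBase, whose public name is
    # tf.data.Iterator, so the rewritten checks still recognize it.
    print(isinstance(it, tf.data.Iterator))  # True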

File 1 of 5

@@ -184,7 +184,7 @@ def set_callback_parameters(callback_list,
 def _is_generator_like(data):
   """Checks if data is a generator, Sequence, or Iterator."""
   return (hasattr(data, '__next__') or hasattr(data, 'next') or isinstance(
-      data, (Sequence, iterator_ops.Iterator, iterator_ops.OwnedIterator)))
+      data, (Sequence, iterator_ops.Iterator, iterator_ops.IteratorBase)))
 
 
 def make_logs(model, logs, outputs, mode, prefix=''):
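Aside from the TF iterator types, `_is_generator_like` also duck-types on `__next__`/`next`; a quick standalone illustration (a hypothetical simplified copy of the function above, with the TF and Sequence checks omitted):

    def is_generator_like(data):  # simplified mirror of the hunk above
        return hasattr(data, '__next__') or hasattr(data, 'next')

    gen = (x * x for x in range(3))
    print(is_generator_like(gen))     # True: generators define __next__
    print(is_generator_like([1, 2]))  # False: a plain list is materialized data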

File 2 of 5

@@ -510,7 +510,7 @@ def _prepare_feed_values(model, inputs, targets, sample_weights, mode):
   # in Distribution Strategy case as it follows the same code path for both
   # eager and graph modes.
   # TODO(priyag,omalleyt): Either we should move the training DS with
-  # OwnedIterator to use training_generator code path, or figure out how to
+  # IteratorBase to use training_generator code path, or figure out how to
   # set a symbolic Iterator out of a Dataset when in eager mode.
   if context.executing_eagerly():
     return get_distributed_inputs

File 3 of 5

@@ -412,7 +412,7 @@ def _validate_arguments(is_sequence, is_dataset, use_multiprocessing, workers,
   val_gen = (
       data_utils.is_generator_or_sequence(validation_data) or
-      isinstance(validation_data, iterator_ops.OwnedIterator))
+      isinstance(validation_data, iterator_ops.IteratorBase))
   if (val_gen and not isinstance(validation_data, data_utils.Sequence) and
       not validation_steps):
     raise ValueError('Please specify the `validation_steps` argument.')
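The guard above exists because generator- and iterator-like validation data has no `__len__`, so the number of validation batches cannot be inferred. A runnable sketch of the requirement (toy model and generator are illustrative, not from this commit):

    import numpy as np
    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile('sgd', 'mse')

    def batches():  # endless generator: len() is undefined
        while True:
            yield np.ones((8, 4)), np.ones((8, 1))

    # Without validation_steps, the guard above raises
    # "Please specify the `validation_steps` argument."
    model.fit(batches(), steps_per_epoch=2,
              validation_data=batches(), validation_steps=1)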
@@ -455,7 +455,7 @@ def convert_to_generator_like(data,
       ele for ele in data if not all(e is None for e in nest.flatten(ele)))
 
   if data_utils.is_generator_or_sequence(data) or isinstance(
-      data, iterator_ops.OwnedIterator):
+      data, iterator_ops.IteratorBase):
     if isinstance(data, data_utils.Sequence):
       if steps_per_epoch is None:
         steps_per_epoch = len(data)

File 4 of 5

@@ -1194,7 +1194,7 @@ def check_steps_argument(input_data, steps, steps_name):
       but not provided.
   """
   is_x_iterator = isinstance(
-      input_data, (iterator_ops.Iterator, iterator_ops.OwnedIterator))
+      input_data, (iterator_ops.Iterator, iterator_ops.IteratorBase))
   if (input_data is None or is_x_iterator or has_symbolic_tensors(input_data) or
       (isinstance(input_data, list) and not input_data)):
     if steps is None:
@@ -1418,7 +1418,7 @@ def is_feature_layer(layer):
 def is_eager_dataset_or_iterator(data):
   return context.executing_eagerly() and isinstance(
       data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
-             iterator_ops.OwnedIterator))
+             iterator_ops.IteratorBase))
 
 
 # pylint: disable=protected-access
@@ -1456,7 +1456,7 @@ def verify_dataset_shuffled(x):
 def is_dataset_or_iterator(data):
   return isinstance(data, (dataset_ops.DatasetV1, dataset_ops.DatasetV2,
-                           iterator_ops.Iterator, iterator_ops.OwnedIterator))
+                           iterator_ops.Iterator, iterator_ops.IteratorBase))
 
 
 def get_iterator(dataset):
@@ -1741,7 +1741,7 @@ def unpack_validation_data(validation_data, raise_if_ambiguous=True):
       tuple of 3, (x, y, sample_weights) for numpy and tensor input.
   """
   if (isinstance(validation_data, (iterator_ops.Iterator,
-                                   iterator_ops.OwnedIterator,
+                                   iterator_ops.IteratorBase,
                                    dataset_ops.DatasetV2,
                                    data_utils.Sequence))
       or not hasattr(validation_data, '__len__')):

File 5 of 5

@@ -576,7 +576,7 @@ class Model(training_lib.Model):
     # integrated into the data adapters in the v2 loop. We can't do this yet
     # because we currently have to fall back for unhandled data types.
     if isinstance(inputs, (iterator_ops.Iterator,
-                           iterator_ops.OwnedIterator)):
+                           iterator_ops.IteratorBase)):
       raise ValueError('For performance reasons Keras `fit`, `evaluate` and'
                        '`predict` accept tf.data `Datasets` as input but not '
                        'iterators that have been manually generated from '
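For context, the error above fires when a hand-made iterator is passed to `fit`/`evaluate`/`predict` instead of its dataset; passing the `Dataset` itself is the supported path. An illustrative sketch (toy model, not from this commit):

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
    model.compile('sgd', 'mse')

    ds = tf.data.Dataset.from_tensor_slices(
        (tf.ones((8, 1)), tf.ones((8, 1)))).batch(4)
    model.fit(ds, epochs=1)  # OK: pass the Dataset itself
    # model.fit(iter(ds))    # would trip the ValueError raised above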
@@ -1742,7 +1742,7 @@ class Model(training_lib.Model):
       # Check Dataset/Iterator batch size is consistent with InputLayer.
       if isinstance(x, (dataset_ops.DatasetV2, iterator_ops.Iterator,
-                        iterator_ops.OwnedIterator)):
+                        iterator_ops.IteratorBase)):
         ds_batch_size = tensor_shape.as_dimension(
             nest.flatten(dataset_ops.get_legacy_output_shapes(x))[0][0]).value
         if ds_batch_size is not None:
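The batch-size check reads the static leading dimension off the dataset's legacy output shapes; a small sketch of the same idea using the public `tf.compat.v1.data.get_output_shapes` helper (scalar dataset assumed, not from this commit):

    import tensorflow as tf

    ds = tf.data.Dataset.range(8).batch(4, drop_remainder=True)
    # The leading dimension of the output shape is the static batch size
    # (None when the batch size is dynamic).
    print(tf.compat.v1.data.get_output_shapes(ds)[0])  # 4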