Disallow dataset iterators in Keras fit, predict, and evaluate.

PiperOrigin-RevId: 259071119
A. Unique TensorFlower 2019-07-19 17:58:46 -07:00 committed by TensorFlower Gardener
parent 6dcc61a0ae
commit 229dae116a
7 changed files with 65 additions and 167 deletions
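For orientation, a minimal sketch of the usage change this commit enforces (the model, optimizer, and data shapes are illustrative, not taken from the diff): `fit`, `evaluate`, and `predict` continue to accept a `tf.data` Dataset directly, while a manually created iterator over that Dataset now raises a ValueError.

import numpy as np
import tensorflow as tf

# Illustrative model; any compiled Keras model behaves the same way.
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
model.compile(optimizer='rmsprop', loss='mse')

x = np.zeros((10, 3), np.float32)
y = np.zeros((10, 4), np.float32)
dataset = tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(10)

# Supported: pass the Dataset object itself.
model.fit(dataset, epochs=1, steps_per_epoch=2)
model.evaluate(dataset, steps=2)
model.predict(dataset, steps=2)

# No longer supported: a manually created iterator now raises ValueError.
# model.fit(iter(dataset), epochs=1, steps_per_epoch=2)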

@ -153,9 +153,8 @@ class TestSequential(keras_parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
iterator = dataset_ops.make_one_shot_iterator(dataset)
model.fit(iterator, epochs=1, steps_per_epoch=steps_per_epoch)
model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch)
self.assertTrue(model.built)
self.assertEqual(len(model.weights), 2 * 2)
self.assertFalse(model._is_graph_network)

@ -465,11 +465,21 @@ class Model(network.Network):
def _select_training_loop(self, inputs):
"""Select training loop for fit/eval/predict based on the inputs."""
# TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely
# integrated into the data adapters in the v2 loop. We can't do this yet
# because we currently have to fall back for unhandled data types.
if isinstance(inputs, (iterator_ops.Iterator,
iterator_ops.IteratorV2)):
raise ValueError('For performance reasons Keras `fit`, `evaluate` and '
'`predict` accept tf.data `Datasets` as input but not '
'iterators that have been manually generated from '
'Datasets by users. Please directly pass in the '
'original `Dataset` object instead of passing in '
'`iter(dataset)`.')
# Experimental training loop with default DS path.
if (context.executing_eagerly()
and self._run_distributed
and not isinstance(inputs, (iterator_ops.Iterator,
iterator_ops.IteratorV2))
# TODO(scottzhu): Finish getting sequences working with the v2 loops.
and not isinstance(inputs, (data_utils.Sequence))
and not distributed_training_utils.is_tpu_strategy(
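A rough sketch of what the new guard rejects (the dataset and model below are illustrative): both `iter(dataset)` and an iterator created with `make_one_shot_iterator` are instances of the iterator classes checked above, so either now raises the ValueError, while passing the Dataset itself proceeds to training-loop selection.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(([[0.0]], [[0.0]])).batch(1)
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile('sgd', 'mse')

model.fit(dataset, steps_per_epoch=1)  # accepted
try:
    model.fit(iter(dataset), steps_per_epoch=1)  # rejected by the new check (eager mode)
except ValueError as err:
    print(err)  # "... pass in the original `Dataset` object instead of passing in `iter(dataset)`."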
@ -535,7 +545,7 @@ class Model(network.Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator. Should return a tuple
- A `tf.data` dataset. Should return a tuple
of either `(inputs, targets)` or
`(inputs, targets, sample_weights)`.
- A generator or `keras.utils.Sequence` returning `(inputs, targets)`
@ -543,14 +553,14 @@ class Model(network.Network):
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely). If `x` is a dataset, dataset
iterator, generator, or `keras.utils.Sequence` instance, `y` should
tensor targets, or inversely). If `x` is a dataset, generator,
or `keras.utils.Sequence` instance, `y` should
not be specified (since targets will be obtained from `x`).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
form of symbolic tensors, dataset, dataset iterators,
form of symbolic tensors, datasets,
generators, or `keras.utils.Sequence` instances (since they generate
batches).
epochs: Integer. Number of epochs to train the model.
@ -577,7 +587,7 @@ class Model(network.Network):
on this data at the end of each epoch.
The validation data is selected from the last samples
in the `x` and `y` data provided, before shuffling. This argument is
not supported when `x` is a dataset, dataset iterator, generator or
not supported when `x` is a dataset, generator or
`keras.utils.Sequence` instance.
validation_data: Data on which to evaluate
the loss and any model metrics at the end of each epoch.
@ -586,7 +596,7 @@ class Model(network.Network):
`validation_data` could be:
- tuple `(x_val, y_val)` of Numpy arrays or tensors
- tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
- dataset or a dataset iterator
- dataset
For the first two cases, `batch_size` must be provided.
For the last case, `validation_steps` must be provided.
shuffle: Boolean (whether to shuffle the training data
@ -611,7 +621,7 @@ class Model(network.Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
supported when `x` is a dataset, dataset iterator, generator, or
supported when `x` is a dataset, generator, or
`keras.utils.Sequence` instance, instead provide the sample_weights
as the third element of `x`.
initial_epoch: Integer.
@ -624,14 +634,14 @@ class Model(network.Network):
TensorFlow data tensors, the default `None` is equal to
the number of samples in your dataset divided by
the batch size, or 1 if that cannot be determined. If x is a
`tf.data` dataset or a dataset iterator, and 'steps_per_epoch'
`tf.data` dataset, and 'steps_per_epoch'
is None, the epoch will run until the input dataset is exhausted.
This argument is not supported with array inputs.
validation_steps: Only relevant if `validation_data` is provided and
is a dataset or dataset iterator. Total number of steps (batches of
is a `tf.data` dataset. Total number of steps (batches of
samples) to draw before stopping when performing validation
at the end of every epoch. If validation_data is a `tf.data` dataset
or a dataset iterator, and 'validation_steps' is None, validation
and 'validation_steps' is None, validation
will run until the `validation_data` dataset is exhausted.
validation_freq: Only relevant if validation data is provided. Integer
or `collections.Container` instance (e.g. list, tuple, etc.). If an
@ -722,20 +732,20 @@ class Model(network.Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
- A generator or `keras.utils.Sequence` instance.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely).
If `x` is a dataset, dataset iterator, generator or
If `x` is a dataset, generator or
`keras.utils.Sequence` instance, `y` should not be specified (since
targets will be obtained from the iterator/dataset).
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
form of symbolic tensors, dataset, dataset iterators,
form of symbolic tensors, dataset,
generators, or `keras.utils.Sequence` instances (since they generate
batches).
verbose: 0 or 1. Verbosity mode.
@ -751,13 +761,13 @@ class Model(network.Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
`sample_weight_mode="temporal"` in `compile()`. This argument is not
supported when `x` is a dataset or a dataset iterator, instead pass
supported when `x` is a dataset, instead pass
sample weights as the third element of `x`.
steps: Integer or `None`.
Total number of steps (batches of samples)
before declaring the evaluation round finished.
Ignored with the default value of `None`.
If x is a `tf.data` dataset or a dataset iterator, and `steps` is
If x is a `tf.data` dataset and `steps` is
None, 'evaluate' will run until the dataset is exhausted.
This argument is not supported with array inputs.
callbacks: List of `keras.callbacks.Callback` instances.
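As a usage illustration of the dataset semantics described in the `fit` and `evaluate` docstrings above (the model and data are made up for the sketch, not part of the change): an infinite, repeated training dataset needs `steps_per_epoch`, a finite validation dataset is simply consumed until it is exhausted when `validation_steps` is None, and `evaluate` with `steps=None` likewise runs until a finite dataset is exhausted.

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
model.compile('rmsprop', 'mse')

x = np.zeros((20, 3), np.float32)
y = np.zeros((20, 4), np.float32)
train_ds = tf.data.Dataset.from_tensor_slices((x, y)).repeat().batch(5)  # infinite
val_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(5)             # finite

# steps_per_epoch bounds the infinite training dataset; the finite
# validation dataset runs to exhaustion because validation_steps is None.
model.fit(train_ds, epochs=1, steps_per_epoch=4, validation_data=val_ds)

# With steps=None, evaluate also runs until the finite dataset is exhausted.
model.evaluate(val_ds)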
@ -822,20 +832,20 @@ class Model(network.Network):
(in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
- A generator or `keras.utils.Sequence` instance.
batch_size: Integer or `None`.
Number of samples per gradient update.
If unspecified, `batch_size` will default to 32.
Do not specify the `batch_size` if your data is in the
form of symbolic tensors, dataset, dataset iterators,
form of symbolic tensors, dataset,
generators, or `keras.utils.Sequence` instances (since they generate
batches).
verbose: Verbosity mode, 0 or 1.
steps: Total number of steps (batches of samples)
before declaring the prediction round finished.
Ignored with the default value of `None`. If x is a `tf.data`
dataset or a dataset iterator, and `steps` is None, `predict` will
dataset and `steps` is None, `predict` will
run until the input dataset is exhausted.
callbacks: List of `keras.callbacks.Callback` instances.
List of callbacks to apply during prediction.
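To illustrate the `steps` semantics for `predict` described above (the model and input dataset are illustrative):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
inputs_ds = tf.data.Dataset.from_tensor_slices(np.zeros((10, 3), np.float32)).batch(5)

preds = model.predict(inputs_ds)                 # steps=None: runs until the dataset is exhausted -> shape (10, 4)
first_batch = model.predict(inputs_ds, steps=1)  # stops after one batch -> shape (5, 4)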
@ -904,11 +914,11 @@ class Model(network.Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
y: Target data. Like the input data `x`, it could be either Numpy
array(s) or TensorFlow tensor(s). It should be consistent with `x`
(you cannot have Numpy inputs and tensor targets, or inversely). If
`x` is a dataset or a dataset iterator, `y` should not be specified
`x` is a dataset, `y` should not be specified
(since targets will be obtained from the iterator).
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample. In the case of
@ -916,7 +926,7 @@ class Model(network.Network):
sequence_length), to apply a different weight to every timestep of
every sample. In this case you should make sure to specify
sample_weight_mode="temporal" in compile(). This argument is not
supported when `x` is a dataset or a dataset iterator.
supported when `x` is a dataset.
class_weight: Optional dictionary mapping class indices (integers) to a
weight (float) to apply to the model's loss for the samples from this
class during training. This can be useful to tell the model to "pay
@ -993,13 +1003,12 @@ class Model(network.Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely). If `x` is a dataset or a
dataset iterator, `y` should not be specified
(since targets will be obtained from the iterator).
tensor targets, or inversely). If `x` is a dataset, `y` should
not be specified (since targets will be obtained from the iterator).
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
In the case of temporal data, you can pass a 2D array
@ -1007,7 +1016,7 @@ class Model(network.Network):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
sample_weight_mode="temporal" in compile(). This argument is not
supported when `x` is a dataset or a dataset iterator.
supported when `x` is a dataset.
reset_metrics: If `True`, the metrics returned will be only for this
batch. If `False`, the metrics will be statefully accumulated across
batches.
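A small sketch of the `reset_metrics` behaviour documented above, using plain NumPy batches (the model and metric are illustrative; per the updated docstring `x` may also be a `tf.data` dataset, but no longer a dataset iterator):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile('rmsprop', 'mse', metrics=['mae'])

x = np.random.random((8, 3)).astype(np.float32)
y = np.random.random((8, 1)).astype(np.float32)

model.train_on_batch(x, y)
out = model.test_on_batch(x, y, reset_metrics=False)  # metric state is kept for the next call
out = model.test_on_batch(x, y, reset_metrics=False)
print(out)  # [loss, mae], with mae accumulated across both evaluation batches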
@ -1068,7 +1077,7 @@ class Model(network.Network):
(in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
Returns:
Numpy array(s) of predictions.
@ -2221,13 +2230,12 @@ class Model(network.Network):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely). If `x` is a dataset or a
dataset iterator, `y` should not be specified
(since targets will be obtained from the iterator).
tensor targets, or inversely). If `x` is a dataset, `y` should not be
specified (since targets will be obtained from the iterator).
sample_weight: An optional sample-weight array passed by the user to
weight the importance of each sample in `x`.
class_weight: An optional class-weight array by the user to

@ -47,100 +47,6 @@ class BatchCounterCallback(callbacks.Callback):
self.batch_count += 1
class TestTrainingWithDatasetIterators(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_training_and_eval_methods_on_iterators_single_io(self):
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
optimizer = 'rmsprop'
loss = 'mse'
metrics = ['mae', metrics_module.CategoricalAccuracy()]
model.compile(
optimizer,
loss,
metrics=metrics,
run_eagerly=testing_utils.should_run_eagerly(),
run_distributed=testing_utils.should_run_distributed())
inputs = np.zeros((10, 3), np.float32)
targets = np.zeros((10, 4), np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
iterator = dataset_ops.make_one_shot_iterator(dataset)
model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(iterator, steps=2, verbose=1)
model.predict(iterator, steps=2)
# Test with validation data
model.fit(iterator,
epochs=1, steps_per_epoch=2, verbose=0,
validation_data=iterator, validation_steps=2)
# Test with validation split
with self.assertRaisesRegexp(
ValueError, '`validation_split` argument is not supported when '):
model.fit(iterator,
epochs=1, steps_per_epoch=2, verbose=0,
validation_split=0.5, validation_steps=2)
# Test with sample weight.
sample_weight = np.random.random((10,))
with self.assertRaisesRegexp(
ValueError, '`sample_weight` argument is not supported '
'when input `x` is a dataset or a dataset iterator'):
model.fit(
iterator,
epochs=1,
steps_per_epoch=2,
verbose=0,
sample_weight=sample_weight)
# Test invalid usage
with self.assertRaisesRegexp(ValueError,
'you should not specify a target'):
model.fit(iterator, iterator,
epochs=1, steps_per_epoch=2, verbose=0)
with self.assertRaisesRegexp(
ValueError, 'the `steps_per_epoch` argument'):
model.fit(iterator, epochs=1, verbose=0)
with self.assertRaisesRegexp(ValueError,
'the `steps` argument'):
model.evaluate(iterator, verbose=0)
with self.assertRaisesRegexp(ValueError,
'the `steps` argument'):
model.predict(iterator, verbose=0)
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_iterators_running_out_of_data(self):
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
optimizer = 'rmsprop'
loss = 'mse'
metrics = ['mae']
model.compile(
optimizer,
loss,
metrics=metrics,
run_eagerly=testing_utils.should_run_eagerly(),
run_distributed=testing_utils.should_run_distributed())
inputs = np.zeros((10, 3), np.float32)
targets = np.zeros((10, 4), np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(2)
dataset = dataset.batch(10)
iterator = dataset_ops.make_one_shot_iterator(dataset)
with test.mock.patch.object(logging, 'warning') as mock_log:
model.fit(iterator, epochs=1, steps_per_epoch=3, verbose=0)
self.assertRegexpMatches(
str(mock_log.call_args),
'dataset iterator ran out of data')
class TestTrainingWithDataset(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
@ -618,11 +524,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase):
model.fit(dataset)
class TestMetricsWithDatasetIterators(keras_parameterized.TestCase):
class TestMetricsWithDatasets(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
def test_metrics_correctness_with_iterator(self):
def test_metrics_correctness_with_dataset(self):
layers = [
keras.layers.Dense(8, activation='relu', input_dim=4,
kernel_initializer='ones'),
@ -643,8 +549,7 @@ class TestMetricsWithDatasetIterators(keras_parameterized.TestCase):
y = np.random.randint(2, size=(100, 1)).astype(np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
dataset = dataset.batch(10)
iterator = dataset_ops.make_one_shot_iterator(dataset)
outs = model.evaluate(iterator, steps=10)
outs = model.evaluate(dataset, steps=10)
self.assertEqual(np.around(outs[1], decimals=1), 0.5)
self.assertEqual(np.around(outs[2], decimals=1), 0.5)
@ -652,8 +557,7 @@ class TestMetricsWithDatasetIterators(keras_parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
iterator = dataset_ops.make_one_shot_iterator(dataset)
outs = model.evaluate(iterator, steps=10)
outs = model.evaluate(dataset, steps=10)
self.assertEqual(outs[1], 0.)
self.assertEqual(outs[2], 0.)

@ -183,30 +183,20 @@ class TrainingTest(keras_parameterized.TestCase):
x = array_ops.zeros(shape=(10, 3))
y = array_ops.zeros(shape=(10, 4))
dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat(10).batch(5)
iterator = dataset_ops.make_one_shot_iterator(dataset)
validation_dataset = dataset_ops.Dataset.from_tensor_slices(
(x, y)).repeat().batch(5) # Infinite dataset.
validation_iterator = dataset_ops.make_one_shot_iterator(validation_dataset)
with self.assertRaisesRegexp(
ValueError, r'specify .* `steps_per_epoch`'):
model.fit(iterator, epochs=1, verbose=0)
if not context.executing_eagerly():
# In eager execution, `array_ops.zeros` returns value tensors
# which can be used for validation without a `validation_steps` argument.
with self.assertRaisesRegexp(
ValueError, r'provide either `batch_size` or `validation_steps`'):
model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
validation_data=(x, y))
model.fit(dataset, epochs=1, verbose=0)
# Step argument is required for infinite datasets.
with self.assertRaisesRegexp(ValueError,
'specify the `validation_steps` argument.'):
model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0,
validation_data=validation_dataset)
with self.assertRaisesRegexp(ValueError,
'specify the `validation_steps` argument.'):
model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
validation_data=validation_iterator)
model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0,
validation_data=validation_dataset)
# TODO(b/120931266): Enable test on subclassed models after bug causing an
# extra dimension to be added to predict outputs is fixed.
@ -282,8 +272,7 @@ class CorrectnessTest(keras_parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
iterator = dataset_ops.make_one_shot_iterator(dataset)
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
history = model.fit(dataset, epochs=1, steps_per_epoch=10)
self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4)
def test_loss_in_call(self):

@ -859,8 +859,7 @@ class TrainingTest(keras_parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.repeat(10)
dataset = dataset.batch(10)
iterator = dataset_ops.make_one_shot_iterator(dataset)
model.fit(iterator, epochs=1, steps_per_epoch=2)
model.fit(dataset, epochs=1, steps_per_epoch=2)
if context.executing_eagerly():
# Test with eager execution
@ -870,7 +869,7 @@ class TrainingTest(keras_parameterized.TestCase):
model.fit(x_train, y_train, batch_size=5, epochs=1)
# Test with eager execution and a dataset
model.fit(iterator, epochs=1, steps_per_epoch=2)
model.fit(dataset, epochs=1, steps_per_epoch=2)
def test_losses_in_defun(self):
with context.eager_mode():

@ -178,11 +178,11 @@ def train_on_batch(
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
y: Target data. Like the input data `x`, it could be either Numpy
array(s) or TensorFlow tensor(s). It should be consistent with `x`
(you cannot have Numpy inputs and tensor targets, or inversely). If
`x` is a dataset or a dataset iterator, `y` should not be specified
`x` is a dataset, `y` should not be specified
(since targets will be obtained from the iterator).
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample. In the case of
@ -190,7 +190,7 @@ def train_on_batch(
sequence_length), to apply a different weight to every timestep of
every sample. In this case you should make sure to specify
sample_weight_mode="temporal" in compile(). This argument is not
supported when `x` is a dataset or a dataset iterator.
supported when `x` is a dataset.
class_weight: Optional dictionary mapping class indices (integers) to a
weight (float) to apply to the model's loss for the samples from this
class during training. This can be useful to tell the model to "pay
@ -249,12 +249,12 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
(in case the model has multiple inputs).
- A dict mapping input names to the corresponding array/tensors,
if the model has named inputs.
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
y: Target data. Like the input data `x`,
it could be either Numpy array(s) or TensorFlow tensor(s).
It should be consistent with `x` (you cannot have Numpy inputs and
tensor targets, or inversely). If `x` is a dataset or a
dataset iterator, `y` should not be specified
tensor targets, or inversely). If `x` is a dataset,
`y` should not be specified
(since targets will be obtained from the iterator).
sample_weight: Optional array of the same length as x, containing
weights to apply to the model's loss for each sample.
@ -263,7 +263,7 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True):
to apply a different weight to every timestep of every sample.
In this case you should make sure to specify
sample_weight_mode="temporal" in compile(). This argument is not
supported when `x` is a dataset or a dataset iterator.
supported when `x` is a dataset.
reset_metrics: If `True`, the metrics returned will be only for this
batch. If `False`, the metrics will be statefully accumulated across
batches.
@ -310,7 +310,7 @@ def predict_on_batch(model, x):
(in case the model has multiple inputs).
- A TensorFlow tensor, or a list of tensors
(in case the model has multiple inputs).
- A `tf.data` dataset or a dataset iterator.
- A `tf.data` dataset.
Returns:
Numpy array(s) of predictions.

@ -646,7 +646,7 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
_ = model.evaluate([x1, x2], [y1, y2], verbose=0)
def test_single_io_workflow_with_dataset_iterators(self):
def test_single_io_workflow_with_datasets(self):
num_classes = 2
num_samples = 10
input_dim = 50
@ -664,10 +664,9 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
dataset = dataset.repeat(100)
dataset = dataset.batch(10)
iterator = dataset_ops.make_one_shot_iterator(dataset)
model.fit(iterator, epochs=2, steps_per_epoch=10, verbose=0)
_ = model.evaluate(iterator, steps=10, verbose=0)
model.fit(dataset, epochs=2, steps_per_epoch=10, verbose=0)
_ = model.evaluate(dataset, steps=10, verbose=0)
def test_attributes(self):
# layers, weights, trainable_weights, non_trainable_weights, inputs, outputs