From 229dae116a1e13b9a6286a7a6bf26c5c3ab6bf28 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 19 Jul 2019 17:58:46 -0700 Subject: [PATCH] Disallow dataset iterators in Keras fit, predict, and evaluate. PiperOrigin-RevId: 259071119 --- .../python/keras/engine/sequential_test.py | 3 +- tensorflow/python/keras/engine/training.py | 74 +++++++------ .../keras/engine/training_dataset_test.py | 104 +----------------- .../keras/engine/training_eager_test.py | 23 +--- .../python/keras/engine/training_test.py | 5 +- .../python/keras/engine/training_v2_utils.py | 16 +-- .../python/keras/model_subclassing_test.py | 7 +- 7 files changed, 65 insertions(+), 167 deletions(-) diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 0dca345e117..babb37d6c37 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -153,9 +153,8 @@ class TestSequential(keras_parameterized.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) dataset = dataset.batch(10) - iterator = dataset_ops.make_one_shot_iterator(dataset) - model.fit(iterator, epochs=1, steps_per_epoch=steps_per_epoch) + model.fit(dataset, epochs=1, steps_per_epoch=steps_per_epoch) self.assertTrue(model.built) self.assertEqual(len(model.weights), 2 * 2) self.assertFalse(model._is_graph_network) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 1fefa5744cd..a415358ff03 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -465,11 +465,21 @@ class Model(network.Network): def _select_training_loop(self, inputs): """Select training loop for fit/eval/predict based on the inputs.""" + # TODO(kaftan) or TODO(scottzhu): This check should eventually be nicely + # integrated into the data adapters in the v2 loop. We can't do this yet + # because we currently have to fall back for unhandled data types. + if isinstance(inputs, (iterator_ops.Iterator, + iterator_ops.IteratorV2)): + raise ValueError('For performance reasons Keras `fit`, `evaluate` and' + '`predict` accept tf.data `Datasets` as input but not ' + 'iterators that have been manually generated from ' + 'Datasets by users. Please directly pass in the ' + 'original `Dataset` object instead of passing in ' + '`iter(dataset)`.') + # Experiment training loop with default DS path. if (context.executing_eagerly() and self._run_distributed - and not isinstance(inputs, (iterator_ops.Iterator, - iterator_ops.IteratorV2)) # TODO(scottzhu): Finish getting sequences working with the v2 loops. and not isinstance(inputs, (data_utils.Sequence)) and not distributed_training_utils.is_tpu_strategy( @@ -535,7 +545,7 @@ class Model(network.Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. Should return a tuple + - A `tf.data` dataset. Should return a tuple of either `(inputs, targets)` or `(inputs, targets, sample_weights)`. - A generator or `keras.utils.Sequence` returning `(inputs, targets)` @@ -543,14 +553,14 @@ class Model(network.Network): y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset, dataset - iterator, generator, or `keras.utils.Sequence` instance, `y` should + tensor targets, or inversely). If `x` is a dataset, generator, + or `keras.utils.Sequence` instance, `y` should not be specified (since targets will be obtained from `x`). batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your data is in the - form of symbolic tensors, dataset, dataset iterators, + form of symbolic tensors, datasets, generators, or `keras.utils.Sequence` instances (since they generate batches). epochs: Integer. Number of epochs to train the model. @@ -577,7 +587,7 @@ class Model(network.Network): on this data at the end of each epoch. The validation data is selected from the last samples in the `x` and `y` data provided, before shuffling. This argument is - not supported when `x` is a dataset, dataset iterator, generator or + not supported when `x` is a dataset, generator or `keras.utils.Sequence` instance. validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. @@ -586,7 +596,7 @@ class Model(network.Network): `validation_data` could be: - tuple `(x_val, y_val)` of Numpy arrays or tensors - tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays - - dataset or a dataset iterator + - dataset For the first two cases, `batch_size` must be provided. For the last case, `validation_steps` must be provided. shuffle: Boolean (whether to shuffle the training data @@ -611,7 +621,7 @@ class Model(network.Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset, dataset iterator, generator, or + supported when `x` is a dataset, generator, or `keras.utils.Sequence` instance, instead provide the sample_weights as the third element of `x`. initial_epoch: Integer. @@ -624,14 +634,14 @@ class Model(network.Network): TensorFlow data tensors, the default `None` is equal to the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined. If x is a - `tf.data` dataset or a dataset iterator, and 'steps_per_epoch' + `tf.data` dataset, and 'steps_per_epoch' is None, the epoch will run until the input dataset is exhausted. This argument is not supported with array inputs. validation_steps: Only relevant if `validation_data` is provided and - is a dataset or dataset iterator. Total number of steps (batches of + is a `tf.data` dataset. Total number of steps (batches of samples) to draw before stopping when performing validation at the end of every epoch. If validation_data is a `tf.data` dataset - or a dataset iterator, and 'validation_steps' is None, validation + and 'validation_steps' is None, validation will run until the `validation_data` dataset is exhausted. validation_freq: Only relevant if validation data is provided. Integer or `collections.Container` instance (e.g. list, tuple, etc.). If an @@ -722,20 +732,20 @@ class Model(network.Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. - A generator or `keras.utils.Sequence` instance. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely). - If `x` is a dataset, dataset iterator, generator or + If `x` is a dataset, generator or `keras.utils.Sequence` instance, `y` should not be specified (since targets will be obtained from the iterator/dataset). batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` is your data is in the - form of symbolic tensors, dataset, dataset iterators, + form of symbolic tensors, dataset, generators, or `keras.utils.Sequence` instances (since they generate batches). verbose: 0 or 1. Verbosity mode. @@ -751,13 +761,13 @@ class Model(network.Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset or a dataset iterator, instead pass + supported when `x` is a dataset, instead pass sample weights as the third element of `x`. steps: Integer or `None`. Total number of steps (batches of samples) before declaring the evaluation round finished. Ignored with the default value of `None`. - If x is a `tf.data` dataset or a dataset iterator, and `steps` is + If x is a `tf.data` dataset and `steps` is None, 'evaluate' will run until the dataset is exhausted. This argument is not supported with array inputs. callbacks: List of `keras.callbacks.Callback` instances. @@ -822,20 +832,20 @@ class Model(network.Network): (in case the model has multiple inputs). - A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. - A generator or `keras.utils.Sequence` instance. batch_size: Integer or `None`. Number of samples per gradient update. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` is your data is in the - form of symbolic tensors, dataset, dataset iterators, + form of symbolic tensors, dataset, generators, or `keras.utils.Sequence` instances (since they generate batches). verbose: Verbosity mode, 0 or 1. steps: Total number of steps (batches of samples) before declaring the prediction round finished. Ignored with the default value of `None`. If x is a `tf.data` - dataset or a dataset iterator, and `steps` is None, `predict` will + dataset and `steps` is None, `predict` will run until the input dataset is exhausted. callbacks: List of `keras.callbacks.Callback` instances. List of callbacks to apply during prediction. @@ -904,11 +914,11 @@ class Model(network.Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely). If - `x` is a dataset or a dataset iterator, `y` should not be specified + `x` is a dataset, `y` should not be specified (since targets will be obtained from the iterator). sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. In the case of @@ -916,7 +926,7 @@ class Model(network.Network): sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset or a dataset iterator. + supported when `x` is a dataset. class_weight: Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples from this class during training. This can be useful to tell the model to "pay @@ -993,13 +1003,12 @@ class Model(network.Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset or a - dataset iterator, `y` should not be specified - (since targets will be obtained from the iterator). + tensor targets, or inversely). If `x` is a dataset `y` should + not be specified (since targets will be obtained from the iterator). sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array @@ -1007,7 +1016,7 @@ class Model(network.Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset or a dataset iterator. + supported when `x` is a dataset. reset_metrics: If `True`, the metrics returned will be only for this batch. If `False`, the metrics will be statefully accumulated across batches. @@ -1068,7 +1077,7 @@ class Model(network.Network): (in case the model has multiple inputs). - A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. Returns: Numpy array(s) of predictions. @@ -2221,13 +2230,12 @@ class Model(network.Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset or a - dataset iterator, `y` should not be specified - (since targets will be obtained from the iterator). + tensor targets, or inversely). If `x` is a dataset, `y` should not be + specified (since targets will be obtained from the iterator). sample_weight: An optional sample-weight array passed by the user to weight the importance of each sample in `x`. class_weight: An optional class-weight array by the user to diff --git a/tensorflow/python/keras/engine/training_dataset_test.py b/tensorflow/python/keras/engine/training_dataset_test.py index cd3613198fd..145465b9f3b 100644 --- a/tensorflow/python/keras/engine/training_dataset_test.py +++ b/tensorflow/python/keras/engine/training_dataset_test.py @@ -47,100 +47,6 @@ class BatchCounterCallback(callbacks.Callback): self.batch_count += 1 -class TestTrainingWithDatasetIterators(keras_parameterized.TestCase): - - @keras_parameterized.run_with_all_model_types - @keras_parameterized.run_all_keras_modes - def test_training_and_eval_methods_on_iterators_single_io(self): - model = testing_utils.get_small_mlp(1, 4, input_dim=3) - optimizer = 'rmsprop' - loss = 'mse' - metrics = ['mae', metrics_module.CategoricalAccuracy()] - model.compile( - optimizer, - loss, - metrics=metrics, - run_eagerly=testing_utils.should_run_eagerly(), - run_distributed=testing_utils.should_run_distributed()) - - inputs = np.zeros((10, 3), np.float32) - targets = np.zeros((10, 4), np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) - iterator = dataset_ops.make_one_shot_iterator(dataset) - - model.fit(iterator, epochs=1, steps_per_epoch=2, verbose=1) - model.evaluate(iterator, steps=2, verbose=1) - model.predict(iterator, steps=2) - - # Test with validation data - model.fit(iterator, - epochs=1, steps_per_epoch=2, verbose=0, - validation_data=iterator, validation_steps=2) - # Test with validation split - with self.assertRaisesRegexp( - ValueError, '`validation_split` argument is not supported when '): - model.fit(iterator, - epochs=1, steps_per_epoch=2, verbose=0, - validation_split=0.5, validation_steps=2) - - # Test with sample weight. - sample_weight = np.random.random((10,)) - with self.assertRaisesRegexp( - ValueError, '`sample_weight` argument is not supported ' - 'when input `x` is a dataset or a dataset iterator'): - model.fit( - iterator, - epochs=1, - steps_per_epoch=2, - verbose=0, - sample_weight=sample_weight) - - # Test invalid usage - with self.assertRaisesRegexp(ValueError, - 'you should not specify a target'): - model.fit(iterator, iterator, - epochs=1, steps_per_epoch=2, verbose=0) - - with self.assertRaisesRegexp( - ValueError, 'the `steps_per_epoch` argument'): - model.fit(iterator, epochs=1, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'the `steps` argument'): - model.evaluate(iterator, verbose=0) - with self.assertRaisesRegexp(ValueError, - 'the `steps` argument'): - model.predict(iterator, verbose=0) - - @keras_parameterized.run_with_all_model_types - @keras_parameterized.run_all_keras_modes - def test_iterators_running_out_of_data(self): - model = testing_utils.get_small_mlp(1, 4, input_dim=3) - optimizer = 'rmsprop' - loss = 'mse' - metrics = ['mae'] - model.compile( - optimizer, - loss, - metrics=metrics, - run_eagerly=testing_utils.should_run_eagerly(), - run_distributed=testing_utils.should_run_distributed()) - - inputs = np.zeros((10, 3), np.float32) - targets = np.zeros((10, 4), np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(2) - dataset = dataset.batch(10) - iterator = dataset_ops.make_one_shot_iterator(dataset) - - with test.mock.patch.object(logging, 'warning') as mock_log: - model.fit(iterator, epochs=1, steps_per_epoch=3, verbose=0) - self.assertRegexpMatches( - str(mock_log.call_args), - 'dataset iterator ran out of data') - - class TestTrainingWithDataset(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types @@ -618,11 +524,11 @@ class TestTrainingWithDataset(keras_parameterized.TestCase): model.fit(dataset) -class TestMetricsWithDatasetIterators(keras_parameterized.TestCase): +class TestMetricsWithDatasets(keras_parameterized.TestCase): @keras_parameterized.run_with_all_model_types @keras_parameterized.run_all_keras_modes - def test_metrics_correctness_with_iterator(self): + def test_metrics_correctness_with_dataset(self): layers = [ keras.layers.Dense(8, activation='relu', input_dim=4, kernel_initializer='ones'), @@ -643,8 +549,7 @@ class TestMetricsWithDatasetIterators(keras_parameterized.TestCase): y = np.random.randint(2, size=(100, 1)).astype(np.float32) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.batch(10) - iterator = dataset_ops.make_one_shot_iterator(dataset) - outs = model.evaluate(iterator, steps=10) + outs = model.evaluate(dataset, steps=10) self.assertEqual(np.around(outs[1], decimals=1), 0.5) self.assertEqual(np.around(outs[2], decimals=1), 0.5) @@ -652,8 +557,7 @@ class TestMetricsWithDatasetIterators(keras_parameterized.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) dataset = dataset.batch(10) - iterator = dataset_ops.make_one_shot_iterator(dataset) - outs = model.evaluate(iterator, steps=10) + outs = model.evaluate(dataset, steps=10) self.assertEqual(outs[1], 0.) self.assertEqual(outs[2], 0.) diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py index 57d2f50d2ec..e74c5b678d4 100644 --- a/tensorflow/python/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/engine/training_eager_test.py @@ -183,30 +183,20 @@ class TrainingTest(keras_parameterized.TestCase): x = array_ops.zeros(shape=(10, 3)) y = array_ops.zeros(shape=(10, 4)) dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat(10).batch(5) - iterator = dataset_ops.make_one_shot_iterator(dataset) validation_dataset = dataset_ops.Dataset.from_tensor_slices( (x, y)).repeat().batch(5) # Infinite dataset. - validation_iterator = dataset_ops.make_one_shot_iterator(validation_dataset) - with self.assertRaisesRegexp( - ValueError, r'specify .* `steps_per_epoch`'): - model.fit(iterator, epochs=1, verbose=0) - if not context.executing_eagerly(): - # In eager execution, `array_ops.zeros` returns value tensors - # which can be used for validation without a `validation_steps` argument. - with self.assertRaisesRegexp( - ValueError, r'provide either `batch_size` or `validation_steps`'): - model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0, - validation_data=(x, y)) + model.fit(dataset, epochs=1, verbose=0) + # Step argument is required for infinite datasets. with self.assertRaisesRegexp(ValueError, 'specify the `validation_steps` argument.'): - model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0, + model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0, validation_data=validation_dataset) with self.assertRaisesRegexp(ValueError, 'specify the `validation_steps` argument.'): - model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0, - validation_data=validation_iterator) + model.fit(dataset, steps_per_epoch=2, epochs=1, verbose=0, + validation_data=validation_dataset) # TODO(b/120931266): Enable test on subclassed models after bug causing an # extra dimension to be added to predict outputs is fixed. @@ -282,8 +272,7 @@ class CorrectnessTest(keras_parameterized.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) dataset = dataset.batch(10) - iterator = dataset_ops.make_one_shot_iterator(dataset) - history = model.fit(iterator, epochs=1, steps_per_epoch=10) + history = model.fit(dataset, epochs=1, steps_per_epoch=10) self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4) def test_loss_in_call(self): diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index 9c82bc1a5ae..9f020221322 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -859,8 +859,7 @@ class TrainingTest(keras_parameterized.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) dataset = dataset.repeat(10) dataset = dataset.batch(10) - iterator = dataset_ops.make_one_shot_iterator(dataset) - model.fit(iterator, epochs=1, steps_per_epoch=2) + model.fit(dataset, epochs=1, steps_per_epoch=2) if context.executing_eagerly(): # Test with eager execution @@ -870,7 +869,7 @@ class TrainingTest(keras_parameterized.TestCase): model.fit(x_train, y_train, batch_size=5, epochs=1) # Test with eager execution and iterator - model.fit(iterator, epochs=1, steps_per_epoch=2) + model.fit(dataset, epochs=1, steps_per_epoch=2) def test_losses_in_defun(self): with context.eager_mode(): diff --git a/tensorflow/python/keras/engine/training_v2_utils.py b/tensorflow/python/keras/engine/training_v2_utils.py index 2f42a5f531b..982ef2a71a1 100644 --- a/tensorflow/python/keras/engine/training_v2_utils.py +++ b/tensorflow/python/keras/engine/training_v2_utils.py @@ -178,11 +178,11 @@ def train_on_batch( (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and tensor targets, or inversely). If - `x` is a dataset or a dataset iterator, `y` should not be specified + `x` is a dataset `y` should not be specified (since targets will be obtained from the iterator). sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. In the case of @@ -190,7 +190,7 @@ def train_on_batch( sequence_length), to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset or a dataset iterator. + supported when `x` is a dataset. class_weight: Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model's loss for the samples from this class during training. This can be useful to tell the model to "pay @@ -249,12 +249,12 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and - tensor targets, or inversely). If `x` is a dataset or a - dataset iterator, `y` should not be specified + tensor targets, or inversely). If `x` is a dataset, + `y` should not be specified (since targets will be obtained from the iterator). sample_weight: Optional array of the same length as x, containing weights to apply to the model's loss for each sample. @@ -263,7 +263,7 @@ def test_on_batch(model, x, y=None, sample_weight=None, reset_metrics=True): to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile(). This argument is not - supported when `x` is a dataset or a dataset iterator. + supported when `x` is a dataset. reset_metrics: If `True`, the metrics returned will be only for this batch. If `False`, the metrics will be statefully accumulated across batches. @@ -310,7 +310,7 @@ def predict_on_batch(model, x): (in case the model has multiple inputs). - A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs). - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset. Returns: Numpy array(s) of predictions. diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index eecb3b5bd20..39d6594a318 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -646,7 +646,7 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) _ = model.evaluate([x1, x2], [y1, y2], verbose=0) - def test_single_io_workflow_with_dataset_iterators(self): + def test_single_io_workflow_with_datasets(self): num_classes = 2 num_samples = 10 input_dim = 50 @@ -664,10 +664,9 @@ class ModelSubclassCompiledTest(keras_parameterized.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((x, y)) dataset = dataset.repeat(100) dataset = dataset.batch(10) - iterator = dataset_ops.make_one_shot_iterator(dataset) - model.fit(iterator, epochs=2, steps_per_epoch=10, verbose=0) - _ = model.evaluate(iterator, steps=10, verbose=0) + model.fit(dataset, epochs=2, steps_per_epoch=10, verbose=0) + _ = model.evaluate(dataset, steps=10, verbose=0) def test_attributes(self): # layers, weights, trainable_weights, non_trainable_weights, inputs, outputs