diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 28973bbf1f5..f0d76389eba 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -248,7 +248,7 @@ class TensorLikeDataAdapter(DataAdapter): # 2. parallelized map # 3. vectorized shuffle by using reshape and unbatch # 4. disabled static optimizations - indices_ds = None + indices_list = [] for _ in range(epochs): indices = np.arange(num_samples) if shuffle: @@ -264,10 +264,10 @@ class TensorLikeDataAdapter(DataAdapter): epoch_indices_ds = epoch_indices_ds.concatenate( dataset_ops.DatasetV2.from_tensors(partial_batch_indices)) - if indices_ds is None: - indices_ds = epoch_indices_ds - else: - indices_ds = indices_ds.concatenate(epoch_indices_ds) + indices_list.append(epoch_indices_ds) + + indices_ds = dataset_ops.DatasetV2.from_tensor_slices( + indices_list).flat_map(lambda x: x) data_ds = dataset_ops.DatasetV2.from_tensors(inputs).repeat() dataset = dataset_ops.DatasetV2.zip((data_ds, indices_ds)) diff --git a/tensorflow/python/keras/engine/data_adapter_test.py b/tensorflow/python/keras/engine/data_adapter_test.py index 525f99217b4..7f08206024b 100644 --- a/tensorflow/python/keras/engine/data_adapter_test.py +++ b/tensorflow/python/keras/engine/data_adapter_test.py @@ -27,6 +27,7 @@ from tensorflow.python import keras from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras.engine import data_adapter from tensorflow.python.keras.utils import data_utils @@ -109,11 +110,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase): self.assertTrue(adapter.has_partial_batch()) self.assertEqual(adapter.partial_batch_size(), 2) + @test_util.run_in_graph_and_eager_modes def test_training_numpy(self): - dataset = self.adapter_cls( - self.numpy_input, self.numpy_target, batch_size=5).get_dataset() + if not context.executing_eagerly(): + return # Only test in eager. + self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd') - self.model.fit(dataset) + self.model.fit(self.numpy_input, self.numpy_target, batch_size=5) def test_can_handle(self): self.assertTrue(self.adapter_cls.can_handle(self.tensor_input)) @@ -124,11 +127,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase): self.assertFalse(self.adapter_cls.can_handle(self.generator_input)) self.assertFalse(self.adapter_cls.can_handle(self.sequence_input)) + @test_util.run_in_graph_and_eager_modes def test_training(self): - dataset = self.adapter_cls( - self.tensor_input, self.tensor_target, batch_size=5).get_dataset() + if not context.executing_eagerly(): + return # Only test EagerTensors. + self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd') - self.model.fit(dataset) + self.model.fit(self.tensor_input, self.tensor_target, batch_size=5) def test_size(self): adapter = self.adapter_cls( @@ -295,4 +300,5 @@ class KerasSequenceAdapterTest(DataAdapterTestBase): if __name__ == '__main__': + ops.enable_eager_execution() test.main()