Fix recursion error from NumPy->DS change.

PiperOrigin-RevId: 263703207
This commit is contained in:
Thomas O'Malley 2019-08-15 21:14:41 -07:00 committed by Goldie Gadde
parent 8235fb5642
commit 9c57a63096
2 changed files with 17 additions and 11 deletions

View File

@ -248,7 +248,7 @@ class TensorLikeDataAdapter(DataAdapter):
# 2. parallelized map
# 3. vectorized shuffle by using reshape and unbatch
# 4. disabled static optimizations
indices_ds = None
indices_list = []
for _ in range(epochs):
indices = np.arange(num_samples)
if shuffle:
@ -264,10 +264,10 @@ class TensorLikeDataAdapter(DataAdapter):
epoch_indices_ds = epoch_indices_ds.concatenate(
dataset_ops.DatasetV2.from_tensors(partial_batch_indices))
if indices_ds is None:
indices_ds = epoch_indices_ds
else:
indices_ds = indices_ds.concatenate(epoch_indices_ds)
indices_list.append(epoch_indices_ds)
indices_ds = dataset_ops.DatasetV2.from_tensor_slices(
indices_list).flat_map(lambda x: x)
data_ds = dataset_ops.DatasetV2.from_tensors(inputs).repeat()
dataset = dataset_ops.DatasetV2.zip((data_ds, indices_ds))

View File

@ -27,6 +27,7 @@ from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.utils import data_utils
@ -109,11 +110,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
self.assertTrue(adapter.has_partial_batch())
self.assertEqual(adapter.partial_batch_size(), 2)
@test_util.run_in_graph_and_eager_modes
def test_training_numpy(self):
dataset = self.adapter_cls(
self.numpy_input, self.numpy_target, batch_size=5).get_dataset()
if not context.executing_eagerly():
return # Only test in eager.
self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
self.model.fit(dataset)
self.model.fit(self.numpy_input, self.numpy_target, batch_size=5)
def test_can_handle(self):
self.assertTrue(self.adapter_cls.can_handle(self.tensor_input))
@ -124,11 +127,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
@test_util.run_in_graph_and_eager_modes
def test_training(self):
dataset = self.adapter_cls(
self.tensor_input, self.tensor_target, batch_size=5).get_dataset()
if not context.executing_eagerly():
return # Only test EagerTensors.
self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
self.model.fit(dataset)
self.model.fit(self.tensor_input, self.tensor_target, batch_size=5)
def test_size(self):
adapter = self.adapter_cls(
@ -295,4 +300,5 @@ class KerasSequenceAdapterTest(DataAdapterTestBase):
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()