Fix recursion error from NumPy->DS change.
PiperOrigin-RevId: 263703207
This commit is contained in:
parent
8235fb5642
commit
9c57a63096
@ -248,7 +248,7 @@ class TensorLikeDataAdapter(DataAdapter):
|
||||
# 2. parallelized map
|
||||
# 3. vectorized shuffle by using reshape and unbatch
|
||||
# 4. disabled static optimizations
|
||||
indices_ds = None
|
||||
indices_list = []
|
||||
for _ in range(epochs):
|
||||
indices = np.arange(num_samples)
|
||||
if shuffle:
|
||||
@ -264,10 +264,10 @@ class TensorLikeDataAdapter(DataAdapter):
|
||||
epoch_indices_ds = epoch_indices_ds.concatenate(
|
||||
dataset_ops.DatasetV2.from_tensors(partial_batch_indices))
|
||||
|
||||
if indices_ds is None:
|
||||
indices_ds = epoch_indices_ds
|
||||
else:
|
||||
indices_ds = indices_ds.concatenate(epoch_indices_ds)
|
||||
indices_list.append(epoch_indices_ds)
|
||||
|
||||
indices_ds = dataset_ops.DatasetV2.from_tensor_slices(
|
||||
indices_list).flat_map(lambda x: x)
|
||||
|
||||
data_ds = dataset_ops.DatasetV2.from_tensors(inputs).repeat()
|
||||
dataset = dataset_ops.DatasetV2.zip((data_ds, indices_ds))
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import data_adapter
|
||||
from tensorflow.python.keras.utils import data_utils
|
||||
@ -109,11 +110,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
|
||||
self.assertTrue(adapter.has_partial_batch())
|
||||
self.assertEqual(adapter.partial_batch_size(), 2)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_training_numpy(self):
|
||||
dataset = self.adapter_cls(
|
||||
self.numpy_input, self.numpy_target, batch_size=5).get_dataset()
|
||||
if not context.executing_eagerly():
|
||||
return # Only test in eager.
|
||||
|
||||
self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
|
||||
self.model.fit(dataset)
|
||||
self.model.fit(self.numpy_input, self.numpy_target, batch_size=5)
|
||||
|
||||
def test_can_handle(self):
|
||||
self.assertTrue(self.adapter_cls.can_handle(self.tensor_input))
|
||||
@ -124,11 +127,13 @@ class TensorLikeDataAdapterTest(DataAdapterTestBase):
|
||||
self.assertFalse(self.adapter_cls.can_handle(self.generator_input))
|
||||
self.assertFalse(self.adapter_cls.can_handle(self.sequence_input))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def test_training(self):
|
||||
dataset = self.adapter_cls(
|
||||
self.tensor_input, self.tensor_target, batch_size=5).get_dataset()
|
||||
if not context.executing_eagerly():
|
||||
return # Only test EagerTensors.
|
||||
|
||||
self.model.compile(loss='sparse_categorical_crossentropy', optimizer='sgd')
|
||||
self.model.fit(dataset)
|
||||
self.model.fit(self.tensor_input, self.tensor_target, batch_size=5)
|
||||
|
||||
def test_size(self):
|
||||
adapter = self.adapter_cls(
|
||||
@ -295,4 +300,5 @@ class KerasSequenceAdapterTest(DataAdapterTestBase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
test.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user