From 5a0ed634afbdc95a9524b5344e8a7b6c6621c3b7 Mon Sep 17 00:00:00 2001 From: Ruoxin Sang <rxsang@google.com> Date: Wed, 11 Nov 2020 17:34:10 -0800 Subject: [PATCH] Always enable get_next_as_optional unless the dataset is infinite. PiperOrigin-RevId: 341945136 Change-Id: I79fdec366be2119b6a28063f193e6cecb7a5f9e2 --- tensorflow/python/distribute/input_lib.py | 3 +- .../python/distribute/input_lib_test.py | 30 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 390d2612753..ba5590e8d10 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -2147,7 +2147,8 @@ def _enable_get_next_as_optional(strategy, dataset): # dataset is created in eager mode, as we need to evaluate the dataset # cardinality. with ops.device(dataset._variant_tensor.device): # pylint: disable=protected-access - return dataset.cardinality().numpy() != cardinality.INFINITE + if dataset.cardinality().numpy() == cardinality.INFINITE: + return False return not _is_statically_shaped( dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index 442dabfd02e..8a85f96d4b1 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -1118,21 +1118,21 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, except (StopIteration, errors.OutOfRangeError): return sums - expected_for_sum = 200. - if (not drop_remainder or input_type == "input_fn"): - expected_for_sum = 310. 
while_sums = sum_while_loop( iter(dataset), defun(lambda state, iterator: _reduce(state, next(iterator)))) - self.assertAllEqual(nest.flatten(while_sums), [expected_for_sum] * 3) - + self.assertAllEqual( + nest.flatten(while_sums), + # When there's no partial batch, the sum is smaller. + [200. if drop_remainder else 310.] * 3) + for_sums = defun(sum_for_loop)(dataset) # For loops always call get next as optional inside tf functions, so we # expect 310 here when using an input function (as there are 5 batches of # size 4 round robined over 2 replicas. expected_for_sum = 200. - if (not drop_remainder or input_type == "input_fn"): + if (not drop_remainder or ( + defun_type == "tf_function" and input_type == "input_fn")): expected_for_sum = 310. - for_sums = defun(sum_for_loop)(dataset) self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3) @combinations.generate( @@ -1146,12 +1146,12 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, ], input_type=["dataset", "input_fn"], drop_remainder=[False, True], - repeat=[False, True], tensor_type=["sparse", "ragged"], - enable_get_next_as_optional=[True, False])) - def testRaggedSparseGetNextAsOptional(self, distribution, input_type, - drop_remainder, repeat, tensor_type, - enable_get_next_as_optional): + enable_get_next_as_optional=[True, False] + )) + def testRaggedSparseGetNextAsOptional( + self, distribution, input_type, drop_remainder, tensor_type, + enable_get_next_as_optional): """Test with `RaggedTensor`s and `SparseTensor`s.""" if not tf2.enabled(): self.skipTest("Only V2 is supported.") @@ -1172,8 +1172,6 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, ragged_tensor.to_sparse()), }) dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) - if repeat: - dataset = dataset.repeat() return dataset.batch(batch_size, drop_remainder=drop_remainder) if input_type == "dataset": @@ -1183,8 +1181,8 @@ class 
DistributedIteratorTensorTypeTest(DistributedIteratorTestBase, ds = distribution.distribute_datasets_from_function(dataset_fn) iterator = iter(ds) - self.assertEqual(iterator._enable_get_next_as_optional, (not repeat) and - enable_get_next_as_optional) + self.assertEqual(iterator._enable_get_next_as_optional, + (not drop_remainder) and enable_get_next_as_optional) @combinations.generate( combinations.combine(