diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 3dbe15aa612..553d1151df8 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -2157,7 +2157,8 @@ def _enable_get_next_as_optional(strategy, dataset):
     # dataset is created in eager mode, as we need to evaluate the dataset
     # cardinality.
     with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
-      return dataset.cardinality().numpy() != cardinality.INFINITE
+      if dataset.cardinality().numpy() == cardinality.INFINITE:
+        return False
 
   return not _is_statically_shaped(
       dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index bce24a03504..d2918a9c267 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -1120,21 +1120,21 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
         except (StopIteration, errors.OutOfRangeError):
           return sums
 
-    expected_for_sum = 200.
-    if (not drop_remainder or input_type == "input_fn"):
-      expected_for_sum = 310.
     while_sums = sum_while_loop(
         iter(dataset),
         defun(lambda state, iterator: _reduce(state, next(iterator))))
-    self.assertAllEqual(nest.flatten(while_sums), [expected_for_sum] * 3)
-
+    self.assertAllEqual(
+        nest.flatten(while_sums),
+        # When there's no partial batch, the sum is smaller.
+        [200. if drop_remainder else 310.] * 3)
+    for_sums = defun(sum_for_loop)(dataset)
     # For loops always call get next as optional inside tf functions, so we
     # expect 310 here when using an input function (as there are 5 batches of
     # size 4 round robined over 2 replicas.
     expected_for_sum = 200.
-    if (not drop_remainder or input_type == "input_fn"):
+    if (not drop_remainder or (
+        defun_type == "tf_function" and input_type == "input_fn")):
       expected_for_sum = 310.
-    for_sums = defun(sum_for_loop)(dataset)
     self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
 
   @combinations.generate(
@@ -1148,12 +1148,12 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
           ],
          input_type=["dataset", "input_fn"],
          drop_remainder=[False, True],
-          repeat=[False, True],
          tensor_type=["sparse", "ragged"],
-          enable_get_next_as_optional=[True, False]))
-  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
-                                        drop_remainder, repeat, tensor_type,
-                                        enable_get_next_as_optional):
+          enable_get_next_as_optional=[True, False]
+      ))
+  def testRaggedSparseGetNextAsOptional(
+      self, distribution, input_type, drop_remainder, tensor_type,
+      enable_get_next_as_optional):
     """Test with `RaggedTensor`s and `SparseTensor`s."""
     if not tf2.enabled():
       self.skipTest("Only V2 is supported.")
@@ -1174,8 +1174,6 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                   ragged_tensor.to_sparse()),
           })
       dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
-      if repeat:
-        dataset = dataset.repeat()
       return dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     if input_type == "dataset":
@@ -1185,8 +1183,8 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       ds = distribution.distribute_datasets_from_function(dataset_fn)
     iterator = iter(ds)
 
-    self.assertEqual(iterator._enable_get_next_as_optional, (not repeat) and
-                     enable_get_next_as_optional)
+    self.assertEqual(iterator._enable_get_next_as_optional,
+                     (not drop_remainder) and enable_get_next_as_optional)
 
   @combinations.generate(
       combinations.combine(
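Reviewer note on the input_lib.py hunk: the sketch below restates the post-patch decision logic as a standalone helper against public tf.data APIs. It is illustrative only; `needs_get_next_as_optional` is a hypothetical name, the real code gates on additional strategy flags before reaching this point, and its `_is_statically_shaped` helper also understands sparse and ragged element specs, which the dense-only `is_fully_defined()` check below does not.

```python
# Hypothetical standalone sketch of the post-patch logic; not the actual
# TensorFlow implementation in input_lib.py.
import tensorflow as tf


def needs_get_next_as_optional(strategy, dataset):
  """Returns True if the iterator should use the optional-based protocol."""
  # An infinite dataset can never end on a partial batch, so the cheaper
  # plain get_next() path is always safe. Before this patch, the function
  # returned early for *any* finite dataset here, bypassing the static
  # shape check below. (Evaluating cardinality is eager-only, as in the
  # patched code.)
  if dataset.cardinality().numpy() == tf.data.INFINITE_CARDINALITY:
    return False
  # Finite or unknown cardinality falls through: the optional protocol is
  # needed when element shapes are dynamic (e.g. drop_remainder=False
  # leaves the batch dimension unknown) or in multi-worker mode, where
  # workers may exhaust their shards at different steps.
  statically_shaped = all(
      spec.shape.is_fully_defined()  # dense-only simplification
      for spec in tf.nest.flatten(dataset.element_spec))
  return (not statically_shaped or
          strategy.extended._in_multi_worker_mode())  # pylint: disable=protected-access
```

Seen this way, the change from `(not repeat)` to `(not drop_remainder)` in the updated assertion follows directly: once the `repeat` parameter is gone, the element spec is fully defined exactly when the remainder batch is dropped.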
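For the while-loop behavior the test reasons about, here is a sketch of the optional-based consumption pattern inside a tf.function, assuming a TF 2.x build where `DistributedIterator.get_next_as_optional()` is public; the dataset, the `sum_all` helper, and the step bound are illustrative, not the test fixture.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
# drop_remainder=False leaves a partial final batch (2 elements), which is
# exactly the case the optional-based protocol exists to handle.
dataset = tf.data.Dataset.range(10).batch(4, drop_remainder=False)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def sum_all(iterator):
  total = tf.constant(0, dtype=tf.int64)
  # Upper bound on steps; the loop exits early once the iterator is empty.
  for _ in tf.range(100):
    opt = iterator.get_next_as_optional()
    if not opt.has_value():
      break
    # Reduce the per-replica batch to a scalar before accumulating.
    total += strategy.reduce(tf.distribute.ReduceOp.SUM, opt.get_value(),
                             axis=0)
  return total


print(sum_all(iter(dist_dataset)).numpy())  # 45, partial batch included
```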