From 5a0ed634afbdc95a9524b5344e8a7b6c6621c3b7 Mon Sep 17 00:00:00 2001
From: Ruoxin Sang <rxsang@google.com>
Date: Wed, 11 Nov 2020 17:34:10 -0800
Subject: [PATCH] Always enable get_next_as_optional unless the dataset is
 infinite.

PiperOrigin-RevId: 341945136
Change-Id: I79fdec366be2119b6a28063f193e6cecb7a5f9e2
---
 tensorflow/python/distribute/input_lib.py     |  3 +-
 .../python/distribute/input_lib_test.py       | 30 +++++++++----------
 2 files changed, 16 insertions(+), 17 deletions(-)
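
Reviewer note: below is a minimal, self-contained sketch of the decision the
patched _enable_get_next_as_optional makes, using only public tf.data APIs and
ignoring the strategy-level experimental_enable_get_next_as_optional flag and
the eager-mode guard. needs_partial_batch_handling, statically_shaped and
multi_worker are illustrative stand-ins, not code from this patch.

import tensorflow as tf

def needs_partial_batch_handling(dataset, statically_shaped, multi_worker):
  # An infinite dataset can never end on a partial batch, so the
  # get_next_as_optional path stays disabled for it.
  if dataset.cardinality().numpy() == tf.data.experimental.INFINITE_CARDINALITY:
    return False
  # Finite (or unknown-cardinality) datasets fall back to the static-shape /
  # multi-worker heuristic, as in the input_lib.py hunk below.
  return not statically_shaped or multi_worker

finite = tf.data.Dataset.range(10).batch(4)              # last batch holds only 2 elements
infinite = tf.data.Dataset.range(10).repeat().batch(4)   # cardinality == INFINITE
print(needs_partial_batch_handling(infinite, False, False))  # False: no partial batch possible
print(needs_partial_batch_handling(finite, False, False))    # True: partial batch needs handling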

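Reviewer note: the changed test expectations hinge on whether the last,
partial batch is observed. The toy sketch below shows that mechanism with a
single-replica MirroredStrategy, loosely following the documented
DistributedIterator.get_next_as_optional loop; the toy sums are unrelated to
the 200/310 values asserted in input_lib_test.py.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(10).batch(4)   # batches [0..3], [4..7] and the partial [8, 9]
iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def sum_batches(iterator):
  total = tf.constant(0, tf.int64)
  for _ in tf.range(10):                        # loose upper bound on the step count
    optional_batch = iterator.get_next_as_optional()
    if not optional_batch.has_value():
      break                                     # exhausted; no OutOfRangeError to catch
    total += strategy.reduce(
        tf.distribute.ReduceOp.SUM, optional_batch.get_value(), axis=0)
  return total

print(sum_batches(iterator).numpy())  # 45: the partial batch [8, 9] is included
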
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 390d2612753..ba5590e8d10 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -2147,7 +2147,8 @@ def _enable_get_next_as_optional(strategy, dataset):
     # dataset is created in eager mode, as we need to evaluate the dataset
     # cardinality.
     with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
-      return dataset.cardinality().numpy() != cardinality.INFINITE
+      if dataset.cardinality().numpy() == cardinality.INFINITE:
+        return False
 
   return not _is_statically_shaped(
       dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 442dabfd02e..8a85f96d4b1 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -1118,21 +1118,21 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
         except (StopIteration, errors.OutOfRangeError):
           return sums
 
-    expected_for_sum = 200.
-    if (not drop_remainder or input_type == "input_fn"):
-      expected_for_sum = 310.
     while_sums = sum_while_loop(
         iter(dataset),
         defun(lambda state, iterator: _reduce(state, next(iterator))))
-    self.assertAllEqual(nest.flatten(while_sums), [expected_for_sum] * 3)
-
+    self.assertAllEqual(
+        nest.flatten(while_sums),
+        # When drop_remainder drops the partial batch, the sum is smaller.
+        [200. if drop_remainder else 310.] * 3)
+    for_sums = defun(sum_for_loop)(dataset)
     # For loops always call get_next_as_optional inside tf.function, so we
     # expect 310 here when using an input function (as there are 5 batches of
     # size 4 round-robined over 2 replicas).
     expected_for_sum = 200.
-    if (not drop_remainder or input_type == "input_fn"):
+    if (not drop_remainder or (
+        defun_type == "tf_function" and input_type == "input_fn")):
       expected_for_sum = 310.
-    for_sums = defun(sum_for_loop)(dataset)
     self.assertAllEqual(nest.flatten(for_sums), [expected_for_sum] * 3)
 
   @combinations.generate(
@@ -1146,12 +1146,12 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
           ],
           input_type=["dataset", "input_fn"],
           drop_remainder=[False, True],
-          repeat=[False, True],
           tensor_type=["sparse", "ragged"],
-          enable_get_next_as_optional=[True, False]))
-  def testRaggedSparseGetNextAsOptional(self, distribution, input_type,
-                                        drop_remainder, repeat, tensor_type,
-                                        enable_get_next_as_optional):
+          enable_get_next_as_optional=[True, False]
+      ))
+  def testRaggedSparseGetNextAsOptional(
+      self, distribution, input_type, drop_remainder, tensor_type,
+      enable_get_next_as_optional):
     """Test with `RaggedTensor`s and `SparseTensor`s."""
     if not tf2.enabled():
       self.skipTest("Only V2 is supported.")
@@ -1172,8 +1172,6 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                         ragged_tensor.to_sparse()),
       })
       dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
-      if repeat:
-        dataset = dataset.repeat()
       return dataset.batch(batch_size, drop_remainder=drop_remainder)
 
     if input_type == "dataset":
@@ -1183,8 +1181,8 @@ class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
       ds = distribution.distribute_datasets_from_function(dataset_fn)
     iterator = iter(ds)
 
-    self.assertEqual(iterator._enable_get_next_as_optional, (not repeat) and
-                     enable_get_next_as_optional)
+    self.assertEqual(iterator._enable_get_next_as_optional,
+                     (not drop_remainder) and enable_get_next_as_optional)
 
   @combinations.generate(
       combinations.combine(