diff --git a/tensorflow/python/distribute/custom_training_loop_input_test.py b/tensorflow/python/distribute/custom_training_loop_input_test.py
index e4f782810dd..5660b5839ce 100644
--- a/tensorflow/python/distribute/custom_training_loop_input_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_input_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -136,8 +137,52 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
 
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.tpu_strategies,
-          mode=["eager"]))
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def testGetNextAsOptional(self, distribution):
+    data = [5., 6., 7., 8.]
+    dataset = get_dataset_from_tensor_slices(data).batch(2)
+    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    iterator = iter(dist_dataset)
+
+    def train_step(data):
+      return math_ops.square(data)
+
+    @def_function.function
+    def run(iterator):
+      return distribution.experimental_local_results(
+          distribution.run(
+              train_step, args=(iterator.get_next_as_optional().get_value(),)))
+
+    self.assert_equal_flattened([[25., 36.]], [run(iterator)])
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def testGetNextAsOptionalExampleUsage(self, distribution):
+    global_batch_size = 2
+    steps_per_loop = 6
+    dataset = dataset_ops.Dataset.range(
+        8, output_type=dtypes.int32).batch(global_batch_size)
+    distributed_iterator = iter(
+        distribution.experimental_distribute_dataset(dataset))
+
+    @def_function.function
+    def train_fn(distributed_iterator):
+
+      def step_fn(x):
+        return x
+
+      for _ in math_ops.range(steps_per_loop):
+        optional_data = distributed_iterator.get_next_as_optional()
+        if not optional_data.has_value():
+          break
+        distribution.run(step_fn, args=(optional_data.get_value(),))
+
+    train_fn(distributed_iterator)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
   def testFullEagerTPU(self, distribution):
     dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
 
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index b6a89463426..ec0b911ebe0 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -200,6 +200,7 @@ import six
 from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
@@ -2879,6 +2880,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
     def get_next(self):
       return self._iterator.get_next()
 
+    def get_next_as_optional(self):
+      return iterator_ops.get_next_as_optional(self._iterator)
+
     @deprecated(None, "Use the iterator's `initializer` property instead.")
     def initialize(self):
       """Initialize underlying iterators.
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index ff468af7f87..e4a362a92c6 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context
@@ -235,6 +236,40 @@ class DistributedIteratorInterface(collections.Iterator,
     raise NotImplementedError(
         "DistributedIterator.element_spec() must be implemented in descendants")
 
+  def get_next_as_optional(self):
+    """Returns a `tf.experimental.Optional` that contains the next value for all replicas.
+
+    If the `tf.distribute.DistributedIterator` has reached the end of the
+    sequence, the returned `tf.experimental.Optional` will have no value.
+
+    Example usage:
+
+    >>> strategy = tf.distribute.MirroredStrategy()
+    >>> global_batch_size = 2
+    >>> steps_per_loop = 2
+    >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
+    >>> distributed_iterator = iter(
+    ...     strategy.experimental_distribute_dataset(dataset))
+    >>> def step_fn(x):
+    ...   return x
+    >>> @tf.function
+    ... def train_fn(distributed_iterator):
+    ...   for _ in tf.range(steps_per_loop):
+    ...     optional_data = distributed_iterator.get_next_as_optional()
+    ...     if not optional_data.has_value():
+    ...       break
+    ...     tf.print(strategy.run(step_fn, args=(optional_data.get_value(),)))
+    >>> train_fn(distributed_iterator)
+    ... # ([0 1],)
+    ... # ([2 3],)
+
+    Returns:
+      An `tf.experimental.Optional` object representing the next value from the
+      `tf.distribute.DistributedIterator` (if it has one) or no value.
+    """
+    raise NotImplementedError(
+        "get_next_as_optional() not implemented in descendants")
+
 
 @tf_export("distribute.DistributedDataset", v1=[])
 class DistributedDatasetInterface(collections.Iterable,
@@ -622,6 +657,31 @@ class DistributedIteratorBase(DistributedIteratorInterface):
   def __iter__(self):
     return self
 
+  def get_next_as_optional(self):
+    global_has_value, replicas = _get_next_as_optional(self, self._strategy)
+
+    def return_none():
+      return optional_ops.Optional.empty(self._element_spec)
+
+    def return_value(replicas):
+      """Wraps the inputs for replicas in an `tf.experimental.Optional`."""
+      results = []
+      for i, worker in enumerate(self._input_workers.worker_devices):
+        with ops.device(worker):
+          devices = self._input_workers.compute_devices_for_worker(i)
+          for j, device in enumerate(devices):
+            with ops.device(device):
+              result = replicas[i][j]
+              results.append(result)
+      replicas = results
+
+      return optional_ops.Optional.from_value(
+          distribute_utils.regroup(replicas))
+
+    return control_flow_ops.cond(global_has_value,
+                                 lambda: return_value(replicas),
+                                 lambda: return_none())  # pylint: disable=unnecessary-lambda
+
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
     if not self._enable_get_next_as_optional:
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index ff4436c4c8c..7f02d0121d0 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -185,38 +185,76 @@ class DistributedIteratorTestBase(test.TestCase):
       if not ops.executing_eagerly_outside_functions():
         evaluate(control_flow_ops.group(iterator.initializer))
 
-      for expected_value in expected_values:
-        next_element = iterator.get_next()
-        computed_value = evaluate(
-            [distribute_utils.select_replica(r, next_element)
-             for r in range(len(devices))])
-        self.assertEqual(len(expected_value), len(computed_value))
-        for i in range(len(expected_value)):
-          self.assertAllEqual(expected_value[i], computed_value[i])
+      def test_get_next(iterator):
+        for expected_value in expected_values:
+          next_element = iterator.get_next()
+          computed_value = evaluate([
+              distribute_utils.select_replica(r, next_element)
+              for r in range(len(devices))
+          ])
 
-      with self.assertRaises(errors.OutOfRangeError):
-        next_element = iterator.get_next()
-        evaluate(
-            [distribute_utils.select_replica(r, next_element)
-             for r in range(len(devices))])
+          self.assertEqual(len(expected_value), len(computed_value))
+          for i in range(len(expected_value)):
+            self.assertAllEqual(expected_value[i], computed_value[i])
 
-      # After re-initializing the iterator, should be able to iterate again.
-      if not ops.executing_eagerly_outside_functions():
-        evaluate(control_flow_ops.group(iterator.initializer))
+        with self.assertRaises(errors.OutOfRangeError):
+          next_element = iterator.get_next()
+          evaluate([
+              distribute_utils.select_replica(r, next_element)
+              for r in range(len(devices))
+          ])
+
+        # After re-initializing the iterator, should be able to iterate again.
+        if not ops.executing_eagerly_outside_functions():
+          evaluate(control_flow_ops.group(iterator.initializer))
+        else:
+          if api_type == "wrap_into_iterator":
+            self.skipTest("unsupported test combination")
+          else:
+            iterator = iter(dataset)
+
+        for expected_value in expected_values:
+          next_element = iterator.get_next()
+          computed_value = evaluate([
+              distribute_utils.select_replica(r, next_element)
+              for r in range(len(devices))
+          ])
+          self.assertEqual(len(expected_value), len(computed_value))
+          for i in range(len(expected_value)):
+            self.assertAllEqual(expected_value[i], computed_value[i])
+
+      def test_get_next_as_optional(iterator):
+        for expected_value in expected_values:
+          next_element = iterator.get_next_as_optional()
+          computed_value = evaluate([
+              distribute_utils.select_replica(r, next_element.get_value())
+              for r in range(len(devices))
+          ])
+
+          self.assertEqual(len(expected_value), len(computed_value))
+          for i in range(len(expected_value)):
+            self.assertAllEqual(expected_value[i], computed_value[i])
+
+        next_element = iterator.get_next_as_optional()
+        self.assertFalse(self.evaluate(next_element.has_value()))
+        with self.assertRaises(errors.InvalidArgumentError):
+          evaluate([
+              distribute_utils.select_replica(r, next_element.get_value())
+              for r in range(len(devices))
+          ])
+
+      test_get_next(iterator)
+
+      # re-initializing the iterator
+      if not tf2.enabled():
+        self.skipTest("Not testing get_next_as_optional in TF1")
       else:
         if api_type == "wrap_into_iterator":
           self.skipTest("unsupported test combination")
         else:
           iterator = iter(dataset)
 
-      for expected_value in expected_values:
-        next_element = iterator.get_next()
-        computed_value = evaluate(
-            [distribute_utils.select_replica(r, next_element)
-             for r in range(len(devices))])
-        self.assertEqual(len(expected_value), len(computed_value))
-        for i in range(len(expected_value)):
-          self.assertAllEqual(expected_value[i], computed_value[i])
+      test_get_next_as_optional(iterator)
 
     if iteration_type == "for_loop" and context.executing_eagerly():
       actual_values = []
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-distributed-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-distributed-iterator.pbtxt
index f712d9058b9..47899cc4188 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-distributed-iterator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-distributed-iterator.pbtxt
@@ -13,4 +13,8 @@ tf_class {
     name: "get_next"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_next_as_optional"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }