diff --git a/tensorflow/python/distribute/custom_training_loop_input_test.py b/tensorflow/python/distribute/custom_training_loop_input_test.py
index e4f782810dd..5660b5839ce 100644
--- a/tensorflow/python/distribute/custom_training_loop_input_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_input_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -136,8 +137,52 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
 
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.tpu_strategies,
-          mode=["eager"]))
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def testGetNextAsOptional(self, distribution):
+    data = [5., 6., 7., 8.]
+    dataset = get_dataset_from_tensor_slices(data).batch(2)
+    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    iterator = iter(dist_dataset)
+
+    def train_step(data):
+      return math_ops.square(data)
+
+    @def_function.function
+    def run(iterator):
+      return distribution.experimental_local_results(
+          distribution.run(
+              train_step, args=(iterator.get_next_as_optional().get_value(),)))
+
+    self.assert_equal_flattened([[25., 36.]], [run(iterator)])
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def testGetNextAsOptionalExampleUsage(self, distribution):
+    global_batch_size = 2
+    steps_per_loop = 6
+    dataset = dataset_ops.Dataset.range(
+        8, output_type=dtypes.int32).batch(global_batch_size)
+    distributed_iterator = iter(
+        distribution.experimental_distribute_dataset(dataset))
+
+    @def_function.function
+    def train_fn(distributed_iterator):
+
+      def step_fn(x):
+        return x
+
+      for _ in math_ops.range(steps_per_loop):
+        optional_data = distributed_iterator.get_next_as_optional()
+        if not optional_data.has_value():
+          break
+        distribution.run(step_fn, args=(optional_data.get_value(),))
+
+    train_fn(distributed_iterator)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
   def testFullEagerTPU(self, distribution):
     dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
 
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index b6a89463426..ec0b911ebe0 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -200,6 +200,7 @@ import six
 from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
@@ -2879,6 +2880,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
     def get_next(self):
       return self._iterator.get_next()
 
+    def get_next_as_optional(self):
+      return iterator_ops.get_next_as_optional(self._iterator)
+
     @deprecated(None, "Use the iterator's `initializer` property instead.")
     def initialize(self):
"""Initialize underlying iterators. diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index ff468af7f87..e4a362a92c6 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import batching from tensorflow.python.data.experimental.ops import distribute from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.python.data.ops import optional_ops from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_utils from tensorflow.python.distribute import distribution_strategy_context @@ -235,6 +236,40 @@ class DistributedIteratorInterface(collections.Iterator, raise NotImplementedError( "DistributedIterator.element_spec() must be implemented in descendants") + def get_next_as_optional(self): + """Returns a `tf.experimental.Optional` that contains the next value for all replicas. + + If the `tf.distribute.DistributedIterator` has reached the end of the + sequence, the returned `tf.experimental.Optional` will have no value. + + Example usage: + + >>> strategy = tf.distribute.MirroredStrategy() + >>> global_batch_size = 2 + >>> steps_per_loop = 2 + >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size) + >>> distributed_iterator = iter( + ... strategy.experimental_distribute_dataset(dataset)) + >>> def step_fn(x): + ... return x + >>> @tf.function + ... def train_fn(distributed_iterator): + ... for _ in tf.range(steps_per_loop): + ... optional_data = distributed_iterator.get_next_as_optional() + ... if not optional_data.has_value(): + ... break + ... tf.print(strategy.run(step_fn, args=(optional_data.get_value(),))) + >>> train_fn(distributed_iterator) + ... # ([0 1],) + ... # ([2 3],) + + Returns: + An `tf.experimental.Optional` object representing the next value from the + `tf.distribute.DistributedIterator` (if it has one) or no value. 
+ """ + raise NotImplementedError( + "get_next_as_optional() not implemented in descendants") + @tf_export("distribute.DistributedDataset", v1=[]) class DistributedDatasetInterface(collections.Iterable, @@ -622,6 +657,31 @@ class DistributedIteratorBase(DistributedIteratorInterface): def __iter__(self): return self + def get_next_as_optional(self): + global_has_value, replicas = _get_next_as_optional(self, self._strategy) + + def return_none(): + return optional_ops.Optional.empty(self._element_spec) + + def return_value(replicas): + """Wraps the inputs for replicas in an `tf.experimental.Optional`.""" + results = [] + for i, worker in enumerate(self._input_workers.worker_devices): + with ops.device(worker): + devices = self._input_workers.compute_devices_for_worker(i) + for j, device in enumerate(devices): + with ops.device(device): + result = replicas[i][j] + results.append(result) + replicas = results + + return optional_ops.Optional.from_value( + distribute_utils.regroup(replicas)) + + return control_flow_ops.cond(global_has_value, + lambda: return_value(replicas), + lambda: return_none()) # pylint: disable=unnecessary-lambda + def get_next(self, name=None): """Returns the next input from the iterator for all replicas.""" if not self._enable_get_next_as_optional: diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index ff4436c4c8c..7f02d0121d0 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -185,38 +185,76 @@ class DistributedIteratorTestBase(test.TestCase): if not ops.executing_eagerly_outside_functions(): evaluate(control_flow_ops.group(iterator.initializer)) - for expected_value in expected_values: - next_element = iterator.get_next() - computed_value = evaluate( - [distribute_utils.select_replica(r, next_element) - for r in range(len(devices))]) - self.assertEqual(len(expected_value), len(computed_value)) - for i in range(len(expected_value)): - self.assertAllEqual(expected_value[i], computed_value[i]) + def test_get_next(iterator): + for expected_value in expected_values: + next_element = iterator.get_next() + computed_value = evaluate([ + distribute_utils.select_replica(r, next_element) + for r in range(len(devices)) + ]) - with self.assertRaises(errors.OutOfRangeError): - next_element = iterator.get_next() - evaluate( - [distribute_utils.select_replica(r, next_element) - for r in range(len(devices))]) + self.assertEqual(len(expected_value), len(computed_value)) + for i in range(len(expected_value)): + self.assertAllEqual(expected_value[i], computed_value[i]) - # After re-initializing the iterator, should be able to iterate again. - if not ops.executing_eagerly_outside_functions(): - evaluate(control_flow_ops.group(iterator.initializer)) + with self.assertRaises(errors.OutOfRangeError): + next_element = iterator.get_next() + evaluate([ + distribute_utils.select_replica(r, next_element) + for r in range(len(devices)) + ]) + + # After re-initializing the iterator, should be able to iterate again. 
+        if not ops.executing_eagerly_outside_functions():
+          evaluate(control_flow_ops.group(iterator.initializer))
+        else:
+          if api_type == "wrap_into_iterator":
+            self.skipTest("unsupported test combination")
+          else:
+            iterator = iter(dataset)
+
+        for expected_value in expected_values:
+          next_element = iterator.get_next()
+          computed_value = evaluate([
+              distribute_utils.select_replica(r, next_element)
+              for r in range(len(devices))
+          ])
+          self.assertEqual(len(expected_value), len(computed_value))
+          for i in range(len(expected_value)):
+            self.assertAllEqual(expected_value[i], computed_value[i])
+
+      def test_get_next_as_optional(iterator):
+        for expected_value in expected_values:
+          next_element = iterator.get_next_as_optional()
+          computed_value = evaluate([
+              distribute_utils.select_replica(r, next_element.get_value())
+              for r in range(len(devices))
+          ])
+
+          self.assertEqual(len(expected_value), len(computed_value))
+          for i in range(len(expected_value)):
+            self.assertAllEqual(expected_value[i], computed_value[i])
+
+        next_element = iterator.get_next_as_optional()
+        self.assertFalse(self.evaluate(next_element.has_value()))
+        with self.assertRaises(errors.InvalidArgumentError):
+          evaluate([
+              distribute_utils.select_replica(r, next_element.get_value())
+              for r in range(len(devices))
+          ])
+
+      test_get_next(iterator)
+
+      # re-initializing the iterator
+      if not tf2.enabled():
+        self.skipTest("Not testing get_next_as_optional in TF1")
       else:
         if api_type == "wrap_into_iterator":
           self.skipTest("unsupported test combination")
         else:
           iterator = iter(dataset)
 
-      for expected_value in expected_values:
-        next_element = iterator.get_next()
-        computed_value = evaluate(
-            [distribute_utils.select_replica(r, next_element)
-             for r in range(len(devices))])
-        self.assertEqual(len(expected_value), len(computed_value))
-        for i in range(len(expected_value)):
-          self.assertAllEqual(expected_value[i], computed_value[i])
+      test_get_next_as_optional(iterator)
 
     if iteration_type == "for_loop" and context.executing_eagerly():
       actual_values = []
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-distributed-iterator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-distributed-iterator.pbtxt
index f712d9058b9..47899cc4188 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-distributed-iterator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-distributed-iterator.pbtxt
@@ -13,4 +13,8 @@ tf_class {
   member_method {
     name: "get_next"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_next_as_optional"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }
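
The following is a minimal usage sketch, not part of the patch above. It shows how the `get_next_as_optional()` method that this change adds to `tf.distribute.DistributedIterator` is meant to be driven from user code inside a `tf.function`, where `get_next()` would instead raise `tf.errors.OutOfRangeError` once the data runs out. The strategy choice, dataset, `steps_per_loop`, and `step_fn` are illustrative placeholders, and the snippet assumes a TensorFlow build that includes this patch.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 2
# 7 elements with batch size 2 leaves a partial final batch, so the loop has
# to stop early via the Optional rather than running a fixed number of steps.
dataset = tf.data.Dataset.range(7).batch(global_batch_size)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))


def step_fn(x):
  # Stand-in for a real per-replica training step.
  return tf.reduce_sum(x)


@tf.function
def train_loop(iterator, steps_per_loop):
  for _ in tf.range(steps_per_loop):
    # Unlike get_next(), this returns a tf.experimental.Optional instead of
    # raising OutOfRangeError when the iterator is exhausted.
    optional_data = iterator.get_next_as_optional()
    if not optional_data.has_value():
      break
    strategy.run(step_fn, args=(optional_data.get_value(),))


train_loop(distributed_iterator, steps_per_loop=10)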