Add get_next_as_optional method for a distributed iterator

The method is called on a `tf.distribute.DistributedIterator` and returns a `tf.experimental.Optional` containing the next value from the distributed iterator (a `PerReplica` input), or no value if the iterator has reached the end of the sequence.
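For readers skimming the diff, a minimal sketch of the intended usage pattern (a sketch only, assuming TF 2.x eager mode; the dataset and strategy below are illustrative, not part of the change):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(6).batch(2)
distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))

# Drain the iterator without catching OutOfRangeError: the returned
# tf.experimental.Optional is simply empty once the sequence is exhausted.
while True:
  optional_data = distributed_iterator.get_next_as_optional()
  if not optional_data.has_value():
    break
  print(strategy.run(lambda x: x, args=(optional_data.get_value(),)))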

PiperOrigin-RevId: 317248910
Change-Id: Ide217da1aff1d62f8d0d8f43423be2d859d933d3
Xinyi Wang 2020-06-18 22:15:44 -07:00 committed by TensorFlower Gardener
parent b7caba2c42
commit 158d4be42d
5 changed files with 177 additions and 26 deletions

View File

@@ -30,6 +30,7 @@ from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -136,8 +137,52 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
   @combinations.generate(
       combinations.combine(
-          distribution=strategy_combinations.tpu_strategies,
-          mode=["eager"]))
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def testGetNextAsOptional(self, distribution):
+    data = [5., 6., 7., 8.]
+    dataset = get_dataset_from_tensor_slices(data).batch(2)
+    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    iterator = iter(dist_dataset)
+
+    def train_step(data):
+      return math_ops.square(data)
+
+    @def_function.function
+    def run(iterator):
+      return distribution.experimental_local_results(
+          distribution.run(
+              train_step, args=(iterator.get_next_as_optional().get_value(),)))
+
+    self.assert_equal_flattened([[25., 36.]], [run(iterator)])
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def testGetNextAsOptionalExampleUsage(self, distribution):
+    global_batch_size = 2
+    steps_per_loop = 6
+    dataset = dataset_ops.Dataset.range(
+        8, output_type=dtypes.int32).batch(global_batch_size)
+    distributed_iterator = iter(
+        distribution.experimental_distribute_dataset(dataset))
+
+    @def_function.function
+    def train_fn(distributed_iterator):
+
+      def step_fn(x):
+        return x
+
+      for _ in math_ops.range(steps_per_loop):
+        optional_data = distributed_iterator.get_next_as_optional()
+        if not optional_data.has_value():
+          break
+        distribution.run(step_fn, args=(optional_data.get_value(),))
+
+    train_fn(distributed_iterator)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
   def testFullEagerTPU(self, distribution):
     dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)

View File

@@ -200,6 +200,7 @@ import six
 from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
 from tensorflow.python.autograph.impl import api as autograph
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribution_strategy_context
@@ -2879,6 +2880,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
     def get_next(self):
       return self._iterator.get_next()

+    def get_next_as_optional(self):
+      return iterator_ops.get_next_as_optional(self._iterator)
+
     @deprecated(None, "Use the iterator's `initializer` property instead.")
     def initialize(self):
       """Initialize underlying iterators.

View File

@@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import distribution_strategy_context
@@ -235,6 +236,40 @@ class DistributedIteratorInterface(collections.Iterator,
     raise NotImplementedError(
         "DistributedIterator.element_spec() must be implemented in descendants")

+  def get_next_as_optional(self):
+    """Returns a `tf.experimental.Optional` that contains the next value for all replicas.
+
+    If the `tf.distribute.DistributedIterator` has reached the end of the
+    sequence, the returned `tf.experimental.Optional` will have no value.
+
+    Example usage:
+
+    >>> strategy = tf.distribute.MirroredStrategy()
+    >>> global_batch_size = 2
+    >>> steps_per_loop = 2
+    >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
+    >>> distributed_iterator = iter(
+    ...     strategy.experimental_distribute_dataset(dataset))
+    >>> def step_fn(x):
+    ...   return x
+    >>> @tf.function
+    ... def train_fn(distributed_iterator):
+    ...   for _ in tf.range(steps_per_loop):
+    ...     optional_data = distributed_iterator.get_next_as_optional()
+    ...     if not optional_data.has_value():
+    ...       break
+    ...     tf.print(strategy.run(step_fn, args=(optional_data.get_value(),)))
+    >>> train_fn(distributed_iterator)
+    ... # ([0 1],)
+    ... # ([2 3],)
+
+    Returns:
+      A `tf.experimental.Optional` object representing the next value from the
+      `tf.distribute.DistributedIterator` (if it has one) or no value.
+    """
+    raise NotImplementedError(
+        "get_next_as_optional() not implemented in descendants")

 @tf_export("distribute.DistributedDataset", v1=[])
 class DistributedDatasetInterface(collections.Iterable,
@@ -622,6 +657,31 @@ class DistributedIteratorBase(DistributedIteratorInterface):
   def __iter__(self):
     return self

+  def get_next_as_optional(self):
+    global_has_value, replicas = _get_next_as_optional(self, self._strategy)
+
+    def return_none():
+      return optional_ops.Optional.empty(self._element_spec)
+
+    def return_value(replicas):
+      """Wraps the inputs for replicas in a `tf.experimental.Optional`."""
+      results = []
+      for i, worker in enumerate(self._input_workers.worker_devices):
+        with ops.device(worker):
+          devices = self._input_workers.compute_devices_for_worker(i)
+          for j, device in enumerate(devices):
+            with ops.device(device):
+              result = replicas[i][j]
+              results.append(result)
+      replicas = results
+
+      return optional_ops.Optional.from_value(
+          distribute_utils.regroup(replicas))
+
+    return control_flow_ops.cond(global_has_value,
+                                 lambda: return_value(replicas),
+                                 lambda: return_none())  # pylint: disable=unnecessary-lambda
+
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
     if not self._enable_get_next_as_optional:
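The cond above is the heart of the change: at the end of the sequence it returns an empty `tf.experimental.Optional`, otherwise one wrapping the regrouped per-replica values. A standalone sketch of that same pattern (assuming TF 2.3+, where `tf.experimental.Optional` is exported; `next_as_optional` is an illustrative name, not part of this commit):

import tensorflow as tf

def next_as_optional(has_value, value):
  # Same shape as the cond in DistributedIteratorBase.get_next_as_optional:
  # an empty Optional at end-of-sequence, a filled one otherwise.
  return tf.cond(
      has_value,
      lambda: tf.experimental.Optional.from_value(value),
      lambda: tf.experimental.Optional.empty(tf.TensorSpec.from_tensor(value)))

opt = next_as_optional(tf.constant(True), tf.constant([1, 2]))
print(opt.has_value().numpy())  # True
print(opt.get_value().numpy())  # [1 2]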

View File

@@ -185,38 +185,76 @@ class DistributedIteratorTestBase(test.TestCase):
       if not ops.executing_eagerly_outside_functions():
         evaluate(control_flow_ops.group(iterator.initializer))

-      for expected_value in expected_values:
-        next_element = iterator.get_next()
-        computed_value = evaluate(
-            [distribute_utils.select_replica(r, next_element)
-             for r in range(len(devices))])
-        self.assertEqual(len(expected_value), len(computed_value))
-        for i in range(len(expected_value)):
-          self.assertAllEqual(expected_value[i], computed_value[i])
-
-      with self.assertRaises(errors.OutOfRangeError):
-        next_element = iterator.get_next()
-        evaluate(
-            [distribute_utils.select_replica(r, next_element)
-             for r in range(len(devices))])
-
-      # After re-initializing the iterator, should be able to iterate again.
-      if not ops.executing_eagerly_outside_functions():
-        evaluate(control_flow_ops.group(iterator.initializer))
+      def test_get_next(iterator):
+        for expected_value in expected_values:
+          next_element = iterator.get_next()
+          computed_value = evaluate([
+              distribute_utils.select_replica(r, next_element)
+              for r in range(len(devices))
+          ])
+          self.assertEqual(len(expected_value), len(computed_value))
+          for i in range(len(expected_value)):
+            self.assertAllEqual(expected_value[i], computed_value[i])
+
+        with self.assertRaises(errors.OutOfRangeError):
+          next_element = iterator.get_next()
+          evaluate([
+              distribute_utils.select_replica(r, next_element)
+              for r in range(len(devices))
+          ])
+
+        # After re-initializing the iterator, should be able to iterate again.
+        if not ops.executing_eagerly_outside_functions():
+          evaluate(control_flow_ops.group(iterator.initializer))
+
+      def test_get_next_as_optional(iterator):
+        for expected_value in expected_values:
+          next_element = iterator.get_next_as_optional()
+          computed_value = evaluate([
+              distribute_utils.select_replica(r, next_element.get_value())
+              for r in range(len(devices))
+          ])
+          self.assertEqual(len(expected_value), len(computed_value))
+          for i in range(len(expected_value)):
+            self.assertAllEqual(expected_value[i], computed_value[i])
+
+        next_element = iterator.get_next_as_optional()
+        self.assertFalse(self.evaluate(next_element.has_value()))
+        with self.assertRaises(errors.InvalidArgumentError):
+          evaluate([
+              distribute_utils.select_replica(r, next_element.get_value())
+              for r in range(len(devices))
+          ])
+
+      test_get_next(iterator)
+
+      # re-initializing the iterator
+      if not tf2.enabled():
+        self.skipTest("Not testing get_next_as_optional in TF1")
+      else:
+        if api_type == "wrap_into_iterator":
+          self.skipTest("unsupported test combination")
+        else:
+          iterator = iter(dataset)
+          test_get_next_as_optional(iterator)
     else:
       if api_type == "wrap_into_iterator":
         self.skipTest("unsupported test combination")
       else:
         iterator = iter(dataset)
         for expected_value in expected_values:
           next_element = iterator.get_next()
           computed_value = evaluate(
               [distribute_utils.select_replica(r, next_element)
                for r in range(len(devices))])
           self.assertEqual(len(expected_value), len(computed_value))
           for i in range(len(expected_value)):
             self.assertAllEqual(expected_value[i], computed_value[i])

     if iteration_type == "for_loop" and context.executing_eagerly():
       actual_values = []

View File

@@ -13,4 +13,8 @@ tf_class {
     name: "get_next"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_next_as_optional"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
 }