Add get_next_as_optional method for a distributed iterator
Called on a `tf.distribute.DistributedIterator`, the method returns a `tf.experimental.Optional` that contains the next value (the `PerReplica` input) from the distributed iterator, or no value if the iterator has reached the end of the sequence.

PiperOrigin-RevId: 317248910
Change-Id: Ide217da1aff1d62f8d0d8f43423be2d859d933d3
Parent: b7caba2c42
Commit: 158d4be42d
Changed paths: tensorflow/python/distribute, tensorflow/tools/api/golden/v2
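For orientation before the diff, here is a minimal sketch of the new API in use, adapted from the docstring added in this change (the strategy choice and toy dataset are illustrative, not part of the commit):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dataset = tf.data.Dataset.range(6).batch(2)
    distributed_iterator = iter(
        strategy.experimental_distribute_dataset(dataset))

    @tf.function
    def train_fn(iterator):
      for _ in tf.range(5):
        # Unlike get_next(), this never raises OutOfRangeError; the returned
        # Optional simply has no value once the data is exhausted.
        optional_data = iterator.get_next_as_optional()
        if not optional_data.has_value():
          break
        tf.print(strategy.run(lambda x: x, args=(optional_data.get_value(),)))

    train_fn(distributed_iterator)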
@@ -30,6 +30,7 @@ from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -136,8 +137,52 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptional(self, distribution):
    data = [5., 6., 7., 8.]
    dataset = get_dataset_from_tensor_slices(data).batch(2)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    iterator = iter(dist_dataset)

    def train_step(data):
      return math_ops.square(data)

    @def_function.function
    def run(iterator):
      return distribution.experimental_local_results(
          distribution.run(
              train_step, args=(iterator.get_next_as_optional().get_value(),)))

    self.assert_equal_flattened([[25., 36.]], [run(iterator)])

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies, mode=["eager"]))
  def testGetNextAsOptionalExampleUsage(self, distribution):
    global_batch_size = 2
    steps_per_loop = 6
    dataset = dataset_ops.Dataset.range(
        8, output_type=dtypes.int32).batch(global_batch_size)
    distributed_iterator = iter(
        distribution.experimental_distribute_dataset(dataset))

    @def_function.function
    def train_fn(distributed_iterator):

      def step_fn(x):
        return x

      for _ in math_ops.range(steps_per_loop):
        optional_data = distributed_iterator.get_next_as_optional()
        if not optional_data.has_value():
          break
        distribution.run(step_fn, args=(optional_data.get_value(),))

    train_fn(distributed_iterator)

  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.tpu_strategies, mode=["eager"]))
  def testFullEagerTPU(self, distribution):
    dataset = get_dataset_from_tensor_slices([5., 6., 7., 8.]).batch(2)
@@ -200,6 +200,7 @@ import six
from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context

@@ -2879,6 +2880,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):

    def get_next(self):
      return self._iterator.get_next()

    def get_next_as_optional(self):
      return iterator_ops.get_next_as_optional(self._iterator)

    @deprecated(None, "Use the iterator's `initializer` property instead.")
    def initialize(self):
      """Initialize underlying iterators.
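The default-strategy path above delegates straight to tf.data's own Optional support. A sketch of that underlying behavior via the public helper (assuming `tf.data.experimental.get_next_as_optional`, the exported counterpart of `iterator_ops.get_next_as_optional`):

    import tensorflow as tf

    iterator = iter(tf.data.Dataset.range(2))
    opt = tf.data.experimental.get_next_as_optional(iterator)
    print(opt.has_value())  # tf.Tensor(True, ...)
    print(opt.get_value())  # tf.Tensor(0, ...)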
@@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
@@ -235,6 +236,40 @@ class DistributedIteratorInterface(collections.Iterator,

    raise NotImplementedError(
        "DistributedIterator.element_spec() must be implemented in descendants")

  def get_next_as_optional(self):
    """Returns a `tf.experimental.Optional` that contains the next value for all replicas.

    If the `tf.distribute.DistributedIterator` has reached the end of the
    sequence, the returned `tf.experimental.Optional` will have no value.

    Example usage:

    >>> strategy = tf.distribute.MirroredStrategy()
    >>> global_batch_size = 2
    >>> steps_per_loop = 2
    >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
    >>> distributed_iterator = iter(
    ...     strategy.experimental_distribute_dataset(dataset))
    >>> def step_fn(x):
    ...   return x
    >>> @tf.function
    ... def train_fn(distributed_iterator):
    ...   for _ in tf.range(steps_per_loop):
    ...     optional_data = distributed_iterator.get_next_as_optional()
    ...     if not optional_data.has_value():
    ...       break
    ...     tf.print(strategy.run(step_fn, args=(optional_data.get_value(),)))
    >>> train_fn(distributed_iterator)
    ... # ([0 1],)
    ... # ([2 3],)

    Returns:
      A `tf.experimental.Optional` object representing the next value from the
      `tf.distribute.DistributedIterator` (if it has one) or no value.
    """
    raise NotImplementedError(
        "get_next_as_optional() not implemented in descendants")


@tf_export("distribute.DistributedDataset", v1=[])
class DistributedDatasetInterface(collections.Iterable,
@@ -622,6 +657,31 @@ class DistributedIteratorBase(DistributedIteratorInterface):

  def __iter__(self):
    return self

  def get_next_as_optional(self):
    global_has_value, replicas = _get_next_as_optional(self, self._strategy)

    def return_none():
      return optional_ops.Optional.empty(self._element_spec)

    def return_value(replicas):
      """Wraps the inputs for replicas in a `tf.experimental.Optional`."""
      results = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        with ops.device(worker):
          devices = self._input_workers.compute_devices_for_worker(i)
          for j, device in enumerate(devices):
            with ops.device(device):
              result = replicas[i][j]
              results.append(result)
      replicas = results

      return optional_ops.Optional.from_value(
          distribute_utils.regroup(replicas))

    return control_flow_ops.cond(global_has_value,
                                 lambda: return_value(replicas),
                                 lambda: return_none())  # pylint: disable=unnecessary-lambda

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    if not self._enable_get_next_as_optional:
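The `cond` above is the core trick: a single global has-value tensor selects between an empty Optional (built from the iterator's `element_spec`) and one wrapping the regrouped per-replica values. A standalone sketch of that pattern with public APIs (assuming `tf.experimental.Optional`, the exported alias of `optional_ops.Optional`; names here are illustrative):

    import tensorflow as tf

    def wrap_as_optional(has_value, value, element_spec):
      # Both branches must produce an Optional with the same element_spec.
      return tf.cond(
          has_value,
          lambda: tf.experimental.Optional.from_value(value),
          lambda: tf.experimental.Optional.empty(element_spec))

    spec = tf.TensorSpec(shape=[2], dtype=tf.int32)
    opt = wrap_as_optional(tf.constant(True), tf.constant([1, 2]), spec)
    print(opt.has_value())  # tf.Tensor(True, ...)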
@@ -185,38 +185,76 @@ class DistributedIteratorTestBase(test.TestCase):

      if not ops.executing_eagerly_outside_functions():
        evaluate(control_flow_ops.group(iterator.initializer))

      def test_get_next(iterator):
        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])

          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

        with self.assertRaises(errors.OutOfRangeError):
          next_element = iterator.get_next()
          evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])

        # After re-initializing the iterator, should be able to iterate again.
        if not ops.executing_eagerly_outside_functions():
          evaluate(control_flow_ops.group(iterator.initializer))
        else:
          if api_type == "wrap_into_iterator":
            self.skipTest("unsupported test combination")
          else:
            iterator = iter(dataset)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element)
              for r in range(len(devices))
          ])
          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

      def test_get_next_as_optional(iterator):
        for expected_value in expected_values:
          next_element = iterator.get_next_as_optional()
          computed_value = evaluate([
              distribute_utils.select_replica(r, next_element.get_value())
              for r in range(len(devices))
          ])

          self.assertEqual(len(expected_value), len(computed_value))
          for i in range(len(expected_value)):
            self.assertAllEqual(expected_value[i], computed_value[i])

        next_element = iterator.get_next_as_optional()
        self.assertFalse(self.evaluate(next_element.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          evaluate([
              distribute_utils.select_replica(r, next_element.get_value())
              for r in range(len(devices))
          ])

      test_get_next(iterator)

      # re-initializing the iterator
      if not tf2.enabled():
        self.skipTest("Not testing get_next_as_optional in TF1")
      else:
        if api_type == "wrap_into_iterator":
          self.skipTest("unsupported test combination")
        else:
          iterator = iter(dataset)
          test_get_next_as_optional(iterator)

    if iteration_type == "for_loop" and context.executing_eagerly():
      actual_values = []
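A detail the new test pins down, sketched in isolation: draining past the end does not raise `OutOfRangeError` as `get_next()` would; the Optional is simply empty, and only forcing `get_value()` on it fails with `InvalidArgumentError`. Again assuming the public `tf.data.experimental.get_next_as_optional` helper:

    import tensorflow as tf

    iterator = iter(tf.data.Dataset.range(1))
    tf.data.experimental.get_next_as_optional(iterator)          # consumes 0
    empty = tf.data.experimental.get_next_as_optional(iterator)  # exhausted
    print(empty.has_value())  # tf.Tensor(False, ...)
    try:
      empty.get_value()       # error: the Optional holds no value
    except tf.errors.InvalidArgumentError:
      print("get_value() on an empty Optional raises InvalidArgumentError")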
@@ -13,4 +13,8 @@ tf_class {
    name: "get_next"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_next_as_optional"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}