Add all_reduce APIs that can be called in replica context to `CrossDeviceOps` and `StrategyExtended`.

For `StrategyExtended`, it is a private API that will be used by `ReplicaContext.all_reduce`.

This is in preparation for deprecating `merge_call` from the user API.
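As a rough sketch of the user-facing pattern this change prepares for (the strategy choice and values below are illustrative assumptions, not part of this commit):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

@tf.function
def step():
  def replica_fn():
    ctx = tf.distribute.get_replica_context()
    # All-reduce from within the replica function; the point of this change is
    # to let such calls eventually avoid merge_call under the hood.
    return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(1.0))
  return strategy.run(replica_fn)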

PiperOrigin-RevId: 356604626
Change-Id: I2528b35b87db1b93907b17a246dbfbcfcb64ad33
Chenkai Kuang 2021-02-09 15:22:19 -08:00 committed by TensorFlower Gardener
parent 2a2cf507f0
commit eb31d8660d
6 changed files with 260 additions and 2 deletions

@@ -51,10 +51,9 @@ py_library(
"//tensorflow/python:tensor_util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:executor",
"//tensorflow/python/util",
"//tensorflow/python/util:tf_export",
"//tensorflow/tools/docs:doc_controls",
"@enum34_archive//:enum",
"@six_archive//:six",
],
)

@@ -31,6 +31,7 @@ from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import multi_worker_util
@@ -771,6 +772,30 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
destinations=destinations,
options=self._communication_options.merge(options))
def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
"""Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
# This implementation avoids using `merge_call` and launches collective
# ops from within each replica.
if options is None:
options = collective_util.Options()
if context.executing_eagerly():
# In eager mode, falls back to the default implementation that uses
# `merge_call`. Replica functions run sequentially in eager mode, and
# due to the blocking nature of collective ops, execution will hang if
# collective ops are launched sequentially.
return super()._replica_ctx_all_reduce(reduce_op, value, options)
replica_context = ds_context.get_replica_context()
assert replica_context, (
"`StrategyExtended._replica_ctx_all_reduce` must be called in a "
"replica context")
return self._cross_device_ops._all_reduce( # pylint: disable=protected-access
reduce_op,
value,
replica_context._replica_id, # pylint: disable=protected-access
options)
def _check_health(self):
while True:
if self._check_health_thread_should_stop.is_set():

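A minimal sketch of how the override above is expected to be reached, assuming `tf.distribute.MultiWorkerMirroredStrategy` (which is backed by `CollectiveAllReduceExtended`); it mirrors the tests added later in this change and is illustrative only:

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()

@tf.function
def all_reduce_step():
  def replica_fn():
    # Inside tf.function this takes the merge_call-free path and dispatches to
    # the per-replica collective launcher; in eager mode it falls back to the
    # default merge_call-based implementation.
    return strategy.extended._replica_ctx_all_reduce('SUM', tf.constant(1.0))
  return strategy.experimental_local_results(strategy.run(replica_fn))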

@@ -45,6 +45,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
@@ -522,6 +523,37 @@ class CrossDeviceOps(object):
"""
return simple_broadcast(tensor, destinations, always_mirrored=True)
# ========================== Collective APIs ================================
#
# Unlike `reduce`, `batch_reduce` and `broadcast`, which must be called in
# cross-replica context, collective APIs are to be called in replica
# context.
def _all_reduce(self, reduce_op, value, replica_id, options):
"""All-reduce the `value` across all replicas so that all get the result.
`value` can be a nested structure of tensors. The implementation should
generally batch the all-reduces when possible. `options` can be set to
hint the batching behavior.

This API must be called in a replica context.

Args:
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
be combined. Allows using string representations of the enum such as
"SUM" and "MEAN".
value: Value to be reduced. A tensor or a nested structure of tensors.
replica_id: An integer indicating the id of the local replica on which
this all_reduce is called. It ranges from 0 to len(local_devices) - 1.
options: A `tf.distribute.experimental.CommunicationOptions`.

Returns:
A tensor or a nested structure of tensors with the reduced values. The
structure is the same as `value`.
"""
raise NotImplementedError("_all_reduce must be implemented in descendants.")
@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
@@ -850,6 +882,8 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
destinations = dense_values[0]._devices # pylint: disable=protected-access
grouped = _group_value_by_device(dense_values)
# device_grad_packs:
# [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)
# The actual aggregation of the repacked gradients. Note that they are
@@ -1044,6 +1078,55 @@ class CollectiveAllReduce(CrossDeviceOps):
# Currently we only support equal number of devices on each worker.
return self._group_size / len(self._devices)
def _all_reduce(self, reduce_op, value, replica_id, options):
"""Implements CrossDeviceOps.all_reduce."""
# TODO(b/122840926): reuse this method in _batch_all_reduce.
flat_values = nest.flatten(value)
if isinstance(flat_values[0], ops.IndexedSlices):
raise NotImplementedError("all_reduce doesn't support IndexedSlices.")
batch_size = len(flat_values)
implementation = options.implementation.value
# If NCCL launches can't be ordered (self._limited_nccl == True), we only
# use NCCL when batch_size > 1, hoping that there's only one batched
# all-reduce, which is the gradient all-reduce.
if (self._limited_nccl and
options.implementation == CommunicationImplementation.NCCL and
batch_size == 1):
implementation = CommunicationImplementation.AUTO.value
# Reverse the lists so that there's a better chance that values follow
# the order in which they are calculated (e.g. when they're gradients), so
# as to overlap calculation with communication. However, this may not be
# optimal for cases like gradients of complicated non-sequential models.
#
# Note that we reverse the list before packing so that the first pack won't
# be too small, since it's more likely for first few packs to have long
# queuing time due to concurrent intense computation.
#
# TODO(b/147393503): explore solutions for optimal ordering.
flat_values.reverse()
packs = cross_device_utils.group_by_size(flat_values,
options.bytes_per_pack)
launcher = self._launchers[replica_id]
if not context.executing_eagerly() and replica_id == 0:
logging.info(
"Collective all_reduce: %d all-reduces, num_devices = %d, "
"group_size = %d, implementation = %s, num_packs = %d", batch_size,
len(self._launchers), self._group_size, implementation, len(packs))
flat_results = launcher.batch_all_reduce(packs, implementation,
options.timeout_seconds)
if reduce_op == reduce_util.ReduceOp.MEAN:
for i, v in enumerate(flat_results):
flat_results[i] = v / self._group_size
flat_results.reverse()
return nest.pack_sequence_as(value, flat_results)
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
options):
values_util.mark_as_unsaveable()

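To make the structure handling in `CollectiveAllReduce._all_reduce` above concrete, here is a standalone sketch of the flatten, reverse and repack round trip using the public `tf.nest` API; the collective launch itself is elided and the values and options are illustrative assumptions:

import tensorflow as tf

options = tf.distribute.experimental.CommunicationOptions(
    bytes_per_pack=4 * 1024 * 1024,
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)

value = (tf.constant(1.0), {'bias': tf.constant(2.0)})
flat_values = tf.nest.flatten(value)   # [tf.Tensor(1.0), tf.Tensor(2.0)]
flat_values.reverse()                  # later values are communicated first
flat_results = list(flat_values)       # stand-in for the batched all-reduce
flat_results.reverse()                 # restore the original flatten order
result = tf.nest.pack_sequence_as(value, flat_results)  # same structure as `value`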

@@ -560,6 +560,69 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
]
self.batch_reduce_and_verify(inputs, expect, options)
@combinations.generate(
combinations.combine(
num_processes=[1, 2],
required_gpus=[0, 1, 2],
implementation=[
CommunicationImplementation.AUTO,
CommunicationImplementation.RING,
CommunicationImplementation.NCCL,
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
))
def testCollectiveAllReduce(self, num_processes, required_gpus,
implementation, reduce_op):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
if (num_processes == 2 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
"physical GPUs for every process.")
def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus)
options = collective_util.Options(implementation=implementation)
group_size = num_processes * (required_gpus or 1)
@def_function.function
def collective_all_reduce():
results = []
for replica_id, device in enumerate(devices):
with ops.device(device):
value = constant_op.constant(1.0)
results.append(
collective._all_reduce(reduce_op, value, replica_id, options))
return results
got = collective_all_reduce()
if reduce_op == ReduceOp.SUM:
expect = [1.0 * group_size] * len(devices)
elif reduce_op == ReduceOp.MEAN:
expect = [1.0] * len(devices)
self.assertAllClose(got, expect)
@def_function.function
def collective_batch_all_reduce():
results = []
for replica_id, device in enumerate(devices):
with ops.device(device):
value = (constant_op.constant(1.0), constant_op.constant(2.0))
results.append(
collective._all_reduce(reduce_op, value, replica_id, options))
return results
got = collective_batch_all_reduce()
if reduce_op == ReduceOp.SUM:
expect = [(1.0 * group_size, 2.0 * group_size)] * len(devices)
elif reduce_op == ReduceOp.MEAN:
expect = [(1.0, 2.0)] * len(devices)
self.assertAllClose(got, expect)
get_global_mpr(num_processes).run(replica_fn)
@combinations.generate(
combinations.combine(
num_processes=[1, 2],

@@ -2384,6 +2384,42 @@ class StrategyExtendedV2(object):
for t, v in value_destination_pairs
]
def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
"""All-reduce `value` across all replicas so that all get the final result.
If `value` is a nested structure of tensors, all-reduces of these tensors
will be batched when possible. `options` can be set to hint the batching
behavior.

This API must be called in a replica context.

Args:
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
be combined. Allows using string representations of the enum such as
"SUM" and "MEAN".
value: Value to be reduced. A tensor or a nested structure of tensors.
options: A `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the
`tf.distribute.Strategy` takes one in the constructor.

Returns:
A tensor or a nested structure of tensors with the reduced values. The
structure is the same as `value`.
"""
if options is None:
options = collective_util.Options()
replica_context = distribution_strategy_context.get_replica_context()
assert replica_context, (
"`StrategyExtended._replica_ctx_all_reduce` must be called in"
" a replica context")
def merge_fn(_, flat_value):
return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
options)
reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
return nest.pack_sequence_as(value, reduced)
def _gather_to(self, value, destinations, axis, options=None):
"""Gather `value` across replicas along axis-th dimension to `destinations`.

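For reference, the default path above is the long-standing merge_call pattern; a hedged sketch of the equivalent logic written against public APIs (values are illustrative and `options` is omitted):

import tensorflow as tf

def replica_fn():
  value = (tf.constant(1.0), tf.constant(2.0))
  flat = tf.nest.flatten(value)

  def merge_fn(strategy, flat_value):
    # Reduce each tensor back onto itself, i.e. use the value as its own
    # destination, batching all tensors into a single call.
    return strategy.extended.batch_reduce_to(
        tf.distribute.ReduceOp.SUM, [(v, v) for v in flat_value])

  ctx = tf.distribute.get_replica_context()
  reduced = ctx.merge_call(merge_fn, args=(flat,))
  return tf.nest.pack_sequence_as(value, reduced)

# Run under a strategy, e.g.: tf.distribute.MirroredStrategy().run(replica_fn)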

@@ -111,6 +111,58 @@ class ReduceTest(test.TestCase, parameterized.TestCase):
self.assertEqual(3 * strategy.num_replicas_in_sync, x_s)
@combinations.generate(
combinations.combine(
strategy=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.tpu_strategy,
] + strategy_combinations.strategies_minus_tpu,
tf_function=[combinations.tf_function, combinations.no_tf_function],
mode=['eager']))
class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
def testBasic(self, strategy, tf_function):
if (isinstance(strategy, tpu_strategy.TPUStrategy) and
tf_function is combinations.no_tf_function):
self.skipTest('Skip TPUStrategy + eager combination.')
@tf_function
def fn():
def replica_fn():
value = constant_op.constant(1.0)
reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
return reduced
return strategy.experimental_local_results(strategy.run(replica_fn))
got = fn()[0]
self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)
def testNestedInput(self, strategy, tf_function):
# TODO(b/122840926): enable this test once cl/122840926 is submitted.
self.skipTest('Enable after cl/353109164 is submitted.')
if (isinstance(strategy, tpu_strategy.TPUStrategy) and
tf_function is combinations.no_tf_function):
self.skipTest('Skip TPUStrategy + eager combination.')
@tf_function
def fn():
def replica_fn():
value = (constant_op.constant(1.0), constant_op.constant(2.0))
reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
return reduced
return strategy.experimental_local_results(strategy.run(replica_fn))
got = fn()[0]
self.assertEqual(got, (1.0 * strategy.num_replicas_in_sync,
2.0 * strategy.num_replicas_in_sync))
def _make_indexed_slices(values, indices, dense_shape):
tensor = ops.IndexedSlices(
values=constant_op.constant(values),