Support reducing IndexedSlices in CollectiveAllReduce._all_reduce.

PiperOrigin-RevId: 358860726
Change-Id: I1f2f4333fe87a22e4902dbea8d787c9fdf50fb9b
This commit is contained in:
Chenkai Kuang 2021-02-22 11:15:04 -08:00 committed by TensorFlower Gardener
parent 6b350fcde9
commit 85c597aaa2
5 changed files with 271 additions and 76 deletions

View File

@ -70,8 +70,9 @@ def validate_destinations(destinations):
"""Validates the `destination` is one of expected types."""
if not isinstance(
destinations,
(value_lib.DistributedValues, ops.Tensor, ps_values.AggregatingVariable,
six.string_types, tpu_values.TPUMirroredVariable
(value_lib.DistributedValues, ops.Tensor, ops.IndexedSlices,
ps_values.AggregatingVariable, six.string_types,
tpu_values.TPUMirroredVariable
)) and not resource_variable_ops.is_resource_variable(destinations):
raise ValueError("destinations must be one of a `DistributedValues` object,"
" a tf.Variable object, or a device string.")
@ -532,25 +533,25 @@ class CrossDeviceOps(object):
def _all_reduce(self, reduce_op, value, replica_id, options):
"""All-reduce the `value` across all replicas so that all get the result.
`value` can be a nested structure of tensors. The implementation should
generally batch the all-reduces when possible. `options` can be set to
hint the batching behavior.
`value` can be a nested structure of tensors or `IndexedSlices`. The
implementation should generally batch the all-reduces when possible.
`options` can be set to hint the batching behavior.
This API must be called in a replica context.
Args:
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
be combined. Allows using string representation of the enum such as
"SUM", "MEAN".
value: Value to be reduced. A tensor or a nested structure of tensors.
be combined.
value: Value to be reduced. A tensor or a nested structure of tensors or
`IndexedSlices`.
replica_id: An integer indicating the id of the replica in which this
all_reduce is called. This is the local replica id that ranges
from 0 to len(local_devices) - 1.
options: A `tf.distribute.experimental.CommunicationOptions`.
Returns:
A tensor or a nested structure of tensors with the reduced values. The
structure is the same as `value`.
A tensor/IndexedSlices or a nested structure of tensors/IndexedSlices with
the reduced values. The structure is the same as `value`.
"""
raise NotImplementedError("_all_reduce must be implemented in descendants.")
@ -1083,48 +1084,74 @@ class CollectiveAllReduce(CrossDeviceOps):
# TODO(b/122840926): reuse this method in _batch_all_reduce.
flat_values = nest.flatten(value)
if isinstance(flat_values[0], ops.IndexedSlices):
raise NotImplementedError("all_reduce doesn't support IndexedSlices.")
batch_size = len(flat_values)
implementation = options.implementation.value
# If NCCL launches can't be ordered (self._limited_nccl == True), we only
# use NCCL only when batch_size > 1, hoping that there's only one batched
# all-reduce, which is the gradients.
# use NCCL when batch_size > 1, hoping that there's only one batched
# all-reduce, which is the gradient aggregation in optimizer. For TF 2.x,
# NCCL launches are always ordered.
if (self._limited_nccl and
options.implementation == CommunicationImplementation.NCCL and
batch_size == 1):
len(flat_values) == 1):
implementation = CommunicationImplementation.AUTO.value
# Reverse the lists so that there's a better chance that values follow
# the order in which they are calculated (e.g. when they're gradients), so
# as to overlap calculation with communication. However, this may not be
# optimal for cases like gradients of complicated non-sequential models.
#
# Note that we reverse the list before packing so that the first pack won't
# be too small, since it's more likely for first few packs to have long
# queuing time due to concurrent intense computation.
#
# TODO(b/147393503): explore solutions for optimal ordering.
flat_values.reverse()
packs = cross_device_utils.group_by_size(flat_values,
options.bytes_per_pack)
launcher = self._launchers[replica_id]
if not context.executing_eagerly() and replica_id == 0:
logging.info(
"Collective all_reduce: %d all-reduces, num_devices = %d, "
"group_size = %d, implementation = %s, num_packs = %d", batch_size,
len(self._launchers), self._group_size, implementation, len(packs))
flat_results = launcher.batch_all_reduce(packs, implementation,
options.timeout_seconds)
dense_values, dense_indices, sparse_values, sparse_indices = (
cross_device_utils.split_by_sparsity(flat_values))
dense_results = []
sparse_results = []
if reduce_op == reduce_util.ReduceOp.MEAN:
for i, v in enumerate(flat_results):
flat_results[i] = v / self._group_size
flat_results.reverse()
if dense_values:
# Reverse the lists so that there's a better chance that values follow
# the order in which they are calculated (e.g. when they're gradients), so
# as to overlap calculation with communication. However, this may not be
# optimal for cases like gradients of complicated non-sequential models.
#
# Note that we reverse the list before packing so that the first pack
# won't be too small, since it's more likely for first few packs to have
# long queuing time due to concurrent intense computation.
#
# TODO(b/147393503): explore solutions for optimal ordering.
dense_values.reverse()
packs = cross_device_utils.group_by_size(dense_values,
options.bytes_per_pack)
if not context.executing_eagerly() and replica_id == 0:
logging.info(
"Collective all_reduce tensors: %d all_reduces, num_devices = %d, "
"group_size = %d, implementation = %s, num_packs = %d",
len(dense_values), len(self._launchers), self._group_size,
implementation, len(packs))
dense_results = launcher.batch_all_reduce(packs, implementation,
options.timeout_seconds)
if reduce_op == reduce_util.ReduceOp.MEAN:
for i, v in enumerate(dense_results):
with ops.device(self._devices[replica_id]):
dense_results[i] = v / self._group_size
dense_results.reverse()
if sparse_values:
if not context.executing_eagerly() and replica_id == 0:
logging.info(
"Collective all_reduce IndexedSlices: %d all_reduces, num_devices ="
"%d, group_size = %d, implementation = %s", len(dense_values),
len(self._launchers), self._group_size, implementation)
for indexed_slice in sparse_values:
sparse_results.append(
launcher.all_reduce_indexed_slices(indexed_slice, implementation,
options.timeout_seconds))
if reduce_op == reduce_util.ReduceOp.MEAN:
for i, v in enumerate(sparse_results):
with ops.device(self._devices[replica_id]):
sparse_results[i] = ops.IndexedSlices(
values=sparse_results[i].values / self._group_size,
indices=sparse_results[i].indices,
dense_shape=sparse_results[i].dense_shape)
flat_results = cross_device_utils.stitch_values(
((dense_results, dense_indices), (sparse_results, sparse_indices)))
return nest.pack_sequence_as(value, flat_results)
def reduce_implementation(self, reduce_op, per_replica_value, destinations,

View File

@ -162,10 +162,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
gpu_per_process: number of GPUs (0 if no GPUs) used by each process.
Returns:
A tuple of (collective, devices, group_size) where collective is an instance
A tuple of (collective, devices, pid) where collective is an instance
of `CollectiveAllReduce`, devices are a list of local devices (str)
attached to the current process, and group_size is the group_size of
collective.
attached to the current process, and pid is the id of this process among
all participant processes.
"""
cluster_resolver = cluster_resolver_lib.TFConfigClusterResolver()
@ -319,8 +319,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
prefer_unique_instance_key=[True, False]))
def testAllReduceDense(self, num_processes, required_gpus, implementation,
reduce_op, prefer_unique_instance_key):
def testReduceDense(self, num_processes, required_gpus, implementation,
reduce_op, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@ -361,8 +361,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
reduce_op=ReduceOp.SUM,
prefer_unique_instance_key=[True, False]))
def testAllReduceSparse(self, num_processes, required_gpus, implementation,
reduce_op, prefer_unique_instance_key):
def testReduceSparse(self, num_processes, required_gpus, implementation,
reduce_op, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@ -410,7 +410,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(prefer_unique_instance_key=[True, False]))
def testAllReduceSparseVariableLength(self, prefer_unique_instance_key):
def testReduceSparseVariableLength(self, prefer_unique_instance_key):
# One device per process, 2 processes, 2 replicas in total.
inputs = [
IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
@ -441,9 +441,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
prefer_unique_instance_key=[True, False]))
def testBatchAllReduceDense(self, num_processes, required_gpus,
implementation, reduce_op,
prefer_unique_instance_key):
def testBatchReduceDense(self, num_processes, required_gpus, implementation,
reduce_op, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@ -485,9 +484,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
reduce_op=ReduceOp.SUM,
prefer_unique_instance_key=[True, False]))
def testBatchAllReduceSparse(self, num_processes, required_gpus,
implementation, reduce_op,
prefer_unique_instance_key):
def testBatchReduceSparse(self, num_processes, required_gpus, implementation,
reduce_op, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@ -571,8 +569,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
))
def testCollectiveAllReduce(self, num_processes, required_gpus,
implementation, reduce_op):
def testAllReduceDense(self, num_processes, required_gpus, implementation,
reduce_op):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@ -623,6 +621,126 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
get_global_mpr(num_processes).run(replica_fn)
@combinations.generate(
combinations.combine(
num_processes=[1, 2],
required_gpus=[0, 1, 2],
implementation=[
CommunicationImplementation.AUTO,
CommunicationImplementation.RING,
CommunicationImplementation.NCCL,
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
))
def testAllReduceSparse(self, num_processes, required_gpus, implementation,
reduce_op):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
if (num_processes == 2 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
"physical GPUs for every process.")
def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus)
options = collective_util.Options(implementation=implementation)
group_size = num_processes * (required_gpus or 1)
@def_function.function
def collective_all_reduce():
results = []
for replica_id, device in enumerate(devices):
with ops.device(device):
value = IndexedSlices(
values=array_ops.identity([[1.]]),
indices=array_ops.identity([0]),
dense_shape=array_ops.identity([5, 1]))
results.append(
collective._all_reduce(reduce_op, value, replica_id, options))
return results
got = collective_all_reduce()
if reduce_op == ReduceOp.SUM:
expect = [IndexedSlices([[1. * group_size]], [0], [5, 1])
] * len(devices)
elif reduce_op == ReduceOp.MEAN:
expect = [IndexedSlices([[1.]], [0], [5, 1])] * len(devices)
self.assertAllClose(
nest.map_structure(ops.convert_to_tensor, got),
nest.map_structure(ops.convert_to_tensor, expect))
@def_function.function
def collective_batch_all_reduce():
results = []
for replica_id, device in enumerate(devices):
with ops.device(device):
value = (IndexedSlices(
array_ops.identity([[1.]]), array_ops.identity([0]),
array_ops.identity([5, 1])),
IndexedSlices(
array_ops.identity([[3.]]), array_ops.identity([2]),
array_ops.identity([5, 1])))
results.append(
collective._all_reduce(reduce_op, value, replica_id, options))
return results
got = collective_batch_all_reduce()
if reduce_op == ReduceOp.SUM:
expect = [(IndexedSlices([[1. * group_size]], [0], [5, 1]),
IndexedSlices([[3. * group_size]], [2], [5, 1]))
] * len(devices)
elif reduce_op == ReduceOp.MEAN:
expect = [(IndexedSlices([[1.]], [0], [5, 1]),
IndexedSlices([[3.]], [2], [5, 1]))] * len(devices)
self.assertAllClose(
nest.map_structure(ops.convert_to_tensor, got),
nest.map_structure(ops.convert_to_tensor, expect))
get_global_mpr(num_processes).run(replica_fn)
@combinations.generate(
combinations.combine(
num_processes=2,
required_gpus=0,
implementation=CommunicationImplementation.AUTO,
reduce_op=ReduceOp.SUM))
def testAllReduceMixedDenseAndSparse(self, num_processes, required_gpus,
implementation, reduce_op):
def replica_fn():
collective, devices, _ = self.make_collective(num_processes,
required_gpus)
options = collective_util.Options(implementation=implementation)
group_size = num_processes * (required_gpus or 1)
@def_function.function
def collective_batch_all_reduce():
results = []
for replica_id, device in enumerate(devices):
with ops.device(device):
value = (IndexedSlices(
array_ops.identity([[1.]]), array_ops.identity([0]),
array_ops.identity([5, 1])), array_ops.identity(1.0),
IndexedSlices(
array_ops.identity([[3.]]), array_ops.identity([2]),
array_ops.identity([5, 1])), array_ops.identity(2.0))
results.append(
collective._all_reduce(reduce_op, value, replica_id, options))
return results
got = collective_batch_all_reduce()
expect = [
(IndexedSlices([[1. * group_size]], [0], [5, 1]), 1.0 * group_size,
IndexedSlices([[3. * group_size]], [2], [5, 1]), 2.0 * group_size)
] * len(devices)
self.assertAllClose(
nest.map_structure(ops.convert_to_tensor, got),
nest.map_structure(ops.convert_to_tensor, expect))
get_global_mpr(num_processes).run(replica_fn)
@combinations.generate(
combinations.combine(
num_processes=[1, 2],

View File

@ -603,8 +603,9 @@ def copy_tensor_or_indexed_slices_to_device(value, device):
def is_indexed_slices(value):
if isinstance(value, ops.IndexedSlices):
return True
assert isinstance(value, value_lib.DistributedValues)
return all(isinstance(v, ops.IndexedSlices) for v in value.values)
if isinstance(value, value_lib.DistributedValues):
return all(isinstance(v, ops.IndexedSlices) for v in value.values)
return False
def split_by_sparsity(values):

View File

@ -2409,8 +2409,7 @@ class StrategyExtendedV2(object):
Args:
reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
be combined. Allows using string representation of the enum such as
"SUM", "MEAN".
be combined.
value: Value to be reduced. A tensor or a nested structure of tensors.
options: A `tf.distribute.experimental.CommunicationOptions`. Options to
perform collective operations. This overrides the default options if the

View File

@ -116,13 +116,14 @@ class ReduceTest(test.TestCase, parameterized.TestCase):
strategy=[
strategy_combinations.multi_worker_mirrored_2x1_cpu,
strategy_combinations.multi_worker_mirrored_2x1_gpu,
strategy_combinations.multi_worker_mirrored_2x2_gpu,
strategy_combinations.tpu_strategy,
] + strategy_combinations.strategies_minus_tpu,
tf_function=[combinations.tf_function, combinations.no_tf_function],
mode=['eager']))
class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
def testBasic(self, strategy, tf_function):
def testDense(self, strategy, tf_function):
if (isinstance(strategy, tpu_strategy.TPUStrategy) and
tf_function is combinations.no_tf_function):
self.skipTest('Skip TPUStrategy + eager combination.')
@ -132,7 +133,8 @@ class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
def replica_fn():
value = array_ops.identity(1.0)
reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
reduced = strategy.extended._replica_ctx_all_reduce(
reduce_util.ReduceOp.SUM, value)
return reduced
return strategy.experimental_local_results(strategy.run(replica_fn))
@ -140,27 +142,75 @@ class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
got = fn()[0]
self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)
def testNestedInput(self, strategy, tf_function):
# TODO(b/122840926): enable this test once cl/122840926 is submitted.
self.skipTest('Enable after cl/353109164 is submitted.')
if (isinstance(strategy, tpu_strategy.TPUStrategy) and
tf_function is combinations.no_tf_function):
self.skipTest('Skip TPUStrategy + eager combination.')
def testSparse(self, strategy, tf_function):
if tf_function is combinations.no_tf_function:
self.skipTest('Skip IndexedSlices + eager combination.')
@tf_function
def fn():
def replica_fn():
value = (array_ops.identity(1.0), array_ops.identity(1.0))
reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
value = ops.IndexedSlices(
values=array_ops.identity([[1.0]]),
indices=array_ops.identity([0]),
dense_shape=array_ops.identity([5, 1]))
reduced = strategy.extended._replica_ctx_all_reduce(
reduce_util.ReduceOp.SUM, value)
return reduced
return strategy.experimental_local_results(strategy.run(replica_fn))
got = fn()[0]
self.assertEqual(got, (1.0 * strategy.num_replicas_in_sync,
2.0 * strategy.num_replicas_in_sync))
expect = ops.IndexedSlices(
values=array_ops.identity([[1.0 * strategy.num_replicas_in_sync]]),
indices=array_ops.identity([0]),
dense_shape=array_ops.identity([5, 1]))
self.assertAllEqual(
ops.convert_to_tensor(got), ops.convert_to_tensor(expect))
def testNestedInput(self, strategy, tf_function):
# TODO(b/122840926): enable this test once cl/356664949 is submitted.
self.skipTest('Enable after cl/353109164 is submitted.')
if tf_function is combinations.no_tf_function:
self.skipTest('Skip IndexedSlices + eager combination.')
@tf_function
def fn():
def replica_fn():
value = (array_ops.identity(1.0),
ops.IndexedSlices(
values=array_ops.identity([[1.0]]),
indices=array_ops.identity([0]),
dense_shape=array_ops.identity([5, 1])),
array_ops.identity(2.0),
ops.IndexedSlices(
values=array_ops.identity([[2.0]]),
indices=array_ops.identity([1]),
dense_shape=array_ops.identity([5, 1])))
reduced = strategy.extended._replica_ctx_all_reduce(
reduce_util.ReduceOp.SUM, value)
return reduced
return strategy.experimental_local_results(strategy.run(replica_fn))
got = fn()[0]
expect = (1.0 * strategy.num_replicas_in_sync,
ops.IndexedSlices(
values=array_ops.identity(
[[1.0 * strategy.num_replicas_in_sync]]),
indices=array_ops.identity([0]),
dense_shape=array_ops.identity([5, 1])),
2.0 * strategy.num_replicas_in_sync,
ops.IndexedSlices(
values=array_ops.identity(
[[2.0 * strategy.num_replicas_in_sync]]),
indices=array_ops.identity([1]),
dense_shape=array_ops.identity([5, 1])))
self.assertAllEqual(
nest.map_structure(ops.convert_to_tensor, got),
nest.map_structure(ops.convert_to_tensor, expect))
def _make_indexed_slices(values, indices, dense_shape):