From eb31d8660d9b246c2f15a5e29d1b21de3731b36f Mon Sep 17 00:00:00 2001
From: Chenkai Kuang <chenkai@google.com>
Date: Tue, 9 Feb 2021 15:22:19 -0800
Subject: [PATCH] Add all_reduce APIs that can be called in replica context to
 class `CrossDeviceOps` and `StrategyExtended`.

For `StrategyExtended`, it is a private API that will be used by
`ReplicaContext.all_reduce`. This is in preparation for deprecating
`merge_call` from the user API.

PiperOrigin-RevId: 356604626
Change-Id: I2528b35b87db1b93907b17a246dbfbcfcb64ad33
---
 tensorflow/python/distribute/BUILD            |  3 +-
 .../collective_all_reduce_strategy.py         | 25 ++++++
 .../python/distribute/cross_device_ops.py     | 83 +++++++++++++++++++
 .../distribute/cross_device_ops_test.py       | 63 ++++++++++++++
 .../python/distribute/distribute_lib.py       | 36 ++++++++
 .../python/distribute/strategy_common_test.py | 52 ++++++++++++
 6 files changed, 260 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index fae5eb24c52..6d858a8295e 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -51,10 +51,9 @@ py_library(
         "//tensorflow/python:tensor_util",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:def_function",
-        "//tensorflow/python/eager:executor",
+        "//tensorflow/python/util",
         "//tensorflow/python/util:tf_export",
         "//tensorflow/tools/docs:doc_controls",
-        "@enum34_archive//:enum",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index 37a440bf46e..bdc90ae841b 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -31,6 +31,7 @@ from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribute_utils
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import mirrored_strategy
 from tensorflow.python.distribute import multi_worker_util
@@ -771,6 +772,30 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
         destinations=destinations,
         options=self._communication_options.merge(options))
 
+  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
+    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
+    # This implementation avoids using `merge_call` and just launches collective
+    # ops in one replica.
+    if options is None:
+      options = collective_util.Options()
+
+    if context.executing_eagerly():
+      # In eager mode, fall back to the default implementation that uses
+      # `merge_call`. Replica functions run sequentially in eager mode, and
+      # since collective ops block until all participants join, execution
+      # would hang if the collective ops were launched sequentially.
+      return super()._replica_ctx_all_reduce(reduce_op, value, options)
+
+    replica_context = ds_context.get_replica_context()
+    assert replica_context, (
+        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
+        "replica context")
+    return self._cross_device_ops._all_reduce(  # pylint: disable=protected-access
+        reduce_op,
+        value,
+        replica_context._replica_id,  # pylint: disable=protected-access
+        options)
+
   def _check_health(self):
     while True:
       if self._check_health_thread_should_stop.is_set():
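The `options` threaded through `_replica_ctx_all_reduce` above is a `collective_util.Options`, the implementation-side counterpart of the public `tf.distribute.experimental.CommunicationOptions`. A minimal sketch of what a caller can construct to hint batching and transport; the specific values here are illustrative, not defaults:

    import tensorflow as tf

    options = tf.distribute.experimental.CommunicationOptions(
        bytes_per_pack=50 * 1024 * 1024,  # pack tensors into ~50MB batches
        timeout_seconds=120.0,            # fail instead of hanging on a lost peer
        implementation=(
            tf.distribute.experimental.CommunicationImplementation.NCCL))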
diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py
index 25fd0c94eae..f8f02ae94b7 100644
--- a/tensorflow/python/distribute/cross_device_ops.py
+++ b/tensorflow/python/distribute/cross_device_ops.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 from tensorflow.tools.docs import doc_controls
 
@@ -522,6 +523,37 @@ class CrossDeviceOps(object):
     """
     return simple_broadcast(tensor, destinations, always_mirrored=True)
 
+  # ========================== Collective APIs ================================
+  #
+  # Unlike `reduce`, `batch_reduce` and `broadcast`, which must be called in
+  # cross-replica context, collective APIs are to be called in replica
+  # context.
+
+  def _all_reduce(self, reduce_op, value, replica_id, options):
+    """All-reduce the `value` across all replicas so that all get the result.
+
+    `value` can be a nested structure of tensors. The implementation should
+    generally batch the all-reduces when possible. `options` can be set to
+    hint the batching behavior.
+
+    This API must be called in a replica context.
+
+    Args:
+      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
+        be combined. Allows using string representation of the enum such as
+        "SUM", "MEAN".
+      value: Value to be reduced. A tensor or a nested structure of tensors.
+      replica_id: An integer indicating the id of the replica calling this
+        all_reduce. This is the local replica id, ranging from 0 to
+        len(local_devices) - 1.
+      options: A `tf.distribute.experimental.CommunicationOptions`.
+
+    Returns:
+      A tensor or a nested structure of tensors with the reduced values. The
+      structure is the same as `value`.
+    """
+    raise NotImplementedError("_all_reduce must be implemented in descendants.")
+
 
 @tf_export("distribute.ReductionToOneDevice")
 class ReductionToOneDevice(CrossDeviceOps):
@@ -850,6 +882,8 @@ class AllReduceCrossDeviceOps(CrossDeviceOps):
       destinations = dense_values[0]._devices  # pylint: disable=protected-access
     grouped = _group_value_by_device(dense_values)
 
+    # device_grad_packs:
+    # [[(t0_gpu0, None), (t1_gpu0, None)], [(t0_gpu1, None), (t1_gpu1, None)]]
     device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)
 
     # The actual aggregation of the repacked gradients. Note that they are
@@ -1044,6 +1078,55 @@ class CollectiveAllReduce(CrossDeviceOps):
       # Currently we only support equal number of devices on each worker.
       return self._group_size / len(self._devices)
 
+  def _all_reduce(self, reduce_op, value, replica_id, options):
+    """Implements `CrossDeviceOps._all_reduce`."""
+    # TODO(b/122840926): reuse this method in _batch_all_reduce.
+    flat_values = nest.flatten(value)
+
+    if isinstance(flat_values[0], ops.IndexedSlices):
+      raise NotImplementedError("all_reduce doesn't support IndexedSlices.")
+
+    batch_size = len(flat_values)
+
+    implementation = options.implementation.value
+    # If NCCL launches can't be ordered (self._limited_nccl == True), we only
+    # use NCCL when batch_size > 1, hoping that there's only one batched
+    # all-reduce, which is the gradients.
+    if (self._limited_nccl and
+        options.implementation == CommunicationImplementation.NCCL and
+        batch_size == 1):
+      implementation = CommunicationImplementation.AUTO.value
+
+    # Reverse the list so that there's a better chance that values follow
+    # the order in which they are calculated (e.g. when they're gradients), so
+    # as to overlap calculation with communication. However, this may not be
+    # optimal for cases like gradients of complicated non-sequential models.
+    #
+    # Note that we reverse the list before packing so that the first pack
+    # won't be too small, since it's more likely for the first few packs to
+    # have long queuing time due to concurrent intense computation.
+    #
+    # TODO(b/147393503): explore solutions for optimal ordering.
+    flat_values.reverse()
+    packs = cross_device_utils.group_by_size(flat_values,
+                                             options.bytes_per_pack)
+
+    launcher = self._launchers[replica_id]
+    if not context.executing_eagerly() and replica_id == 0:
+      logging.info(
+          "Collective all_reduce: %d all-reduces, num_devices = %d, "
+          "group_size = %d, implementation = %s, num_packs = %d", batch_size,
+          len(self._launchers), self._group_size, implementation, len(packs))
+    flat_results = launcher.batch_all_reduce(packs, implementation,
+                                             options.timeout_seconds)
+
+    if reduce_op == reduce_util.ReduceOp.MEAN:
+      for i, v in enumerate(flat_results):
+        flat_results[i] = v / self._group_size
+    flat_results.reverse()
+
+    return nest.pack_sequence_as(value, flat_results)
+
   def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                             options):
     values_util.mark_as_unsaveable()
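Inside `CollectiveAllReduce._all_reduce` above, the data flow is: flatten the nested `value`, reverse, pack, launch batched all-reduces, divide by the group size for MEAN, then restore the original order and structure. The following is a self-contained sketch of that flow with a stand-in launcher; `fake_batch_all_reduce` and `GROUP_SIZE` are illustrative assumptions, not TF APIs:

    import tensorflow as tf

    GROUP_SIZE = 2  # stands in for CollectiveAllReduce._group_size

    def fake_batch_all_reduce(packs):
      # Stand-in for launcher.batch_all_reduce: pretend every replica
      # contributed an identical tensor, so the summed result is value * group.
      return [t * GROUP_SIZE for pack in packs for t in pack]

    def all_reduce_sketch(reduce_op, value):
      flat_values = tf.nest.flatten(value)
      flat_values.reverse()     # gradients tend to arrive in reverse order
      packs = [flat_values]     # group_by_size would split these by byte budget
      flat_results = fake_batch_all_reduce(packs)
      if reduce_op == "MEAN":
        flat_results = [t / GROUP_SIZE for t in flat_results]
      flat_results.reverse()    # undo the reversal to restore input order
      return tf.nest.pack_sequence_as(value, flat_results)

    print(all_reduce_sketch("MEAN", (tf.constant(1.0), tf.constant(2.0))))
    # -> (1.0, 2.0): structure preserved, each element averaged over the group.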
diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 47af7174fb1..d078a66ae81 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -560,6 +560,69 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
     ]
     self.batch_reduce_and_verify(inputs, expect, options)
 
+  @combinations.generate(
+      combinations.combine(
+          num_processes=[1, 2],
+          required_gpus=[0, 1, 2],
+          implementation=[
+              CommunicationImplementation.AUTO,
+              CommunicationImplementation.RING,
+              CommunicationImplementation.NCCL,
+          ],
+          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
+      ))
+  def testCollectiveAllReduce(self, num_processes, required_gpus,
+                              implementation, reduce_op):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
+    if (num_processes == 2 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
+                    "physical GPUs for every process.")
+
+    def replica_fn():
+      collective, devices, _ = self.make_collective(num_processes,
+                                                    required_gpus)
+      options = collective_util.Options(implementation=implementation)
+      group_size = num_processes * (required_gpus or 1)
+
+      @def_function.function
+      def collective_all_reduce():
+        results = []
+        for replica_id, device in enumerate(devices):
+          with ops.device(device):
+            value = constant_op.constant(1.0)
+            results.append(
+                collective._all_reduce(reduce_op, value, replica_id, options))
+        return results
+
+      got = collective_all_reduce()
+      if reduce_op == ReduceOp.SUM:
+        expect = [1.0 * group_size] * len(devices)
+      elif reduce_op == ReduceOp.MEAN:
+        expect = [1.0] * len(devices)
+      self.assertAllClose(got, expect)
+
+      @def_function.function
+      def collective_batch_all_reduce():
+        results = []
+        for replica_id, device in enumerate(devices):
+          with ops.device(device):
+            value = (constant_op.constant(1.0), constant_op.constant(2.0))
+            results.append(
+                collective._all_reduce(reduce_op, value, replica_id, options))
+        return results
+
+      got = collective_batch_all_reduce()
+      if reduce_op == ReduceOp.SUM:
+        expect = [(1.0 * group_size, 2.0 * group_size)] * len(devices)
+      elif reduce_op == ReduceOp.MEAN:
+        expect = [(1.0, 2.0)] * len(devices)
+      self.assertAllClose(got, expect)
+
+    get_global_mpr(num_processes).run(replica_fn)
+
   @combinations.generate(
       combinations.combine(
           num_processes=[1, 2],
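The `bytes_per_pack` field of `collective_util.Options` exercised above feeds `cross_device_utils.group_by_size`, which greedily splits the flattened tensors into packs under a byte budget. A rough, illustrative re-implementation; the splitting rule here is an approximation of the real helper, which also requires statically known shapes:

    import tensorflow as tf

    def group_by_size_sketch(tensors, bytes_per_pack):
      # Greedy packing by cumulative byte size.
      if bytes_per_pack == 0:  # 0 disables packing: everything in one pack
        return [tensors]
      packs, pack, pack_bytes = [], [], 0
      for t in tensors:
        size = t.shape.num_elements() * t.dtype.size
        if pack and pack_bytes + size > bytes_per_pack:
          packs.append(pack)
          pack, pack_bytes = [], 0
        pack.append(t)
        pack_bytes += size
      if pack:
        packs.append(pack)
      return packs

    packs = group_by_size_sketch(
        [tf.zeros([1024]), tf.zeros([1024]), tf.zeros([8])],
        bytes_per_pack=8 * 1024)
    # -> [[zeros(1024), zeros(1024)], [zeros(8)]]: the two 4KB tensors fill
    #    the 8KB budget, so the small trailing tensor starts a new pack.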
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index f5618835ea7..68b9ee1fc79 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -2384,6 +2384,42 @@ class StrategyExtendedV2(object):
         for t, v in value_destination_pairs
     ]
 
+  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
+    """All-reduce `value` across all replicas so that all get the final result.
+
+    If `value` is a nested structure of tensors, the all-reduces of these
+    tensors will be batched when possible. `options` can be set to hint the
+    batching behavior.
+
+    This API must be called in a replica context.
+
+    Args:
+      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
+        be combined. Allows using string representation of the enum such as
+        "SUM", "MEAN".
+      value: Value to be reduced. A tensor or a nested structure of tensors.
+      options: A `tf.distribute.experimental.CommunicationOptions`. Options to
+        perform collective operations. This overrides the default options if
+        the `tf.distribute.Strategy` takes one in the constructor.
+
+    Returns:
+      A tensor or a nested structure of tensors with the reduced values. The
+      structure is the same as `value`.
+    """
+    if options is None:
+      options = collective_util.Options()
+    replica_context = distribution_strategy_context.get_replica_context()
+    assert replica_context, (
+        "`StrategyExtended._replica_ctx_all_reduce` must be called in"
+        " a replica context")
+
+    def merge_fn(_, flat_value):
+      return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value],
+                                  options)
+
+    reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),))
+    return nest.pack_sequence_as(value, reduced)
+
   def _gather_to(self, value, destinations, axis, options=None):
     """Gather `value` across replicas along axis-th dimension to `destinations`.
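Since `merge_fn` pairs each flattened tensor with itself as its own destination, a single `merge_call` plus one `batch_reduce_to` serves the whole nested structure. A minimal usage sketch mirroring the new tests below; it assumes a `MirroredStrategy` whose replica count follows the visible devices, and it calls a private API, so this is illustrative rather than supported usage:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()

    @tf.function
    def step():

      def replica_fn():
        value = (tf.constant(1.0), tf.constant(2.0))
        # Each replica receives the summed tuple, same structure as `value`.
        return strategy.extended._replica_ctx_all_reduce('SUM', value)

      return strategy.experimental_local_results(strategy.run(replica_fn))

    print(step())
    # With N replicas in sync: N copies of (1.0 * N, 2.0 * N).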
diff --git a/tensorflow/python/distribute/strategy_common_test.py b/tensorflow/python/distribute/strategy_common_test.py
index 6b19a744457..d70f2c8c4a6 100644
--- a/tensorflow/python/distribute/strategy_common_test.py
+++ b/tensorflow/python/distribute/strategy_common_test.py
@@ -111,6 +111,58 @@ class ReduceTest(test.TestCase, parameterized.TestCase):
       self.assertEqual(3 * strategy.num_replicas_in_sync, x_s)
 
 
+@combinations.generate(
+    combinations.combine(
+        strategy=[
+            strategy_combinations.multi_worker_mirrored_2x1_cpu,
+            strategy_combinations.multi_worker_mirrored_2x1_gpu,
+            strategy_combinations.tpu_strategy,
+        ] + strategy_combinations.strategies_minus_tpu,
+        tf_function=[combinations.tf_function, combinations.no_tf_function],
+        mode=['eager']))
+class ReplicaCtxAllReduceTest(test.TestCase, parameterized.TestCase):
+
+  def testBasic(self, strategy, tf_function):
+    if (isinstance(strategy, tpu_strategy.TPUStrategy) and
+        tf_function is combinations.no_tf_function):
+      self.skipTest('Skip TPUStrategy + eager combination.')
+
+    @tf_function
+    def fn():
+
+      def replica_fn():
+        value = constant_op.constant(1.0)
+        reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
+        return reduced
+
+      return strategy.experimental_local_results(strategy.run(replica_fn))
+
+    got = fn()[0]
+    self.assertEqual(got, 1.0 * strategy.num_replicas_in_sync)
+
+  def testNestedInput(self, strategy, tf_function):
+    # TODO(b/122840926): enable this test once cl/353109164 is submitted.
+    self.skipTest('Enable after cl/353109164 is submitted.')
+
+    if (isinstance(strategy, tpu_strategy.TPUStrategy) and
+        tf_function is combinations.no_tf_function):
+      self.skipTest('Skip TPUStrategy + eager combination.')
+
+    @tf_function
+    def fn():
+
+      def replica_fn():
+        value = (constant_op.constant(1.0), constant_op.constant(2.0))
+        reduced = strategy.extended._replica_ctx_all_reduce('SUM', value)
+        return reduced
+
+      return strategy.experimental_local_results(strategy.run(replica_fn))
+
+    got = fn()[0]
+    self.assertEqual(got, (1.0 * strategy.num_replicas_in_sync,
+                           2.0 * strategy.num_replicas_in_sync))
+
+
 def _make_indexed_slices(values, indices, dense_shape):
   tensor = ops.IndexedSlices(
       values=constant_op.constant(values),