diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py
index 63f32d8101c..47af7174fb1 100644
--- a/tensorflow/python/distribute/cross_device_ops_test.py
+++ b/tensorflow/python/distribute/cross_device_ops_test.py
@@ -52,6 +52,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import nest
 
+CollectiveReplicaLauncher = cross_device_utils.CollectiveReplicaLauncher
 CommunicationImplementation = collective_util.CommunicationImplementation
 ReduceOp = reduce_util.ReduceOp
 IndexedSlicesValue = indexed_slices.IndexedSlicesValue
@@ -107,9 +108,8 @@ def enable_collective_ops():
       protocol=cluster_resolver.rpc_layer)
   context.context().enable_collective_ops(server_def)
   # Recover default flag values.
-  cross_device_utils.CollectiveReplicaLauncher._prefer_scoped_allocator = True
-  cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
-  cross_device_utils.CollectiveReplicaLauncher._prefer_ordering_token = False
+  CollectiveReplicaLauncher._prefer_unique_instance_key = True
+  CollectiveReplicaLauncher._prefer_ordering_token = False
 
 
 class MultiProcessPoolRunner():
@@ -215,12 +215,11 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           "gpus_per_process",
           "reduce_op",
           "communication_options",
-          "prefer_scoped_allocator",
-          "prefer_collective_v2",
+          "prefer_unique_instance_key",
       ])
   RunOptions.__new__.__defaults__ = (["eager", "func_graph"], 2, 0,
                                      ReduceOp.SUM,
-                                     collective_util.Options(), True, False)
+                                     collective_util.Options(), True)
 
   def reduce_and_verify(self, inputs, expect, options):
     """Reduce the given `inputs` and verify the output matches `expect`.
@@ -234,8 +233,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
     """
 
     def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          options.prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          options.prefer_unique_instance_key)
       collective, devices, pid = self.make_collective(options.num_processes,
                                                       options.gpus_per_process)
 
@@ -273,10 +272,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
     """
 
     def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_scoped_allocator = (
-          options.prefer_scoped_allocator)
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          options.prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          options.prefer_unique_instance_key)
       collective, devices, pid = self.make_collective(options.num_processes,
                                                       options.gpus_per_process)
 
@@ -321,9 +318,9 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
               CommunicationImplementation.NCCL,
           ],
           reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
   def testAllReduceDense(self, num_processes, required_gpus, implementation,
-                         reduce_op, prefer_collective_v2):
+                         reduce_op, prefer_unique_instance_key):
     if (required_gpus == 0 and
         implementation == CommunicationImplementation.NCCL):
       self.skipTest("Skip CPU + NCCL combination")
@@ -337,7 +334,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
         reduce_op=reduce_op,
         communication_options=collective_util.Options(
            implementation=implementation),
-        prefer_collective_v2=prefer_collective_v2)
+        prefer_unique_instance_key=prefer_unique_instance_key)
     group_size = options.num_processes * (options.gpus_per_process or 1)
 
     inputs_data = [1.0, 2.0, 3.0, 4.0]
@@ -363,9 +360,9 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           ],
           # TODO(b/166682130): add MEAN reduce once the bug is fixed.
           reduce_op=ReduceOp.SUM,
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
   def testAllReduceSparse(self, num_processes, required_gpus, implementation,
                          reduce_op, prefer_collective_v2):
-                          reduce_op, prefer_collective_v2):
+                          reduce_op, prefer_unique_instance_key):
     if (required_gpus == 0 and
         implementation == CommunicationImplementation.NCCL):
       self.skipTest("Skip CPU + NCCL combination")
@@ -380,7 +377,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
         reduce_op=reduce_op,
         communication_options=collective_util.Options(
            implementation=implementation),
-        prefer_collective_v2=prefer_collective_v2)
+        prefer_unique_instance_key=prefer_unique_instance_key)
     group_size = options.num_processes * (options.gpus_per_process or 1)
 
     inputs_data = [
@@ -412,8 +409,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
     self.reduce_and_verify(inputs, expect, options)
 
   @combinations.generate(
-      combinations.combine(prefer_collective_v2=[True, False]))
-  def testAllReduceSparseVariableLength(self, prefer_collective_v2):
+      combinations.combine(prefer_unique_instance_key=[True, False]))
+  def testAllReduceSparseVariableLength(self, prefer_unique_instance_key):
     # One device per process, 2 processes, 2 replicas in total.
     inputs = [
         IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
@@ -431,7 +428,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
            mode=["func_graph"],  # Sparse reduce is not supported in eager.
            num_processes=2,
            reduce_op=ReduceOp.SUM,
-            prefer_collective_v2=prefer_collective_v2))
+            prefer_unique_instance_key=prefer_unique_instance_key))
 
   @combinations.generate(
       combinations.combine(
@@ -443,11 +440,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
              CommunicationImplementation.NCCL,
          ],
          reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
-          prefer_scoped_allocator=[True, False],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testBatchAllReduceDense(self, num_processes, required_gpus,
                              implementation, reduce_op,
-                              prefer_scoped_allocator, prefer_collective_v2):
+                              prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
@@ -462,8 +458,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
-        prefer_scoped_allocator=prefer_scoped_allocator,
-        prefer_collective_v2=prefer_collective_v2)
+        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)
 
    inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
@@ -489,11 +484,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
          ],
          # TODO(b/166682130): add MEAN reduce once the bug is fixed.
          reduce_op=ReduceOp.SUM,
-          prefer_scoped_allocator=[True, False],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testBatchAllReduceSparse(self, num_processes, required_gpus,
                               implementation, reduce_op,
-                               prefer_scoped_allocator, prefer_collective_v2):
+                               prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
@@ -509,8 +503,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
        reduce_op=reduce_op,
        communication_options=collective_util.Options(
            implementation=implementation),
-        prefer_scoped_allocator=prefer_scoped_allocator,
-        prefer_collective_v2=prefer_collective_v2)
+        prefer_unique_instance_key=prefer_unique_instance_key)
    group_size = options.num_processes * (options.gpus_per_process or 1)
 
    inputs_data = ([
@@ -578,13 +571,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
              CommunicationImplementation.RING,
              CommunicationImplementation.NCCL,
          ],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
-                             func_mode, axis, prefer_collective_v2):
+                             func_mode, axis, prefer_unique_instance_key):
 
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
@@ -624,7 +617,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
                                 implementation):
 
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
+      CollectiveReplicaLauncher._prefer_unique_instance_key = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
@@ -653,15 +646,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
          implementation=[
              CommunicationImplementation.NCCL, CommunicationImplementation.RING
          ],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
                                                    required_gpus,
                                                    implementation,
-                                                    prefer_collective_v2):
+                                                    prefer_unique_instance_key):
 
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
@@ -715,13 +708,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
          implementation=[
              CommunicationImplementation.NCCL, CommunicationImplementation.RING
          ],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testInputsAreFunctionArgs(self, num_processes, required_gpus,
-                                implementation, prefer_collective_v2):
+                                implementation, prefer_unique_instance_key):
 
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          prefer_unique_instance_key)
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(implementation=implementation)
@@ -757,16 +750,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceDense(self, num_processes, implementation, required_gpus,
-                             prefer_collective_v2):
+                             prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
+
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
@@ -794,16 +788,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceDense(self, num_processes, implementation,
-                                  required_gpus, prefer_collective_v2):
+                                  required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
 
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
@@ -832,16 +826,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testTimeoutReduceSparse(self, num_processes, implementation,
-                              required_gpus, prefer_collective_v2):
+                              required_gpus, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
 
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
@@ -871,16 +865,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
          implementation=[
              CommunicationImplementation.RING, CommunicationImplementation.NCCL
          ],
-          prefer_collective_v2=[True, False]))
+          prefer_unique_instance_key=[True, False]))
  def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
-                                   implementation, prefer_collective_v2):
+                                   implementation, prefer_unique_instance_key):
    if (required_gpus == 0 and
        implementation == CommunicationImplementation.NCCL):
      self.skipTest("Skip CPU + NCCL combination")
 
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
-          prefer_collective_v2)
+      CollectiveReplicaLauncher._prefer_unique_instance_key = (
+          prefer_unique_instance_key)
      collective, devices, task_id = self.make_collective(
          num_processes, required_gpus)
      if task_id != 0:
@@ -908,8 +902,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
  def testNcclOrdering(self, num_processes, required_gpus):
 
    def replica_fn():
-      cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
-      cross_device_utils.CollectiveReplicaLauncher._prefer_ordering_token = True
+      CollectiveReplicaLauncher._prefer_unique_instance_key = True
+      CollectiveReplicaLauncher._prefer_ordering_token = True
      collective, devices, _ = self.make_collective(num_processes,
                                                    required_gpus)
      options = collective_util.Options(
diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py
index cfc3bc0182b..c6edc2cf736 100644
--- a/tensorflow/python/distribute/cross_device_utils.py
+++ b/tensorflow/python/distribute/cross_device_utils.py
@@ -257,8 +257,7 @@ class CollectiveKeys(object):
 class CollectiveReplicaLauncher(object):
   """Launch collectives on one replica."""
 
-  _prefer_scoped_allocator = True
-  _prefer_collective_v2 = True
+  _prefer_unique_instance_key = True
   _prefer_ordering_token = True
 
   def __init__(self,
@@ -281,26 +280,19 @@ class CollectiveReplicaLauncher(object):
       return ops.control_dependencies([control_input])
    return ops.NullContextmanager()
 
-  def _use_collective_v2(self):
+  def _use_unique_instance_key(self):
    if not ops.executing_eagerly_outside_functions():
      return False
-    return CollectiveReplicaLauncher._prefer_collective_v2
-
-  def _use_scoped_allocator(self):
-    if self._use_collective_v2():
-      # ScopedAllocator doesn't support collective V2.
-      return False
-    return CollectiveReplicaLauncher._prefer_scoped_allocator
+    return CollectiveReplicaLauncher._prefer_unique_instance_key
 
  def _use_ordering_token(self):
-    if not self._use_collective_v2():
-      # Only collective V2 supports ordering token.
+    if not ops.executing_eagerly_outside_functions():
      return False
    return CollectiveReplicaLauncher._prefer_ordering_token
 
  def _next_instance_key(self):
    """Returns the next instance key."""
-    if self._use_collective_v2():
+    if self._use_unique_instance_key():
      # Assigning instance keys at function building time have issues since
      # different workers may retrace the function at different times. With
      # collective V2 we can use capture_call_time_value to use a placeholder as
@@ -360,23 +352,14 @@ class CollectiveReplicaLauncher(object):
    ordering_token = self._get_ordering_token(communication_hint)
    with ops.device(self._device), \
         self._control_input(control_input):
-      if self._use_collective_v2():
-        return collective_ops.all_reduce_v2(
-            input_tensor,
-            self._group_size,
-            self._group_key,
-            instance_key,
-            communication_hint=communication_hint,
-            timeout=timeout,
-            ordering_token=ordering_token)
-      else:
-        return collective_ops.all_reduce(
-            input_tensor,
-            self._group_size,
-            self._group_key,
-            instance_key,
-            communication_hint=communication_hint,
-            timeout=timeout)
+      return collective_ops.all_reduce_v2(
+          input_tensor,
+          self._group_size,
+          self._group_key,
+          instance_key,
+          communication_hint=communication_hint,
+          timeout=timeout,
+          ordering_token=ordering_token)
 
  def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
    """All-gather a dense tensor.
@@ -393,23 +376,14 @@ class CollectiveReplicaLauncher(object):
    instance_key = self._next_instance_key()
    ordering_token = self._get_ordering_token(communication_hint)
    with ops.device(self._device):
-      if self._use_collective_v2():
-        return collective_ops.all_gather_v2(
-            input_tensor,
-            self._group_size,
-            self._group_key,
-            instance_key,
-            communication_hint=communication_hint,
-            timeout=timeout,
-            ordering_token=ordering_token)
-      else:
-        return collective_ops.all_gather(
-            input_tensor,
-            self._group_size,
-            self._group_key,
-            instance_key,
-            communication_hint=communication_hint,
-            timeout=timeout)
+      return collective_ops.all_gather_v2(
+          input_tensor,
+          self._group_size,
+          self._group_key,
+          instance_key,
+          communication_hint=communication_hint,
+          timeout=timeout,
+          ordering_token=ordering_token)
 
  def batch_all_reduce(self,
                       input_tensor_packs,
@@ -430,17 +404,17 @@ class CollectiveReplicaLauncher(object):
    Returns:
      A flat list of reduced tensors.
    """
-    # We don't batch with concat in eager. It's easy to get it wrong because
-    # we need to avoid any numpy() calls on values produced by the async
-    # executor. This effectively disables batching in eager, but it's unlikely
-    # to all-reduce a large number of tensors in eager.
-    batch_with_concat = (not self._use_scoped_allocator() and
-                         not context.executing_eagerly())
    outputs = []
    for pack in input_tensor_packs:
-      # TODO(b/169168846): inserts a parallel all_gather to verify packings
-      # are the same on each replica.
-      if batch_with_concat:
+      if context.executing_eagerly():
+        # We don't batch in eager as it sometimes makes the performance worse
+        # due the concat/split ops.
+        for input_tensor in pack:
+          outputs.append(
+              self.all_reduce(input_tensor, None, communication_hint, timeout))
+      else:
+        # TODO(b/169168846): inserts a parallel all_gather to verify packings
+        # are the same on each replica.
        with ops.device(self._device):
          flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
          shapes = [array_ops.shape(t) for t in pack]
@@ -455,19 +429,6 @@ class CollectiveReplicaLauncher(object):
          flat_outputs = array_ops.split(reduced, num_elements, axis=0)
          for shape, flat_output in zip(shapes, flat_outputs):
            outputs.append(array_ops.reshape(flat_output, shape))
-      else:
-        # By placing all CollectiveReduce ops in a batch under single name
-        # scope, we ensure they will be picked up by the `ScopedAllocator`
-        # grappler optimizer and packed into a single all-reduce.
-        with ops.name_scope('allreduce'):
-          for input_tensor in pack:
-            if communication_hint == 'NCCL' and outputs:
-              control_input = outputs[-1]
-            else:
-              control_input = None
-            outputs.append(
-                self.all_reduce(input_tensor, control_input, communication_hint,
-                                timeout))
 
    return outputs
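
Reviewer note: after this change, `batch_all_reduce` keeps the concat/split batching only in the graph branch, while the eager branch issues one all-reduce per tensor. Below is a minimal, self-contained NumPy sketch of that concat/split round trip for reference; `batch_all_reduce_concat`, `all_reduce_fn`, and the fake two-replica buffers are illustrative stand-ins, not TensorFlow APIs.

```python
import numpy as np


def batch_all_reduce_concat(pack, all_reduce_fn):
  """Concat/split batching: reduce one flat buffer instead of N tensors."""
  flat = [np.reshape(t, [-1]) for t in pack]      # flatten every tensor
  shapes = [np.shape(t) for t in pack]            # remember original shapes
  num_elements = [f.size for f in flat]           # sizes for splitting later
  reduced = all_reduce_fn(np.concatenate(flat))   # single collective launch
  splits = np.split(reduced, np.cumsum(num_elements)[:-1])
  return [np.reshape(s, shape) for s, shape in zip(splits, shapes)]


# Fake two-replica "all-reduce": add the other replica's flattened buffer.
replica_a = [np.array([[1.0, 2.0]]), np.array([3.0, 4.0, 5.0])]
replica_b = [np.array([[10.0, 20.0]]), np.array([30.0, 40.0, 50.0])]
fake_all_reduce = lambda buf: buf + np.concatenate(
    [np.reshape(t, [-1]) for t in replica_b])

print(batch_all_reduce_concat(replica_a, fake_all_reduce))
# [array([[11., 22.]]), array([33., 44., 55.])]
```

The split offsets (`np.cumsum(num_elements)[:-1]`) recover each tensor's slice of the single reduced buffer, which is what lets one collective launch stand in for N of them in the graph path.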