diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index dac29b1c15e..f5b74545ad9 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -106,6 +106,7 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:nccl_ops", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", ], @@ -1079,6 +1080,7 @@ cuda_py_test( ":multi_process_runner", ":multi_worker_test_base", ":reduce_util", + ":test_util", ":values", "//tensorflow/python:array_ops", "//tensorflow/python:collective_ops", diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index c5aca728827..3c424b301a8 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -990,6 +990,11 @@ class CollectiveAllReduce(CrossDeviceOps): all workers and then put results on the right destinations. """ + # Whether to only use NCCL for batched all-reduce when NCCL is requested. This + # is because of the lack of mechanism to order NCCL operations + # deterministically. + _limited_nccl = True + def __init__(self, devices, group_size, collective_keys=None): """Initializes the object. @@ -1121,8 +1126,8 @@ class CollectiveAllReduce(CrossDeviceOps): # all-reduce, which is the gradients. # TODO(b/132575814): switch to NCCL for all collectives when communication # is NCCL if and only if we can order collectives deterministically. - # is NCCL. - if (options.implementation == CommunicationImplementation.NCCL and + if (self._limited_nccl and + options.implementation == CommunicationImplementation.NCCL and batch_size == 1): implementation = CommunicationImplementation.AUTO.value @@ -1182,8 +1187,9 @@ class CollectiveAllReduce(CrossDeviceOps): # For now, we use NCCL only when batch_size > 1. 
# TODO(b/132575814): switch to NCCL for all collectives when implementation # is NCCL. - if options.implementation == CommunicationImplementation.NCCL and len( - per_replica_values) == 1: + if (self._limited_nccl and + options.implementation == CommunicationImplementation.NCCL and + len(per_replica_values) == 1): implementation = CommunicationImplementation.AUTO.value gathered_values = [] diff --git a/tensorflow/python/distribute/cross_device_ops_test.py b/tensorflow/python/distribute/cross_device_ops_test.py index 191394f69af..a5818c37aa5 100644 --- a/tensorflow/python/distribute/cross_device_ops_test.py +++ b/tensorflow/python/distribute/cross_device_ops_test.py @@ -32,9 +32,11 @@ from tensorflow.python.distribute import collective_util from tensorflow.python.distribute import combinations from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import cross_device_utils +from tensorflow.python.distribute import device_util from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base from tensorflow.python.distribute import reduce_util +from tensorflow.python.distribute import test_util from tensorflow.python.distribute import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import def_function @@ -70,7 +72,12 @@ def make_per_replica_value(value, devices): """ values = [] for device_idx, device in enumerate(devices): - v = value(device_idx) if callable(value) else value + if callable(value): + v = value(device_idx) + elif isinstance(value, list): + v = value[device_idx] + else: + v = value if isinstance(v, IndexedSlicesValue): with ops.device(device): values.append( @@ -99,6 +106,11 @@ def enable_collective_ops(): task_index=cluster_resolver.task_id, protocol=cluster_resolver.rpc_layer) context.context().enable_collective_ops(server_def) + # Recover default flag values. 
+ cross_device_ops_lib.CollectiveAllReduce._limited_nccl = True + cross_device_utils.CollectiveReplicaLauncher._use_scoped_allocator = True + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = False + cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = False class MultiProcessPoolRunner(): @@ -858,9 +870,101 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase): get_global_mpr(num_processes).run(replica_fn) + @combinations.generate(combinations.combine(num_processes=1, required_gpus=2)) + def testNcclOrdering(self, num_processes, required_gpus): + + def replica_fn(): + cross_device_ops_lib.CollectiveAllReduce._limited_nccl = False + cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True + cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = True + collective, devices, _ = self.make_collective( + num_processes, required_gpus) + options = collective_util.Options( + implementation=CommunicationImplementation.NCCL) + + v_dense = make_per_replica_value([1.0, 1.0], devices) + v_sparse = make_per_replica_value([ + IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]), + IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]), + ], devices) + + @def_function.function + def nested_dense(): + collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options) + + @def_function.function + def nested_sparse(): + collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options) + + # All collectives, function calls, if clause and while loops should be + # chained by control dependencies, so that the execution order is + # deterministic. + @def_function.function + def f(): + # pylint: disable=pointless-statement + collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options) + # reducing dense value. + collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options) + # reducing sparse value. 
+ collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options) + # reduce dense value in nested tf.function. + nested_dense() + # reduce sparse value in nested tf.function. + nested_sparse() + # reduce dense value in tf.cond. + if array_ops.identity(1.0) > array_ops.identity(2.0): + collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options) + else: + v_dense + # reduce sparse value in tf.cond. + if array_ops.identity(1.0) > array_ops.identity(2.0): + v_sparse + else: + collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, + options) + # reduce dense value in tf.while_loop. + i = array_ops.identity(1) + while i < 3: + collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options) + i += 1 + # reduce sparse value in tf.while_loop. + i = array_ops.identity(1) + while i < 3: + collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, + options) + i += 1 + # reducing dense and sparse value again. + collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options) + collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options) + # pylint: enable=pointless-statement + + graph = f.get_concrete_function().graph + should_be_ordered = set([ + "CollectiveReduce", "CollectiveGather", "If", "While", + "StatefulPartitionedCall" + ]) + nodes_by_device = {} + for op in graph.get_operations(): + if op.type in should_be_ordered: + if op.device not in nodes_by_device: + nodes_by_device[op.device] = [] + nodes_by_device[op.device].append(op) + order = test_util.topological_sort_operations(graph.get_operations()) + for device in devices: + device = device_util.canonicalize(device) + # Those function ops don't have device annotations, but they contain + # collectives for both devices so we always include them. + operations = nodes_by_device[device] + nodes_by_device[""] + # Verify that we get all types of nodes we want. 
+ self.assertEqual(set(op.type for op in operations), should_be_ordered) + test_util.assert_sequential_execution(order, operations) + + get_global_mpr(num_processes).run(replica_fn) + if __name__ == "__main__": # Set default inter op thread pool size to one to ensure we don't exhaust the # thread pool with the additional executors to run collectives in eager. os.environ["TF_NUM_INTEROP_THREADS"] = "1" - multi_process_runner.test_main() + # TODO(b/172304955): figure out why logical devices don't work. + test_util.main(config_logical_devices=False) diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py index 96866fb1ca4..d90c3b73717 100644 --- a/tensorflow/python/distribute/cross_device_utils.py +++ b/tensorflow/python/distribute/cross_device_utils.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import collective_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nccl_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging INSTANCE_KEY_START_NUMBER = 100 @@ -258,6 +259,7 @@ class CollectiveReplicaLauncher(object): _use_scoped_allocator = True _use_collective_v2 = False + _use_ordering_token = False def __init__(self, group_key, @@ -272,6 +274,12 @@ class CollectiveReplicaLauncher(object): self._collective_keys = collective_keys self._device = device self._executor = executor + if (self._use_ordering_token and self._use_collective_v2 and + ops.executing_eagerly_outside_functions()): + with ops.init_scope(), ops.device(device): + self._ordering_token = resource_variable_ops.ResourceVariable(0.) 
+ else: + self._ordering_token = None def _executor_scope(self): if context.executing_eagerly() and not self._executor: @@ -281,7 +289,7 @@ class CollectiveReplicaLauncher(object): return ops.NullContextmanager() def _control_input(self, control_input): - if control_input is not None: + if control_input is not None and self._ordering_token is None: return ops.control_dependencies([control_input]) return ops.NullContextmanager() @@ -323,6 +331,11 @@ class CollectiveReplicaLauncher(object): return self._collective_keys.get_instance_key(self._group_key, self._device) + def _get_ordering_token(self, communication_hint): + if self._ordering_token is not None and communication_hint == 'NCCL': + return self._ordering_token.handle + return None + def all_reduce(self, input_tensor, control_input=None, @@ -345,6 +358,7 @@ class CollectiveReplicaLauncher(object): The reduced tensor. """ instance_key = self._next_instance_key() + ordering_token = self._get_ordering_token(communication_hint) with self._executor_scope(), \ ops.device(self._device), \ self._control_input(control_input): @@ -355,7 +369,8 @@ class CollectiveReplicaLauncher(object): self._group_key, instance_key, communication_hint=communication_hint, - timeout=timeout) + timeout=timeout, + ordering_token=ordering_token) else: return collective_ops.all_reduce( input_tensor, @@ -381,6 +396,7 @@ class CollectiveReplicaLauncher(object): The reduced tensor. 
""" instance_key = self._next_instance_key() + ordering_token = self._get_ordering_token(communication_hint) with self._executor_scope(), ops.device(self._device): if self._should_use_collective_v2(): return collective_ops.all_gather_v2( @@ -389,7 +405,8 @@ class CollectiveReplicaLauncher(object): self._group_key, instance_key, communication_hint=communication_hint, - timeout=timeout) + timeout=timeout, + ordering_token=ordering_token) else: return collective_ops.all_gather( input_tensor, diff --git a/tensorflow/python/distribute/test_util.py b/tensorflow/python/distribute/test_util.py index 2cc278c647e..45085ba6203 100644 --- a/tensorflow/python/distribute/test_util.py +++ b/tensorflow/python/distribute/test_util.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools +import itertools from absl import app @@ -106,3 +107,86 @@ def main(enable_v2_behavior=True, config_logical_devices=True): v2_compat.disable_v2_behavior() # TODO(b/131360402): configure default logical devices. multi_process_runner.test_main() + + +def _op_dependencies(op): + """Returns the data and control dependencies of a tf.Operation combined.""" + deps = [] + for node in itertools.chain(op.inputs, op.control_inputs): + if isinstance(node, ops.Tensor): + node = node.op + assert isinstance(node, ops.Operation) + deps.append(node) + return deps + + +def topological_sort_operations(operations): + """Topological sorts a list of operations. + + This does a topological sort of the operations in a graph. The edges include + both data dependencies and control dependencies. Note that the edge goes from + an operation to its dependencies. + + Args: + operations: a list of tf.Operation in the same graph. + + Returns: + A map from a tf.Operation to its topological order. 
+ """ + in_degrees = {} + for op in operations: + if op not in in_degrees: + in_degrees[op] = 0 + for next_op in _op_dependencies(op): + in_degrees[next_op] = in_degrees.get(next_op, 0) + 1 + nexts = [] + for op, in_degree in in_degrees.items(): + if in_degree == 0: + nexts.append(op) + order = {} + next_order = 0 + while nexts: + op, nexts = nexts[0], nexts[1:] + order[op] = next_order + next_order += 1 + for next_op in _op_dependencies(op): + in_degrees[next_op] -= 1 + if in_degrees[next_op] == 0: + nexts.append(next_op) + assert len(order) == len(operations) + return order + + +def _exists_dependency(start, end): + """Returns whether there exists a dependency chain from start to end.""" + nexts = [start] + while nexts: + op, nexts = nexts[0], nexts[1:] + for next_op in _op_dependencies(op): + if next_op == end: + return True + nexts.append(next_op) + return False + + +def assert_sequential_execution(order, operations): + """Asserts there's a deterministic execution order between the operations. + + Args: + order: a map from a tf.Operation to its topological order. + operations: a list of operations that should be executed sequentially. It + can be given in any order. + """ + # Topological ordering guarantees that, if there's a dependency from N_a to + # N_b, then order[N_a] < order[N_b]. If there is a path of dependencies + # among the operations, it always goes from an operation with a smaller + # topological order to one with a larger topological order. Therefore, we only + # need to sort the operations by their topological orders, and verify that + # there's a path of dependency between adjacent pairs. + operations = sorted(operations, key=lambda op: order[op]) + for i in range(len(operations) - 1): + if not _exists_dependency(operations[i], operations[i + 1]): + print(operations[i].graph.as_graph_def()) + raise AssertionError( + "No dependency between {} and {}. 
Graph is dumped to stdout.".format( + operations[i].name, operations[i + 1].name)) diff --git a/tensorflow/python/distribute/test_util_test.py b/tensorflow/python/distribute/test_util_test.py index 165f97be6e2..756e08dbb42 100644 --- a/tensorflow/python/distribute/test_util_test.py +++ b/tensorflow/python/distribute/test_util_test.py @@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.eager import test from tensorflow.python.framework import config from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -82,5 +83,37 @@ class LogicalDevicesTest(test.TestCase): self.assertLen(config.get_logical_device_configuration(cpu_device), 3) +class AssertSequentailExecutionTest(test.TestCase): + + def test1(self): + + @def_function.function + def f(): + a = array_ops.identity(1., name='a') + b = a + 1 + c = array_ops.identity(2., name='c') + d = array_ops.identity(a + c, name='d') + with ops.control_dependencies([b]): + e = array_ops.identity(3., name='e') + f = array_ops.identity(c + e, name='f') + return d, f + + graph = f.get_concrete_function().graph + order = test_util.topological_sort_operations(graph.get_operations()) + a = graph.get_operation_by_name('a') + c = graph.get_operation_by_name('c') + d = graph.get_operation_by_name('d') + e = graph.get_operation_by_name('e') + f = graph.get_operation_by_name('f') + test_util.assert_sequential_execution(order, [a, d]) + test_util.assert_sequential_execution(order, [e, a, f]) + with self.assertRaises(AssertionError): + test_util.assert_sequential_execution(order, [a, c]) + with self.assertRaises(AssertionError): + test_util.assert_sequential_execution(order, [f, a, c]) + with self.assertRaises(AssertionError): + test_util.assert_sequential_execution(order, [d, e, a, c]) + + if __name__ == '__main__': test_util.main()