Order NCCL all-reduce with ordering token

Automatic control dependencies chain operations that share the same resource input, so passing the ordering token to each NCCL all-reduce serializes them deterministically. We'll do the same for all-gather once some refactoring is done.

PiperOrigin-RevId: 341868107
Change-Id: I5570a28c2e1c638980e3509088c0525e957c463b
Ran Chen 2020-11-11 11:09:51 -08:00 committed by TensorFlower Gardener
parent 28835e4103
commit 07b9eccccf
6 changed files with 255 additions and 9 deletions
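
For background on the mechanism the description relies on, here is a minimal sketch (not part of this commit; plain TF 2.x public API, with a hypothetical token variable standing in for the ordering token). Inside a tf.function, stateful ops that take the same resource handle as input are chained by automatic control dependencies and therefore run in program order:

import tensorflow as tf

# Hypothetical illustration: `token` plays the role of the ordering token.
# Both assign_add ops consume the same resource handle, so tf.function's
# automatic control dependencies insert a control edge between them and
# they execute in program order.
token = tf.Variable(0.0)

@tf.function
def two_ordered_updates(x):
  token.assign_add(x)        # first stateful op on the shared resource
  token.assign_add(2.0 * x)  # chained after the first by auto control deps
  return token.read_value()

print(two_ordered_updates(tf.constant(1.0)))  # tf.Tensor(3.0, ...)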


@@ -106,6 +106,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:nccl_ops",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
],
@@ -1079,6 +1080,7 @@ cuda_py_test(
":multi_process_runner",
":multi_worker_test_base",
":reduce_util",
":test_util",
":values",
"//tensorflow/python:array_ops",
"//tensorflow/python:collective_ops",


@@ -990,6 +990,11 @@ class CollectiveAllReduce(CrossDeviceOps):
all workers and then put results on the right destinations.
"""
# Whether to only use NCCL for batched all-reduce when NCCL is requested. This
# is because we lack a mechanism to order NCCL operations deterministically.
_limited_nccl = True
def __init__(self, devices, group_size, collective_keys=None):
"""Initializes the object.
@@ -1121,8 +1126,8 @@ class CollectiveAllReduce(CrossDeviceOps):
# all-reduce, which is the gradients.
# TODO(b/132575814): switch to NCCL for all collectives when communication
- # is NCCL if and only if we can order collectives deterministically.
- if (options.implementation == CommunicationImplementation.NCCL and
+ # is NCCL.
+ if (self._limited_nccl and
+     options.implementation == CommunicationImplementation.NCCL and
batch_size == 1):
implementation = CommunicationImplementation.AUTO.value
@@ -1182,8 +1187,9 @@ class CollectiveAllReduce(CrossDeviceOps):
# For now, we use NCCL only when batch_size > 1.
# TODO(b/132575814): switch to NCCL for all collectives when implementation
# is NCCL.
- if options.implementation == CommunicationImplementation.NCCL and len(
-     per_replica_values) == 1:
+ if (self._limited_nccl and
+     options.implementation == CommunicationImplementation.NCCL and
+     len(per_replica_values) == 1):
implementation = CommunicationImplementation.AUTO.value
gathered_values = []
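
As context for the _limited_nccl gate above, a usage sketch (not part of this commit; assumes the TF 2.4-era public API): NCCL is requested through the strategy's communication options, and with _limited_nccl left at its default of True, single-tensor reductions are downgraded to AUTO while batched gradient all-reduces use NCCL.

import tensorflow as tf

# Hypothetical usage sketch: request NCCL for the collective-based strategy.
# CollectiveAllReduce._limited_nccl (added above) decides whether non-batched
# reductions are still downgraded to AUTO despite this request.
options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=options)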


@@ -32,9 +32,11 @@ from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
@@ -70,7 +72,12 @@ def make_per_replica_value(value, devices):
"""
values = []
for device_idx, device in enumerate(devices):
- v = value(device_idx) if callable(value) else value
+ if callable(value):
+   v = value(device_idx)
+ elif isinstance(value, list):
+   v = value[device_idx]
+ else:
+   v = value
if isinstance(v, IndexedSlicesValue):
with ops.device(device):
values.append(
@@ -99,6 +106,11 @@ def enable_collective_ops():
task_index=cluster_resolver.task_id,
protocol=cluster_resolver.rpc_layer)
context.context().enable_collective_ops(server_def)
# Recover default flag values.
cross_device_ops_lib.CollectiveAllReduce._limited_nccl = True
cross_device_utils.CollectiveReplicaLauncher._use_scoped_allocator = True
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = False
cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = False
class MultiProcessPoolRunner():
@@ -858,9 +870,101 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
get_global_mpr(num_processes).run(replica_fn)
@combinations.generate(combinations.combine(num_processes=1, required_gpus=2))
def testNcclOrdering(self, num_processes, required_gpus):
def replica_fn():
cross_device_ops_lib.CollectiveAllReduce._limited_nccl = False
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True
cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = True
collective, devices, _ = self.make_collective(
num_processes, required_gpus)
options = collective_util.Options(
implementation=CommunicationImplementation.NCCL)
v_dense = make_per_replica_value([1.0, 1.0], devices)
v_sparse = make_per_replica_value([
IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
], devices)
@def_function.function
def nested_dense():
collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
@def_function.function
def nested_sparse():
collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
# All collectives, function calls, if clauses and while loops should be
# chained by control dependencies, so that the execution order is
# deterministic.
@def_function.function
def f():
# pylint: disable=pointless-statement
collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
# reducing dense value.
collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
# reducing sparse value.
collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
# reduce dense value in nested tf.function.
nested_dense()
# reduce sparse value in nested tf.function.
nested_sparse()
# reduce dense value in tf.cond.
if array_ops.identity(1.0) > array_ops.identity(2.0):
collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
else:
v_dense
# reduce sparse value in tf.cond.
if array_ops.identity(1.0) > array_ops.identity(2.0):
v_sparse
else:
collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
options)
# reduce dense value in tf.while_loop.
i = array_ops.identity(1)
while i < 3:
collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
i += 1
# reduce sparse value in tf.while_loop.
i = array_ops.identity(1)
while i < 3:
collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
options)
i += 1
# reducing dense and sparse value again.
collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
# pylint: enable=pointless-statement
graph = f.get_concrete_function().graph
should_be_ordered = set([
"CollectiveReduce", "CollectiveGather", "If", "While",
"StatefulPartitionedCall"
])
nodes_by_device = {}
for op in graph.get_operations():
if op.type in should_be_ordered:
if op.device not in nodes_by_device:
nodes_by_device[op.device] = []
nodes_by_device[op.device].append(op)
order = test_util.topological_sort_operations(graph.get_operations())
for device in devices:
device = device_util.canonicalize(device)
# Those function ops don't have device annotations, but they contain
# collectives for both devices, so we always include them.
operations = nodes_by_device[device] + nodes_by_device[""]
# Verify that we get all types of nodes we want.
self.assertEqual(set(op.type for op in operations), should_be_ordered)
test_util.assert_sequential_execution(order, operations)
get_global_mpr(num_processes).run(replica_fn)
if __name__ == "__main__":
# Set default inter op thread pool size to one to ensure we don't exhaust the
# thread pool with the additional executors to run collectives in eager.
os.environ["TF_NUM_INTEROP_THREADS"] = "1"
- multi_process_runner.test_main()
+ # TODO(b/172304955): figure out why logical devices don't work.
+ test_util.main(config_logical_devices=False)


@@ -32,6 +32,7 @@ from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
INSTANCE_KEY_START_NUMBER = 100
@@ -258,6 +259,7 @@ class CollectiveReplicaLauncher(object):
_use_scoped_allocator = True
_use_collective_v2 = False
_use_ordering_token = False
def __init__(self,
group_key,
@@ -272,6 +274,12 @@ class CollectiveReplicaLauncher(object):
self._collective_keys = collective_keys
self._device = device
self._executor = executor
if (self._use_ordering_token and self._use_collective_v2 and
ops.executing_eagerly_outside_functions()):
with ops.init_scope(), ops.device(device):
self._ordering_token = resource_variable_ops.ResourceVariable(0.)
else:
self._ordering_token = None
def _executor_scope(self):
if context.executing_eagerly() and not self._executor:
@@ -281,7 +289,7 @@ class CollectiveReplicaLauncher(object):
return ops.NullContextmanager()
def _control_input(self, control_input):
- if control_input is not None:
+ if control_input is not None and self._ordering_token is None:
return ops.control_dependencies([control_input])
return ops.NullContextmanager()
@@ -323,6 +331,11 @@ class CollectiveReplicaLauncher(object):
return self._collective_keys.get_instance_key(self._group_key,
self._device)
def _get_ordering_token(self, communication_hint):
if self._ordering_token is not None and communication_hint == 'NCCL':
return self._ordering_token.handle
return None
def all_reduce(self,
input_tensor,
control_input=None,
@@ -345,6 +358,7 @@ class CollectiveReplicaLauncher(object):
The reduced tensor.
"""
instance_key = self._next_instance_key()
ordering_token = self._get_ordering_token(communication_hint)
with self._executor_scope(), \
ops.device(self._device), \
self._control_input(control_input):
@@ -355,7 +369,8 @@ class CollectiveReplicaLauncher(object):
self._group_key,
instance_key,
communication_hint=communication_hint,
- timeout=timeout)
+ timeout=timeout,
+ ordering_token=ordering_token)
else:
return collective_ops.all_reduce(
input_tensor,
@@ -381,6 +396,7 @@ class CollectiveReplicaLauncher(object):
The reduced tensor.
"""
instance_key = self._next_instance_key()
ordering_token = self._get_ordering_token(communication_hint)
with self._executor_scope(), ops.device(self._device):
if self._should_use_collective_v2():
return collective_ops.all_gather_v2(
@@ -389,7 +405,8 @@ class CollectiveReplicaLauncher(object):
self._group_key,
instance_key,
communication_hint=communication_hint,
- timeout=timeout)
+ timeout=timeout,
+ ordering_token=ordering_token)
else:
return collective_ops.all_gather(
input_tensor,


@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import functools
import itertools
from absl import app
@@ -106,3 +107,86 @@ def main(enable_v2_behavior=True, config_logical_devices=True):
v2_compat.disable_v2_behavior()
# TODO(b/131360402): configure default logical devices.
multi_process_runner.test_main()
def _op_dependencies(op):
"""Returns the data and control dependencies of a tf.Operation combined."""
deps = []
for node in itertools.chain(op.inputs, op.control_inputs):
if isinstance(node, ops.Tensor):
node = node.op
assert isinstance(node, ops.Operation)
deps.append(node)
return deps
def topological_sort_operations(operations):
"""Topological sorts a list of operations.
This does a topological sort of the operations in a graph. The edges include
both data dependencies and control dependencies. Note that the edge goes from
an operation to its dependencies.
Args:
operations: a list of tf.Operation in the same graph.
Returns:
A map from a tf.Operation to its topological order.
"""
in_degrees = {}
for op in operations:
if op not in in_degrees:
in_degrees[op] = 0
for next_op in _op_dependencies(op):
in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
nexts = []
for op, in_degree in in_degrees.items():
if in_degree == 0:
nexts.append(op)
order = {}
next_order = 0
while nexts:
op, nexts = nexts[0], nexts[1:]
order[op] = next_order
next_order += 1
for next_op in _op_dependencies(op):
in_degrees[next_op] -= 1
if in_degrees[next_op] == 0:
nexts.append(next_op)
assert len(order) == len(operations)
return order
def _exists_dependency(start, end):
"""Returns whether there exists a dependency chain from start to end."""
nexts = [start]
while nexts:
op, nexts = nexts[0], nexts[1:]
for next_op in _op_dependencies(op):
if next_op == end:
return True
nexts.append(next_op)
return False
def assert_sequential_execution(order, operations):
"""Asserts there's a deterministic execution order between the operations.
Args:
order: a map from a tf.Operation to its topological order.
operations: a list of operations that should be executed sequentially. It
can be given in any order.
"""
# Topological ordering guarantees that, if there's a dependency from N_a to
# N_b, then order[N_a] < order[N_b]. If there exists a path of dependencies
# among the operations, it always goes from an operation with a smaller
# topological order to one with a larger topological order. Therefore, we only
# need to sort the operations by their topological order and verify that
# there's a dependency path between each adjacent pair.
operations = sorted(operations, key=lambda op: order[op])
for i in range(len(operations) - 1):
if not _exists_dependency(operations[i], operations[i + 1]):
print(operations[i].graph.as_graph_def())
raise AssertionError(
"No dependency between {} and {}. Graph is dumped to stdout.".format(
operations[i].name, operations[i + 1].name))


@@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -82,5 +83,37 @@ class LogicalDevicesTest(test.TestCase):
self.assertLen(config.get_logical_device_configuration(cpu_device), 3)
class AssertSequentialExecutionTest(test.TestCase):
def test1(self):
@def_function.function
def f():
a = array_ops.identity(1., name='a')
b = a + 1
c = array_ops.identity(2., name='c')
d = array_ops.identity(a + c, name='d')
with ops.control_dependencies([b]):
e = array_ops.identity(3., name='e')
f = array_ops.identity(c + e, name='f')
return d, f
graph = f.get_concrete_function().graph
order = test_util.topological_sort_operations(graph.get_operations())
a = graph.get_operation_by_name('a')
c = graph.get_operation_by_name('c')
d = graph.get_operation_by_name('d')
e = graph.get_operation_by_name('e')
f = graph.get_operation_by_name('f')
test_util.assert_sequential_execution(order, [a, d])
test_util.assert_sequential_execution(order, [e, a, f])
with self.assertRaises(AssertionError):
test_util.assert_sequential_execution(order, [a, c])
with self.assertRaises(AssertionError):
test_util.assert_sequential_execution(order, [f, a, c])
with self.assertRaises(AssertionError):
test_util.assert_sequential_execution(order, [d, e, a, c])
if __name__ == '__main__':
test_util.main()