Order NCCL all-reduce with ordering token
Auto control dependencies will chain operations that take the same resource input. We'll do the same thing for all-gather after some refactoring is done.

PiperOrigin-RevId: 341868107
Change-Id: I5570a28c2e1c638980e3509088c0525e957c463b
parent 28835e4103
commit 07b9eccccf
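The idea behind the change, in brief: TensorFlow's automatic control dependencies add control edges between stateful ops inside a tf.function that use the same resource, so feeding one shared resource handle (an "ordering token") to every NCCL collective forces the collectives to launch in a fixed order. The snippet below is a minimal, self-contained sketch of that chaining, not code from this commit; the ordering_token variable and the assign_add ops are hypothetical stand-ins for the collectives that receive the token.

import tensorflow as tf

# A dummy resource variable standing in for the ordering token.
ordering_token = tf.Variable(0.0)

@tf.function
def two_stateful_ops():
  # Both ops touch the same resource handle, so the auto-control-dependencies
  # pass inserts a control edge between them; they cannot be reordered.
  ordering_token.assign_add(1.0)
  ordering_token.assign_add(1.0)
  return ordering_token.read_value()

graph = two_stateful_ops.get_concrete_function().graph
assigns = [op for op in graph.get_operations()
           if op.type == "AssignAddVariableOp"]
# Expect the second assign to carry a control input from the first one.
print([ctrl.name for ctrl in assigns[1].control_inputs])

In the same spirit, the launcher change below creates a per-device ResourceVariable and passes its handle to the collective ops, and the test verifies that the resulting graph chains all collectives with control dependencies.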
tensorflow/python/distribute/BUILD

@@ -106,6 +106,7 @@ py_library(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nccl_ops",
         "//tensorflow/python:platform",
+        "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
     ],
@@ -1079,6 +1080,7 @@ cuda_py_test(
         ":multi_process_runner",
         ":multi_worker_test_base",
         ":reduce_util",
+        ":test_util",
         ":values",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:collective_ops",
tensorflow/python/distribute/cross_device_ops.py

@@ -990,6 +990,11 @@ class CollectiveAllReduce(CrossDeviceOps):
   all workers and then put results on the right destinations.
   """

+  # Whether to only use NCCL for batched all-reduce when NCCL is requested. This
+  # is because of the lack of mechanism to order NCCL operations
+  # deterministically.
+  _limited_nccl = True
+
   def __init__(self, devices, group_size, collective_keys=None):
     """Initializes the object.

@@ -1121,8 +1126,8 @@ class CollectiveAllReduce(CrossDeviceOps):
     # all-reduce, which is the gradients.
     # TODO(b/132575814): switch to NCCL for all collectives when communication
-    # is NCCL if and only if we can order collectives deterministically.
-    if (options.implementation == CommunicationImplementation.NCCL and
+    # is NCCL.
+    if (self._limited_nccl and
+        options.implementation == CommunicationImplementation.NCCL and
         batch_size == 1):
       implementation = CommunicationImplementation.AUTO.value

@@ -1182,8 +1187,9 @@ class CollectiveAllReduce(CrossDeviceOps):
     # For now, we use NCCL only when batch_size > 1.
     # TODO(b/132575814): switch to NCCL for all collectives when implementation
     # is NCCL.
-    if options.implementation == CommunicationImplementation.NCCL and len(
-        per_replica_values) == 1:
+    if (self._limited_nccl and
+        options.implementation == CommunicationImplementation.NCCL and
+        len(per_replica_values) == 1):
       implementation = CommunicationImplementation.AUTO.value

     gathered_values = []
tensorflow/python/distribute/cross_device_ops_test.py

@@ -32,9 +32,11 @@ from tensorflow.python.distribute import collective_util
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
 from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import multi_process_runner
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import test_util
 from tensorflow.python.distribute import values as value_lib
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
@@ -70,7 +72,12 @@ def make_per_replica_value(value, devices):
   """
   values = []
   for device_idx, device in enumerate(devices):
-    v = value(device_idx) if callable(value) else value
+    if callable(value):
+      v = value(device_idx)
+    elif isinstance(value, list):
+      v = value[device_idx]
+    else:
+      v = value
     if isinstance(v, IndexedSlicesValue):
       with ops.device(device):
         values.append(
@@ -99,6 +106,11 @@ def enable_collective_ops():
       task_index=cluster_resolver.task_id,
       protocol=cluster_resolver.rpc_layer)
   context.context().enable_collective_ops(server_def)
+  # Recover default flag values.
+  cross_device_ops_lib.CollectiveAllReduce._limited_nccl = True
+  cross_device_utils.CollectiveReplicaLauncher._use_scoped_allocator = True
+  cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = False
+  cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = False


 class MultiProcessPoolRunner():
@@ -858,9 +870,101 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):

     get_global_mpr(num_processes).run(replica_fn)

+  @combinations.generate(combinations.combine(num_processes=1, required_gpus=2))
+  def testNcclOrdering(self, num_processes, required_gpus):
+
+    def replica_fn():
+      cross_device_ops_lib.CollectiveAllReduce._limited_nccl = False
+      cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True
+      cross_device_utils.CollectiveReplicaLauncher._use_ordering_token = True
+      collective, devices, _ = self.make_collective(
+          num_processes, required_gpus)
+      options = collective_util.Options(
+          implementation=CommunicationImplementation.NCCL)
+
+      v_dense = make_per_replica_value([1.0, 1.0], devices)
+      v_sparse = make_per_replica_value([
+          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
+          IndexedSlicesValue([[4., 6.], [5., 6.]], [1, 3], [5, 2]),
+      ], devices)
+
+      @def_function.function
+      def nested_dense():
+        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
+
+      @def_function.function
+      def nested_sparse():
+        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
+
+      # All collectives, function calls, if clause and while loops should be
+      # chained by control dependencies, so that the execution order is
+      # deterministic.
+      @def_function.function
+      def f():
+        # pylint: disable=pointless-statement
+        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
+        # reducing dense value.
+        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
+        # reducing sparse value.
+        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
+        # reduce dense value in nested tf.function.
+        nested_dense()
+        # reduce sparse value in nested tf.function.
+        nested_sparse()
+        # reduce dense value in tf.cond.
+        if array_ops.identity(1.0) > array_ops.identity(2.0):
+          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
+        else:
+          v_dense
+        # reduce sparse value in tf.cond.
+        if array_ops.identity(1.0) > array_ops.identity(2.0):
+          v_sparse
+        else:
+          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
+                            options)
+        # reduce dense value in tf.while_loop.
+        i = array_ops.identity(1)
+        while i < 3:
+          collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
+          i += 1
+        # reduce sparse value in tf.while_loop.
+        i = array_ops.identity(1)
+        while i < 3:
+          collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse,
+                            options)
+          i += 1
+        # reducing dense and sparse value again.
+        collective.reduce(reduce_util.ReduceOp.SUM, v_dense, v_dense, options)
+        collective.reduce(reduce_util.ReduceOp.SUM, v_sparse, v_sparse, options)
+        # pylint: enable=pointless-statement
+
+      graph = f.get_concrete_function().graph
+      should_be_ordered = set([
+          "CollectiveReduce", "CollectiveGather", "If", "While",
+          "StatefulPartitionedCall"
+      ])
+      nodes_by_device = {}
+      for op in graph.get_operations():
+        if op.type in should_be_ordered:
+          if op.device not in nodes_by_device:
+            nodes_by_device[op.device] = []
+          nodes_by_device[op.device].append(op)
+      order = test_util.topological_sort_operations(graph.get_operations())
+      for device in devices:
+        device = device_util.canonicalize(device)
+        # Those function ops don't have device annotations, but they contain
+        # collectives for both devices so we always include them.
+        operations = nodes_by_device[device] + nodes_by_device[""]
+        # Verify that we get all types of nodes we want.
+        self.assertEqual(set(op.type for op in operations), should_be_ordered)
+        test_util.assert_sequential_execution(order, operations)
+
+    get_global_mpr(num_processes).run(replica_fn)


 if __name__ == "__main__":
   # Set default inter op thread pool size to one to ensure we don't exhaust the
   # thread pool with the additional executors to run collectives in eager.
   os.environ["TF_NUM_INTEROP_THREADS"] = "1"
-  multi_process_runner.test_main()
+  # TODO(b/172304955): figure why logical devices doesn't work.
+  test_util.main(config_logical_devices=False)
tensorflow/python/distribute/cross_device_utils.py

@@ -32,6 +32,7 @@ from tensorflow.python.ops import collective_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nccl_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import tf_logging as logging

 INSTANCE_KEY_START_NUMBER = 100
@@ -258,6 +259,7 @@ class CollectiveReplicaLauncher(object):

   _use_scoped_allocator = True
   _use_collective_v2 = False
+  _use_ordering_token = False

   def __init__(self,
                group_key,
@@ -272,6 +274,12 @@ class CollectiveReplicaLauncher(object):
     self._collective_keys = collective_keys
     self._device = device
     self._executor = executor
+    if (self._use_ordering_token and self._use_collective_v2 and
+        ops.executing_eagerly_outside_functions()):
+      with ops.init_scope(), ops.device(device):
+        self._ordering_token = resource_variable_ops.ResourceVariable(0.)
+    else:
+      self._ordering_token = None

   def _executor_scope(self):
     if context.executing_eagerly() and not self._executor:
@@ -281,7 +289,7 @@ class CollectiveReplicaLauncher(object):
     return ops.NullContextmanager()

   def _control_input(self, control_input):
-    if control_input is not None:
+    if control_input is not None and self._ordering_token is None:
       return ops.control_dependencies([control_input])
     return ops.NullContextmanager()

@@ -323,6 +331,11 @@ class CollectiveReplicaLauncher(object):
     return self._collective_keys.get_instance_key(self._group_key,
                                                   self._device)

+  def _get_ordering_token(self, communication_hint):
+    if self._ordering_token is not None and communication_hint == 'NCCL':
+      return self._ordering_token.handle
+    return None
+
   def all_reduce(self,
                  input_tensor,
                  control_input=None,
@@ -345,6 +358,7 @@ class CollectiveReplicaLauncher(object):
       The reduced tensor.
     """
     instance_key = self._next_instance_key()
+    ordering_token = self._get_ordering_token(communication_hint)
     with self._executor_scope(), \
          ops.device(self._device), \
          self._control_input(control_input):
@@ -355,7 +369,8 @@ class CollectiveReplicaLauncher(object):
             self._group_key,
             instance_key,
             communication_hint=communication_hint,
-            timeout=timeout)
+            timeout=timeout,
+            ordering_token=ordering_token)
       else:
         return collective_ops.all_reduce(
             input_tensor,
@@ -381,6 +396,7 @@ class CollectiveReplicaLauncher(object):
       The reduced tensor.
     """
     instance_key = self._next_instance_key()
+    ordering_token = self._get_ordering_token(communication_hint)
     with self._executor_scope(), ops.device(self._device):
       if self._should_use_collective_v2():
         return collective_ops.all_gather_v2(
@@ -389,7 +405,8 @@ class CollectiveReplicaLauncher(object):
             self._group_key,
             instance_key,
             communication_hint=communication_hint,
-            timeout=timeout)
+            timeout=timeout,
+            ordering_token=ordering_token)
       else:
         return collective_ops.all_gather(
             input_tensor,
tensorflow/python/distribute/test_util.py

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 import functools
+import itertools

 from absl import app

@@ -106,3 +107,86 @@ def main(enable_v2_behavior=True, config_logical_devices=True):
     v2_compat.disable_v2_behavior()
   # TODO(b/131360402): configure default logical devices.
   multi_process_runner.test_main()
+
+
+def _op_dependencies(op):
+  """Returns the data and control dependencies of a tf.Operation combined."""
+  deps = []
+  for node in itertools.chain(op.inputs, op.control_inputs):
+    if isinstance(node, ops.Tensor):
+      node = node.op
+    assert isinstance(node, ops.Operation)
+    deps.append(node)
+  return deps
+
+
+def topological_sort_operations(operations):
+  """Topological sorts a list of operations.
+
+  This does a topological sort of the operations in a graph. The edges include
+  both data dependencies and control dependencies. Note that the edge goes from
+  an operation to its dependencies.
+
+  Args:
+    operations: a list of tf.Operation in the same graph.
+
+  Returns:
+    A map from a tf.Operation to its topological order.
+  """
+  in_degrees = {}
+  for op in operations:
+    if op not in in_degrees:
+      in_degrees[op] = 0
+    for next_op in _op_dependencies(op):
+      in_degrees[next_op] = in_degrees.get(next_op, 0) + 1
+  nexts = []
+  for op, in_degree in in_degrees.items():
+    if in_degree == 0:
+      nexts.append(op)
+  order = {}
+  next_order = 0
+  while nexts:
+    op, nexts = nexts[0], nexts[1:]
+    order[op] = next_order
+    next_order += 1
+    for next_op in _op_dependencies(op):
+      in_degrees[next_op] -= 1
+      if in_degrees[next_op] == 0:
+        nexts.append(next_op)
+  assert len(order) == len(operations)
+  return order
+
+
+def _exists_dependency(start, end):
+  """Returns whether there exists a dependency chain from start to end."""
+  nexts = [start]
+  while nexts:
+    op, nexts = nexts[0], nexts[1:]
+    for next_op in _op_dependencies(op):
+      if next_op == end:
+        return True
+      nexts.append(next_op)
+  return False
+
+
+def assert_sequential_execution(order, operations):
+  """Asserts there's a deterministic execution order between the operations.
+
+  Args:
+    order: a map from a tf.Operation to its topological order.
+    operations: a list of operations that should be executed sequentially. It
+      can be given in any order.
+  """
+  # Topological ordering guarantees that, if there's a dependency from N_a to
+  # N_b, then order[N_a] < order[N_b]. If there do exist a path of dependencies
+  # among the operations, it always goes from a operation with a smaller
+  # topological order to one with a larger topological order. Therefore, we only
+  # need to sort the operations by their topological orders, and verify that
+  # there's a path of dependency between adjacent pairs.
+  operations = sorted(operations, key=lambda op: order[op])
+  for i in range(len(operations) - 1):
+    if not _exists_dependency(operations[i], operations[i + 1]):
+      print(operations[i].graph.as_graph_def())
+      raise AssertionError(
+          "No dependency between {} and {}. Graph is dumped to stdout.".format(
+              operations[i].name, operations[i + 1].name))
tensorflow/python/distribute/test_util_test.py

@@ -28,6 +28,7 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import config
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops


@@ -82,5 +83,37 @@ class LogicalDevicesTest(test.TestCase):
     self.assertLen(config.get_logical_device_configuration(cpu_device), 3)


+class AssertSequentailExecutionTest(test.TestCase):
+
+  def test1(self):
+
+    @def_function.function
+    def f():
+      a = array_ops.identity(1., name='a')
+      b = a + 1
+      c = array_ops.identity(2., name='c')
+      d = array_ops.identity(a + c, name='d')
+      with ops.control_dependencies([b]):
+        e = array_ops.identity(3., name='e')
+      f = array_ops.identity(c + e, name='f')
+      return d, f
+
+    graph = f.get_concrete_function().graph
+    order = test_util.topological_sort_operations(graph.get_operations())
+    a = graph.get_operation_by_name('a')
+    c = graph.get_operation_by_name('c')
+    d = graph.get_operation_by_name('d')
+    e = graph.get_operation_by_name('e')
+    f = graph.get_operation_by_name('f')
+    test_util.assert_sequential_execution(order, [a, d])
+    test_util.assert_sequential_execution(order, [e, a, f])
+    with self.assertRaises(AssertionError):
+      test_util.assert_sequential_execution(order, [a, c])
+    with self.assertRaises(AssertionError):
+      test_util.assert_sequential_execution(order, [f, a, c])
+    with self.assertRaises(AssertionError):
+      test_util.assert_sequential_execution(order, [d, e, a, c])
+
+
 if __name__ == '__main__':
   test_util.main()