Remove collective v1 code path

PiperOrigin-RevId: 354577402
Change-Id: I200d98a6a80dfe1e463044f9dedef9291ff7d846
Ran Chen 2021-01-29 11:43:39 -08:00 committed by TensorFlower Gardener
parent cb4c7d28f0
commit 1a46fdc4a2
2 changed files with 86 additions and 131 deletions


@@ -52,6 +52,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
CollectiveReplicaLauncher = cross_device_utils.CollectiveReplicaLauncher
CommunicationImplementation = collective_util.CommunicationImplementation
ReduceOp = reduce_util.ReduceOp
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
@@ -107,9 +108,8 @@ def enable_collective_ops():
protocol=cluster_resolver.rpc_layer)
context.context().enable_collective_ops(server_def)
# Recover default flag values.
cross_device_utils.CollectiveReplicaLauncher._prefer_scoped_allocator = True
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
cross_device_utils.CollectiveReplicaLauncher._prefer_ordering_token = False
CollectiveReplicaLauncher._prefer_unique_instance_key = True
CollectiveReplicaLauncher._prefer_ordering_token = False
class MultiProcessPoolRunner():
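The tests below flip these launcher preferences by assigning to class attributes, which is why enable_collective_ops() restores the defaults between runs. A minimal generic sketch (plain Python, not TF code) of why a single class-level assignment is enough:

class Launcher:
  _prefer_unique_instance_key = True  # class-level default

a, b = Launcher(), Launcher()
Launcher._prefer_unique_instance_key = False
# Instances that never shadowed the attribute read it through the class.
print(a._prefer_unique_instance_key, b._prefer_unique_instance_key)  # False False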
@@ -215,12 +215,11 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
"gpus_per_process",
"reduce_op",
"communication_options",
"prefer_scoped_allocator",
"prefer_collective_v2",
"prefer_unique_instance_key",
])
RunOptions.__new__.__defaults__ = (["eager",
"func_graph"], 2, 0, ReduceOp.SUM,
collective_util.Options(), True, False)
collective_util.Options(), True)
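RunOptions relies on the namedtuple trailing-defaults trick: a tuple assigned to __new__.__defaults__ supplies defaults for the last N fields. A standalone sketch with shortened, illustrative field names:

import collections

Opts = collections.namedtuple(
    "Opts", ["mode", "num_processes", "prefer_unique_instance_key"])
# Defaults bind to the trailing fields, so every field may be omitted here.
Opts.__new__.__defaults__ = (["eager"], 2, True)
print(Opts())  # Opts(mode=['eager'], num_processes=2, prefer_unique_instance_key=True)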
def reduce_and_verify(self, inputs, expect, options):
"""Reduce the given `inputs` and verify the output matches `expect`.
@@ -234,8 +233,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
"""
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
options.prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
options.prefer_unique_instance_key)
collective, devices, pid = self.make_collective(options.num_processes,
options.gpus_per_process)
@@ -273,10 +272,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
"""
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_scoped_allocator = (
options.prefer_scoped_allocator)
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
options.prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
options.prefer_unique_instance_key)
collective, devices, pid = self.make_collective(options.num_processes,
options.gpus_per_process)
@@ -321,9 +318,9 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
CommunicationImplementation.NCCL,
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testAllReduceDense(self, num_processes, required_gpus, implementation,
reduce_op, prefer_collective_v2):
reduce_op, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@@ -337,7 +334,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
reduce_op=reduce_op,
communication_options=collective_util.Options(
implementation=implementation),
prefer_collective_v2=prefer_collective_v2)
prefer_unique_instance_key=prefer_unique_instance_key)
group_size = options.num_processes * (options.gpus_per_process or 1)
inputs_data = [1.0, 2.0, 3.0, 4.0]
@@ -363,9 +360,9 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
],
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
reduce_op=ReduceOp.SUM,
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testAllReduceSparse(self, num_processes, required_gpus, implementation,
reduce_op, prefer_collective_v2):
reduce_op, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@@ -380,7 +377,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
reduce_op=reduce_op,
communication_options=collective_util.Options(
implementation=implementation),
prefer_collective_v2=prefer_collective_v2)
prefer_unique_instance_key=prefer_unique_instance_key)
group_size = options.num_processes * (options.gpus_per_process or 1)
inputs_data = [
@@ -412,8 +409,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
self.reduce_and_verify(inputs, expect, options)
@combinations.generate(
combinations.combine(prefer_collective_v2=[True, False]))
def testAllReduceSparseVariableLength(self, prefer_collective_v2):
combinations.combine(prefer_unique_instance_key=[True, False]))
def testAllReduceSparseVariableLength(self, prefer_unique_instance_key):
# One device per process, 2 processes, 2 replicas in total.
inputs = [
IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
@@ -431,7 +428,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
mode=["func_graph"], # Sparse reduce is not supported in eager.
num_processes=2,
reduce_op=ReduceOp.SUM,
prefer_collective_v2=prefer_collective_v2))
prefer_unique_instance_key=prefer_unique_instance_key))
@combinations.generate(
combinations.combine(
@@ -443,11 +440,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
CommunicationImplementation.NCCL,
],
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
prefer_scoped_allocator=[True, False],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testBatchAllReduceDense(self, num_processes, required_gpus,
implementation, reduce_op,
prefer_scoped_allocator, prefer_collective_v2):
prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@@ -462,8 +458,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
reduce_op=reduce_op,
communication_options=collective_util.Options(
implementation=implementation),
prefer_scoped_allocator=prefer_scoped_allocator,
prefer_collective_v2=prefer_collective_v2)
prefer_unique_instance_key=prefer_unique_instance_key)
group_size = options.num_processes * (options.gpus_per_process or 1)
inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
@@ -489,11 +484,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
],
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
reduce_op=ReduceOp.SUM,
prefer_scoped_allocator=[True, False],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testBatchAllReduceSparse(self, num_processes, required_gpus,
implementation, reduce_op,
prefer_scoped_allocator, prefer_collective_v2):
prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
@@ -509,8 +503,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
reduce_op=reduce_op,
communication_options=collective_util.Options(
implementation=implementation),
prefer_scoped_allocator=prefer_scoped_allocator,
prefer_collective_v2=prefer_collective_v2)
prefer_unique_instance_key=prefer_unique_instance_key)
group_size = options.num_processes * (options.gpus_per_process or 1)
inputs_data = ([
@@ -578,13 +571,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
CommunicationImplementation.RING,
CommunicationImplementation.NCCL,
],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
func_mode, axis, prefer_collective_v2):
func_mode, axis, prefer_unique_instance_key):
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
prefer_unique_instance_key)
collective, devices, _ = self.make_collective(num_processes,
required_gpus)
options = collective_util.Options(implementation=implementation)
@@ -624,7 +617,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
implementation):
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
CollectiveReplicaLauncher._prefer_unique_instance_key = True
collective, devices, _ = self.make_collective(num_processes,
required_gpus)
options = collective_util.Options(implementation=implementation)
@@ -653,15 +646,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
implementation=[
CommunicationImplementation.NCCL, CommunicationImplementation.RING
],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
required_gpus,
implementation,
prefer_collective_v2):
prefer_unique_instance_key):
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
prefer_unique_instance_key)
collective, devices, _ = self.make_collective(num_processes,
required_gpus)
options = collective_util.Options(implementation=implementation)
@@ -715,13 +708,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
implementation=[
CommunicationImplementation.NCCL, CommunicationImplementation.RING
],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testInputsAreFunctionArgs(self, num_processes, required_gpus,
implementation, prefer_collective_v2):
implementation, prefer_unique_instance_key):
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
prefer_unique_instance_key)
collective, devices, _ = self.make_collective(num_processes,
required_gpus)
options = collective_util.Options(implementation=implementation)
@@ -757,16 +750,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
implementation=[
CommunicationImplementation.RING, CommunicationImplementation.NCCL
],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testTimeoutReduceDense(self, num_processes, implementation, required_gpus,
prefer_collective_v2):
prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
prefer_unique_instance_key)
collective, devices, task_id = self.make_collective(
num_processes, required_gpus)
if task_id != 0:
@@ -794,16 +788,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
implementation=[
CommunicationImplementation.RING, CommunicationImplementation.NCCL
],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testTimeoutBatchReduceDense(self, num_processes, implementation,
required_gpus, prefer_collective_v2):
required_gpus, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
prefer_unique_instance_key)
collective, devices, task_id = self.make_collective(
num_processes, required_gpus)
if task_id != 0:
@@ -832,16 +826,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
implementation=[
CommunicationImplementation.RING, CommunicationImplementation.NCCL
],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testTimeoutReduceSparse(self, num_processes, implementation,
required_gpus, prefer_collective_v2):
required_gpus, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
prefer_unique_instance_key)
collective, devices, task_id = self.make_collective(
num_processes, required_gpus)
if task_id != 0:
@@ -871,16 +865,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
implementation=[
CommunicationImplementation.RING, CommunicationImplementation.NCCL
],
prefer_collective_v2=[True, False]))
prefer_unique_instance_key=[True, False]))
def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
implementation, prefer_collective_v2):
implementation, prefer_unique_instance_key):
if (required_gpus == 0 and
implementation == CommunicationImplementation.NCCL):
self.skipTest("Skip CPU + NCCL combination")
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
prefer_collective_v2)
CollectiveReplicaLauncher._prefer_unique_instance_key = (
prefer_unique_instance_key)
collective, devices, task_id = self.make_collective(
num_processes, required_gpus)
if task_id != 0:
@@ -908,8 +902,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
def testNcclOrdering(self, num_processes, required_gpus):
def replica_fn():
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
cross_device_utils.CollectiveReplicaLauncher._prefer_ordering_token = True
CollectiveReplicaLauncher._prefer_unique_instance_key = True
CollectiveReplicaLauncher._prefer_ordering_token = True
collective, devices, _ = self.make_collective(num_processes,
required_gpus)
options = collective_util.Options(


@@ -257,8 +257,7 @@ class CollectiveKeys(object):
class CollectiveReplicaLauncher(object):
"""Launch collectives on one replica."""
_prefer_scoped_allocator = True
_prefer_collective_v2 = True
_prefer_unique_instance_key = True
_prefer_ordering_token = True
def __init__(self,
@@ -281,26 +280,19 @@ class CollectiveReplicaLauncher(object):
return ops.control_dependencies([control_input])
return ops.NullContextmanager()
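_control_input lets a caller serialize collectives: the removed NCCL branch of batch_all_reduce (at the end of this diff) chained each launch on the previous output through it. A generic, runnable sketch of that conditional-dependency pattern, with tf.multiply standing in for the collective op:

import tensorflow as tf

@tf.function
def serialized(xs):
  outputs = []
  for x in xs:
    # Depend on the previous output when there is one, else add no edge.
    ctrl = [outputs[-1]] if outputs else []
    with tf.control_dependencies(ctrl):
      outputs.append(tf.multiply(x, 2.0))
  return outputs

print(serialized([tf.constant(1.0), tf.constant(3.0)]))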
def _use_collective_v2(self):
def _use_unique_instance_key(self):
if not ops.executing_eagerly_outside_functions():
return False
return CollectiveReplicaLauncher._prefer_collective_v2
def _use_scoped_allocator(self):
if self._use_collective_v2():
# ScopedAllocator doesn't support collective V2.
return False
return CollectiveReplicaLauncher._prefer_scoped_allocator
return CollectiveReplicaLauncher._prefer_unique_instance_key
def _use_ordering_token(self):
if not self._use_collective_v2():
# Only collective V2 supports ordering token.
if not ops.executing_eagerly_outside_functions():
return False
return CollectiveReplicaLauncher._prefer_ordering_token
def _next_instance_key(self):
"""Returns the next instance key."""
if self._use_collective_v2():
if self._use_unique_instance_key():
# Assigning instance keys at function building time has issues since
# different workers may retrace the function at different times. With
# collective V2 we can use capture_call_time_value to use a placeholder as
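The comment above (truncated by the hunk window) points at trace-time freezing: a key minted while tracing is baked into the graph, so workers that retrace at different times end up with mismatched keys, whereas a call-time placeholder stays fresh. A toy illustration of the freezing behavior (plain tf.function, not the launcher internals):

import tensorflow as tf

_counter = 0

def next_key():
  global _counter
  _counter += 1
  return _counter

@tf.function
def traced():
  # next_key() runs once, while tracing; the graph keeps that constant.
  return tf.constant(next_key())

print(traced().numpy())  # 1
print(traced().numpy())  # 1 again: the value was fixed at trace time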
@@ -360,23 +352,14 @@ class CollectiveReplicaLauncher(object):
ordering_token = self._get_ordering_token(communication_hint)
with ops.device(self._device), \
self._control_input(control_input):
if self._use_collective_v2():
return collective_ops.all_reduce_v2(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout,
ordering_token=ordering_token)
else:
return collective_ops.all_reduce(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout)
return collective_ops.all_reduce_v2(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout,
ordering_token=ordering_token)
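With the v1 branch gone, every dense all-reduce emits CollectiveReduceV2. A hedged sketch of the call shape (the group/instance keys are illustrative; real ones come from CollectiveKeys), tracing only, since actually executing it would block until every group member launches the same instance:

import tensorflow as tf
from tensorflow.python.ops import collective_ops

@tf.function
def reduce_on_one_member(t):
  return collective_ops.all_reduce_v2(
      t, group_size=2, group_key=1, instance_key=1,
      communication_hint='auto', timeout=0)

graph = reduce_on_one_member.get_concrete_function(
    tf.TensorSpec([2], tf.float32)).graph
print([op.type for op in graph.get_operations() if 'Collective' in op.type])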
def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
"""All-gather a dense tensor.
@@ -393,23 +376,14 @@ class CollectiveReplicaLauncher(object):
instance_key = self._next_instance_key()
ordering_token = self._get_ordering_token(communication_hint)
with ops.device(self._device):
if self._use_collective_v2():
return collective_ops.all_gather_v2(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout,
ordering_token=ordering_token)
else:
return collective_ops.all_gather(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout)
return collective_ops.all_gather_v2(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout,
ordering_token=ordering_token)
def batch_all_reduce(self,
input_tensor_packs,
@@ -430,17 +404,17 @@ class CollectiveReplicaLauncher(object):
Returns:
A flat list of reduced tensors.
"""
# We don't batch with concat in eager. It's easy to get it wrong because
# we need to avoid any numpy() calls on values produced by the async
# executor. This effectively disables batching in eager, but it's unlikely
# that a large number of tensors would be all-reduced in eager anyway.
batch_with_concat = (not self._use_scoped_allocator() and
not context.executing_eagerly())
outputs = []
for pack in input_tensor_packs:
# TODO(b/169168846): insert a parallel all_gather to verify packings
# are the same on each replica.
if batch_with_concat:
if context.executing_eagerly():
# We don't batch in eager as it sometimes makes performance worse
# due to the concat/split ops.
for input_tensor in pack:
outputs.append(
self.all_reduce(input_tensor, None, communication_hint, timeout))
else:
# TODO(b/169168846): insert a parallel all_gather to verify packings
# are the same on each replica.
with ops.device(self._device):
flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
shapes = [array_ops.shape(t) for t in pack]
@@ -455,19 +429,6 @@ class CollectiveReplicaLauncher(object):
flat_outputs = array_ops.split(reduced, num_elements, axis=0)
for shape, flat_output in zip(shapes, flat_outputs):
outputs.append(array_ops.reshape(flat_output, shape))
else:
# By placing all CollectiveReduce ops in a batch under single name
# scope, we ensure they will be picked up by the `ScopedAllocator`
# grappler optimizer and packed into a single all-reduce.
with ops.name_scope('allreduce'):
for input_tensor in pack:
if communication_hint == 'NCCL' and outputs:
control_input = outputs[-1]
else:
control_input = None
outputs.append(
self.all_reduce(input_tensor, control_input, communication_hint,
timeout))
return outputs
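The surviving graph-mode branch batches by flattening each tensor, concatenating, reducing once, then splitting and reshaping back. A single-process sketch of that round trip, with doubling standing in for the fused all-reduce:

import tensorflow as tf

def batch_reduce_with_concat(pack):
  shapes = [tf.shape(t) for t in pack]
  sizes = [tf.size(t) for t in pack]
  flat = tf.concat([tf.reshape(t, [-1]) for t in pack], axis=0)
  reduced = flat * 2.0  # stands in for one all-reduce over the whole batch
  return [tf.reshape(out, shape)
          for out, shape in zip(tf.split(reduced, sizes, axis=0), shapes)]

print(batch_reduce_with_concat([tf.ones([2, 2]), tf.ones([3])]))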