Enable NCCL for all all-reduces

PiperOrigin-RevId: 344134571
Change-Id: I7234c42d196716570c820c1714376a2b4311cc06
Ran Chen authored on 2020-11-24 14:40:57 -08:00; committed by TensorFlower Gardener
parent d45e8258f1
commit 239a96c0ac
3 changed files with 48 additions and 13 deletions
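
For context on what enabling NCCL means at the user level: the CommunicationImplementation values exercised by the tests below are the same ones exposed through tf.distribute's CommunicationOptions. The sketch that follows is not part of this commit; it is a minimal illustration, assuming the public TF 2.4+ API and at least one GPU per worker, of how an all-reduce gets routed through NCCL instead of the default RING path.

import tensorflow as tf

# Ask for NCCL rather than RING for collective communication.
options = tf.distribute.experimental.CommunicationOptions(
    implementation=tf.distribute.experimental.CommunicationImplementation.NCCL)
strategy = tf.distribute.MultiWorkerMirroredStrategy(
    communication_options=options)

# Values produced per replica by strategy.run are combined with an all-reduce;
# with the options above that all-reduce is expected to be dispatched to NCCL.
per_replica = strategy.run(lambda: tf.constant(1.0))
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)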


@@ -1030,6 +1030,7 @@ cuda_py_test(
     name = "cross_device_ops_test",
     srcs = ["cross_device_ops_test.py"],
     python_version = "PY3",
+    shard_count = 2,
     tags = [
         "multi_and_single_gpu",
         "no_cuda_asan",  # times out

@@ -316,15 +316,21 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           num_processes=[1, 2],
           required_gpus=[0, 1, 2],
           implementation=[
-              # NCCL is only used for batch reduce, so we are not including
-              # NCCL combination here.
               CommunicationImplementation.AUTO,
-              CommunicationImplementation.RING
+              CommunicationImplementation.RING,
+              CommunicationImplementation.NCCL,
           ],
           reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
           prefer_collective_v2=[True, False]))
   def testAllReduceDense(self, num_processes, required_gpus, implementation,
                          reduce_op, prefer_collective_v2):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
+    if (num_processes == 2 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
+                    "physical GPUs for every process.")
     options = self.RunOptions(
         num_processes=num_processes,
         gpus_per_process=required_gpus,
@@ -351,16 +357,22 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           num_processes=[1, 2],
           required_gpus=[0, 1, 2],
           implementation=[
-              # NCCL is only used for batch reduce, so we are not including
-              # NCCL combination here.
               CommunicationImplementation.AUTO,
-              CommunicationImplementation.RING
+              CommunicationImplementation.RING,
+              CommunicationImplementation.NCCL,
           ],
           # TODO(b/166682130): add MEAN reduce once the bug is fixed.
           reduce_op=ReduceOp.SUM,
           prefer_collective_v2=[True, False]))
   def testAllReduceSparse(self, num_processes, required_gpus, implementation,
                           reduce_op, prefer_collective_v2):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
+    if (num_processes == 2 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip NCCL + 2 processes combination. NCCL requires "
+                    "physical GPUs for every process.")
     options = self.RunOptions(
         mode=["func_graph"],  # Sparse reduce is not supported in eager.
         num_processes=num_processes,
@@ -427,7 +439,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           required_gpus=[0, 1, 2],
           implementation=[
               CommunicationImplementation.AUTO,
-              CommunicationImplementation.RING, CommunicationImplementation.NCCL
+              CommunicationImplementation.RING,
+              CommunicationImplementation.NCCL,
           ],
           reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
           prefer_scoped_allocator=[True, False],
@@ -561,8 +574,9 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
           axis=[0, 1, 2],
           func_mode=["eager", "func_graph"],
           implementation=[
+              CommunicationImplementation.AUTO,
+              CommunicationImplementation.RING,
               CommunicationImplementation.NCCL,
-              CommunicationImplementation.AUTO, CommunicationImplementation.RING
           ],
           prefer_collective_v2=[True, False]))
   def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
@@ -740,11 +754,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
       combinations.combine(
           num_processes=2,
           required_gpus=[0, 1],
-          implementation=[CommunicationImplementation.RING],
+          implementation=[
+              CommunicationImplementation.RING, CommunicationImplementation.NCCL
+          ],
           prefer_collective_v2=[True, False]))
   def testTimeoutReduceDense(self, num_processes, implementation, required_gpus,
                              prefer_collective_v2):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
 
     def replica_fn():
       cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
           prefer_collective_v2)
@@ -772,10 +791,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
       combinations.combine(
           num_processes=2,
           required_gpus=[0, 1],
-          implementation=[CommunicationImplementation.RING],
+          implementation=[
+              CommunicationImplementation.RING, CommunicationImplementation.NCCL
+          ],
           prefer_collective_v2=[True, False]))
   def testTimeoutBatchReduceDense(self, num_processes, implementation,
                                   required_gpus, prefer_collective_v2):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
 
     def replica_fn():
       cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
@@ -805,10 +829,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
       combinations.combine(
           num_processes=2,
           required_gpus=[0, 1],
-          implementation=[CommunicationImplementation.RING],
+          implementation=[
+              CommunicationImplementation.RING, CommunicationImplementation.NCCL
+          ],
           prefer_collective_v2=[True, False]))
   def testTimeoutReduceSparse(self, num_processes, implementation,
                               required_gpus, prefer_collective_v2):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
 
     def replica_fn():
       cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
@@ -839,10 +868,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
       combinations.combine(
          num_processes=2,
          required_gpus=[0, 1],
-          implementation=[CommunicationImplementation.RING],
+          implementation=[
+              CommunicationImplementation.RING, CommunicationImplementation.NCCL
+          ],
           prefer_collective_v2=[True, False]))
   def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
                                    implementation, prefer_collective_v2):
+    if (required_gpus == 0 and
+        implementation == CommunicationImplementation.NCCL):
+      self.skipTest("Skip CPU + NCCL combination")
 
     def replica_fn():
       cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (


@@ -259,7 +259,7 @@ class CollectiveReplicaLauncher(object):
   _prefer_scoped_allocator = True
   _prefer_collective_v2 = True
-  _prefer_ordering_token = False
+  _prefer_ordering_token = True
 
   def __init__(self,
                group_key,
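
The _prefer_ordering_token default flipped above ties in with NCCL's requirement that every participant launch its collectives in the same order; the launcher's ordering token is one way to pin that order. The sketch below is not the CollectiveReplicaLauncher implementation, only an illustration (using an artificial data dependency inside strategy.run) of forcing two all-reduces into a fixed launch order on every replica.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # falls back to one CPU replica without GPUs

@tf.function
def two_ordered_all_reduces(x, y):
  def replica_fn(a, b):
    ctx = tf.distribute.get_replica_context()
    first = ctx.all_reduce(tf.distribute.ReduceOp.SUM, a)
    # The data dependency on `first` forces the second all-reduce to launch
    # after the first on every replica, the kind of consistent ordering NCCL
    # needs to avoid mismatched launches across devices.
    second = ctx.all_reduce(tf.distribute.ReduceOp.SUM, b + 0.0 * first)
    return first, second
  return strategy.run(replica_fn, args=(x, y))

print(two_ordered_all_reduces(tf.constant(1.0), tf.constant(2.0)))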