Remove collective v1 code path
PiperOrigin-RevId: 354577402 Change-Id: I200d98a6a80dfe1e463044f9dedef9291ff7d846
This commit is contained in:
parent
cb4c7d28f0
commit
1a46fdc4a2
@ -52,6 +52,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
CollectiveReplicaLauncher = cross_device_utils.CollectiveReplicaLauncher
|
||||
CommunicationImplementation = collective_util.CommunicationImplementation
|
||||
ReduceOp = reduce_util.ReduceOp
|
||||
IndexedSlicesValue = indexed_slices.IndexedSlicesValue
|
||||
@ -107,9 +108,8 @@ def enable_collective_ops():
|
||||
protocol=cluster_resolver.rpc_layer)
|
||||
context.context().enable_collective_ops(server_def)
|
||||
# Recover default flag values.
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_scoped_allocator = True
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_ordering_token = False
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = True
|
||||
CollectiveReplicaLauncher._prefer_ordering_token = False
|
||||
|
||||
|
||||
class MultiProcessPoolRunner():
|
||||
@ -215,12 +215,11 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"gpus_per_process",
|
||||
"reduce_op",
|
||||
"communication_options",
|
||||
"prefer_scoped_allocator",
|
||||
"prefer_collective_v2",
|
||||
"prefer_unique_instance_key",
|
||||
])
|
||||
RunOptions.__new__.__defaults__ = (["eager",
|
||||
"func_graph"], 2, 0, ReduceOp.SUM,
|
||||
collective_util.Options(), True, False)
|
||||
collective_util.Options(), True)
|
||||
|
||||
def reduce_and_verify(self, inputs, expect, options):
|
||||
"""Reduce the given `inputs` and verify the output matches `expect`.
|
||||
@ -234,8 +233,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
options.prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
options.prefer_unique_instance_key)
|
||||
collective, devices, pid = self.make_collective(options.num_processes,
|
||||
options.gpus_per_process)
|
||||
|
||||
@ -273,10 +272,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_scoped_allocator = (
|
||||
options.prefer_scoped_allocator)
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
options.prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
options.prefer_unique_instance_key)
|
||||
collective, devices, pid = self.make_collective(options.num_processes,
|
||||
options.gpus_per_process)
|
||||
|
||||
@ -321,9 +318,9 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
CommunicationImplementation.NCCL,
|
||||
],
|
||||
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testAllReduceDense(self, num_processes, required_gpus, implementation,
|
||||
reduce_op, prefer_collective_v2):
|
||||
reduce_op, prefer_unique_instance_key):
|
||||
if (required_gpus == 0 and
|
||||
implementation == CommunicationImplementation.NCCL):
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
@ -337,7 +334,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
reduce_op=reduce_op,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation),
|
||||
prefer_collective_v2=prefer_collective_v2)
|
||||
prefer_unique_instance_key=prefer_unique_instance_key)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = [1.0, 2.0, 3.0, 4.0]
|
||||
@ -363,9 +360,9 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
],
|
||||
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
|
||||
reduce_op=ReduceOp.SUM,
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testAllReduceSparse(self, num_processes, required_gpus, implementation,
|
||||
reduce_op, prefer_collective_v2):
|
||||
reduce_op, prefer_unique_instance_key):
|
||||
if (required_gpus == 0 and
|
||||
implementation == CommunicationImplementation.NCCL):
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
@ -380,7 +377,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
reduce_op=reduce_op,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation),
|
||||
prefer_collective_v2=prefer_collective_v2)
|
||||
prefer_unique_instance_key=prefer_unique_instance_key)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = [
|
||||
@ -412,8 +409,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
self.reduce_and_verify(inputs, expect, options)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(prefer_collective_v2=[True, False]))
|
||||
def testAllReduceSparseVariableLength(self, prefer_collective_v2):
|
||||
combinations.combine(prefer_unique_instance_key=[True, False]))
|
||||
def testAllReduceSparseVariableLength(self, prefer_unique_instance_key):
|
||||
# One device per process, 2 processes, 2 replicas in total.
|
||||
inputs = [
|
||||
IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
|
||||
@ -431,7 +428,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
mode=["func_graph"], # Sparse reduce is not supported in eager.
|
||||
num_processes=2,
|
||||
reduce_op=ReduceOp.SUM,
|
||||
prefer_collective_v2=prefer_collective_v2))
|
||||
prefer_unique_instance_key=prefer_unique_instance_key))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
@ -443,11 +440,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
CommunicationImplementation.NCCL,
|
||||
],
|
||||
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
|
||||
prefer_scoped_allocator=[True, False],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testBatchAllReduceDense(self, num_processes, required_gpus,
|
||||
implementation, reduce_op,
|
||||
prefer_scoped_allocator, prefer_collective_v2):
|
||||
prefer_unique_instance_key):
|
||||
if (required_gpus == 0 and
|
||||
implementation == CommunicationImplementation.NCCL):
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
@ -462,8 +458,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
reduce_op=reduce_op,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation),
|
||||
prefer_scoped_allocator=prefer_scoped_allocator,
|
||||
prefer_collective_v2=prefer_collective_v2)
|
||||
prefer_unique_instance_key=prefer_unique_instance_key)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
|
||||
@ -489,11 +484,10 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
],
|
||||
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
|
||||
reduce_op=ReduceOp.SUM,
|
||||
prefer_scoped_allocator=[True, False],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testBatchAllReduceSparse(self, num_processes, required_gpus,
|
||||
implementation, reduce_op,
|
||||
prefer_scoped_allocator, prefer_collective_v2):
|
||||
prefer_unique_instance_key):
|
||||
if (required_gpus == 0 and
|
||||
implementation == CommunicationImplementation.NCCL):
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
@ -509,8 +503,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
reduce_op=reduce_op,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation),
|
||||
prefer_scoped_allocator=prefer_scoped_allocator,
|
||||
prefer_collective_v2=prefer_collective_v2)
|
||||
prefer_unique_instance_key=prefer_unique_instance_key)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = ([
|
||||
@ -578,13 +571,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
CommunicationImplementation.RING,
|
||||
CommunicationImplementation.NCCL,
|
||||
],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
|
||||
func_mode, axis, prefer_collective_v2):
|
||||
func_mode, axis, prefer_unique_instance_key):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
prefer_unique_instance_key)
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(implementation=implementation)
|
||||
@ -624,7 +617,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
implementation):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = True
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(implementation=implementation)
|
||||
@ -653,15 +646,15 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
implementation=[
|
||||
CommunicationImplementation.NCCL, CommunicationImplementation.RING
|
||||
],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
|
||||
required_gpus,
|
||||
implementation,
|
||||
prefer_collective_v2):
|
||||
prefer_unique_instance_key):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
prefer_unique_instance_key)
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(implementation=implementation)
|
||||
@ -715,13 +708,13 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
implementation=[
|
||||
CommunicationImplementation.NCCL, CommunicationImplementation.RING
|
||||
],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testInputsAreFunctionArgs(self, num_processes, required_gpus,
|
||||
implementation, prefer_collective_v2):
|
||||
implementation, prefer_unique_instance_key):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
prefer_unique_instance_key)
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(implementation=implementation)
|
||||
@ -757,16 +750,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
implementation=[
|
||||
CommunicationImplementation.RING, CommunicationImplementation.NCCL
|
||||
],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testTimeoutReduceDense(self, num_processes, implementation, required_gpus,
|
||||
prefer_collective_v2):
|
||||
prefer_unique_instance_key):
|
||||
|
||||
if (required_gpus == 0 and
|
||||
implementation == CommunicationImplementation.NCCL):
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
prefer_unique_instance_key)
|
||||
collective, devices, task_id = self.make_collective(
|
||||
num_processes, required_gpus)
|
||||
if task_id != 0:
|
||||
@ -794,16 +788,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
implementation=[
|
||||
CommunicationImplementation.RING, CommunicationImplementation.NCCL
|
||||
],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testTimeoutBatchReduceDense(self, num_processes, implementation,
|
||||
required_gpus, prefer_collective_v2):
|
||||
required_gpus, prefer_unique_instance_key):
|
||||
if (required_gpus == 0 and
|
||||
implementation == CommunicationImplementation.NCCL):
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
prefer_unique_instance_key)
|
||||
collective, devices, task_id = self.make_collective(
|
||||
num_processes, required_gpus)
|
||||
if task_id != 0:
|
||||
@ -832,16 +826,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
implementation=[
|
||||
CommunicationImplementation.RING, CommunicationImplementation.NCCL
|
||||
],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testTimeoutReduceSparse(self, num_processes, implementation,
|
||||
required_gpus, prefer_collective_v2):
|
||||
required_gpus, prefer_unique_instance_key):
|
||||
if (required_gpus == 0 and
|
||||
implementation == CommunicationImplementation.NCCL):
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
prefer_unique_instance_key)
|
||||
collective, devices, task_id = self.make_collective(
|
||||
num_processes, required_gpus)
|
||||
if task_id != 0:
|
||||
@ -871,16 +865,16 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
implementation=[
|
||||
CommunicationImplementation.RING, CommunicationImplementation.NCCL
|
||||
],
|
||||
prefer_collective_v2=[True, False]))
|
||||
prefer_unique_instance_key=[True, False]))
|
||||
def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
|
||||
implementation, prefer_collective_v2):
|
||||
implementation, prefer_unique_instance_key):
|
||||
if (required_gpus == 0 and
|
||||
implementation == CommunicationImplementation.NCCL):
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = (
|
||||
prefer_collective_v2)
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = (
|
||||
prefer_unique_instance_key)
|
||||
collective, devices, task_id = self.make_collective(
|
||||
num_processes, required_gpus)
|
||||
if task_id != 0:
|
||||
@ -908,8 +902,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
def testNcclOrdering(self, num_processes, required_gpus):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_collective_v2 = True
|
||||
cross_device_utils.CollectiveReplicaLauncher._prefer_ordering_token = True
|
||||
CollectiveReplicaLauncher._prefer_unique_instance_key = True
|
||||
CollectiveReplicaLauncher._prefer_ordering_token = True
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(
|
||||
|
@ -257,8 +257,7 @@ class CollectiveKeys(object):
|
||||
class CollectiveReplicaLauncher(object):
|
||||
"""Launch collectives on one replica."""
|
||||
|
||||
_prefer_scoped_allocator = True
|
||||
_prefer_collective_v2 = True
|
||||
_prefer_unique_instance_key = True
|
||||
_prefer_ordering_token = True
|
||||
|
||||
def __init__(self,
|
||||
@ -281,26 +280,19 @@ class CollectiveReplicaLauncher(object):
|
||||
return ops.control_dependencies([control_input])
|
||||
return ops.NullContextmanager()
|
||||
|
||||
def _use_collective_v2(self):
|
||||
def _use_unique_instance_key(self):
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
return False
|
||||
return CollectiveReplicaLauncher._prefer_collective_v2
|
||||
|
||||
def _use_scoped_allocator(self):
|
||||
if self._use_collective_v2():
|
||||
# ScopedAllocator doesn't support collective V2.
|
||||
return False
|
||||
return CollectiveReplicaLauncher._prefer_scoped_allocator
|
||||
return CollectiveReplicaLauncher._prefer_unique_instance_key
|
||||
|
||||
def _use_ordering_token(self):
|
||||
if not self._use_collective_v2():
|
||||
# Only collective V2 supports ordering token.
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
return False
|
||||
return CollectiveReplicaLauncher._prefer_ordering_token
|
||||
|
||||
def _next_instance_key(self):
|
||||
"""Returns the next instance key."""
|
||||
if self._use_collective_v2():
|
||||
if self._use_unique_instance_key():
|
||||
# Assigning instance keys at function building time have issues since
|
||||
# different workers may retrace the function at different times. With
|
||||
# collective V2 we can use capture_call_time_value to use a placeholder as
|
||||
@ -360,23 +352,14 @@ class CollectiveReplicaLauncher(object):
|
||||
ordering_token = self._get_ordering_token(communication_hint)
|
||||
with ops.device(self._device), \
|
||||
self._control_input(control_input):
|
||||
if self._use_collective_v2():
|
||||
return collective_ops.all_reduce_v2(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout,
|
||||
ordering_token=ordering_token)
|
||||
else:
|
||||
return collective_ops.all_reduce(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout)
|
||||
return collective_ops.all_reduce_v2(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout,
|
||||
ordering_token=ordering_token)
|
||||
|
||||
def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
|
||||
"""All-gather a dense tensor.
|
||||
@ -393,23 +376,14 @@ class CollectiveReplicaLauncher(object):
|
||||
instance_key = self._next_instance_key()
|
||||
ordering_token = self._get_ordering_token(communication_hint)
|
||||
with ops.device(self._device):
|
||||
if self._use_collective_v2():
|
||||
return collective_ops.all_gather_v2(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout,
|
||||
ordering_token=ordering_token)
|
||||
else:
|
||||
return collective_ops.all_gather(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout)
|
||||
return collective_ops.all_gather_v2(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout,
|
||||
ordering_token=ordering_token)
|
||||
|
||||
def batch_all_reduce(self,
|
||||
input_tensor_packs,
|
||||
@ -430,17 +404,17 @@ class CollectiveReplicaLauncher(object):
|
||||
Returns:
|
||||
A flat list of reduced tensors.
|
||||
"""
|
||||
# We don't batch with concat in eager. It's easy to get it wrong because
|
||||
# we need to avoid any numpy() calls on values produced by the async
|
||||
# executor. This effectively disables batching in eager, but it's unlikely
|
||||
# to all-reduce a large number of tensors in eager.
|
||||
batch_with_concat = (not self._use_scoped_allocator() and
|
||||
not context.executing_eagerly())
|
||||
outputs = []
|
||||
for pack in input_tensor_packs:
|
||||
# TODO(b/169168846): inserts a parallel all_gather to verify packings
|
||||
# are the same on each replica.
|
||||
if batch_with_concat:
|
||||
if context.executing_eagerly():
|
||||
# We don't batch in eager as it sometimes makes the performance worse
|
||||
# due the concat/split ops.
|
||||
for input_tensor in pack:
|
||||
outputs.append(
|
||||
self.all_reduce(input_tensor, None, communication_hint, timeout))
|
||||
else:
|
||||
# TODO(b/169168846): inserts a parallel all_gather to verify packings
|
||||
# are the same on each replica.
|
||||
with ops.device(self._device):
|
||||
flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
|
||||
shapes = [array_ops.shape(t) for t in pack]
|
||||
@ -455,19 +429,6 @@ class CollectiveReplicaLauncher(object):
|
||||
flat_outputs = array_ops.split(reduced, num_elements, axis=0)
|
||||
for shape, flat_output in zip(shapes, flat_outputs):
|
||||
outputs.append(array_ops.reshape(flat_output, shape))
|
||||
else:
|
||||
# By placing all CollectiveReduce ops in a batch under single name
|
||||
# scope, we ensure they will be picked up by the `ScopedAllocator`
|
||||
# grappler optimizer and packed into a single all-reduce.
|
||||
with ops.name_scope('allreduce'):
|
||||
for input_tensor in pack:
|
||||
if communication_hint == 'NCCL' and outputs:
|
||||
control_input = outputs[-1]
|
||||
else:
|
||||
control_input = None
|
||||
outputs.append(
|
||||
self.all_reduce(input_tensor, control_input, communication_hint,
|
||||
timeout))
|
||||
|
||||
return outputs
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user