Support V2 collective in dist strat
With V2 collective ops we can defer instance key assignment to execution time. This avoid the problem that different worker may retrace at different pace. We use capture_call_time_value instead of a instance key tf.Variable. The latter doesn't work if there're exceptions in the execution since the variable on different workers can go out of sync. PiperOrigin-RevId: 338175667 Change-Id: Ie0f5e607a25485c4c10de4a6cae137cb2b7ad729
This commit is contained in:
parent
c3a4d7ee32
commit
c3f422ffb9
tensorflow/python/distribute
@ -46,6 +46,8 @@ from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
CommunicationImplemenation = collective_util.CommunicationImplemenation
|
||||
@ -203,10 +205,11 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"reduce_op",
|
||||
"communication_options",
|
||||
"use_scoped_allocator",
|
||||
"use_collective_v2",
|
||||
])
|
||||
RunOptions.__new__.__defaults__ = (["eager",
|
||||
"func_graph"], 2, 0, ReduceOp.SUM,
|
||||
collective_util.Options(), True)
|
||||
collective_util.Options(), True, False)
|
||||
|
||||
def reduce_and_verify(self, inputs, expect, options):
|
||||
"""Reduce the given `inputs` and verify the output matches `expect`.
|
||||
@ -220,6 +223,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
"""
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
options.use_collective_v2)
|
||||
collective, devices, pid = self.make_collective(options.num_processes,
|
||||
options.gpus_per_process)
|
||||
|
||||
@ -259,6 +264,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_scoped_allocator = (
|
||||
options.use_scoped_allocator)
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
options.use_collective_v2)
|
||||
collective, devices, pid = self.make_collective(options.num_processes,
|
||||
options.gpus_per_process)
|
||||
|
||||
@ -303,15 +310,17 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
CommunicationImplemenation.AUTO,
|
||||
CommunicationImplemenation.RING
|
||||
],
|
||||
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN]))
|
||||
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
|
||||
use_collective_v2=[True, False]))
|
||||
def testAllReduceDense(self, num_processes, required_gpus, implementation,
|
||||
reduce_op):
|
||||
reduce_op, use_collective_v2):
|
||||
options = self.RunOptions(
|
||||
num_processes=num_processes,
|
||||
gpus_per_process=required_gpus,
|
||||
reduce_op=reduce_op,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation))
|
||||
implementation=implementation),
|
||||
use_collective_v2=use_collective_v2)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = [1.0, 2.0, 3.0, 4.0]
|
||||
@ -337,16 +346,18 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
CommunicationImplemenation.RING
|
||||
],
|
||||
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
|
||||
reduce_op=ReduceOp.SUM))
|
||||
reduce_op=ReduceOp.SUM,
|
||||
use_collective_v2=[True, False]))
|
||||
def testAllReduceSparse(self, num_processes, required_gpus, implementation,
|
||||
reduce_op):
|
||||
reduce_op, use_collective_v2):
|
||||
options = self.RunOptions(
|
||||
mode=["func_graph"], # Sparse reduce is not supported in eager.
|
||||
num_processes=num_processes,
|
||||
gpus_per_process=required_gpus,
|
||||
reduce_op=reduce_op,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation))
|
||||
implementation=implementation),
|
||||
use_collective_v2=use_collective_v2)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = [
|
||||
@ -377,7 +388,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
self.reduce_and_verify(inputs, expect, options)
|
||||
|
||||
def testAllReduceSparseVariableLength(self):
|
||||
@combinations.generate(combinations.combine(use_collective_v2=[True, False]))
|
||||
def testAllReduceSparseVariableLength(self, use_collective_v2):
|
||||
# One device per process, 2 processes, 2 replicas in total.
|
||||
inputs = [
|
||||
IndexedSlicesValue(values=[[1.]], indices=[0], dense_shape=[10, 1]),
|
||||
@ -394,7 +406,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
self.RunOptions(
|
||||
mode=["func_graph"], # Sparse reduce is not supported in eager.
|
||||
num_processes=2,
|
||||
reduce_op=ReduceOp.SUM))
|
||||
reduce_op=ReduceOp.SUM,
|
||||
use_collective_v2=use_collective_v2))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
@ -405,9 +418,11 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
CommunicationImplemenation.NCCL
|
||||
],
|
||||
reduce_op=[ReduceOp.SUM, ReduceOp.MEAN],
|
||||
use_scoped_allocator=[True, False]))
|
||||
use_scoped_allocator=[True, False],
|
||||
use_collective_v2=[True, False]))
|
||||
def testBatchAllReduceDense(self, num_processes, required_gpus,
|
||||
implementation, reduce_op, use_scoped_allocator):
|
||||
implementation, reduce_op, use_scoped_allocator,
|
||||
use_collective_v2):
|
||||
if required_gpus == 0 and implementation == CommunicationImplemenation.NCCL:
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
if num_processes == 2 and implementation == CommunicationImplemenation.NCCL:
|
||||
@ -420,7 +435,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
reduce_op=reduce_op,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation),
|
||||
use_scoped_allocator=use_scoped_allocator)
|
||||
use_scoped_allocator=use_scoped_allocator,
|
||||
use_collective_v2=use_collective_v2)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
|
||||
@ -446,9 +462,11 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
],
|
||||
# TODO(b/166682130): add MEAN reduce once the bug is fixed.
|
||||
reduce_op=ReduceOp.SUM,
|
||||
use_scoped_allocator=[True, False]))
|
||||
use_scoped_allocator=[True, False],
|
||||
use_collective_v2=[True, False]))
|
||||
def testBatchAllReduceSparse(self, num_processes, required_gpus,
|
||||
implementation, reduce_op, use_scoped_allocator):
|
||||
implementation, reduce_op, use_scoped_allocator,
|
||||
use_collective_v2):
|
||||
if required_gpus == 0 and implementation == CommunicationImplemenation.NCCL:
|
||||
self.skipTest("Skip CPU + NCCL combination")
|
||||
if num_processes == 2 and implementation == CommunicationImplemenation.NCCL:
|
||||
@ -462,7 +480,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
reduce_op=reduce_op,
|
||||
communication_options=collective_util.Options(
|
||||
implementation=implementation),
|
||||
use_scoped_allocator=use_scoped_allocator)
|
||||
use_scoped_allocator=use_scoped_allocator,
|
||||
use_collective_v2=use_collective_v2)
|
||||
group_size = options.num_processes * (options.gpus_per_process or 1)
|
||||
|
||||
inputs_data = ([
|
||||
@ -528,11 +547,14 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
implementation=[
|
||||
CommunicationImplemenation.NCCL, CommunicationImplemenation.AUTO,
|
||||
CommunicationImplemenation.RING
|
||||
]))
|
||||
],
|
||||
use_collective_v2=[True, False]))
|
||||
def testAllGatherSameShape(self, num_processes, required_gpus, implementation,
|
||||
func_mode, axis):
|
||||
func_mode, axis, use_collective_v2):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
use_collective_v2)
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(implementation=implementation)
|
||||
@ -563,18 +585,53 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
get_global_mpr(num_processes).run(replica_fn)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
num_processes=[1, 2],
|
||||
required_gpus=[0, 1, 2],
|
||||
implementation=[CommunicationImplemenation.RING]))
|
||||
def testCollectiveV2ControlFlow(self, num_processes, required_gpus,
|
||||
implementation):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = True
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(implementation=implementation)
|
||||
value = make_per_replica_value(constant_op.constant([1.]), devices)
|
||||
|
||||
@def_function.function
|
||||
def reduce_fn():
|
||||
|
||||
def cond_body():
|
||||
reduced = collective.reduce(reduce_util.ReduceOp.SUM, value, value,
|
||||
options)
|
||||
return math_ops.add_n(self.as_list(reduced)) / len(devices)
|
||||
|
||||
return control_flow_ops.cond(
|
||||
array_ops.identity(False), cond_body, cond_body)
|
||||
|
||||
num_replicas = num_processes * len(devices)
|
||||
self.assertAllEqual(reduce_fn(), [1. * num_replicas])
|
||||
|
||||
get_global_mpr(num_processes).run(replica_fn)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
num_processes=1,
|
||||
required_gpus=2,
|
||||
implementation=[
|
||||
CommunicationImplemenation.NCCL, CommunicationImplemenation.RING
|
||||
]))
|
||||
],
|
||||
use_collective_v2=[True, False]))
|
||||
def testMultiThreadedCollectiveLaunchNoInterleave(self, num_processes,
|
||||
required_gpus,
|
||||
implementation):
|
||||
implementation,
|
||||
use_collective_v2):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
use_collective_v2)
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(implementation=implementation)
|
||||
@ -627,11 +684,14 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
required_gpus=2,
|
||||
implementation=[
|
||||
CommunicationImplemenation.NCCL, CommunicationImplemenation.RING
|
||||
]))
|
||||
],
|
||||
use_collective_v2=[True, False]))
|
||||
def testInputsAreFunctionArgs(self, num_processes, required_gpus,
|
||||
implementation):
|
||||
implementation, use_collective_v2):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
use_collective_v2)
|
||||
collective, devices, _ = self.make_collective(num_processes,
|
||||
required_gpus)
|
||||
options = collective_util.Options(implementation=implementation)
|
||||
@ -664,11 +724,14 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.combine(
|
||||
num_processes=2,
|
||||
required_gpus=[0, 1],
|
||||
implementation=[CommunicationImplemenation.RING]))
|
||||
def testTimeoutReduceDense(self, num_processes, implementation,
|
||||
required_gpus):
|
||||
implementation=[CommunicationImplemenation.RING],
|
||||
use_collective_v2=[True, False]))
|
||||
def testTimeoutReduceDense(self, num_processes, implementation, required_gpus,
|
||||
use_collective_v2):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
use_collective_v2)
|
||||
collective, devices, task_id = self.make_collective(
|
||||
num_processes, required_gpus)
|
||||
if task_id != 0:
|
||||
@ -680,7 +743,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@def_function.function
|
||||
def reduce_dense():
|
||||
collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)
|
||||
return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)
|
||||
|
||||
# The collective should time out because we only launch it on worker-0,
|
||||
# while there're three workers in total.
|
||||
@ -693,11 +756,14 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.combine(
|
||||
num_processes=2,
|
||||
required_gpus=[0, 1],
|
||||
implementation=[CommunicationImplemenation.RING]))
|
||||
implementation=[CommunicationImplemenation.RING],
|
||||
use_collective_v2=[True, False]))
|
||||
def testTimeoutBatchReduceDense(self, num_processes, implementation,
|
||||
required_gpus):
|
||||
required_gpus, use_collective_v2):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
use_collective_v2)
|
||||
collective, devices, task_id = self.make_collective(
|
||||
num_processes, required_gpus)
|
||||
if task_id != 0:
|
||||
@ -709,8 +775,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@def_function.function
|
||||
def batch_reduce_dense():
|
||||
collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
|
||||
options)
|
||||
return collective.batch_reduce(reduce_util.ReduceOp.SUM,
|
||||
[(v, v), (v, v)], options)
|
||||
|
||||
# The collective should time out because we only launch it on worker-0,
|
||||
# while there're two workers in total.
|
||||
@ -723,11 +789,14 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.combine(
|
||||
num_processes=2,
|
||||
required_gpus=[0, 1],
|
||||
implementation=[CommunicationImplemenation.RING]))
|
||||
implementation=[CommunicationImplemenation.RING],
|
||||
use_collective_v2=[True, False]))
|
||||
def testTimeoutReduceSparse(self, num_processes, implementation,
|
||||
required_gpus):
|
||||
required_gpus, use_collective_v2):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
use_collective_v2)
|
||||
collective, devices, task_id = self.make_collective(
|
||||
num_processes, required_gpus)
|
||||
if task_id != 0:
|
||||
@ -741,7 +810,7 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@def_function.function
|
||||
def reduce_sparse():
|
||||
collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)
|
||||
return collective.reduce(reduce_util.ReduceOp.SUM, v, v, options)
|
||||
|
||||
# The collective should time out because we only launch it on worker-0,
|
||||
# while there're two workers in total.
|
||||
@ -754,11 +823,14 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
combinations.combine(
|
||||
num_processes=2,
|
||||
required_gpus=[0, 1],
|
||||
implementation=[CommunicationImplemenation.RING]))
|
||||
implementation=[CommunicationImplemenation.RING],
|
||||
use_collective_v2=[True, False]))
|
||||
def testTimeoutBatchReduceSparse(self, num_processes, required_gpus,
|
||||
implementation):
|
||||
implementation, use_collective_v2):
|
||||
|
||||
def replica_fn():
|
||||
cross_device_utils.CollectiveReplicaLauncher._use_collective_v2 = (
|
||||
use_collective_v2)
|
||||
collective, devices, task_id = self.make_collective(
|
||||
num_processes, required_gpus)
|
||||
if task_id != 0:
|
||||
@ -772,8 +844,8 @@ class CollectiveOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@def_function.function
|
||||
def batch_reduce_sparse():
|
||||
collective.batch_reduce(reduce_util.ReduceOp.SUM, [(v, v), (v, v)],
|
||||
options)
|
||||
return collective.batch_reduce(reduce_util.ReduceOp.SUM,
|
||||
[(v, v), (v, v)], options)
|
||||
|
||||
# The collective should time out because we only launch it on worker-0,
|
||||
# while there're two workers in total.
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -256,6 +257,7 @@ class CollectiveReplicaLauncher(object):
|
||||
"""Launch collectives on one replica."""
|
||||
|
||||
_use_scoped_allocator = True
|
||||
_use_collective_v2 = False
|
||||
|
||||
def __init__(self,
|
||||
group_key,
|
||||
@ -283,6 +285,44 @@ class CollectiveReplicaLauncher(object):
|
||||
return ops.control_dependencies([control_input])
|
||||
return ops.NullContextmanager()
|
||||
|
||||
def _should_use_collective_v2(self):
|
||||
if not CollectiveReplicaLauncher._use_collective_v2:
|
||||
return False
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
return False
|
||||
return True
|
||||
|
||||
def _next_instance_key(self):
|
||||
"""Returns the next instance key."""
|
||||
if self._should_use_collective_v2():
|
||||
# Assigning instance keys at function building time have issues since
|
||||
# different workers may retrace the function at different times. With
|
||||
# collective V2 we can use capture_call_time_value to use a placeholder as
|
||||
# the instance key and feed it at function call time. In this way we also
|
||||
# don't reuse instance keys, which allows for per-instance cancellation.
|
||||
graph = ops.get_default_graph()
|
||||
# Control flow ops don't work with capture_call_time_value, so we put the
|
||||
# capture in the function graph of that control flow op.
|
||||
while getattr(graph, 'is_control_flow_graph', False):
|
||||
graph = graph.outer_graph
|
||||
if not context.executing_eagerly() and graph.building_function:
|
||||
with graph.as_default():
|
||||
# Capture self._next_instance_key so that when building a function
|
||||
# that calls another tf.function, the instance key assignment is
|
||||
# further delayed until we actually call the function in eager. Note
|
||||
# that capture_call_time_value doesn't automatically propagate the
|
||||
# deferred capture to the outer function.
|
||||
return graph.capture_call_time_value(
|
||||
self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32))
|
||||
else:
|
||||
instance_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
with ops.device('CPU:0'):
|
||||
return ops.convert_to_tensor(instance_key, dtype=dtypes.int32)
|
||||
else:
|
||||
return self._collective_keys.get_instance_key(self._group_key,
|
||||
self._device)
|
||||
|
||||
def all_reduce(self,
|
||||
input_tensor,
|
||||
control_input=None,
|
||||
@ -304,18 +344,60 @@ class CollectiveReplicaLauncher(object):
|
||||
Returns:
|
||||
The reduced tensor.
|
||||
"""
|
||||
instance_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
instance_key = self._next_instance_key()
|
||||
with self._executor_scope(), \
|
||||
ops.device(self._device), \
|
||||
self._control_input(control_input):
|
||||
return collective_ops.all_reduce(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout)
|
||||
if self._should_use_collective_v2():
|
||||
return collective_ops.all_reduce_v2(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout)
|
||||
else:
|
||||
return collective_ops.all_reduce(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout)
|
||||
|
||||
def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
|
||||
"""All-gather a dense tensor.
|
||||
|
||||
This can be called in eager mode if an async executor is supplied when
|
||||
creating the launcher.
|
||||
|
||||
Args:
|
||||
input_tensor: a dense tensor. It must have the same shape on all replicas.
|
||||
communication_hint: string providing hint to runtime for choosing
|
||||
collective implementation.
|
||||
timeout: a float. The timeout in seconds.
|
||||
|
||||
Returns:
|
||||
The reduced tensor.
|
||||
"""
|
||||
instance_key = self._next_instance_key()
|
||||
with self._executor_scope(), ops.device(self._device):
|
||||
if self._should_use_collective_v2():
|
||||
return collective_ops.all_gather_v2(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout)
|
||||
else:
|
||||
return collective_ops.all_gather(
|
||||
input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key,
|
||||
communication_hint=communication_hint,
|
||||
timeout=timeout)
|
||||
|
||||
def batch_all_reduce(self,
|
||||
input_tensor_packs,
|
||||
@ -408,10 +490,6 @@ class CollectiveReplicaLauncher(object):
|
||||
if context.executing_eagerly():
|
||||
raise RuntimeError('all_gather in eager mode is not supported')
|
||||
|
||||
instance_key_tensor = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
instance_key_shape = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
with ops.device(self._device), \
|
||||
ops.control_dependencies([array_ops.identity(input_tensor)]):
|
||||
# 1. Transpose
|
||||
@ -425,11 +503,8 @@ class CollectiveReplicaLauncher(object):
|
||||
axis=0)
|
||||
input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
|
||||
# 2. Pad
|
||||
gathered_shape = collective_ops.all_gather(
|
||||
gathered_shape = self._all_gather(
|
||||
array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key_shape,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
first_dims = gathered_shape[:, 0]
|
||||
@ -437,16 +512,11 @@ class CollectiveReplicaLauncher(object):
|
||||
padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)
|
||||
|
||||
# 3. Gather
|
||||
gather_padded_out_tensor = collective_ops.all_gather(
|
||||
padded_input_tensor,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
instance_key_tensor,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
gather_padded_out_tensor = self._all_gather(
|
||||
padded_input_tensor, communication_hint, timeout=timeout)
|
||||
# 4. Unpad
|
||||
split_tensors = []
|
||||
for i in range(first_dims.shape[0]):
|
||||
for i in range(self._group_size):
|
||||
start_pos = i * full_axis_dim
|
||||
split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
|
||||
first_dims[i]])
|
||||
@ -483,15 +553,6 @@ class CollectiveReplicaLauncher(object):
|
||||
raise RuntimeError(
|
||||
'all_reduce_indexed_slices in eager mode is not supported')
|
||||
|
||||
gather_length_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
gather_indices_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
gather_values_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
reduce_densified_key = self._collective_keys.get_instance_key(
|
||||
self._group_key, self._device)
|
||||
|
||||
# Current CollectiveAllGather implementations require input IndexedSlices to
|
||||
# have consistent length across the board, we handle the reduction of
|
||||
# IndexedSlices as follows:
|
||||
@ -503,23 +564,13 @@ class CollectiveReplicaLauncher(object):
|
||||
|
||||
def all_gather():
|
||||
"""Use all_gather to aggregate `IndexedSlices`."""
|
||||
all_values = collective_ops.all_gather(
|
||||
input_slices.values,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
gather_values_key,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
all_values = self._all_gather(
|
||||
input_slices.values, communication_hint, timeout=timeout)
|
||||
# Add control dependency to order the all-gather.
|
||||
control = [all_values] if communication_hint == 'NCCL' else []
|
||||
with ops.control_dependencies(control):
|
||||
all_indices = collective_ops.all_gather(
|
||||
input_slices.indices,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
gather_indices_key,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
all_indices = self._all_gather(
|
||||
input_slices.indices, communication_hint, timeout=timeout)
|
||||
return ops.IndexedSlices(
|
||||
values=all_values,
|
||||
indices=all_indices,
|
||||
@ -528,15 +579,8 @@ class CollectiveReplicaLauncher(object):
|
||||
def densify_and_all_reduce():
|
||||
"""Use all_reduce to aggregate `IndexedSlices`."""
|
||||
densified = ops.convert_to_tensor(input_slices)
|
||||
reduced = collective_ops.all_reduce(
|
||||
densified,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
reduce_densified_key,
|
||||
'Add',
|
||||
'Id', [0],
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
reduced = self.all_reduce(
|
||||
densified, communication_hint=communication_hint, timeout=timeout)
|
||||
# We have to convert dense grad to IndexedSlice because all_reduce()
|
||||
# and all_gather() must have the same return type as required by
|
||||
# control_flow_ops.cond.
|
||||
@ -546,13 +590,8 @@ class CollectiveReplicaLauncher(object):
|
||||
dense_shape=input_slices.dense_shape)
|
||||
|
||||
length = array_ops.shape(input_slices.indices)
|
||||
all_lengths = collective_ops.all_gather(
|
||||
length,
|
||||
self._group_size,
|
||||
self._group_key,
|
||||
gather_length_key,
|
||||
communication_hint,
|
||||
timeout=timeout)
|
||||
all_lengths = self._all_gather(
|
||||
length, communication_hint, timeout=timeout)
|
||||
return control_flow_ops.cond(
|
||||
math_ops.equal(
|
||||
math_ops.reduce_max(all_lengths),
|
||||
|
Loading…
Reference in New Issue
Block a user