Swap the use of NcclAllReduce for NCCL Collective Ops in MirroredStrategy.

Also remove the use of the async executor to launch collective ops in eager mode and use one thread per device instead. This resolves the issue that numpy() could not be called on the result returned by the async executor. This change applies to MultiWorkerMirroredStrategy (MWMS) as well.
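
As a rough sketch of the new launch pattern (illustration only, not the TensorFlow implementation; the helper name and device handling below are made up): one ThreadPool worker per device keeps each device's launch order FIFO while the devices themselves launch in parallel, and the returned values are ordinary eager tensors, so numpy() works on them directly.

import multiprocessing.pool

import tensorflow as tf


def launch_per_device(devices, fn):
  # One worker thread per device; each device's ops are enqueued from a single
  # thread, so the per-device launch order stays FIFO.
  pool = multiprocessing.pool.ThreadPool(len(devices))

  def thread_fn(device):
    with tf.device(device):
      return fn()

  try:
    return pool.map(thread_fn, devices)
  finally:
    pool.close()


# Hypothetical usage: run the same op on every logical CPU device.
devices = [d.name for d in tf.config.list_logical_devices("CPU")]
results = launch_per_device(devices, lambda: tf.reduce_sum(tf.ones([4])))
print([r.numpy() for r in results])  # numpy() on each result now just works.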

PiperOrigin-RevId: 353355403
Change-Id: I9c9f30dfe18dc830a4a8fa9bbaec042c7c2edd8f
Author: Xinyi Wang, 2021-01-22 18:11:37 -08:00 (committed by TensorFlower Gardener)
parent 9304b8af0e
commit 30fb80d468
7 changed files with 135 additions and 74 deletions

View File

@ -20,6 +20,8 @@ from __future__ import print_function
import collections
import copy
import multiprocessing.dummy
import multiprocessing.pool
import threading
import six
@ -36,7 +38,6 @@ from tensorflow.python.distribute import values as value_lib
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor as executor_lib
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@ -1013,8 +1014,8 @@ class CollectiveAllReduce(CrossDeviceOps):
# deadlocks. E.g. if two user threads both are launching collectives:
# user-thread-0 device0 device1
# user-thread-1 device0 device1
# In eager mode, we use one executor per device. Executors use single FIFO
# queues, so the above launch sequences end up with the following queues:
# In eager mode, we use one thread per device to launch collective ops, so
# the above launch sequences end up with the following queues:
# device-0 collective-0 collective-1
# device-1 collective-1 collective-0
# This deadlocks since neither collective is able to finish.
@ -1022,26 +1023,20 @@ class CollectiveAllReduce(CrossDeviceOps):
self._devices = tuple(device_util.canonicalize(d) for d in devices)
group_key = self._collective_keys.get_group_key(self._devices)
# Collective ops requires all devices to participate and is blocking. In
# eager, we need one async executor for each device to be able to launch
# them altogether. Note that async doesn't imply concurrency. Within an
# async executor operations are still executed sequentially. In graph or
# function building, the executors are not used.
self._executors = []
self._launchers = []
# Whether to only use NCCL for batched all-reduce when NCCL is requested.
# This is because of the lack of mechanism to order NCCL operations
# deterministically.
self._limited_nccl = False
for device in self._devices:
executor = executor_lib.new_executor(enable_async=True)
self._executors.append(executor)
launcher = cross_device_utils.CollectiveReplicaLauncher(
group_key, group_size, self._collective_keys, device, executor)
group_key, group_size, self._collective_keys, device)
self._launchers.append(launcher)
if not launcher.can_order_nccl():
self._limited_nccl = True
self._pool = multiprocessing.pool.ThreadPool(len(self._devices))
super(CollectiveAllReduce, self).__init__()
@property
@ -1147,22 +1142,31 @@ class CollectiveAllReduce(CrossDeviceOps):
for i in range(len(self._devices)):
values_by_device[i].append(per_replica.values[i])
outputs_by_device = []
with self._lock:
for i in range(len(self._devices)):
packs = cross_device_utils.group_by_size(
values_by_device[i], options.bytes_per_pack)
if not context.executing_eagerly() and i == 0:
logging.info(
"Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
"group_size = %d, implementation = %s, num_packs = %d",
batch_size, len(self._launchers), self._group_size,
implementation, len(packs))
outputs_by_device.append(self._launchers[i].batch_all_reduce(
packs, implementation, options.timeout_seconds))
if context.executing_eagerly():
def thread_fn(device_id):
with context.eager_mode():
packs = cross_device_utils.group_by_size(values_by_device[device_id],
options.bytes_per_pack)
return self._launchers[device_id].batch_all_reduce(
packs, implementation, options.timeout_seconds)
for e in self._executors:
e.wait()
num_devices = len(self._devices)
with self._lock:
outputs_by_device = self._pool.map(thread_fn, list(range(num_devices)))
else:
outputs_by_device = []
with self._lock:
for i in range(len(self._devices)):
packs = cross_device_utils.group_by_size(
values_by_device[i], options.bytes_per_pack)
if i == 0:
logging.info(
"Collective batch_all_reduce: %d all-reduces, num_devices = %d,"
" group_size = %d, implementation = %s, num_packs = %d",
batch_size, len(self._launchers), self._group_size,
implementation, len(packs))
outputs_by_device.append(self._launchers[i].batch_all_reduce(
packs, implementation, options.timeout_seconds))
mirrored = []
for values in zip(*outputs_by_device):

View File

@ -265,28 +265,17 @@ class CollectiveReplicaLauncher(object):
group_key,
group_size,
collective_keys,
device,
executor=None):
if executor and not executor.is_async():
raise ValueError('executor must be async')
device):
self._group_key = group_key
self._group_size = group_size
self._collective_keys = collective_keys
self._device = device
self._executor = executor
if self._use_ordering_token():
with ops.init_scope(), ops.device(device):
self._ordering_token = resource_variable_ops.ResourceVariable(0.)
else:
self._ordering_token = None
def _executor_scope(self):
if context.executing_eagerly() and not self._executor:
raise ValueError('collectives requires a async executor in eager mode')
if context.executing_eagerly():
return context.executor_scope(self._executor)
return ops.NullContextmanager()
def _control_input(self, control_input):
if control_input is not None and not self._use_ordering_token():
return ops.control_dependencies([control_input])
@ -356,9 +345,6 @@ class CollectiveReplicaLauncher(object):
timeout=0):
"""All-reduce a dense tensor.
This can be called in eager mode if a async executor is supplied when
creating the launcher.
Args:
input_tensor: a dense tensor. It must have the same shape on all replicas.
control_input: if not None, add control edges between control_input and
@ -372,8 +358,7 @@ class CollectiveReplicaLauncher(object):
"""
instance_key = self._next_instance_key()
ordering_token = self._get_ordering_token(communication_hint)
with self._executor_scope(), \
ops.device(self._device), \
with ops.device(self._device), \
self._control_input(control_input):
if self._use_collective_v2():
return collective_ops.all_reduce_v2(
@ -396,9 +381,6 @@ class CollectiveReplicaLauncher(object):
def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
"""All-gather a dense tensor.
This can be called in eager mode if an async executor is supplied when
creating the launcher.
Args:
input_tensor: a dense tensor. It must have the same shape on all replicas.
communication_hint: string providing hint to runtime for choosing
@ -410,7 +392,7 @@ class CollectiveReplicaLauncher(object):
"""
instance_key = self._next_instance_key()
ordering_token = self._get_ordering_token(communication_hint)
with self._executor_scope(), ops.device(self._device):
with ops.device(self._device):
if self._use_collective_v2():
return collective_ops.all_gather_v2(
input_tensor,
@ -439,9 +421,6 @@ class CollectiveReplicaLauncher(object):
benefit that it doesn't need to wait for all inputs to be ready to start the
all-reduce.
This can be called in eager mode if a async executor is supplied when
creating the launcher.
Args:
input_tensor_packs: a list of lists of dense tensors.
communication_hint: string providing hint to runtime for choosing

View File

@ -27,6 +27,7 @@ from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
@ -991,4 +992,4 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
if __name__ == "__main__":
test.main()
test_util.main()

View File

@ -22,6 +22,7 @@ import copy
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
@ -37,6 +38,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -268,6 +270,9 @@ class MirroredStrategy(distribute_lib.Strategy):
the particular hardware is available.
"""
# Only set this in tests.
_collective_key_base = 0
def __init__(self, devices=None, cross_device_ops=None):
extended = MirroredExtended(
self, devices=devices, cross_device_ops=cross_device_ops)
@ -281,6 +286,9 @@ class MirroredStrategyV1(distribute_lib.StrategyV1): # pylint: disable=g-missin
__doc__ = MirroredStrategy.__doc__
# Only set this in tests.
_collective_key_base = 0
def __init__(self, devices=None, cross_device_ops=None):
extended = MirroredExtended(
self, devices=devices, cross_device_ops=cross_device_ops)
@ -293,6 +301,10 @@ class MirroredStrategyV1(distribute_lib.StrategyV1): # pylint: disable=g-missin
class MirroredExtended(distribute_lib.StrategyExtendedV1):
"""Implementation of MirroredStrategy."""
# If this is set to True, use NCCL collective ops instead of NCCL cross device
# ops.
_prefer_collective_ops = True
def __init__(self, container_strategy, devices=None, cross_device_ops=None):
super(MirroredExtended, self).__init__(container_strategy)
if context.executing_eagerly():
@ -314,7 +326,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
assert devices, ("Got an empty `devices` list and unable to recognize "
"any local devices.")
self._cross_device_ops = cross_device_ops
self._communication_options = collective_util.Options()
self._communication_options = collective_util.Options(
implementation=collective_util.CommunicationImplementation.NCCL)
self._collective_ops_in_use = False
self._collective_key_base = container_strategy._collective_key_base
self._initialize_strategy(devices)
# TODO(b/128995245): Enable last partial batch support in graph mode.
@ -333,9 +348,34 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
"No duplicates allowed in `devices` argument: %s" % (devices,))
if _is_device_list_single_worker(devices):
self._initialize_single_worker(devices)
if self._prefer_collective_ops and (
isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce) or
isinstance(self._inferred_cross_device_ops,
cross_device_ops_lib.NcclAllReduce)):
self._use_collective_ops(devices)
self._inferred_cross_device_ops = None
logging.info("Using MirroredStrategy with devices %r", devices)
else:
self._initialize_multi_worker(devices)
def _use_collective_ops(self, devices):
if ops.executing_eagerly_outside_functions():
try:
context.context().configure_collective_ops(
scoped_allocator_enabled_ops=("CollectiveReduce",))
except RuntimeError:
logging.warning("Collective ops is not configured at program startup."
" Some performance features may not be enabled.")
self._collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=1 + self._collective_key_base) # pylint: disable=protected-access
self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
devices=self._devices,
group_size=len(self._devices),
collective_keys=self._collective_keys)
self._collective_ops_in_use = True
def _initialize_single_worker(self, devices):
"""Initializes the object for single-worker training."""
self._devices = tuple(device_util.canonicalize(d) for d in devices)
@ -347,7 +387,6 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
self._host_input_device = numpy_dataset.SingleDevice(
self._input_workers_devices[0][0])
self._is_multi_worker_training = False
logging.info("Using MirroredStrategy with devices %r", devices)
device_spec = tf_device.DeviceSpec.from_string(
self._input_workers_devices[0][0])
# Ensures when we enter strategy.scope() we use the correct default device
@ -652,8 +691,16 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return updated_config
def _get_cross_device_ops(self, value):
del value # Unused.
return self._cross_device_ops or self._inferred_cross_device_ops
if isinstance(value, values.DistributedValues):
value_int32 = True in {
dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
}
else:
value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32
if self._collective_ops_in_use and value_int32:
return cross_device_ops_lib.ReductionToOneDevice()
else:
return self._cross_device_ops or self._inferred_cross_device_ops
def _gather_to_implementation(self, value, destinations, axis, options):
if not isinstance(value, values.DistributedValues):
@ -677,6 +724,12 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
# be 0.
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, value, destinations, self._num_replicas_in_sync)
if self._collective_ops_in_use and (
(not cross_device_ops_lib._devices_match(value, destinations) or # pylint: disable=protected-access
any("cpu" in d.lower()
for d in cross_device_ops_lib.get_devices_from(destinations)))):
return cross_device_ops_lib.ReductionToOneDevice().reduce(
reduce_op, value, destinations)
return self._get_cross_device_ops(value).reduce(
reduce_op,
value,

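A hedged usage sketch of the fallback added above (behavior inferred from this diff, not a documented guarantee; assumes a machine with two GPUs): reducing an int32 per-replica value is expected to take the ReductionToOneDevice path even though the strategy now prefers NCCL collective ops, since the NCCL-based collective path here does not handle int32 values or CPU destinations.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"])


def value_fn(ctx):
  # int32 per-replica value (replica ids 0 and 1): expected to be reduced via
  # ReductionToOneDevice rather than the NCCL collective all-reduce.
  return tf.constant(ctx.replica_id_in_sync_group, dtype=tf.int32)


per_replica = strategy.experimental_distribute_values_from_function(value_fn)
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
print(total.numpy())  # 0 + 1 = 1
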
View File

@ -39,6 +39,7 @@ from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import test_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
@ -1313,10 +1314,7 @@ class MultiWorkerMirroredStrategyTestWithChief(
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
strategy = mirrored_strategy.MirroredStrategy()
if context.num_gpus() > 0:
self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
cross_device_ops_lib.NcclAllReduce)
else:
if context.num_gpus() == 0:
self.assertIsInstance(strategy.extended._inferred_cross_device_ops,
cross_device_ops_lib.ReductionToOneDevice)
self.skipTest("b/130551176, run the following once fixed.")
@ -1437,4 +1435,5 @@ def _replica_id_as_int():
if __name__ == "__main__":
test.main()
# TODO(b/172304955)
test_util.main(config_logical_devices=False)

View File

@ -119,6 +119,12 @@ def _get_tpu_strategy_creator(steps_per_run,
return _create_tpu_strategy
def _mirrored_strategy_with_collective_key_base(devices):
mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
mirrored_lib.MirroredStrategy._collective_key_base += 100000
return MirroredStrategy(devices)
def _get_multi_worker_mirrored_creator(required_gpus):
def _create_multi_worker_mirrored():
@ -244,20 +250,24 @@ cloud_tpu_strategy = combinations.NamedDistribution(
required_tpu=True,
use_cloud_tpu=True)
mirrored_strategy_with_one_cpu = combinations.NamedDistribution(
"Mirrored1CPU", lambda: MirroredStrategy(["/cpu:0"]))
"Mirrored1CPU",
lambda: _mirrored_strategy_with_collective_key_base(["/cpu:0"]))
mirrored_strategy_with_one_gpu = combinations.NamedDistribution(
"Mirrored1GPU", lambda: MirroredStrategy(["/gpu:0"]), required_gpus=1)
"Mirrored1GPU",
lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0"]),
required_gpus=1)
mirrored_strategy_with_gpu_and_cpu = combinations.NamedDistribution(
"MirroredCPUAndGPU",
lambda: MirroredStrategy(["/gpu:0", "/cpu:0"]),
lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/cpu:0"]),
required_gpus=1)
mirrored_strategy_with_two_gpus = combinations.NamedDistribution(
"Mirrored2GPUs",
lambda: MirroredStrategy(["/gpu:0", "/gpu:1"]),
lambda: _mirrored_strategy_with_collective_key_base(["/gpu:0", "/gpu:1"]),
required_gpus=2)
# Should call set_virtual_cpus_to_at_least(3) in your test's setUp methods.
mirrored_strategy_with_cpu_1_and_2 = combinations.NamedDistribution(
"Mirrored2CPU", lambda: MirroredStrategy(["/cpu:1", "/cpu:2"]))
"Mirrored2CPU",
lambda: _mirrored_strategy_with_collective_key_base(["/cpu:1", "/cpu:2"]))
mirrored_strategy_with_cpu_1_and_2.__doc__ = (
"""Mirrored strategy with 2 virtual CPUs.

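A brief note on the test plumbing above (the rationale is an assumption, not stated in the commit): bumping the class-level _collective_key_base before each construction gives every MirroredStrategy created in the same test process its own collective key range, so their collective op instances do not collide. A minimal sketch of the same pattern:

from tensorflow.python.distribute import mirrored_strategy as mirrored_lib


def make_isolated_mirrored_strategy(devices):
  # Offset the collective key space for each new strategy in this process.
  mirrored_lib.MirroredStrategyV1._collective_key_base += 100000
  mirrored_lib.MirroredStrategy._collective_key_base += 100000
  return mirrored_lib.MirroredStrategy(devices)


strategy_a = make_isolated_mirrored_strategy(["/cpu:0"])
strategy_b = make_isolated_mirrored_strategy(["/cpu:0"])  # distinct key range from strategy_a
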
View File

@ -248,10 +248,20 @@ class GatherTest(test.TestCase, parameterized.TestCase):
elif isinstance(
strategy,
(mirrored_strategy.MirroredStrategy,
central_storage_strategy.CentralStorageStrategy)) and pure_eager:
with self.assertRaisesRegex(errors.InvalidArgumentError,
r'Ranks of all input tensors should match'):
run()
central_storage_strategy.CentralStorageStrategy)):
if pure_eager:
with self.assertRaises(errors.InvalidArgumentError) as e:
run()
# Different error message depending on whether collective ops is used.
self.assertRegexMatch(
str(e.exception),
['Ranks of all input tensors should match', 'Shape mismatch'])
else:
with self.assertRaises((errors.InvalidArgumentError, ValueError)) as e:
run()
self.assertRegexMatch(
str(e.exception),
[r'Shape must be rank \d but is rank \d', 'Shape mismatch'])
elif _is_tpu_strategy(strategy) and pure_eager:
with self.assertRaisesRegex(ValueError,
r'Dimension \d in both shapes must be equal'):
@ -562,13 +572,18 @@ class GatherTest(test.TestCase, parameterized.TestCase):
(mirrored_strategy.MirroredStrategy,
central_storage_strategy.CentralStorageStrategy)):
if pure_eager:
with self.assertRaisesRegex(errors.InvalidArgumentError,
r'Ranks of all input tensors should match'):
with self.assertRaises(errors.InvalidArgumentError) as e:
strategy.run(run, args=(per_replica_value,))
# Different error message depending on whether collective ops is used.
self.assertRegexMatch(
str(e.exception),
['Ranks of all input tensors should match', 'Shape mismatch'])
else:
with self.assertRaisesRegex(ValueError,
r'Shape must be rank \d but is rank \d'):
with self.assertRaises((errors.InvalidArgumentError, ValueError)) as e:
strategy.run(run, args=(per_replica_value,))
self.assertRegexMatch(
str(e.exception),
[r'Shape must be rank \d but is rank \d', 'Shape mismatch'])
else:
with self.assertRaisesRegex(ValueError,
r'Dimension \d in both shapes must be equal'):