Retire MultiWorkerAllReduce
MultiWorkerAllReduce is for in graph replicated multi worker training, which is no longer supported. MultiWorkerAllReduce is no longer exposed so we shouldn't expect any usage of it. Existing tests with MultiWorkerAllReduce are changed to use CollectiveAllReduce. It's not officially supported but we should grandfather existing tests since it's non-trivial to migrate them to MultiWorkerMirroredStrategy. PiperOrigin-RevId: 328685198 Change-Id: Ib3e65fa772f05ca6a69f9e98864ac42bfdf4af51
This commit is contained in:
parent
9394966102
commit
bf0b6db744
@ -802,117 +802,6 @@ class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
|
||||
num_packs=num_packs)
|
||||
|
||||
|
||||
class MultiWorkerAllReduce(AllReduceCrossDeviceOps):
|
||||
"""All-reduce algorithms for distributed TensorFlow."""
|
||||
|
||||
def __init__(self,
|
||||
worker_devices,
|
||||
num_gpus_per_worker,
|
||||
all_reduce_spec=("pscpu/pscpu", 2, -1),
|
||||
num_packs=0):
|
||||
"""Initialize the all-reduce algorithm.
|
||||
|
||||
Args:
|
||||
worker_devices: a list of device strings for workers participating in
|
||||
all-reduce.
|
||||
num_gpus_per_worker: number of GPU devices per worker.
|
||||
all_reduce_spec: a tuple or a named tuple or a list of tuples specifying
|
||||
the all-reduce algorithm.
|
||||
1. The first element of a tuple is the name of the all-reduce algorithm.
|
||||
Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd",
|
||||
"nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with
|
||||
a "/" are hierarchical, so two all-reduces are executed, the first one
|
||||
aggregates tensors within a worker and the second aggregates across
|
||||
workers.
|
||||
2. The second element of a tuple is the number of shards when doing
|
||||
all-reduce. Let's say its values is M, each tensor after packing will be
|
||||
split into M shards and then M parallel all-reduces would be performed
|
||||
before finally they are concatenated backed into a complete tensor.
|
||||
3. The third element is the maximum size of tensors that will be
|
||||
applicable for the algorithm specified by the first element. For
|
||||
example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)],
|
||||
tensors with size not larger than 1024 bytes will be applied a 2-shard
|
||||
"nccl" all-reduce and other tensors will be applied a 2-shard
|
||||
"pscpu/pscpu" algorithm. The third elements should be in increasing
|
||||
order across tuples and end with -1 which indicates infinity.
|
||||
num_packs: see AllReduceCrossDeviceOps.
|
||||
"""
|
||||
self._worker_devices = worker_devices
|
||||
self._num_gpus_per_worker = num_gpus_per_worker
|
||||
super(MultiWorkerAllReduce, self).__init__(num_packs=num_packs)
|
||||
|
||||
def validate_and_complete_spec(spec):
|
||||
"""Validate and complete the all-reduce spec."""
|
||||
# TODO(yuefengz): support namedtuple.
|
||||
if not isinstance(spec, tuple):
|
||||
raise ValueError(
|
||||
"A tuple is expected for all-reduce spec: %r" % all_reduce_spec)
|
||||
if not spec or len(spec) > 3:
|
||||
raise ValueError(
|
||||
"Too many elements in the all-reduce spec tuple: %r" % spec)
|
||||
if len(spec) == 1:
|
||||
return AllReduceSpecTuple(spec[0], 1, -1)
|
||||
elif len(spec) == 2:
|
||||
return AllReduceSpecTuple(spec[0], spec[1], -1)
|
||||
else:
|
||||
return AllReduceSpecTuple(*spec)
|
||||
|
||||
self._all_reduce_spec = []
|
||||
if isinstance(all_reduce_spec, six.string_types):
|
||||
self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1))
|
||||
elif isinstance(all_reduce_spec, tuple):
|
||||
self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec))
|
||||
elif isinstance(all_reduce_spec, list):
|
||||
self._all_reduce_spec = [
|
||||
validate_and_complete_spec(spec) for spec in all_reduce_spec
|
||||
]
|
||||
|
||||
def _batch_all_reduce(self, reduce_op, per_replica_values):
|
||||
"""All-reduce algorithm in a batch."""
|
||||
logging.log_first_n(
|
||||
logging.INFO, "Distributed batch_all_reduce: %d all-reduces with "
|
||||
"allreduce_spec = %r, num_packs = %d" %
|
||||
(len(per_replica_values), self._all_reduce_spec, self._num_packs), 10)
|
||||
|
||||
device_grads = _group_value_by_device(per_replica_values)
|
||||
|
||||
# The all-reduce library requires fully defined shapes.
|
||||
# TODO(yuefengz): when tensor sharding is not needed, static shapes are not
|
||||
# required as well.
|
||||
for device_grad in device_grads:
|
||||
for grad, _ in device_grad:
|
||||
if not grad.shape.is_fully_defined():
|
||||
raise ValueError("Shape is unknown for node %r" % grad)
|
||||
|
||||
remaining_grads = device_grads
|
||||
aggregated_grads = []
|
||||
for spec_tuple in self._all_reduce_spec:
|
||||
if spec_tuple.limit < 0:
|
||||
this_grads = remaining_grads
|
||||
remaining_grads = []
|
||||
else:
|
||||
(this_grads, remaining_grads) = cross_device_utils.split_grads_by_size(
|
||||
spec_tuple.limit, remaining_grads)
|
||||
if this_grads:
|
||||
device_grad_packs, tensor_packer = _pack_tensors(
|
||||
this_grads, self._num_packs)
|
||||
range_agg_grads = cross_device_utils.sum_gradients_all_reduce(
|
||||
self._worker_devices, device_grad_packs, len(self._worker_devices),
|
||||
spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
|
||||
range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)
|
||||
|
||||
if not aggregated_grads:
|
||||
aggregated_grads = range_agg_grads
|
||||
else:
|
||||
assert len(aggregated_grads) == len(range_agg_grads)
|
||||
for i, range_agg_grad in enumerate(range_agg_grads):
|
||||
aggregated_grads[i] += range_agg_grad
|
||||
assert not remaining_grads
|
||||
|
||||
return _ungroup_and_make_mirrored(aggregated_grads, per_replica_values[0],
|
||||
reduce_op)
|
||||
|
||||
|
||||
@tf_export("distribute.experimental.CollectiveCommunication")
|
||||
class CollectiveCommunication(enum.Enum):
|
||||
"""Communication choices for CollectiveOps.
|
||||
|
@ -433,55 +433,6 @@ class SingleWorkerCrossDeviceOpsTest(CrossDeviceOpsTestBase):
|
||||
self.assertAllEqual(self.evaluate(result.values), [1.0, 1.0])
|
||||
|
||||
|
||||
class MultiWorkerCrossDeviceOpsTest(multi_worker_test_base.MultiWorkerTestBase,
|
||||
CrossDeviceOpsTestBase):
|
||||
|
||||
worker_devices = [
|
||||
"/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
|
||||
]
|
||||
multi_worker_allreduce_combinations = combinations.combine(
|
||||
cross_device_ops=[
|
||||
combinations.NamedObject(
|
||||
"MultiWorkerAllReduce",
|
||||
cross_device_ops_lib.MultiWorkerAllReduce(worker_devices, 2,
|
||||
("pscpu/pscpu", 2, -1),
|
||||
0)),
|
||||
combinations.NamedObject(
|
||||
"MultiWorkerAllReducePack",
|
||||
cross_device_ops_lib.MultiWorkerAllReduce(worker_devices, 2,
|
||||
("pscpu/pscpu", 2, -1),
|
||||
1)),
|
||||
combinations.NamedObject(
|
||||
"MultiWorkerAllReduceMultipleSpecs",
|
||||
cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
worker_devices, 2, [("pscpu/pscpu", 2, 100),
|
||||
("xring", 2, -1)], 0)),
|
||||
],
|
||||
devices=[
|
||||
[
|
||||
"/job:worker/replica:0/task:0/device:CPU:0",
|
||||
"/job:worker/replica:0/task:1/device:CPU:0"
|
||||
],
|
||||
[
|
||||
"/job:worker/replica:0/task:0/device:GPU:0",
|
||||
"/job:worker/replica:0/task:1/device:GPU:0"
|
||||
],
|
||||
[
|
||||
"/job:worker/replica:0/task:0/device:GPU:0",
|
||||
"/job:worker/replica:0/task:0/device:GPU:1",
|
||||
"/job:worker/replica:0/task:1/device:GPU:0",
|
||||
"/job:worker/replica:0/task:1/device:GPU:1"
|
||||
],
|
||||
],
|
||||
mode=["graph"])
|
||||
|
||||
@combinations.generate(multi_worker_allreduce_combinations)
|
||||
def testReductionAndBroadcast(self, cross_device_ops, devices):
|
||||
# Mimic the default device of multi-worker strategies.
|
||||
with ops.device("/job:worker/replica:0/task:0"):
|
||||
self._testReductionAndBroadcast(cross_device_ops, devices)
|
||||
|
||||
|
||||
NUM_WORKERS = 3
|
||||
|
||||
CollectiveCommunication = cross_device_ops_lib.CollectiveCommunication
|
||||
|
@ -976,8 +976,7 @@ class DistributedIteratorMultiWorkerTest(
|
||||
|
||||
strategy = mirrored_strategy.MirroredStrategy(
|
||||
devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
|
||||
cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
["/job:worker/task:0", "/job:worker/task:1"], 1))
|
||||
cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
|
||||
worker_devices = self._cpu_devices()
|
||||
with context.graph_mode(), self.cached_session() as sess:
|
||||
if auto_shard_policy == AutoShardPolicy.AUTO:
|
||||
@ -1003,8 +1002,7 @@ class DistributedIteratorMultiWorkerTest(
|
||||
|
||||
strategy = mirrored_strategy.MirroredStrategy(
|
||||
devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
|
||||
cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
["/job:worker/task:0", "/job:worker/task:1"], 1))
|
||||
cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
|
||||
worker_devices = self._cpu_devices()
|
||||
with context.graph_mode(), strategy.scope(), self.cached_session() as sess:
|
||||
|
||||
@ -1042,8 +1040,7 @@ class DistributedIteratorMultiWorkerTest(
|
||||
strategy = mirrored_strategy.MirroredStrategy(
|
||||
devices=(self._cpu_and_one_gpu_devices()[0][1] +
|
||||
self._cpu_and_one_gpu_devices()[1][1]),
|
||||
cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
["/job:worker/task:0", "/job:worker/task:1"], 2))
|
||||
cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
|
||||
worker_devices = self._cpu_and_one_gpu_devices()
|
||||
with context.graph_mode(), strategy.scope(), self.cached_session() as sess:
|
||||
|
||||
@ -1075,8 +1072,7 @@ class DistributedIteratorMultiWorkerTest(
|
||||
enable_get_next_as_optional):
|
||||
strategy = mirrored_strategy.MirroredStrategy(
|
||||
devices=(self._cpu_devices()[0][1] + self._cpu_devices()[1][1]),
|
||||
cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
["/job:worker/task:0", "/job:worker/task:1"], 1))
|
||||
cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
|
||||
worker_devices = self._cpu_devices()
|
||||
|
||||
def dataset_fn(ctx):
|
||||
@ -1118,8 +1114,7 @@ class DistributedIteratorMultiWorkerTest(
|
||||
strategy = mirrored_strategy.MirroredStrategy(
|
||||
devices=(self._cpu_and_one_gpu_devices()[0][1] +
|
||||
self._cpu_and_one_gpu_devices()[1][1]),
|
||||
cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
["/job:worker/task:0", "/job:worker/task:1"], 2))
|
||||
cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(9).batch(2)
|
||||
dataset_or_input_fn = self._create_dataset_or_input_fn(
|
||||
input_type, dataset_fn)
|
||||
@ -1231,8 +1226,7 @@ class DistributedIteratorMultiWorkerTest(
|
||||
strategy = mirrored_strategy.MirroredStrategy(
|
||||
devices=(self._cpu_and_one_gpu_devices()[0][1] +
|
||||
self._cpu_and_one_gpu_devices()[1][1]),
|
||||
cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
["/job:worker/task:0", "/job:worker/task:1"], 2))
|
||||
cross_device_ops=cross_device_ops_lib.ReductionToOneDevice())
|
||||
worker_devices = self._cpu_and_one_gpu_devices()
|
||||
with context.graph_mode(), strategy.scope(), self.cached_session() as sess:
|
||||
|
||||
|
@ -86,7 +86,7 @@ def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
|
||||
for task_type in ("chief", "worker"):
|
||||
for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
|
||||
if num_gpus_per_worker == 0:
|
||||
devices.append("/job:%s/task:%d" % (task_type, task_id))
|
||||
devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id))
|
||||
else:
|
||||
devices.extend([
|
||||
"/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
|
||||
@ -378,8 +378,10 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
self._is_multi_worker_training = True
|
||||
|
||||
if len(workers) > 1:
|
||||
if not isinstance(self._cross_device_ops,
|
||||
cross_device_ops_lib.MultiWorkerAllReduce):
|
||||
# Grandfather usage in the legacy tests if they're configured properly.
|
||||
if (not isinstance(self._cross_device_ops,
|
||||
cross_device_ops_lib.ReductionToOneDevice) or
|
||||
self._cross_device_ops._num_between_graph_workers > 1): # pylint: disable=protected-access
|
||||
raise ValueError(
|
||||
"In-graph multi-worker training with `MirroredStrategy` is not "
|
||||
"supported.")
|
||||
|
@ -1148,9 +1148,9 @@ class MirroredStrategyDefunTest(test.TestCase):
|
||||
# pylint: disable=g-long-lambda
|
||||
lambda: mirrored_strategy.MirroredStrategy(
|
||||
devices=mirrored_strategy.all_local_devices(),
|
||||
cross_device_ops=cross_device_ops_lib.MultiWorkerAllReduce([
|
||||
"/job:worker/task:0", "/job:worker/task:1"
|
||||
], context.num_gpus())),
|
||||
cross_device_ops=cross_device_ops_lib.ReductionToOneDevice(
|
||||
),
|
||||
),
|
||||
required_gpus=1)
|
||||
],
|
||||
mode=["graph"]))
|
||||
@ -1288,9 +1288,7 @@ class MultiWorkerMirroredStrategyTestWithChief(
|
||||
cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
|
||||
|
||||
def _make_cross_device_ops(self):
|
||||
return cross_device_ops_lib.MultiWorkerAllReduce(
|
||||
["/job:chief/task:0", "/job:worker/task:0", "/job:worker/task:1"],
|
||||
context.num_gpus())
|
||||
return cross_device_ops_lib.ReductionToOneDevice()
|
||||
|
||||
def testMinimizeLossGraph(self):
|
||||
with context.graph_mode():
|
||||
|
Loading…
Reference in New Issue
Block a user