We can only order NCCL for collective V2. This allows enabling NCCL for all collectives for V2 kernels only, leaving TF1 users unaffected. PiperOrigin-RevId: 343958688 Change-Id: Ib01309e220670d09f78dab7c8f1ae1e01a872f91
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different algorithms of reduction and broadcasting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import threading

import six

from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import executor as executor_lib
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls

def check_destinations(destinations):
  """Checks whether `destinations` is not empty.

  Args:
    destinations: a `DistributedValues`, variable, or string object.

  Returns:
    Boolean which is True if `destinations` is not empty.
  """
  # Calling bool() on a ResourceVariable is not allowed.
  if isinstance(destinations,
                (resource_variable_ops.BaseResourceVariable, ops.Tensor)):
    return bool(destinations.device)
  return bool(destinations)


def validate_destinations(destinations):
  """Validates that `destinations` is one of the expected types."""
  if not isinstance(
      destinations,
      (value_lib.DistributedValues, ops.Tensor, ps_values.AggregatingVariable,
       six.string_types, tpu_values.TPUMirroredVariable
      )) and not resource_variable_ops.is_resource_variable(destinations):
    raise ValueError("destinations must be a `DistributedValues` object, "
                     "a `tf.Variable` object, or a device string.")

  if not check_destinations(destinations):
    raise ValueError("destinations cannot be empty")


def reduce_non_distributed_value(
    reduce_op, value, destinations, num_replicas_in_graph):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, value_lib.DistributedValues):
    raise ValueError("You are passing a `DistributedValues` to "
                     "`reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all replicas then the PerReplica value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  # TODO(b/138823479): handle the tensor value properly.
  if not tensor_util.is_tensor(value) and value == 0:
    return 0
  # If there is only a single value and the reduce op is MEAN,
  # that value should be on all destinations.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value
  elif num_replicas_in_graph != 1:
    # We do not support a reduce op of SUM if the value is the same across
    # all replicas: this function is called as part of the assign functions
    # for MirroredVariables, and summing up identical values across replicas
    # is not clearly defined.
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  else:
    validate_destinations(destinations)
    return simple_broadcast(value, destinations)
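

# NOTE: The helper below is an illustrative sketch added for documentation
# purposes only; it is not used by TensorFlow. It assumes an eager context
# with a local CPU device and exercises the branches of
# `reduce_non_distributed_value` above.
def _example_reduce_non_distributed_value():
  """Sketch (not part of the API): reduce a plain tensor to a device string."""
  value = array_ops.ones([2])
  # MEAN of a value that is already identical on every replica is the value
  # itself.
  mean_result = reduce_non_distributed_value(
      reduce_util.ReduceOp.MEAN, value, "/cpu:0", num_replicas_in_graph=1)
  # With a single replica in the graph, SUM falls through to a simple
  # broadcast to `destinations`.
  sum_result = reduce_non_distributed_value(
      reduce_util.ReduceOp.SUM, value, "/cpu:0", num_replicas_in_graph=1)
  return mean_result, sum_result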


def _make_tensor_into_per_replica(input_tensor):
  """Converts a single tensor into a PerReplica object."""
  if isinstance(input_tensor, (tuple, list)):
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
                     "got %r but expected an object that is not a tuple or "
                     "list." % (input_tensor,))
  if isinstance(input_tensor, value_lib.PerReplica):
    return input_tensor
  elif hasattr(input_tensor, "device"):
    return value_lib.PerReplica((input_tensor,))
  else:
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                     "because it doesn't have a device set.")


def _normalize_value_destination_pairs(value_destination_pairs):
  """Converts each tensor into a PerReplica object in the input list."""
  result = []

  value_destination_pairs = list(value_destination_pairs)

  if not isinstance(value_destination_pairs, (list, tuple)):
    raise ValueError("`value_destination_pairs` should be a list or tuple")
  for pair in value_destination_pairs:
    if not isinstance(pair, tuple):
      raise ValueError(
          "Each element of `value_destination_pairs` should be a tuple.")
    if len(pair) != 2:
      raise ValueError("Each element of `value_destination_pairs` should be a "
                       "tuple of size 2.")

    per_replica = _make_tensor_into_per_replica(pair[0])
    result.append((per_replica, pair[1]))
  return result


def _validate_value_destination_pairs(value_destination_pairs):
  """Validates that `value_destination_pairs` is well formed."""
  # TODO(yuefengz): raise exceptions instead of returning False.
  if not value_destination_pairs: return False
  if not isinstance(value_destination_pairs, (list, tuple)): return False
  if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
    return False
  if not all(isinstance(v[0], value_lib.PerReplica)
             for v in value_destination_pairs):
    return False
  return True


# TODO(yuefengz): consider calling this function in the caller of
# CrossDeviceOps.
def get_devices_from(destinations):
  if isinstance(destinations, value_lib.DistributedValues):
    return destinations._devices  # pylint: disable=protected-access
  elif isinstance(destinations, six.string_types):
    return (device_util.resolve(destinations),)
  return (device_util.resolve(destinations.device),)


def _devices_match(left, right):
  return left is right or set(get_devices_from(left)) == set(
      get_devices_from(right))


def _all_devices_match(value_destination_pairs):
  if not all(_devices_match(v, d) for v, d in value_destination_pairs):
    return False
  if not all(_devices_match(v, value_destination_pairs[0][0])
             for v, _ in value_destination_pairs[1:]):
    return False
  return True


def simple_broadcast(value, destinations, always_mirrored=False):
  """Broadcast `value` to `destinations` using simple copies."""
  devices = get_devices_from(destinations)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)
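

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. It assumes an eager context with a local CPU device.
def _example_simple_broadcast():
  """Sketch (not part of the API): broadcast a tensor with `simple_broadcast`."""
  value = array_ops.ones([2, 2])
  # A single destination without `always_mirrored` returns a plain copy on
  # that device.
  single_copy = simple_broadcast(value, "/cpu:0")
  # With `always_mirrored=True` the copy goes through the multi-destination
  # path and is regrouped with `Mirrored` as the wrap class.
  mirrored = simple_broadcast(value, "/cpu:0", always_mirrored=True)
  return single_copy, mirrored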


def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  """Reduces the value by accumulation_fn and reduce_op."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
  return reduced


def _simple_gather(per_replica_value, reduce_to_device, axis):
  """Concatenate all values in the DistributedValues input and return."""
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      gathered = array_ops.concat(all_values, axis)
  return gathered
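

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. It assumes an eager context; both components live on the local
# CPU, which is enough to show the semantics of SUM, MEAN and gathering.
def _example_simple_reduce_and_gather():
  """Sketch (not part of the API): reduce and gather a two-component value."""
  per_replica = value_lib.PerReplica(
      (array_ops.ones([2]), 2.0 * array_ops.ones([2])))
  # SUM accumulates with the given accumulation function (`tf.math.add_n`
  # here); MEAN additionally divides by the number of components.
  summed = _simple_reduce(per_replica, "/cpu:0", math_ops.add_n,
                          reduce_util.ReduceOp.SUM)
  averaged = _simple_reduce(per_replica, "/cpu:0", math_ops.add_n,
                            reduce_util.ReduceOp.MEAN)
  # Gathering concatenates the components along the requested axis instead of
  # accumulating them.
  gathered = _simple_gather(per_replica, "/cpu:0", axis=0)
  return summed, averaged, gathered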


@tf_export("distribute.CrossDeviceOps")
class CrossDeviceOps(object):
  """Base class for cross-device reduction and broadcasting algorithms.

  The main purpose of this class is to be passed to
  `tf.distribute.MirroredStrategy` in order to choose among different cross
  device communication implementations. Prefer using the methods of
  `tf.distribute.Strategy` instead of the ones of this class.

  Implementations:
  * `tf.distribute.ReductionToOneDevice`
  * `tf.distribute.NcclAllReduce`
  * `tf.distribute.HierarchicalCopyAllReduce`
  """

  def __init__(self):
    pass

  @property
  def _num_between_graph_workers(self):
    # Returns 1 by default; the value may be overridden by subclasses.
    return 1

  def reduce(self, reduce_op, per_replica_value, destinations, options=None):
    """Reduce `per_replica_value` to `destinations`.

    See `tf.distribute.StrategyExtended.reduce_to`. This can only be called in
    the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same value to `value`
        and `destinations`. Note that if it's a `tf.Variable`, the value is
        reduced to the devices of that variable, and this method doesn't update
        the variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if options is None:
      options = collective_util.Options()
    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self.reduce_implementation(reduce_op, per_replica_value,
                                      destinations, options)

  def _gather(self, per_replica_value, destinations, axis, options=None):
    """Gather `per_replica_value` to `destinations`.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same value to `value`
        and `destinations`. Note that if it's a `tf.Variable`, the value is
        gathered to the devices of that variable, and this method doesn't
        update the variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    if isinstance(per_replica_value, ops.IndexedSlices):
      raise NotImplementedError("gather/all_gather does not support "
                                "IndexedSlices")
    if options is None:
      options = collective_util.Options()

    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    return self._gather_implementation(per_replica_value, destinations, axis,
                                       options)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    """Implementation of `gather` method of `tf.distribute.CrossDeviceOps`.

    Overriding this method is useful for subclass implementers.

    Args:
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to gather to. To perform an all-gather, pass the same value to `value`
        and `destinations`. Note that if it's a `tf.Variable`, the value is
        gathered to the devices of that variable, and this method doesn't
        update the variable.
      axis: specifies the dimension to gather along within each replica's
        tensor.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "_gather_implementation method must be implemented in descendants.")

  def batch_reduce(self, reduce_op, value_destination_pairs, options=None):
    """Reduce values to destinations in batches.

    See `tf.distribute.StrategyExtended.batch_reduce_to`. This can only be
    called in the cross-replica context.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `tf.distribute.CrossDeviceOps.reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    if options is None:
      options = collective_util.Options()
    # TODO(yuefengz): if destinations are different, split into several
    # `_batch_reduce` invocations.
    if not _validate_value_destination_pairs(value_destination_pairs):
      # If the first element of each pair is a tensor, we try to turn it into a
      # PerReplica object.
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut if all PerReplica objects only contain one value.
    if self._num_between_graph_workers == 1 and _all_devices_match(
        value_destination_pairs) and len(
            value_destination_pairs[0][0].values) == 1:
      return [
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                            options)

  def broadcast(self, tensor, destinations):
    """Broadcast `tensor` to `destinations`.

    This can only be called in the cross-replica context.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcast to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    validate_destinations(destinations)
    return self.broadcast_implementation(tensor, destinations)

  @doc_controls.for_subclass_implementers
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    """Implementation of `reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      per_replica_value: a `tf.distribute.DistributedValues`, or a `tf.Tensor`
        like object.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to reduce to. To perform an all-reduce, pass the same value to `value`
        and `destinations`. Note that if it's a `tf.Variable`, the value is
        reduced to the devices of that variable, and this method doesn't update
        the variable.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.

    Raises:
      ValueError: if per_replica_value can't be converted to a
        `tf.distribute.DistributedValues` or if destinations is not a string,
        `tf.Variable` or `tf.distribute.DistributedValues`.
    """
    raise NotImplementedError(
        "reduce_implementation method must be implemented in descendants.")

  @doc_controls.for_subclass_implementers
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    """Implementation of `batch_reduce`.

    Overriding this method is useful for subclass implementers.

    Args:
      reduce_op: a `tf.distribute.ReduceOp` specifying how values should be
        combined.
      value_destination_pairs: a sequence of (value, destinations) pairs. See
        `reduce` for descriptions.
      options: a `tf.distribute.experimental.CommunicationOptions`. See
        `tf.distribute.experimental.CommunicationOptions` for details.

    Returns:
      A list of `tf.Tensor` or `tf.distribute.DistributedValues`, one per pair
      in `value_destination_pairs`.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of `tf.distribute.DistributedValues` and destinations.
    """
    raise NotImplementedError(
        "batch_reduce_implementation method must be implemented in descendants."
    )

  @doc_controls.for_subclass_implementers
  def broadcast_implementation(self, tensor, destinations):
    """Implementation of `broadcast`.

    Args:
      tensor: a `tf.Tensor` like object. The value to broadcast.
      destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a
        `tf.Tensor` alike object, or a device string. It specifies the devices
        to broadcast to. Note that if it's a `tf.Variable`, the value is
        broadcast to the devices of that variable, and this method doesn't
        update the variable.

    Returns:
      A `tf.Tensor` or `tf.distribute.DistributedValues`.
    """
    return simple_broadcast(tensor, destinations, always_mirrored=True)
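

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. It drives the public `CrossDeviceOps` methods through
# `ReductionToOneDevice` (defined below) and assumes an eager context with a
# local CPU device.
def _example_cross_device_ops_api():
  """Sketch (not part of the API): reduce, batch_reduce and broadcast."""
  cross_device_ops = ReductionToOneDevice()
  per_replica = value_lib.PerReplica(
      (array_ops.ones([2]), 3.0 * array_ops.ones([2])))
  # `reduce` combines the per-replica components and places the result on the
  # destination device(s).
  reduced = cross_device_ops.reduce(
      reduce_util.ReduceOp.SUM, per_replica, destinations="/cpu:0")
  # `batch_reduce` does the same for several (value, destinations) pairs so
  # that implementations can batch the communication.
  batched = cross_device_ops.batch_reduce(
      reduce_util.ReduceOp.MEAN, [(per_replica, "/cpu:0")])
  # `broadcast` copies a single tensor to every destination device.
  broadcasted = cross_device_ops.broadcast(array_ops.ones([2]), "/cpu:0")
  return reduced, batched, broadcasted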


@tf_export("distribute.ReductionToOneDevice")
class ReductionToOneDevice(CrossDeviceOps):
  """A CrossDeviceOps implementation that copies values to one device to reduce.

  This implementation always copies values to one device to reduce them, then
  broadcasts the reduced values to the destinations. It doesn't support
  efficient batching.

  Here is how you can use `ReductionToOneDevice` in
  `tf.distribute.MirroredStrategy`:

  ```
  strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.ReductionToOneDevice())
  ```
  """

  def __init__(self, reduce_to_device=None, accumulation_fn=None):
    """Initializes with a device to reduce to and a way to accumulate.

    Args:
      reduce_to_device: the intermediate device to reduce to. If None, reduce
        to the first device in `destinations` of the `reduce` method.
      accumulation_fn: a function that does accumulation. If None,
        `tf.math.add_n` is used.
    """
    self.reduce_to_device = reduce_to_device
    self.accumulation_fn = accumulation_fn or math_ops.add_n
    super(ReductionToOneDevice, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    if check_destinations(destinations):
      devices = get_devices_from(destinations)
    else:
      devices = get_devices_from(per_replica_value)
    reduce_to_device = self.reduce_to_device or devices[0]
    logging.log_first_n(
        logging.INFO,
        "Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10)
    reduced = _simple_reduce(per_replica_value, reduce_to_device,
                             self.accumulation_fn, reduce_op)
    return self.broadcast(reduced, destinations)

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    del options  # Unused.
    if check_destinations(destinations):
      devices = get_devices_from(destinations)
    else:
      devices = get_devices_from(per_replica_value)
    reduce_to_device = self.reduce_to_device or devices[0]
    logging.log_first_n(
        logging.INFO,
        "Gather to %s then broadcast to %r." % (reduce_to_device, devices), 10)
    gathered = _simple_gather(per_replica_value, reduce_to_device, axis)
    return self.broadcast(gathered, destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    return [
        self.reduce_implementation(
            reduce_op, t, destinations=v, options=options)
        for t, v in value_destination_pairs
    ]
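

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. Strategies normally construct the cross-device ops themselves;
# this shows standalone use with an explicit intermediate device and
# accumulation function, assuming an eager context with a local CPU device.
def _example_reduction_to_one_device():
  """Sketch (not part of the API): standalone `ReductionToOneDevice` use."""
  cross_device_ops = ReductionToOneDevice(
      reduce_to_device="/cpu:0", accumulation_fn=math_ops.add_n)
  per_replica = value_lib.PerReplica(
      (array_ops.ones([4]), array_ops.ones([4])))
  # Values are copied to `/cpu:0`, accumulated there, and the result is then
  # broadcast back to the destination devices.
  return cross_device_ops.reduce(
      reduce_util.ReduceOp.MEAN, per_replica, destinations="/cpu:0")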


def _group_value_by_device(per_replica_values):
  """Group values into sublists by their devices.

  This grouping is needed to call the all-reduce library because it expects a
  list of the following form:
    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
     ...
    ]

  Args:
    per_replica_values: a list of PerReplica objects.

  Returns:
    a list of lists, one sublist per device; each sublist contains the
    components of the PerReplica objects for that device, each paired with a
    None.
  """
  destinations = per_replica_values[0]._devices  # pylint: disable=protected-access
  grouped = [[] for _ in range(len(destinations))]
  for per_replica_value in per_replica_values:
    # pylint: disable=protected-access
    for i, v in enumerate(per_replica_value.values):
      assert per_replica_value._devices == destinations
      grouped[i].append((v, None))
  return grouped
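

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. It shows the layout change performed by `_group_value_by_device`;
# both components live on the local CPU here, but logically each slot stands
# for one device.
def _example_group_value_by_device():
  """Sketch (not part of the API): transpose PerReplica values by device."""
  grad0 = value_lib.PerReplica((array_ops.ones([2]), array_ops.ones([2])))
  grad1 = value_lib.PerReplica((array_ops.zeros([3]), array_ops.zeros([3])))
  # The result has one sublist per device slot, e.g.
  #   [[(grad0_dev0, None), (grad1_dev0, None)],
  #    [(grad0_dev1, None), (grad1_dev1, None)]]
  return _group_value_by_device([grad0, grad1])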


def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
  """Ungroup results from all-reduce and make Mirrored objects.

  Each all-reduce result will be divided by the number of destinations before
  Mirrored objects are created if reduce_op is "mean".

  Args:
    grouped_reduced: a list of lists, each sublist has components for each
      device, paired with a None. It is the result from
      cross_device_utils.aggregate_gradients_using*.
    destinations: a value to colocate the result with.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    a list of Mirrored objects.
  """
  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
  index = [[] for _ in range(len(grouped_reduced[0]))]
  for per_replica_reduced in grouped_reduced:
    for i, (v, _) in enumerate(per_replica_reduced):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        with ops.device(v.device):
          index[i].append(v / num_replicas)
      else:
        index[i].append(v)
  return [distribute_utils.regroup(
      v, wrap_class=value_lib.Mirrored) for v in index]
class _ConcatAndSplitPacker(object):
  """Concatenate and split tensors for reduction."""

  def __init__(self, num_packs=1):
    """Initialize the _ConcatAndSplitPacker object.

    Args:
      num_packs: specifies the number of split packs that will be
        formed.

    Raises:
      ValueError: if num_packs is not greater than 0.
    """
    if num_packs <= 0:
      raise ValueError("num_packs must be greater than zero.")
    self.num_packs = num_packs

  def pack(self, grouped_grads_and_vars):
    """Pack tensors."""
    self.grouped_grads_and_vars = grouped_grads_and_vars
    self.all_device_shapes = []
    self.all_device_sizes = []

    device_grad_packs = []
    for device_grads_and_vars in grouped_grads_and_vars:
      with ops.colocate_with(device_grads_and_vars[0][0]):
        # Flatten all the grads.
        flat_grads = [
            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
        ]
        # Remember the original shape of all the grads.
        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
        # Remember the original sizes of all the grads.
        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
        # Concat all the flat grads into a big flat tensor.
        concat_grads = array_ops.concat(flat_grads, 0)

        # Split the big tensor into num_splits packs. In cases where the
        # total size is not divisible by num_splits, the last pack gets
        # more elements.
        # TODO(zhengxq): it is also possible to optimize away all the concat
        # as well.
        num_splits = self.num_packs

        # The array_ops.size function will sometimes remove static shapes. So
        # if all gradient shapes are defined, we use another method to get the
        # total size.
        # TODO(yuefengz): move this logic to array_ops.size.
        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
          total_grad_size = sum(
              [g.shape.num_elements() for g, _ in device_grads_and_vars])
        else:
          total_grad_size = array_ops.size(concat_grads)

        split_size = total_grad_size // num_splits
        split_size_last = total_grad_size - split_size * (num_splits - 1)
        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
        grad_packs = array_ops.split(concat_grads, split_sizes)

        # Ready to aggregate the repacked gradients, with fake variables.
        # TODO(zhengxq): It is hacky to have to use fake variables.
        # We should remove the need for variables in
        # aggregate_gradients_using*.
        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
        self.all_device_shapes.append(device_shapes)
        self.all_device_sizes.append(device_sizes)

    return device_grad_packs

  def unpack(self, summed_device_grad_packs):
    """Reverse the pack."""
    aggregated_device_grads = []
    for (summed_device_grad_packs,
         device_grads_and_vars, device_shapes, device_sizes) in zip(
             summed_device_grad_packs, self.grouped_grads_and_vars,
             self.all_device_shapes, self.all_device_sizes):
      # Reverse the packing operations in the previous steps. Form the
      # summed gradients back into their original shapes.
      with ops.colocate_with(summed_device_grad_packs[0][0]):
        # Form a list of the summed grad packs.
        device_grad_packs = [g for g, _ in summed_device_grad_packs]

        # Concat them back into a big flat tensor.
        device_grads_concat = array_ops.concat(device_grad_packs, 0)

        # Split the tensors back into their original sizes.
        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)

        # Reshape the tensors back into their original shapes.
        grads_with_shapes = [
            array_ops.reshape(grad, shape)
            for shape, grad in zip(device_shapes, grads_with_sizes)
        ]

        # Form the list with the original list of variables.
        summed_device_grads = [
            (g, v) for g, (_, v) in zip(grads_with_shapes,
                                        device_grads_and_vars)
        ]
        aggregated_device_grads.append(summed_device_grads)
    return aggregated_device_grads
def _pack_tensors(device_grads, num_packs=0):
  """Pack tensors if specified."""
  if num_packs > 0:
    tensor_packer = _ConcatAndSplitPacker(num_packs)
    device_grad_packs = tensor_packer.pack(device_grads)
  else:
    tensor_packer = None
    device_grad_packs = device_grads
  return device_grad_packs, tensor_packer


def _unpack_tensors(reduced, tensor_packer=None):
  """Unpack tensors if they are packed before all-reduce."""
  if tensor_packer:
    return tensor_packer.unpack(reduced)
  return reduced
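

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. It runs inside a `tf.function` so the colocation logic in the
# packer sees graph tensors, mirroring how the all-reduce path uses it. In
# real use the packed values are all-reduced before `_unpack_tensors`; here
# they are passed through unchanged just to show that shapes are restored.
def _example_pack_and_unpack():
  """Sketch (not part of the API): round-trip gradients through the packer."""

  @def_function.function
  def pack_and_unpack():
    # Same nested structure as `_group_value_by_device` output: one sublist
    # per device of (gradient, variable) pairs, with None for the variables.
    device_grads = [
        [(array_ops.ones([4]), None), (array_ops.ones([2, 3]), None)],
    ]
    packed, packer = _pack_tensors(device_grads, num_packs=2)
    # The aggregation functions consume the (grad, None) iterators and return
    # lists, so materialize the packs the same way before unpacking.
    packed = [list(pack) for pack in packed]
    unpacked = _unpack_tensors(packed, packer)
    # Drop the None placeholders and return just the restored gradients.
    return [[g for g, _ in device_grads_and_vars]
            for device_grads_and_vars in unpacked]

  return pack_and_unpack()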


class AllReduceCrossDeviceOps(CrossDeviceOps):
  """All-reduce implementation of CrossDeviceOps.

  It performs all-reduce when applicable using NCCL or hierarchical copy. For
  the batch API, tensors will be repacked or aggregated for more efficient
  cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.
  """

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """Initializes the object.

    Args:
      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
        "hierarchical_copy" are supported.
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    del options  # Unused.
    # To use NCCL or all-reduce, source and destination devices should match,
    # and none of the devices should be CPU.
    if (_devices_match(per_replica_value, destinations) and
        not any("cpu" in d.lower() for d in get_devices_from(destinations))):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""
    logging.log_first_n(
        logging.INFO,
        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for %d IndexedSlices" %
        len(sparse_values), 10)
    # Use `sparse_values` as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, zip(sparse_values, sparse_values))

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    logging.warning("gather/all_gather with NCCL or HierarchicalCopy is not "
                    "supported. Falling back to gather on one device and "
                    "then broadcast. We're working on a more efficient "
                    "implementation.")
    return ReductionToOneDevice()._gather(per_replica_value, destinations, axis,  # pylint: disable=protected-access
                                          options)
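

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. `AllReduceCrossDeviceOps` only runs NCCL/hierarchical-copy
# all-reduce when source and destination devices match and none of them is a
# CPU; otherwise it silently falls back to `ReductionToOneDevice`, which is
# the path this sketch exercises on a CPU-only machine in eager mode.
def _example_all_reduce_cpu_fallback():
  """Sketch (not part of the API): the CPU fallback path of the class above."""
  cross_device_ops = AllReduceCrossDeviceOps(all_reduce_alg="nccl", num_packs=1)
  per_replica = value_lib.PerReplica(
      (array_ops.ones([2]), array_ops.ones([2])))
  return cross_device_ops.reduce(
      reduce_util.ReduceOp.SUM, per_replica, destinations="/cpu:0")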


# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
AllReduceCrossTowerOps = AllReduceCrossDeviceOps


AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
                                            "alg shards limit")


@tf_export("distribute.NcclAllReduce")
class NcclAllReduce(AllReduceCrossDeviceOps):
  """NCCL all-reduce implementation of CrossDeviceOps.

  It uses Nvidia NCCL for all-reduce. For the batch API, tensors will be
  repacked or aggregated for more efficient cross-device transportation.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `NcclAllReduce` in `tf.distribute.MirroredStrategy`:

  ```
  strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.NcclAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      raise ValueError(
          "NCCL all-reduce requires num_packs >= 0, but {} is specified".format(
              num_packs))
    super(NcclAllReduce, self).__init__(
        all_reduce_alg="nccl", num_packs=num_packs)


@tf_export("distribute.HierarchicalCopyAllReduce")
class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
  """Hierarchical copy all-reduce implementation of CrossDeviceOps.

  It reduces to one GPU along edges in some hierarchy and broadcasts back to
  each GPU along the same path. For the batch API, tensors will be repacked or
  aggregated for more efficient cross-device transportation.

  This reduction was created for the Nvidia DGX-1 and assumes that GPUs are
  connected as they are on a DGX-1 machine. If you have a different GPU
  interconnect, it is likely to be slower than
  `tf.distribute.ReductionToOneDevice`.

  For reduces that are not all-reduce, it falls back to
  `tf.distribute.ReductionToOneDevice`.

  Here is how you can use `HierarchicalCopyAllReduce` in
  `tf.distribute.MirroredStrategy`:

  ```
  strategy = tf.distribute.MirroredStrategy(
      cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
  ```
  """

  def __init__(self, num_packs=1):
    """Initializes the object.

    Args:
      num_packs: a non-negative integer. The number of packs to split values
        into. If zero, no packing will be done.

    Raises:
      ValueError: if `num_packs` is negative.
    """
    if num_packs < 0:
      raise ValueError(
          "HierarchicalCopy requires num_packs >= 0, but {} is specified"
          .format(num_packs))
    super(HierarchicalCopyAllReduce, self).__init__(
        all_reduce_alg="hierarchical_copy",
        num_packs=num_packs)


# TODO(crccw): remove after migrating all callers.
CollectiveCommunication = collective_util.CommunicationImplementation
CommunicationImplementation = collective_util.CommunicationImplementation


# TODO(yuefengz): support in-graph collective all-reduce.
class CollectiveAllReduce(CrossDeviceOps):
  """All-reduce cross device ops using collective ops.

  In between-graph replicated training, it will still do all-reduces across
  all workers and then put results on the right destinations.
  """

  def __init__(self, devices, group_size, collective_keys=None):
    """Initializes the object.

    Args:
      devices: a list of device strings to run collectives on.
      group_size: the global group size. For between-graph replicated training
        it's the total number of devices across all workers.
      collective_keys: an optional CollectiveKeys object.
    """
    if group_size % len(devices) > 0:
      raise ValueError("group_size must be divisible by the number of devices.")

    self._group_size = group_size
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    # This lock guards all collective launches, i.e. calls to
    # cross_device_utils.build_collective_*.
    #
    # In a multi-threaded eager program we need to ensure different groups of
    # collectives don't interleave each other, otherwise there could be
    # deadlocks. E.g. if two user threads both are launching collectives:
    #   user-thread-0  device0  device1
    #   user-thread-1  device0  device1
    # In eager mode, we use one executor per device. Executors use single FIFO
    # queues, so the above launch sequences end up with the following queues:
    #   device-0  collective-0  collective-1
    #   device-1  collective-1  collective-0
    # This deadlocks since neither collective is able to finish.
    self._lock = threading.Lock()

    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    group_key = self._collective_keys.get_group_key(self._devices)
    # Collective ops require all devices to participate and are blocking. In
    # eager, we need one async executor for each device to be able to launch
    # them altogether. Note that async doesn't imply concurrency. Within an
    # async executor operations are still executed sequentially. In graph or
    # function building, the executors are not used.
    self._executors = []
    self._launchers = []
    # Whether to only use NCCL for batched all-reduce when NCCL is requested.
    # This is because of the lack of a mechanism to order NCCL operations
    # deterministically.
    self._limited_nccl = False
    for device in self._devices:
      executor = executor_lib.new_executor(enable_async=True)
      self._executors.append(executor)
      launcher = cross_device_utils.CollectiveReplicaLauncher(
          group_key, group_size, self._collective_keys, device, executor)
      self._launchers.append(launcher)
      if not launcher.can_order_nccl():
        self._limited_nccl = True

    super(CollectiveAllReduce, self).__init__()

  @property
  def _num_between_graph_workers(self):
    # Currently we only support an equal number of devices on each worker.
    return self._group_size / len(self._devices)

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            options):
    values_util.mark_as_unsaveable()
    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
                                         options)[0]
    devices = get_devices_from(destinations)

    if _devices_match(per_replica_value, destinations):
      return all_reduced

    # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_reduced, value_lib.Mirrored):
      all_reduced = value_lib.Mirrored([all_reduced])

    # If we got this far, the destination devices do not match the all-reduce
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_reduced.values):
      for d in devices:
        with ops.device(d):
          for v in all_reduced.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            # TODO(josh11b): Once we add support for model parallelism, get the
            # copy from the corresponding replica instead of the primary.
            index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  options):
    values_util.mark_as_unsaveable()
    all_devices_match = _all_devices_match(value_destination_pairs)
    if all_devices_match:
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs],
                                    options)
    else:
      logging.log_first_n(
          logging.WARN, "Efficient batch_reduce is not supported if "
          "destinations are different.", 10)
      return [
          self.reduce_implementation(reduce_op, value, dest, options)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values, options):
    """All reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
                                                      options)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values, options)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(
        ((dense_results, dense_indices), (sparse_results, sparse_indices)))

  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values, options):
    """All-reduce across all workers in a batch."""

    batch_size = len(per_replica_values)
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1 since we don't have a way
    # to order NCCL launches. We're hoping that there's only one batched
    # all-reduce, which is the gradients.
    # TODO(b/132575814): switch to NCCL for all collectives when communication
    # is NCCL if and only if we can order collectives deterministically.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    # Reverse the lists so that there's a better chance that values follow
    # the order in which they are calculated (e.g. when they're gradients), so
    # as to overlap calculation with communication. However, this may not be
    # optimal for cases like gradients of complicated non-sequential models.
    #
    # Note that we reverse the list before packing so that the first pack won't
    # be too small, since it's more likely for first few packs to have long
    # queuing time due to concurrent intense computation.
    #
    # TODO(b/147393503): explore solutions for optimal ordering.
    values_by_device = [[] for _ in range(len(self._devices))]
    for per_replica in reversed(per_replica_values):
      for i in range(len(self._devices)):
        values_by_device[i].append(per_replica.values[i])

    outputs_by_device = []
    with self._lock:
      for i in range(len(self._devices)):
        packs = cross_device_utils.group_by_size(
            values_by_device[i], options.bytes_per_pack)
        if not context.executing_eagerly() and i == 0:
          logging.info(
              "Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
              "group_size = %d, implementation = %s, num_packs = %d",
              batch_size, len(self._launchers), self._group_size,
              implementation, len(packs))
        outputs_by_device.append(self._launchers[i].batch_all_reduce(
            packs, implementation, options.timeout_seconds))

    for e in self._executors:
      e.wait()

    mirrored = []
    for values in zip(*outputs_by_device):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        values = list(values)
        for i, v in enumerate(values):
          with ops.device(v.device):
            values[i] = v / self._group_size
      mirrored.append(
          distribute_utils.regroup(values, wrap_class=value_lib.Mirrored))
    # Reverse the order of reduced value to recover the order in the input.
    return list(reversed(mirrored))
  def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values, options):
    """All-reduce IndexedSlices across all workers in a batch."""

    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
        "%d all-reduces, group_size = %d" %
        (len(per_replica_values), self._group_size), 10)

    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (self._limited_nccl and
        options.implementation == CommunicationImplementation.NCCL and
        len(per_replica_values) == 1):
      implementation = CommunicationImplementation.AUTO.value

    gathered_values = []
    with self._lock:
      for per_replica in per_replica_values:
        outputs = []
        for i in range(len(self._devices)):
          outputs.append(self._launchers[i].all_reduce_indexed_slices(
              per_replica.values[i], implementation, options.timeout_seconds))
        gathered_values.append(outputs)

    mirrored = []
    for value in gathered_values:
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # Assume each worker has the same number of replicas.
        for i, v in enumerate(value):
          with ops.device(v.device):
            value[i].values = value[i].values / self._group_size
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def _gather_implementation(self, per_replica_value, destinations, axis,
                             options):
    all_gathered = self._batch_all_gather([per_replica_value], axis, options)[0]
    values_util.mark_as_unsaveable()
    devices = get_devices_from(destinations)

    if _devices_match(per_replica_value, destinations):
      return all_gathered

    # Convert `all_gathered` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_gathered, value_lib.Mirrored):
      all_gathered = value_lib.Mirrored([all_gathered])

    # If we got this far, the destination devices do not match the all-gather
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_gathered.values):
      for d in devices:
        with ops.device(d):
          for v in all_gathered.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            index.append(array_ops.identity(all_gathered._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def _batch_all_gather(self, per_replica_values, axis, options):
    """All-gathers multiple per-replica values."""
    batch_size = len(per_replica_values)
    # Pass options.implementation to the runtime as a communication
    # implementation hint.
    implementation = options.implementation.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when implementation
    # is NCCL.
    if (options.implementation == CommunicationImplementation.NCCL and
        batch_size == 1):
      implementation = CommunicationImplementation.AUTO.value

    logging.log_first_n(
        logging.INFO, "Collective batch_all_gather: %d all-gathers, "
        "num_devices = %d, group_size = %d, implementation = %s, " %
        (batch_size, len(self._devices), self._group_size, implementation), 10)

    def compute_gathered_values():
      gathered_values = []
      with self._lock, ops.name_scope("allgather"):
        for per_replica in per_replica_values:
          outputs = []
          for i in range(len(self._devices)):
            outputs.append(self._launchers[i].all_gather(
                per_replica.values[i], axis, implementation,
                options.timeout_seconds))
          gathered_values.append(outputs)
      return gathered_values

    if context.executing_eagerly():
      gathered_values = def_function.function(compute_gathered_values)()
    else:
      gathered_values = compute_gathered_values()

    mirrored = []
    for value in gathered_values:
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveAllReduce needs to support deep copy as well.
    collective_keys = copy.deepcopy(self._collective_keys, memo)
    return CollectiveAllReduce(self._devices, self._group_size, collective_keys)
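

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. Strategies normally construct `CollectiveAllReduce` themselves;
# this only shows the shape of the calls for a single local CPU "group" of
# size one. The keyword argument passed to `collective_util.Options` is an
# assumption based on the attributes read by the methods above.
def _example_collective_all_reduce():
  """Sketch (not part of the API): wiring up `CollectiveAllReduce` by hand."""
  collective_ops = CollectiveAllReduce(devices=["/cpu:0"], group_size=1)
  options = collective_util.Options(bytes_per_pack=32 * 1024 * 1024)
  per_replica = value_lib.PerReplica((array_ops.ones([2]),))
  # In a real multi-worker setup every participant must issue the same
  # collectives in the same order for them to complete.
  return collective_ops.reduce(
      reduce_util.ReduceOp.SUM, per_replica, destinations="/cpu:0",
      options=options)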


def select_cross_device_ops(devices, session_config=None):
  """Find the best `CrossDeviceOps` locally given a `tf.compat.v1.ConfigProto`.

  Args:
    devices: a list of devices passed to `tf.distribute.Strategy`.
    session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
      make the decision based on all logical devices.

  Returns:
    A subclass of `CrossDeviceOps`.
  """
  requested_devices = set(device_util.canonicalize(d) for d in devices)
  if ops.executing_eagerly_outside_functions():
    logical_gpus = context.context().list_logical_devices(device_type="GPU")
    physical_gpus = context.context().list_physical_devices(device_type="GPU")
    if len(logical_gpus) != len(physical_gpus):
      logging.warning("NCCL is not supported when using virtual GPUs, falling "
                      "back to reduction to one device")
      return ReductionToOneDevice()

    machine_devices = context.context().list_logical_devices()
  else:
    machine_devices = device_lib.list_local_devices(
        session_config=session_config)
  using_devices = set()
  for d in machine_devices:
    if device_util.canonicalize(d.name) in requested_devices:
      using_devices.add(d.name)

  if len(using_devices) != len(requested_devices):
    logging.warning(
        "Some requested devices in `tf.distribute.Strategy` are not visible "
        "to TensorFlow: %s", ",".join(list(requested_devices - using_devices)))

  if any("gpu" not in d.lower() for d in requested_devices):
    logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, "
                    "not using nccl allreduce.")
    return ReductionToOneDevice()

  if kernels.get_registered_kernels_for_op("NcclAllReduce"):
    return NcclAllReduce(num_packs=1)
  else:
    logging.warning("Nccl kernel is not found, not using nccl allreduce.")
    return ReductionToOneDevice()
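

# NOTE: Illustrative sketch added for documentation purposes only; not used by
# TensorFlow. On a CPU-only machine `select_cross_device_ops` returns
# `ReductionToOneDevice`; with multiple physical GPUs and a registered NCCL
# kernel it returns `NcclAllReduce`.
def _example_select_cross_device_ops():
  """Sketch (not part of the API): pick a default `CrossDeviceOps` locally."""
  return select_cross_device_ops(["/cpu:0"])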