We used to infer the devices from the inputs, but sometimes the inputs don't have device placement. E.g. when passing into or returning from tf.function, the device placement may be lost. Instead of inferring from the inputs we should just be explicit about the collective devices. PiperOrigin-RevId: 316743112 Change-Id: I2f6995f2f4cc86864723e203deb7562363cdbc38
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for different algorithms of reduction and broadcasting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum
import threading

import six

from tensorflow.python.client import device_lib
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import executor
from tensorflow.python.framework import kernels
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


def check_destinations(destinations):
  """Checks whether `destinations` is not empty.

  Args:
    destinations: a `DistributedValues`, variable, or string object.

  Returns:
    Boolean which is True if `destinations` is not empty.
  """
  # Calling bool() on a ResourceVariable is not allowed.
  if isinstance(destinations,
                (resource_variable_ops.BaseResourceVariable, ops.Tensor)):
    return bool(destinations.device)
  return bool(destinations)


def validate_destinations(destinations):
  """Validates that `destinations` is one of the expected types."""
  if not isinstance(
      destinations,
      (value_lib.DistributedValues, ops.Tensor, ps_values.AggregatingVariable,
       six.string_types, tpu_values.TPUMirroredVariable
      )) and not resource_variable_ops.is_resource_variable(destinations):
    raise ValueError("destinations must be one of a `DistributedValues` object,"
                     " a tf.Variable object, or a device string.")

  if not check_destinations(destinations):
    raise ValueError("destinations can not be empty")


def reduce_non_distributed_value(
    reduce_op, value, destinations, num_replicas_in_graph):
  """Reduce a non-DistributedValue `value` to `destinations`."""
  if isinstance(value, value_lib.DistributedValues):
    raise ValueError("You are passing a `DistributedValue` to "
                     "`reduce_non_distributed_value`, which is not allowed.")

  # If the same value is present on all replicas then the PerReplica value will
  # be a single value. We also handle the case when `value` is a single value
  # and equal to 0.
  # TODO(b/138823479): handle the tensor value properly.
  if not tensor_util.is_tensor(value) and value == 0:
    return 0
  # If there is only a single value and the reduce op is MEAN,
  # that value should be on all destinations.
  if reduce_op == reduce_util.ReduceOp.MEAN:
    return value
  elif num_replicas_in_graph != 1:
    # We do not support a reduce op of SUM if the value is the same across
    # all replicas. We call this as part of assign functions for
    # MirroredVariables and summing up identical values across replicas is not
    # clearly defined.
    raise ValueError("A non-DistributedValues value %s cannot be reduced with "
                     "the given reduce op %s." % (value, reduce_op))
  else:
    validate_destinations(destinations)
    return simple_broadcast(value, destinations)


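# Example behavior of `reduce_non_distributed_value` (illustrative): with
# `num_replicas_in_graph=2`, reducing the plain Python value 3.0 with
# `reduce_util.ReduceOp.MEAN` simply returns 3.0, whereas the same call with
# `reduce_util.ReduceOp.SUM` raises a ValueError, since summing a value that
# is identical on every replica is not clearly defined.

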
def _make_tensor_into_per_replica(input_tensor):
  """Converts a single tensor into a PerReplica object."""
  if isinstance(input_tensor, (tuple, list)):
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object, "
                     "got %r but expected an object that is not a tuple or "
                     "list." % (input_tensor,))
  if isinstance(input_tensor, value_lib.PerReplica):
    return input_tensor
  elif hasattr(input_tensor, "device"):
    return value_lib.PerReplica((input_tensor,))
  else:
    raise ValueError("Cannot convert `input_tensor` to a `PerReplica` object "
                     "because it doesn't have device set.")


def _normalize_value_destination_pairs(value_destination_pairs):
  """Converts each tensor into a PerReplica object in the input list."""
  result = []

  value_destination_pairs = list(value_destination_pairs)

  if not isinstance(value_destination_pairs, (list, tuple)):
    raise ValueError("`value_destination_pairs` should be a list or tuple")
  for pair in value_destination_pairs:
    if not isinstance(pair, tuple):
      raise ValueError(
          "Each element of `value_destination_pairs` should be a tuple.")
    if len(pair) != 2:
      raise ValueError("Each element of `value_destination_pairs` should be a "
                       "tuple of size 2.")

    per_replica = _make_tensor_into_per_replica(pair[0])
    result.append((per_replica, pair[1]))
  return result


def _validate_value_destination_pairs(value_destination_pairs):
  # TODO(yuefengz): raise exceptions instead of returning False.
  # pylint: disable=g-missing-docstring
  if not value_destination_pairs: return False
  if not isinstance(value_destination_pairs, (list, tuple)): return False
  if not all(isinstance(pair, tuple) for pair in value_destination_pairs):
    return False
  if not all(isinstance(v[0], value_lib.PerReplica)
             for v in value_destination_pairs):
    return False
  return True


# TODO(yuefengz): consider calling this function in the caller of
# CrossDeviceOps.
def get_devices_from(destinations):
  if isinstance(destinations, value_lib.DistributedValues):
    return destinations._devices  # pylint: disable=protected-access
  elif isinstance(destinations, six.string_types):
    return (device_util.resolve(destinations),)
  return (device_util.resolve(destinations.device),)


def _devices_match(left, right):
  return left is right or set(get_devices_from(left)) == set(
      get_devices_from(right))


def _all_devices_match(value_destination_pairs):
  if not all(_devices_match(v, d) for v, d in value_destination_pairs):
    return False
  if not all(_devices_match(v, value_destination_pairs[0][0])
             for v, _ in value_destination_pairs[1:]):
    return False
  return True


def simple_broadcast(value, destinations, always_mirrored=False):
  """Broadcast `value` to `destinations` using simple copies."""
  devices = get_devices_from(destinations)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)


def _simple_reduce(per_replica_value, reduce_to_device, accumulation_fn,
                   reduce_op):
  # pylint: disable=g-missing-docstring
  all_values = per_replica_value.values
  if not all_values:
    raise ValueError("`per_replica_value` must be non-empty")
  count = len(all_values)

  with ops.device(reduce_to_device):
    with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
      reduced = cross_device_utils.aggregate_tensors_or_indexed_slices(
          all_values, accumulation_fn)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        reduced = cross_device_utils.divide_by_n_tensors_or_indexed_slices(
            reduced, count)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise ValueError("`reduce_op` must be ReduceOp.SUM or ReduceOp.MEAN.")
  return reduced


@tf_export("distribute.CrossDeviceOps")
|
|
class CrossDeviceOps(object):
|
|
"""Base class for cross-device reduction and broadcasting algorithms."""
|
|
|
|
def __init__(self):
|
|
pass
|
|
|
|
@property
|
|
def _num_between_graph_workers(self):
|
|
# Returns 1 by default, the value may be overridden by sub classes.
|
|
return 1
|
|
|
|
  def reduce(self,
             reduce_op,
             per_replica_value,
             destinations,
             experimental_hints=None):
    """Reduce `per_replica_value` to `destinations`.

    It runs the reduction operation defined by `reduce_op` and puts the
    result on `destinations`.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
        per_replica_value will be reduced.
      per_replica_value: A `tf.distribute.DistributedValues` object or a tensor
        with device set.
      destinations: the reduction destinations.
      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
        to perform collective operations.

    Returns:
      a Mirrored object.

    Raises:
      ValueError: if per_replica_value can't be converted to a PerReplica
        object or if destinations aren't strings, Variables or
        DistributedValues.
    """
    if not isinstance(per_replica_value, value_lib.DistributedValues):
      per_replica_value = _make_tensor_into_per_replica(per_replica_value)

    validate_destinations(destinations)

    # Shortcut if `per_replica_value` only contains one value.
    if self._num_between_graph_workers == 1 and len(
        per_replica_value.values) == 1 and _devices_match(
            per_replica_value, destinations):
      with ops.device(per_replica_value.values[0].device):
        v = array_ops.identity(per_replica_value.values[0])
      return distribute_utils.regroup((v,), wrap_class=value_lib.Mirrored)

    if experimental_hints is None:
      experimental_hints = collective_util.Hints()
    return self.reduce_implementation(reduce_op, per_replica_value,
                                      destinations, experimental_hints)

  def batch_reduce(self,
                   reduce_op,
                   value_destination_pairs,
                   experimental_hints=None):
    """Reduce PerReplica objects in a batch.

    Reduce the first element of each pair in `value_destination_pairs` to its
    second element, which indicates the destinations.

    This can be faster than multiple individual `reduce`s because we can
    fuse several tensors into one or multiple packs before reduction.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the
        `per_replica_value` will be reduced.
      value_destination_pairs: A list or a tuple of PerReplica objects (or
        tensors with device set if there is one device) and destinations.
      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
        to perform collective operations.

    Returns:
      a list of Mirrored objects.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of PerReplica objects and destinations.
    """
    # TODO(yuefengz): if destinations are different, split into several
    # `_batch_reduce` invocations.
    if not _validate_value_destination_pairs(value_destination_pairs):
      # If the first element of each pair is a tensor, we try to turn it into a
      # PerReplica object.
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut if all PerReplica objects only contain one value.
    if self._num_between_graph_workers == 1 and _all_devices_match(
        value_destination_pairs) and len(
            value_destination_pairs[0][0].values) == 1:
      return [
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    if experimental_hints is None:
      experimental_hints = collective_util.Hints()
    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                            experimental_hints)

  def broadcast(self, tensor, destinations):
    """Broadcast the `tensor` to destinations.

    Args:
      tensor: the tensor to broadcast.
      destinations: the broadcast destinations.

    Returns:
      a Mirrored object.
    """
    validate_destinations(destinations)
    return self.broadcast_implementation(tensor, destinations)

  @doc_controls.for_subclass_implementers
  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            experimental_hints):
    """The implementation of reduce of `per_replica_value` to `destinations`.

    Overriding this method is useful for subclass implementers.

    It runs the reduction operation defined by `reduce_op` and puts the
    result on `destinations`.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
        per_replica_value will be reduced.
      per_replica_value: A PerReplica object or a tensor with device set.
      destinations: the reduction destinations.
      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
        to perform collective operations.

    Returns:
      a Mirrored object.

    Raises:
      ValueError: if per_replica_value can't be converted to a PerReplica
        object.
    """
    raise NotImplementedError(
        "_reduce method must be implemented in descendants.")

  @doc_controls.for_subclass_implementers
  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  experimental_hints):
    """Implementation of reducing PerReplica objects in a batch.

    Overriding this method is useful for subclass implementers.

    Reduce the first element of each pair in `value_destination_pairs` to its
    second element, which indicates the destinations.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
        per_replica_value will be reduced.
      value_destination_pairs: An iterable of tuples of PerReplica objects
        (or tensors with device set if there is one device) and destinations.
      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
        to perform collective operations.

    Returns:
      a list of Mirrored objects.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of PerReplica objects and destinations.
    """
    raise NotImplementedError(
        "batch_reduce_implementation method must be implemented in descendants."
    )

  @doc_controls.for_subclass_implementers
  def broadcast_implementation(self, tensor, destinations):
    """Implementation of broadcasting the `tensor` to destinations.

    Args:
      tensor: the tensor to broadcast.
      destinations: the broadcast destinations.

    Returns:
      a Mirrored object.
    """
    return simple_broadcast(tensor, destinations, always_mirrored=True)


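# Example (illustrative sketch): user code normally reaches these cross-device
# ops indirectly through `tf.distribute.Strategy.reduce`, e.g.
#
#   strategy = tf.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"])
#   per_replica = strategy.run(lambda: tf.constant(1.0))
#   total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
#
# which delegates to the strategy's configured `CrossDeviceOps` instance.

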
@tf_export("distribute.ReductionToOneDevice")
|
|
class ReductionToOneDevice(CrossDeviceOps):
|
|
"""Always do reduction to one device first and then do broadcasting.
|
|
|
|
Batch reduction is done by reduction on each element one by one.
|
|
|
|
```
|
|
mirrored_strategy = tf.distribute.MirroredStrategy(
|
|
cross_device_ops=tf.distribute.ReductionToOneDevice())
|
|
```
|
|
"""
|
|
|
|
def __init__(self, reduce_to_device=None, accumulation_fn=None):
|
|
"""Initializes with a device to reduce to and a way to accumulate.
|
|
|
|
Args:
|
|
reduce_to_device: the intermediate device to reduce to. If None, reduce
|
|
to the first device in `destinations` of the `reduce()` method.
|
|
accumulation_fn: a function that does accumulation. If None, then
|
|
`tf.math.add_n` is used.
|
|
"""
|
|
self.reduce_to_device = reduce_to_device
|
|
self.accumulation_fn = accumulation_fn or math_ops.add_n
|
|
super(ReductionToOneDevice, self).__init__()
|
|
|
|
def reduce_implementation(self, reduce_op, per_replica_value, destinations,
|
|
experimental_hints):
|
|
del experimental_hints # Unused.
|
|
if check_destinations(destinations):
|
|
devices = get_devices_from(destinations)
|
|
else:
|
|
devices = get_devices_from(per_replica_value)
|
|
reduce_to_device = self.reduce_to_device or devices[0]
|
|
logging.log_first_n(
|
|
logging.INFO,
|
|
"Reduce to %s then broadcast to %r." % (reduce_to_device, devices), 10)
|
|
reduced = _simple_reduce(per_replica_value, reduce_to_device,
|
|
self.accumulation_fn, reduce_op)
|
|
return self.broadcast(reduced, destinations)
|
|
|
|
def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
|
|
experimental_hints):
|
|
return [
|
|
self.reduce_implementation(
|
|
reduce_op, t, destinations=v, experimental_hints=experimental_hints)
|
|
for t, v in value_destination_pairs
|
|
]
|
|
|
|
|
|
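# Example (illustrative): `ReductionToOneDevice` can also be told to reduce on
# a specific device, e.g. the host CPU, assuming such a device exists:
#
#   strategy = tf.distribute.MirroredStrategy(
#       cross_device_ops=tf.distribute.ReductionToOneDevice(
#           reduce_to_device="/cpu:0"))

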
def _group_value_by_device(per_replica_values):
  """Group values into sublists by their devices.

  This grouping is needed to call the all-reduce library because it expects a
  list of the following form:
    [[(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ...],
     [(grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ...],
     [(grad0_gpu2, v0_gpu2), (grad1_gpu2, v1_gpu2), (grad2_gpu2, v2_gpu2) ...],
     ...
    ]

  Args:
    per_replica_values: a list of PerReplica objects.

  Returns:
    a list of lists, one sublist per device; each sublist contains the
    components of the PerReplica objects placed on that device, each paired
    with a None.
  """
  destinations = per_replica_values[0]._devices  # pylint: disable=protected-access
  grouped = [[] for _ in range(len(destinations))]
  for per_replica_value in per_replica_values:
    # pylint: disable=protected-access
    for i, v in enumerate(per_replica_value.values):
      assert per_replica_value._devices == destinations
      grouped[i].append((v, None))
  return grouped


def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
  """Ungroup results from all-reduce and make Mirrored objects.

  Each all-reduce result will be divided by the number of destinations before
  Mirrored objects are created if reduce_op is "mean".

  Args:
    grouped_reduced: a list of lists, each sublist has components for each
      device, paired with a None. It is the result from
      cross_device_utils.aggregate_gradients_using*.
    destinations: a value to colocate the result with.
    reduce_op: Indicates how values will be aggregated. Accepted values
      are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
    num_between_graph_workers: number of workers in the between-graph
      replication.

  Returns:
    a list of Mirrored objects.
  """
  num_replicas = len(get_devices_from(destinations)) * num_between_graph_workers
  index = [[] for _ in range(len(grouped_reduced[0]))]
  for per_replica_reduced in grouped_reduced:
    for i, (v, _) in enumerate(per_replica_reduced):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        with ops.device(v.device):
          index[i].append(v / num_replicas)
      else:
        index[i].append(v)
  return [distribute_utils.regroup(
      v, wrap_class=value_lib.Mirrored) for v in index]


class _ConcatAndSplitPacker(object):
  """Concatenate and split tensors for reduction."""

  def __init__(self, num_packs=1):
    """Initialize the _ConcatAndSplitPacker object.

    Args:
      num_packs: specifies the number of split packs that will be
        formed.

    Raises:
      ValueError: if num_packs is not greater than 0.
    """
    if num_packs <= 0:
      raise ValueError("num_packs must be greater than zero.")
    self.num_packs = num_packs

  def pack(self, grouped_grads_and_vars):
    """Pack tensors."""
    self.grouped_grads_and_vars = grouped_grads_and_vars
    self.all_device_shapes = []
    self.all_device_sizes = []

    device_grad_packs = []
    for device_grads_and_vars in grouped_grads_and_vars:
      with ops.colocate_with(device_grads_and_vars[0][0]):
        # Flatten all the grads.
        flat_grads = [
            array_ops.reshape(g, [-1]) for g, _ in device_grads_and_vars
        ]
        # Remember the original shape of all the grads.
        device_shapes = [array_ops.shape(g) for g, _ in device_grads_and_vars]
        # Remember the original sizes of all the grads.
        device_sizes = [array_ops.size(g) for g, _ in device_grads_and_vars]
        # Concat all the flat grads into a big flat tensor.
        concat_grads = array_ops.concat(flat_grads, 0)

        # Split the big tensor into num_splits packs. In cases where the
        # total size is not divisible by num_splits, the last pack gets
        # more elements.
        # TODO(zhengxq): it is also possible to optimize away all the concat
        # as well.
        num_splits = self.num_packs

        # The array_ops.size function will sometimes remove static shapes. So if
        # all gradient shapes are defined, we use another method to get the
        # total size.
        # TODO(yuefengz): move this logic to array_ops.size.
        if all(g.shape.is_fully_defined() for g, _ in device_grads_and_vars):
          total_grad_size = sum(
              [g.shape.num_elements() for g, _ in device_grads_and_vars])
        else:
          total_grad_size = array_ops.size(concat_grads)

        split_size = total_grad_size // num_splits
        split_size_last = total_grad_size - split_size * (num_splits - 1)
        split_sizes = [split_size] * (num_splits - 1) + [split_size_last]
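        # For example (illustrative): with total_grad_size = 10 and
        # num_splits = 3, split_size = 3, split_size_last = 10 - 3 * 2 = 4,
        # and split_sizes = [3, 3, 4], so the last pack absorbs the remainder.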
        grad_packs = array_ops.split(concat_grads, split_sizes)

        # Ready to aggregate the repacked gradients, with fake variables.
        # TODO(zhengxq): It is hacky to have to use fake variables.
        # We should remove the need for variables in
        # aggregate_gradients_using*.
        device_grad_packs.append(zip(grad_packs, [None] * num_splits))
        self.all_device_shapes.append(device_shapes)
        self.all_device_sizes.append(device_sizes)

    return device_grad_packs

  def unpack(self, summed_device_grad_packs):
    """Reverse the pack."""
    aggregated_device_grads = []
    for (summed_device_grad_packs,
         device_grads_and_vars, device_shapes, device_sizes) in zip(
             summed_device_grad_packs, self.grouped_grads_and_vars,
             self.all_device_shapes, self.all_device_sizes):
      # pylint: enable=line-too-long
      # Reverse the packing operations in the previous steps. Form the
      # summed gradients back into their original shapes.
      with ops.colocate_with(summed_device_grad_packs[0][0]):
        # Form a list of the summed grad packs.
        device_grad_packs = [g for g, _ in summed_device_grad_packs]

        # Concat them back into a big flat tensor.
        device_grads_concat = array_ops.concat(device_grad_packs, 0)

        # Split the tensors back into their original sizes.
        grads_with_sizes = array_ops.split(device_grads_concat, device_sizes)

        # Reshape the tensors back into their original shapes.
        grads_with_shapes = [
            array_ops.reshape(grad, shape)
            for shape, grad in zip(device_shapes, grads_with_sizes)
        ]

        # Form the list with the original list of variables.
        summed_device_grads = [
            (g, v) for g, (_, v) in zip(grads_with_shapes,
                                        device_grads_and_vars)
        ]
        aggregated_device_grads.append(summed_device_grads)
    return aggregated_device_grads


def _pack_tensors(device_grads, num_packs=0):
  """Pack tensors if specified."""
  if num_packs > 0:
    tensor_packer = _ConcatAndSplitPacker(num_packs)
    device_grad_packs = tensor_packer.pack(device_grads)
  else:
    tensor_packer = None
    device_grad_packs = device_grads
  return device_grad_packs, tensor_packer


def _unpack_tensors(reduced, tensor_packer=None):
  """Unpack tensors if they are packed before all-reduce."""
  if tensor_packer:
    return tensor_packer.unpack(reduced)
  return reduced


class AllReduceCrossDeviceOps(CrossDeviceOps):
  """Reduction using all-reduce."""

  def __init__(self, all_reduce_alg="nccl", num_packs=1):
    """All-reduce implementation of CrossDeviceOps.

    Before performing all-reduce, tensors will be packed for more efficient
    cross-device transportation.

    Args:
      all_reduce_alg: the all-reduce algorithm to use, currently only "nccl" or
        "hierarchical_copy" are supported.
      num_packs: If non-zero, pack values into `num_packs` splits.
    """
    self._all_reduce_alg = all_reduce_alg
    self._num_packs = num_packs
    self._simple_cross_replica_ops = ReductionToOneDevice()
    super(AllReduceCrossDeviceOps, self).__init__()

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            experimental_hints):
    del experimental_hints  # Unused.
    if _devices_match(per_replica_value, destinations):
      return self._batch_all_reduce(reduce_op, [per_replica_value])[0]
    else:
      return self._simple_cross_replica_ops.reduce(reduce_op, per_replica_value,
                                                   destinations)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  experimental_hints):
    if _all_devices_match(value_destination_pairs):
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs])
    else:
      return [
          self.reduce_implementation(reduce_op, value, dest, experimental_hints)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce(reduce_op, dense_values)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(((dense_results, dense_indices),
                                             (sparse_results, sparse_indices)))

  def _do_batch_all_reduce(self, reduce_op, dense_values):
    """Run batch all-reduces."""
    logging.log_first_n(
        logging.INFO,
        "batch_all_reduce: %d all-reduces with algorithm = %s, num_packs = %d" %
        (len(dense_values), self._all_reduce_alg, self._num_packs), 10)

    destinations = dense_values[0]._devices  # pylint: disable=protected-access
    grouped = _group_value_by_device(dense_values)

    device_grad_packs, tensor_packer = _pack_tensors(grouped, self._num_packs)

    # The actual aggregation of the repacked gradients. Note that they are
    # sharded among different aggregation trees. So it is important to strike
    # the balance on num_splits.
    if self._all_reduce_alg == "nccl":
      # TODO(yuefengz): merge this into the all-reduce library.
      reduced = cross_device_utils.aggregate_gradients_using_nccl(
          device_grad_packs)
    else:
      # TODO(yuefengz): check that gpu ids in `destinations` are in ascending
      # order.
      reduced = (
          cross_device_utils.aggregate_gradients_using_hierarchical_copy(
              destinations, device_grad_packs))

    reduced = _unpack_tensors(reduced, tensor_packer)
    return _ungroup_and_make_mirrored(reduced, dense_values[0], reduce_op)

  def _do_batch_all_reduce_sparse(self, reduce_op, sparse_values):
    """Run batch all-reduce for sparse values."""
    logging.log_first_n(
        logging.WARN,
        "Efficient allreduce is not supported for %d IndexedSlices" %
        len(sparse_values), 10)
    # Use `sparse_values` as destinations to do all-reduces. It is effectively
    # an allgather under the hood but not an efficient one.
    return self._simple_cross_replica_ops.batch_reduce(
        reduce_op, zip(sparse_values, sparse_values))


# For compatibility with code using the old name of `AllReduceCrossDeviceOps`.
AllReduceCrossTowerOps = AllReduceCrossDeviceOps


AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple",
                                            "alg shards limit")


@tf_export("distribute.NcclAllReduce")
|
|
class NcclAllReduce(AllReduceCrossDeviceOps):
|
|
"""Reduction using NCCL all-reduce."""
|
|
|
|
def __init__(self, num_packs=1):
|
|
"""NCCL all-reduce implementation of CrossDeviceOps.
|
|
|
|
It uses Nvidia NCCL for all-reduce. Before performing all-reduce, tensors
|
|
will be repacked or aggregated for more efficient cross-device
|
|
transportation.
|
|
|
|
Args:
|
|
num_packs: values will be packed in this many splits. `num_packs` should
|
|
be greater than or equals 0. When it is zero, no packing will be done.
|
|
|
|
Raises:
|
|
ValueError if `num_packs` is negative.
|
|
"""
|
|
if num_packs < 0:
|
|
raise ValueError(
|
|
"NCCL all-reduce requires num_packs >= 0, but {} is specified".format(
|
|
num_packs))
|
|
super(NcclAllReduce, self).__init__(
|
|
all_reduce_alg="nccl", num_packs=num_packs)
|
|
|
|
|
|
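# Example (illustrative):
#
#   strategy = tf.distribute.MirroredStrategy(
#       cross_device_ops=tf.distribute.NcclAllReduce())

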
@tf_export("distribute.HierarchicalCopyAllReduce")
|
|
class HierarchicalCopyAllReduce(AllReduceCrossDeviceOps):
|
|
"""Reduction using hierarchical copy all-reduce.
|
|
|
|
It reduces to one GPU along edges in some hierarchy and broadcasts back to
|
|
each GPU along the same path. Before performing all-reduce, tensors will be
|
|
repacked or aggregated for more efficient cross-device transportation.
|
|
|
|
This is a reduction created for Nvidia DGX-1 which assumes GPUs connects like
|
|
that on DGX-1 machine. If you have different GPU inter-connections, it is
|
|
likely that it would be slower than `tf.distribute.ReductionToOneDevice`.
|
|
"""
|
|
|
|
def __init__(self, num_packs=1):
|
|
"""Initializes the object.
|
|
|
|
Args:
|
|
num_packs: values will be packed in this many splits. `num_packs` should
|
|
be greater than or equals 0. When it is zero, no packing will be done.
|
|
|
|
Raises:
|
|
ValueError if `num_packs` is negative.
|
|
"""
|
|
if num_packs < 0:
|
|
raise ValueError(
|
|
"HierarchicalCopy requires num_packs >= 0, but {} is specified"
|
|
.format(num_packs))
|
|
super(HierarchicalCopyAllReduce, self).__init__(
|
|
all_reduce_alg="hierarchical_copy",
|
|
num_packs=num_packs)
|
|
|
|
|
|
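# Example (illustrative):
#
#   strategy = tf.distribute.MirroredStrategy(
#       cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())

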
class MultiWorkerAllReduce(AllReduceCrossDeviceOps):
  """All-reduce algorithms for distributed TensorFlow."""

  def __init__(self,
               worker_devices,
               num_gpus_per_worker,
               all_reduce_spec=("pscpu/pscpu", 2, -1),
               num_packs=0):
    """Initialize the all-reduce algorithm.

    Args:
      worker_devices: a list of device strings for workers participating in
        all-reduce.
      num_gpus_per_worker: number of GPU devices per worker.
      all_reduce_spec: a tuple or a named tuple or a list of tuples specifying
        the all-reduce algorithm.
        1. The first element of a tuple is the name of the all-reduce algorithm.
        Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd",
        "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with
        a "/" are hierarchical, so two all-reduces are executed, the first one
        aggregates tensors within a worker and the second aggregates across
        workers.
        2. The second element of a tuple is the number of shards when doing
        all-reduce. Let's say its value is M: each tensor after packing will be
        split into M shards and then M parallel all-reduces would be performed
        before finally they are concatenated back into a complete tensor.
        3. The third element is the maximum size of tensors that will be
        applicable for the algorithm specified by the first element. For
        example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)],
        tensors with size not larger than 1024 bytes will use a 2-shard
        "nccl" all-reduce and other tensors will use a 2-shard
        "pscpu/pscpu" algorithm. The third element should be in increasing
        order across tuples and end with -1 which indicates infinity.
      num_packs: see AllReduceCrossDeviceOps.
    """
    self._worker_devices = worker_devices
    self._num_gpus_per_worker = num_gpus_per_worker
    super(MultiWorkerAllReduce, self).__init__(num_packs=num_packs)

    def validate_and_complete_spec(spec):
      """Validate and complete the all-reduce spec."""
      # TODO(yuefengz): support namedtuple.
      if not isinstance(spec, tuple):
        raise ValueError(
            "A tuple is expected for all-reduce spec: %r" % all_reduce_spec)
      if not spec or len(spec) > 3:
        raise ValueError(
            "Too many elements in the all-reduce spec tuple: %r" % spec)
      if len(spec) == 1:
        return AllReduceSpecTuple(spec[0], 1, -1)
      elif len(spec) == 2:
        return AllReduceSpecTuple(spec[0], spec[1], -1)
      else:
        return AllReduceSpecTuple(*spec)

    self._all_reduce_spec = []
    if isinstance(all_reduce_spec, six.string_types):
      self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1))
    elif isinstance(all_reduce_spec, tuple):
      self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec))
    elif isinstance(all_reduce_spec, list):
      self._all_reduce_spec = [
          validate_and_complete_spec(spec) for spec in all_reduce_spec
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values):
    """All-reduce algorithm in a batch."""
    logging.log_first_n(
        logging.INFO, "Distributed batch_all_reduce: %d all-reduces with "
        "allreduce_spec = %r, num_packs = %d" %
        (len(per_replica_values), self._all_reduce_spec, self._num_packs), 10)

    device_grads = _group_value_by_device(per_replica_values)

    # The all-reduce library requires fully defined shapes.
    # TODO(yuefengz): when tensor sharding is not needed, static shapes are not
    # required as well.
    for device_grad in device_grads:
      for grad, _ in device_grad:
        if not grad.shape.is_fully_defined():
          raise ValueError("Shape is unknown for node %r" % grad)

    remaining_grads = device_grads
    aggregated_grads = []
    for spec_tuple in self._all_reduce_spec:
      if spec_tuple.limit < 0:
        this_grads = remaining_grads
        remaining_grads = []
      else:
        (this_grads, remaining_grads) = cross_device_utils.split_grads_by_size(
            spec_tuple.limit, remaining_grads)
      if this_grads:
        device_grad_packs, tensor_packer = _pack_tensors(
            this_grads, self._num_packs)
        range_agg_grads = cross_device_utils.sum_gradients_all_reduce(
            self._worker_devices, device_grad_packs, len(self._worker_devices),
            spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker))
        range_agg_grads = _unpack_tensors(range_agg_grads, tensor_packer)

        if not aggregated_grads:
          aggregated_grads = range_agg_grads
        else:
          assert len(aggregated_grads) == len(range_agg_grads)
          for i, range_agg_grad in enumerate(range_agg_grads):
            aggregated_grads[i] += range_agg_grad
    assert not remaining_grads

    return _ungroup_and_make_mirrored(aggregated_grads, per_replica_values[0],
                                      reduce_op)


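# Example (illustrative; the worker device strings below are hypothetical):
#
#   MultiWorkerAllReduce(
#       worker_devices=["/job:worker/task:0", "/job:worker/task:1"],
#       num_gpus_per_worker=2,
#       all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)])
#
# reduces tensors of at most 1024 bytes with a 2-shard in-worker "nccl"
# all-reduce and everything else with a 2-shard "pscpu/pscpu" all-reduce, as
# described in the constructor's `all_reduce_spec` documentation.

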
@tf_export("distribute.experimental.CollectiveCommunication")
|
|
class CollectiveCommunication(enum.Enum):
|
|
"""Communication choices for CollectiveOps.
|
|
|
|
* `AUTO`: Default to runtime's automatic choices.
|
|
* `RING`: TensorFlow's ring algorithms for all-reduce and
|
|
all-gather.
|
|
* `NCCL`: Use ncclAllReduce for all-reduce, and ring algorithms for
|
|
all-gather.
|
|
"""
|
|
AUTO = "AUTO"
|
|
RING = "RING"
|
|
NCCL = "NCCL"
|
|
# TODO(ayushd): add ncclAllGather implementation.
|
|
|
|
|
|
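# Example (illustrative): the communication choice is typically passed to
# `tf.distribute.experimental.MultiWorkerMirroredStrategy`, which forwards it
# to `CollectiveAllReduce`:
#
#   strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
#       communication=tf.distribute.experimental.CollectiveCommunication.NCCL)

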
# TODO(yuefengz): support in-graph collective all-reduce.
class CollectiveAllReduce(CrossDeviceOps):
  """All-reduce cross device ops using collective ops.

  In between-graph replicated training, it will still do all-reduces across
  all workers and then put results on the right destinations.
  """

  def __init__(self,
               devices,
               group_size,
               collective_keys=None,
               communication=CollectiveCommunication.AUTO):
    """Initializes the object.

    Args:
      devices: a list of device strings to run collectives on.
      group_size: the global group size. For between-graph replicated training
        it's the total number of devices across all workers.
      collective_keys: an optional CollectiveKey object.
      communication: indicates which collective communication to use.
    """
    if group_size % len(devices) > 0:
      raise ValueError("group_size must be divisible by the number of devices.")

    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._group_size = group_size
    self._collective_keys = (collective_keys or
                             cross_device_utils.CollectiveKeys())
    self._communication = communication
    # In a multi-threaded eager program we need to ensure different groups of
    # collectives don't interleave with each other, otherwise there will be
    # deadlocks.
    self._lock = threading.Lock()

    # Collective ops require all devices to participate and are blocking. In
    # eager, we need one async executor for each device to be able to launch
    # them altogether. Note that async doesn't imply concurrency. Within an
    # async executor operations are still executed sequentially. In graph or
    # function building, the executors are not used.
    self._executors = []
    for _ in range(len(devices)):
      self._executors.append(executor.new_executor(enable_async=True))

    super(CollectiveAllReduce, self).__init__()

  @property
  def _num_between_graph_workers(self):
    # Currently we only support equal number of devices on each worker.
    return self._group_size / len(self._devices)

  def reduce_implementation(self, reduce_op, per_replica_value, destinations,
                            experimental_hints):
    all_reduced = self._batch_all_reduce(reduce_op, [per_replica_value],
                                         experimental_hints)[0]
    devices = get_devices_from(destinations)

    if _devices_match(per_replica_value, destinations):
      return all_reduced

    # Convert `all_reduced` to a `Mirrored` object, as a simple and uniform
    # utility to access component for a particular device.
    if not isinstance(all_reduced, value_lib.Mirrored):
      all_reduced = value_lib.Mirrored([all_reduced])

    # If we got this far, the destination devices do not match the all-reduce
    # devices, so we must map from one to the other.
    index = []
    # We must add these control dependencies, otherwise we can get deadlock.
    with ops.control_dependencies(all_reduced.values):
      for d in devices:
        with ops.device(d):
          for v in all_reduced.values:
            if v.device == d:
              index.append(array_ops.identity(v))
              break
          else:
            # TODO(josh11b): Once we add support for model parallelism, get the
            # copy from the corresponding replica instead of the primary.
            index.append(array_ops.identity(all_reduced._primary))  # pylint: disable=protected-access
    return distribute_utils.regroup(index, wrap_class=value_lib.Mirrored)

  def batch_reduce_implementation(self, reduce_op, value_destination_pairs,
                                  experimental_hints):
    all_devices_match = _all_devices_match(value_destination_pairs)
    if all_devices_match:
      return self._batch_all_reduce(reduce_op,
                                    [v[0] for v in value_destination_pairs],
                                    experimental_hints)
    else:
      if not all_devices_match:
        logging.log_first_n(
            logging.WARN, "Efficient batch_reduce is not supported if "
            "destinations are different.", 10)

      return [
          self.reduce_implementation(reduce_op, value, dest, experimental_hints)
          for value, dest in value_destination_pairs
      ]

  def _batch_all_reduce(self, reduce_op, per_replica_values,
                        experimental_hints):
    """All reduce algorithm in a batch."""
    dense_values, dense_indices, sparse_values, sparse_indices = (
        cross_device_utils.split_by_sparsity(per_replica_values))
    if dense_values:
      dense_results = self._do_batch_all_reduce_dense(reduce_op, dense_values,
                                                      experimental_hints)
    else:
      dense_results = []
    if sparse_values:
      sparse_results = self._do_batch_all_reduce_sparse(reduce_op,
                                                        sparse_values)
    else:
      sparse_results = []
    return cross_device_utils.stitch_values(
        ((dense_results, dense_indices), (sparse_results, sparse_indices)))

  def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
                                 experimental_hints):
    """All-reduce across all workers in a batch."""

    batch_size = len(per_replica_values)
    # Pass self._communication to the runtime as a communication hint.
    communication = self._communication.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when communication
    # is NCCL.
    if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
      communication = CollectiveCommunication.AUTO.value

    # Reverse the lists so that there's a better chance that values follow
    # the order in which they are calculated (e.g. when they're gradients), so
    # as to overlap calculation with communication. However, this may not be
    # optimal for cases like gradients of complicated non-sequential models.
    #
    # Note that we reverse the list before packing so that the first pack won't
    # be too small, since it's more likely for first few packs to have long
    # queuing time due to concurrent intense computation.
    #
    # TODO(b/147393503): explore solutions for optimal ordering.
    packs = cross_device_utils.pack_by_size(
        list(reversed(per_replica_values)), experimental_hints.bytes_per_pack)

    if batch_size > 1:
      logging.info(
          "Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
          "group_size = %d, communication_hint = %s, num_packs = %d",
          batch_size, len(self._devices), self._group_size, communication,
          len(packs))
    else:
      logging.log_first_n(
          logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
          "num_devices = %d, group_size = %d, communication_hint = %s, "
          "num_packs = %d" % (batch_size, len(
              self._devices), self._group_size, communication, len(packs)), 10)

    reduced_values = []
    for pack in packs:
      # By placing all CollectiveReduce ops in a pack under single name scope,
      # we ensure they will be picked up by the `ScopedAllocator` grappler
      # optimizer and packed into a single all-reduce.
      with self._lock, ops.name_scope("allreduce"):
        for per_replica in pack:
          # Add control dependencies per device from the last gradients to the
          # current set, in order to serialize NCCL launches.
          if (communication == CollectiveCommunication.NCCL.value and
              reduced_values):
            control_inputs = list(reduced_values[-1])
          else:
            control_inputs = None
          reduced_values.append(
              cross_device_utils.build_collective_reduce(
                  per_replica.values,
                  self._devices,
                  self._group_size,
                  self._collective_keys,
                  "Add",
                  "Id",
                  communication,
                  control_inputs,
                  executors=self._executors))

    mirrored = []
    # Reverse the order of reduced values to recover the order in the input.
    for value in reversed(reduced_values):
      if reduce_op == reduce_util.ReduceOp.MEAN:
        for i, v in enumerate(value):
          with ops.device(v.device):
            value[i] = v / self._group_size
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values):
    """All-reduce IndexedSlices across all workers in a batch."""

    logging.log_first_n(
        logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
        "%d all-reduces, group_size = %d" %
        (len(per_replica_values), self._group_size), 10)

    # Pass self._communication to the runtime as a communication hint.
    communication_hint = self._communication.value
    # For now, we use NCCL only when batch_size > 1.
    # TODO(b/132575814): switch to NCCL for all collectives when communication
    # is NCCL.
    if self._communication == CollectiveCommunication.NCCL and len(
        per_replica_values) == 1:
      communication_hint = CollectiveCommunication.AUTO.value

    gathered_values = []
    with ops.name_scope("allreduce"):
      for per_replica in per_replica_values:
        gathered_values.append(
            cross_device_utils.build_collective_gather_indexed_slices(
                per_replica.values, self._devices, self._group_size,
                self._collective_keys, communication_hint))

    mirrored = []
    for value in gathered_values:
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # Assume each worker has the same number of replicas.
        for i, v in enumerate(value):
          with ops.device(v.device):
            value[i].values = value[i].values / self._group_size
      mirrored.append(
          distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
    return mirrored

  def __deepcopy__(self, memo):
    # distribute_coordinator deep-copies the strategy object, so
    # CollectiveAllReduce needs to support deep copy as well.
    return CollectiveAllReduce(self._devices, self._group_size,
                               self._collective_keys, self._communication)


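# Example (illustrative, for a single worker with two local GPUs; these
# device strings are assumptions, not required values):
#
#   collective_ops = CollectiveAllReduce(
#       devices=["/gpu:0", "/gpu:1"], group_size=2,
#       communication=CollectiveCommunication.NCCL)

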
def choose_the_best(devices, session_config=None):
  """Find the best CrossDeviceOps locally given a `tf.compat.v1.ConfigProto`.

  Args:
    devices: a list of devices passed to `tf.distribute.Strategy`.
    session_config: a `tf.compat.v1.ConfigProto` or `None`. If `None`, it will
      make the decision based on all logical devices.

  Returns:
    A subclass of `CrossDeviceOps`.
  """
  requested_devices = set(device_util.canonicalize(d) for d in devices)
  if ops.executing_eagerly_outside_functions():
    logical_gpus = context.context().list_logical_devices(device_type="GPU")
    physical_gpus = context.context().list_physical_devices(device_type="GPU")
    if len(logical_gpus) != len(physical_gpus):
      logging.warning("NCCL is not supported when using virtual GPUs, falling"
                      " back to reduction to one device")
      return ReductionToOneDevice()

    machine_devices = context.context().list_logical_devices()
  else:
    machine_devices = device_lib.list_local_devices(
        session_config=session_config)
  using_devices = set()
  for d in machine_devices:
    if device_util.canonicalize(d.name) in requested_devices:
      using_devices.add(d.name)

  if len(using_devices) != len(requested_devices):
    logging.warning(
        "Some requested devices in `tf.distribute.Strategy` are not visible "
        "to TensorFlow: %s", ",".join(list(requested_devices - using_devices)))

  if any("gpu" not in d.lower() for d in requested_devices):
    logging.warning("There are non-GPU devices in `tf.distribute.Strategy`, "
                    "not using nccl allreduce.")
    return ReductionToOneDevice()

  if kernels.get_registered_kernels_for_op("NcclAllReduce"):
    return NcclAllReduce(num_packs=1)
  else:
    logging.warning("Nccl kernel is not found, not using nccl allreduce.")
    return ReductionToOneDevice()
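

# Example (illustrative): selecting a cross-device op for two local GPUs,
# assuming both devices exist on the machine:
#
#   cross_device_ops = choose_the_best(["/gpu:0", "/gpu:1"])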