STT-tensorflow/tensorflow/python/distribute/cross_device_utils.py
Ran Chen c3f422ffb9 Support V2 collective in dist strat
With V2 collective ops we can defer instance key assignment to execution time. This avoids the problem that different workers may retrace at different paces.

We use capture_call_time_value instead of an instance-key tf.Variable. The latter doesn't work if there are exceptions during execution, since the variables on different workers can go out of sync.

PiperOrigin-RevId: 338175667
Change-Id: Ie0f5e607a25485c4c10de4a6cae137cb2b7ad729
2020-10-20 18:29:04 -07:00


# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for cross_device_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import threading
from tensorflow.python.distribute import values as value_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
from tensorflow.python.platform import tf_logging as logging
INSTANCE_KEY_START_NUMBER = 100
def aggregate_gradients_using_nccl(replica_grads):
"""Aggregate gradients using nccl allreduce."""
agg_all_g_and_v = []
for single_g_and_v in zip(*replica_grads):
single_grads = [g for g, _ in single_g_and_v]
agg_grads = nccl_ops.all_sum(single_grads)
agg_all_g_and_v.append(
[(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)])
agg_all_g_and_v = list(zip(*agg_all_g_and_v))
return agg_all_g_and_v
def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads):
"""Aggregate gradients using hierarchical copies.
Args:
avail_devices: available GPU devices.
replica_grads: List of lists of (gradient, variable) tuples. The outer list
is over replicas. The inner list is over individual gradients.
Returns:
The list of (aggregated_gradient, variable), where the gradient has been
summed across all replicas and the variable is chosen from the first
replica.
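
  For example (an illustrative sketch; gradient names are made up), with two
  replicas and a single variable `v`, `replica_grads` looks like:

    replica_grads = [
        [(grad_on_device_0, v)],  # replica 0
        [(grad_on_device_1, v)],  # replica 1
    ]

  The result has the same structure, with each gradient replaced by the sum
  of the gradients across all replicas.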
"""
# This only works for DGX-1 type of machine topology
# Device peer to peer matrix
# DMA: 0 1 2 3 4 5 6 7
# 0: Y Y Y Y Y N N N
# 1: Y Y Y Y N Y N N
# 2: Y Y Y Y N N Y N
# 3: Y Y Y Y N N N Y
# 4: Y N N N Y Y Y Y
# 5: N Y N N Y Y Y Y
# 6: N N Y N Y Y Y Y
# 7: N N N Y Y Y Y Y
agg_grads = []
num_devices = len(avail_devices)
# In the special case of DGX-1 machine topology, the two groups have equal
# size.
group_size = num_devices // 2
for i, single_grads in enumerate(zip(*replica_grads)):
group_0_main_device = i % num_devices
group_1_main_device = (group_0_main_device + group_size) % num_devices
if group_0_main_device < group_size:
group_0_begin = 0
group_1_begin = group_size
else:
group_0_begin = group_size
group_1_begin = 0
# Aggregate the first group.
group_0_device_grads = single_grads[group_0_begin:
group_0_begin + group_size]
with ops.device(avail_devices[group_0_main_device]):
group_0_agg_grads, _ = aggregate_single_gradient_using_copy(
group_0_device_grads, False, False)
# Aggregate the second group.
group_1_device_grads = single_grads[group_1_begin:
group_1_begin + group_size]
with ops.device(avail_devices[group_1_main_device]):
group_1_agg_grads, _ = aggregate_single_gradient_using_copy(
group_1_device_grads, False, False)
# Aggregate between the groups.
with ops.device(avail_devices[group_0_main_device]):
(agg_total_grads, _), _ = aggregate_single_gradient_using_copy(
[group_0_agg_grads, group_1_agg_grads], False, False)
# Broadcast the result back into the root of each group.
with ops.device(avail_devices[group_0_main_device]):
group_0_agg_grads_bcast = array_ops.identity(agg_total_grads)
with ops.device(avail_devices[group_1_main_device]):
group_1_agg_grads_bcast = array_ops.identity(agg_total_grads)
agg_grads_bcast = []
for j in range(len(single_grads)):
with ops.device(avail_devices[j]):
# Broadcast the result back to each member in the group from the root.
if (group_0_main_device < group_size) == (j < group_size):
src_device_grad = group_0_agg_grads_bcast
else:
src_device_grad = group_1_agg_grads_bcast
agg_grads_bcast.append(array_ops.identity(src_device_grad))
agg_grads.append(
[(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)])
agg_grads = list(zip(*agg_grads))
return agg_grads
def aggregate_single_gradient_using_copy(grad_and_vars, use_mean,
check_inf_nan):
"""Calculate the average gradient for a shared variable across all replicas.
Note that this function provides a synchronization point across all replicas.
Args:
grad_and_vars: A list or tuple of (gradient, variable) tuples. Each
(gradient, variable) pair within the outer list represents the gradient
of the variable calculated for a single replica, and the number of pairs
equals the number of replicas.
use_mean: if True, mean is taken, else sum of gradients is taken.
check_inf_nan: check grads for nans and infs.
Returns:
    The tuple ((average_gradient, variable), has_nan_or_inf) where the
    gradient has been averaged across all replicas. The variable is chosen
    from the first replica. has_nan_or_inf indicates whether the gradients
    contain any NaN or Inf values, and is None if check_inf_nan is False.
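
  A minimal usage sketch (gradient names are illustrative):

    (avg_grad, v), _ = aggregate_single_gradient_using_copy(
        [(grad_0, v), (grad_1, v)], use_mean=True, check_inf_nan=False)
    # avg_grad == (grad_0 + grad_1) / 2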
"""
grads = [g for g, _ in grad_and_vars]
grad = math_ops.add_n(grads)
  if use_mean and len(grads) > 1:
    grad = math_ops.multiply(grad, 1.0 / len(grads))
v = grad_and_vars[0][1]
if check_inf_nan:
    has_nan_or_inf = math_ops.logical_not(
        math_ops.reduce_all(math_ops.is_finite(grads)))
return (grad, v), has_nan_or_inf
else:
return (grad, v), None
# TODO(yuefengz): use random key starts to avoid reusing keys?
class CollectiveKeys(object):
"""Class that manages collective keys.
  We need to manage two different keys for collectives:

  *Group key*: an integer key to identify the set of cooperative devices.
  Collective ops working under the same set of devices must use the same
  group key.

  *Instance key*: an integer key to identify a particular instance of a
  collective op, i.e. the set of corresponding tensors on different devices
  in a device group that need to be all-reduced.
This class is thread safe.
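
  A minimal usage sketch (device strings are illustrative):

    collective_keys = CollectiveKeys()
    group_key = collective_keys.get_group_key(
        ['/job:worker/replica:0/task:0/device:GPU:0',
         '/job:worker/replica:0/task:1/device:GPU:0'])
    instance_key = collective_keys.get_instance_key(
        group_key, '/job:worker/replica:0/task:0/device:GPU:0')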
"""
def __init__(self, group_key_start=1):
"""Initializes the object.
Args:
group_key_start: the starting integer of group key.
"""
self._group_key = group_key_start
self._group_key_table = {}
self._instance_key_table = {}
self._lock = threading.Lock()
def get_group_key(self, devices):
"""Returns a group key for the set of devices.
Args:
devices: a list of canonical device strings in a collective group.
Returns:
int key uniquely identifying the set of device names.
"""
key_id = hash(tuple(sorted(devices)))
with self._lock:
if key_id not in self._group_key_table:
new_key = self._group_key
self._group_key += 1
self._group_key_table[key_id] = new_key
self._instance_key_table[new_key] = {}
for device in devices:
self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER
return self._group_key_table[key_id]
def get_instance_key(self, group_key, device):
"""Returns a new instance key for use in defining a collective op.
    You should call this once for each collective op of a collective instance.
Args:
group_key: the group key returned by get_group_key(). You should not
assign the group key yourself.
device: a canonical device string. It should be the device this collective
op is on.
Returns:
a new instance key.
Raises:
ValueError: when the group key is invalid or the device is not in the
group.
"""
with self._lock:
group = self._instance_key_table.get(group_key, None)
if group is None:
raise ValueError('group {} not found'.format(group_key))
if device not in group:
raise ValueError('{} not in group {}'.format(device, group_key))
v = group[device]
group[device] += 1
return v
def __deepcopy__(self, memo):
# distribute_coordinator deep-copies the strategy object, so
# CollectiveKeys needs to support deep copy as well.
copied = CollectiveKeys()
copied._group_key = self._group_key
copied._group_key_table = copy.deepcopy(self._group_key_table, memo)
copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo)
return copied
class CollectiveReplicaLauncher(object):
"""Launch collectives on one replica."""
_use_scoped_allocator = True
_use_collective_v2 = False
def __init__(self,
group_key,
group_size,
collective_keys,
device,
executor=None):
if executor and not executor.is_async():
raise ValueError('executor must be async')
self._group_key = group_key
self._group_size = group_size
self._collective_keys = collective_keys
self._device = device
self._executor = executor
def _executor_scope(self):
if context.executing_eagerly() and not self._executor:
      raise ValueError('collectives require an async executor in eager mode')
if context.executing_eagerly():
return context.executor_scope(self._executor)
return ops.NullContextmanager()
def _control_input(self, control_input):
if control_input is not None:
return ops.control_dependencies([control_input])
return ops.NullContextmanager()
def _should_use_collective_v2(self):
if not CollectiveReplicaLauncher._use_collective_v2:
return False
if not ops.executing_eagerly_outside_functions():
return False
return True
def _next_instance_key(self):
"""Returns the next instance key."""
if self._should_use_collective_v2():
      # Assigning instance keys at function building time has issues since
      # different workers may retrace the function at different times. With
      # collective V2 we can use capture_call_time_value to use a placeholder
      # as the instance key and feed it at function call time. In this way we
      # also don't reuse instance keys, which allows for per-instance
      # cancellation.
graph = ops.get_default_graph()
# Control flow ops don't work with capture_call_time_value, so we put the
# capture in the function graph of that control flow op.
while getattr(graph, 'is_control_flow_graph', False):
graph = graph.outer_graph
if not context.executing_eagerly() and graph.building_function:
with graph.as_default():
# Capture self._next_instance_key so that when building a function
# that calls another tf.function, the instance key assignment is
# further delayed until we actually call the function in eager. Note
# that capture_call_time_value doesn't automatically propagate the
# deferred capture to the outer function.
return graph.capture_call_time_value(
self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32))
else:
instance_key = self._collective_keys.get_instance_key(
self._group_key, self._device)
with ops.device('CPU:0'):
return ops.convert_to_tensor(instance_key, dtype=dtypes.int32)
else:
return self._collective_keys.get_instance_key(self._group_key,
self._device)
def all_reduce(self,
input_tensor,
control_input=None,
communication_hint='AUTO',
timeout=0):
"""All-reduce a dense tensor.
    This can be called in eager mode if an async executor is supplied when
creating the launcher.
Args:
input_tensor: a dense tensor. It must have the same shape on all replicas.
control_input: if not None, add control edges between control_input and
the all-reduce.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
Returns:
The reduced tensor.
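
    A usage sketch (device string and tensor values are illustrative; every
    replica must issue a matching call, and this must run in graph mode, e.g.
    inside a tf.function, unless an async executor is supplied):

      launcher = CollectiveReplicaLauncher(
          group_key, group_size, collective_keys,
          '/job:worker/replica:0/task:0/device:GPU:0')
      reduced = launcher.all_reduce(array_ops.ones([2, 2]))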
"""
instance_key = self._next_instance_key()
with self._executor_scope(), \
ops.device(self._device), \
self._control_input(control_input):
if self._should_use_collective_v2():
return collective_ops.all_reduce_v2(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout)
else:
return collective_ops.all_reduce(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout)
def _all_gather(self, input_tensor, communication_hint='AUTO', timeout=0):
"""All-gather a dense tensor.
This can be called in eager mode if an async executor is supplied when
creating the launcher.
Args:
input_tensor: a dense tensor. It must have the same shape on all replicas.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
Returns:
      The gathered tensor.
"""
instance_key = self._next_instance_key()
with self._executor_scope(), ops.device(self._device):
if self._should_use_collective_v2():
return collective_ops.all_gather_v2(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout)
else:
return collective_ops.all_gather(
input_tensor,
self._group_size,
self._group_key,
instance_key,
communication_hint=communication_hint,
timeout=timeout)
def batch_all_reduce(self,
input_tensor_packs,
communication_hint='AUTO',
timeout=0):
"""Batch all-reduce dense tensors.
    This takes a list of batches of tensors. Using multiple batches has the
    benefit that the all-reduce on one batch can start before all inputs are
    ready, instead of waiting for every input.

    This can be called in eager mode if an async executor is supplied when
creating the launcher.
Args:
input_tensor_packs: a list of lists of dense tensors.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
Returns:
A flat list of reduced tensors.
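
    For example (illustrative tensors; `launcher` is a
    CollectiveReplicaLauncher for the group):

      reduced = launcher.batch_all_reduce([[t0, t1], [t2, t3]])
      # reduced is the flat list [r0, r1, r2, r3], in the same order as the
      # inputs.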
"""
    # We don't batch with concat in eager. It's easy to get it wrong because
    # we need to avoid any numpy() calls on values produced by the async
    # executor. This effectively disables batching in eager, but it's unlikely
    # that a large number of tensors will be all-reduced in eager mode anyway.
batch_with_concat = (not self._use_scoped_allocator and
not context.executing_eagerly())
outputs = []
for pack in input_tensor_packs:
      # TODO(b/169168846): insert a parallel all_gather to verify that packings
      # are the same on each replica.
if batch_with_concat:
with ops.device(self._device):
flat_tensors = [array_ops.reshape(t, [-1]) for t in pack]
shapes = [array_ops.shape(t) for t in pack]
if communication_hint == 'NCCL' and outputs:
control_input = outputs[-1]
else:
control_input = None
reduced = self.all_reduce(
array_ops.concat(flat_tensors, axis=0), control_input,
communication_hint, timeout)
num_elements = [math_ops.reduce_prod(s) for s in shapes]
flat_outputs = array_ops.split(reduced, num_elements, axis=0)
for shape, flat_output in zip(shapes, flat_outputs):
outputs.append(array_ops.reshape(flat_output, shape))
else:
# By placing all CollectiveReduce ops in a batch under single name
# scope, we ensure they will be picked up by the `ScopedAllocator`
# grappler optimizer and packed into a single all-reduce.
with ops.name_scope('allreduce'):
for input_tensor in pack:
if communication_hint == 'NCCL' and outputs:
control_input = outputs[-1]
else:
control_input = None
outputs.append(
self.all_reduce(input_tensor, control_input, communication_hint,
timeout))
return outputs
def all_gather(self,
input_tensor,
axis,
communication_hint='AUTO',
timeout=0):
"""All-gather a dense tensor.
This method must be called inside a tf.function.
Args:
input_tensor: a dense tensor. It must have the same rank on all replicas,
and dimensions other than `axis` need to be the same as well.
      axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the
        range [0, rank(input_tensor)).
communication_hint: string providing hint to runtime for choosing
collective implementation. Available options are `AUTO`, `NCCL`, and
`RING`.
timeout: a float. The timeout in seconds.
Returns:
The gathered Tensor.
Raises:
RuntimeError: if called in eager mode.
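
    For example (illustrative shapes), with a group of size 2 where replica 0
    holds a [2, 3] tensor and replica 1 holds a [2, 5] tensor:

      gathered = launcher.all_gather(local_tensor, axis=1)
      # gathered has shape [2, 8] on every replica.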
"""
if context.executing_eagerly():
raise RuntimeError('all_gather in eager mode is not supported')
with ops.device(self._device), \
ops.control_dependencies([array_ops.identity(input_tensor)]):
      # 1. Transpose
      # E.g. given an input_tensor with shape [2,2,5,1] and axis=3, we use
      # perm_pre=[3,0,1,2] to transpose it to [1,2,2,5], which brings the
      # gather axis to the front; afterwards we use perm_after=[1,2,3,0] to
      # transpose it back.
perm_pre = array_ops.concat(
([axis], math_ops.range(axis),
math_ops.range(axis + 1, array_ops.rank(input_tensor))),
axis=0)
input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre)
# 2. Pad
gathered_shape = self._all_gather(
array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0),
communication_hint,
timeout=timeout)
first_dims = gathered_shape[:, 0]
full_axis_dim = math_ops.reduce_max(first_dims)
padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim)
# 3. Gather
gather_padded_out_tensor = self._all_gather(
padded_input_tensor, communication_hint, timeout=timeout)
# 4. Unpad
split_tensors = []
for i in range(self._group_size):
start_pos = i * full_axis_dim
split_tensors.append(gather_padded_out_tensor[start_pos:start_pos +
first_dims[i]])
out_tensor_t = array_ops.concat(split_tensors, 0)
# 5. Transpose back
perm_after = array_ops.concat(
(math_ops.range(1, axis + 1), [0],
math_ops.range(axis + 1, array_ops.rank(input_tensor_t))),
axis=0)
return array_ops.transpose(out_tensor_t, perm=perm_after)
def all_reduce_indexed_slices(self,
input_slices,
communication_hint='AUTO',
timeout=0):
"""All-reduce an IndexedSlices.
This method must be called inside a tf.function.
Args:
input_slices: an IndexedSlices.
communication_hint: string providing hint to runtime for choosing
collective implementation.
timeout: a float. The timeout in seconds.
Returns:
The reduced IndexedSlices.
Raises:
RuntimeError: if called in eager mode.
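
    A usage sketch (values are illustrative; must run inside a tf.function):

      slices = ops.IndexedSlices(
          values=array_ops.ones([2, 4]),
          indices=math_ops.range(2),
          dense_shape=ops.convert_to_tensor([8, 4]))
      reduced = launcher.all_reduce_indexed_slices(slices)

    If the number of rows differs across replicas, the slices are densified
    and all-reduced; otherwise their values and indices are all-gathered.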
"""
if context.executing_eagerly():
raise RuntimeError(
'all_reduce_indexed_slices in eager mode is not supported')
    # Current CollectiveAllGather implementations require the input
    # IndexedSlices to have consistent length across the board, so we handle
    # the reduction of IndexedSlices as follows:
# 1. Gather the lengths of IndexedSlices from all participants.
# 2. If they have consistent length, apply all_gather.
# 3. Otherwise convert IndexedSlices to dense tensors and apply
# all_reduce.
with ops.device(self._device):
def all_gather():
"""Use all_gather to aggregate `IndexedSlices`."""
all_values = self._all_gather(
input_slices.values, communication_hint, timeout=timeout)
# Add control dependency to order the all-gather.
control = [all_values] if communication_hint == 'NCCL' else []
with ops.control_dependencies(control):
all_indices = self._all_gather(
input_slices.indices, communication_hint, timeout=timeout)
return ops.IndexedSlices(
values=all_values,
indices=all_indices,
dense_shape=input_slices.dense_shape)
def densify_and_all_reduce():
"""Use all_reduce to aggregate `IndexedSlices`."""
densified = ops.convert_to_tensor(input_slices)
reduced = self.all_reduce(
densified, communication_hint=communication_hint, timeout=timeout)
        # We have to convert the dense gradient to IndexedSlices because
        # all_reduce() and all_gather() must have the same return type, as
        # required by control_flow_ops.cond.
return ops.IndexedSlices(
values=reduced,
indices=math_ops.range(array_ops.shape(reduced)[0]),
dense_shape=input_slices.dense_shape)
length = array_ops.shape(input_slices.indices)
all_lengths = self._all_gather(
length, communication_hint, timeout=timeout)
return control_flow_ops.cond(
math_ops.equal(
math_ops.reduce_max(all_lengths),
math_ops.reduce_min(all_lengths)), all_gather,
densify_and_all_reduce)
def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
"""Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
if any(isinstance(v, ops.IndexedSlices) for v in values):
return backprop.aggregate_indexed_slices_gradients(values)
else:
return accumulation_fn(values)
def divide_by_n_tensors_or_indexed_slices(value, n):
if isinstance(value, ops.IndexedSlices):
value = backprop.flatten_nested_indexed_slices(value)
return ops.IndexedSlices(
value.values / n, value.indices, value.dense_shape)
else:
return value / n
def copy_tensor_or_indexed_slices_to_device(value, device):
with ops.device(device):
if isinstance(value, ops.IndexedSlices):
copied_values = array_ops.identity(value.values)
copied_indices = array_ops.identity(value.indices)
copied_shape = array_ops.identity(value.dense_shape)
result = ops.IndexedSlices(copied_values, copied_indices, copied_shape)
else:
result = array_ops.identity(value)
return result
def is_indexed_slices(value):
if isinstance(value, ops.IndexedSlices):
return True
assert isinstance(value, value_lib.DistributedValues)
return all(isinstance(v, ops.IndexedSlices) for v in value.values)
def split_by_sparsity(values):
"""Split values into dense and sparse values.
Args:
values: a list of tensors or `PerReplica`s.
Returns:
Four lists:
      a list of dense values, a list of their indices in `values`,
      a list of sparse values, and a list of their indices in `values`.
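
  For example (illustrative values), if `values` is `[dense_0, sparse_0,
  dense_1]` where `sparse_0` is an `IndexedSlices`, the result is
  `([dense_0, dense_1], [0, 2], [sparse_0], [1])`.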
"""
dense_values = []
dense_indices = []
sparse_values = []
sparse_indices = []
for i, v in enumerate(values):
if is_indexed_slices(v):
sparse_values.append(v)
sparse_indices.append(i)
else:
dense_values.append(v)
dense_indices.append(i)
return dense_values, dense_indices, sparse_values, sparse_indices
def stitch_values(values_and_indices_list):
"""Stitch values together according to their indices.
Args:
values_and_indices_list: a list of tuples of values and indices indicating
the values and positions in the returned list.
Returns:
a stitched list of values.
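
  For example (illustrative values):

    stitch_values([([a, c], [0, 2]), ([b], [1])]) == [a, b, c]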
"""
length = 0
for values_and_indices in values_and_indices_list:
length += len(values_and_indices[0])
result = [None] * length
for values_and_indices in values_and_indices_list:
if values_and_indices and values_and_indices[0]:
for v, i in zip(*values_and_indices):
assert result[i] is None
result[i] = v
return result
def group_by_size(input_tensors, bytes_per_pack):
"""Groups `input_tensors` into chunks of `bytes_per_pack`.
  The method preserves the original order of `input_tensors`. The grouping is
  best effort; each pack could have more or fewer bytes than `bytes_per_pack`.
  It only groups values with known shape.
Args:
input_tensors: a list of Tensor.
bytes_per_pack: an integer.
Returns:
    A list of packs of Tensor. All values are grouped into one pack if
    `bytes_per_pack` is zero or any of the values has an unknown shape.
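
  For example (illustrative sizes), with `bytes_per_pack=200` and four float32
  tensors of 100 bytes each, the result is `[[t0, t1, t2], [t3]]`: a new pack
  is only started once the current pack has grown past `bytes_per_pack`.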
"""
if bytes_per_pack == 0:
return [input_tensors]
packs = []
last_pack_size = 0
for value in input_tensors:
num_elements = value.shape.num_elements()
if num_elements is None:
# Can't pack values with unknown shape.
logging.warning(
'not packing values due to the unknown or inconsistent shape of %s',
value)
return [input_tensors]
size = num_elements * value.dtype.size
    # Try to keep each pack as close to bytes_per_pack as possible, while
    # ensuring each pack (except possibly the last) is at least bytes_per_pack
    # large. I.e. we err on the side of having fewer but larger packs.
if not packs or last_pack_size > bytes_per_pack:
packs.append([])
last_pack_size = 0
packs[-1].append(value)
last_pack_size += size
return packs
def _pad_util(input_tensor, full_axis_dim):
"""Pad the `input_tensor`'s first dimension to be `full_axis_dim`."""
missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0]
tensor_rank = array_ops.rank(input_tensor)
paddings_axis = [[0, missing_axis_dim]]
paddings = array_ops.concat([
paddings_axis,
array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32)
],
axis=0)
padded_input_tensor = array_ops.pad(input_tensor, paddings)
return padded_input_tensor