# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops


def _flatten_tensors(tensors):
  """Check tensors for isomorphism and flatten.

  Args:
    tensors: list of `tf.Tensor` which must all have the same shape.

  Returns:
    tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
    shape: the original shape of each element of input tensors

  Raises:
    ValueError: tensors are empty or non-isomorphic or have unknown shape.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  for tensor in tensors:
    shape = shape.merge_with(tensor.shape)
  if not shape.is_fully_defined():
    raise ValueError("Tensors must have statically known shape.")
  if len(shape) != 1:
    reshaped = []
    for t in tensors:
      with ops.colocate_with(t):
        reshaped.append(array_ops.reshape(t, [-1]))
    tensors = reshaped
  return tensors, shape


def _reshape_tensors(tensors, shape):
  """Reshape tensors flattened by _flatten_tensors.

  Args:
    tensors: list of `tf.Tensor` of identical length 1D tensors.
    shape: list of integers describing the desired shape. Product of
      the elements must equal the length of each tensor.

  Returns:
    list of `tf.Tensor` which are the reshaped inputs.
  """
  reshaped = []
  for t in tensors:
    with ops.colocate_with(t):
      reshaped.append(array_ops.reshape(t, shape))
  return reshaped


def _padded_split(tensor, pieces):
  """Like split for 1D tensors but pads-out case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may
      be zero-padded on the end to make its size equal to those of all
      of the other tensors.

  Raises:
    ValueError: The input tensor is not 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  with ops.colocate_with(tensor):
    if tensor_len % pieces != 0:
      # pad to an even length
      chunk_size = 1 + tensor_len // pieces
      if pieces > tensor_len:
        # This is an edge case that should not come up in practice,
        # i.e. a different reduction algorithm would be better,
        # but we'll make it work just for completeness.
        pad_len = pieces - tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      elif (pieces - 1) * chunk_size >= tensor_len:
        # Another edge case of limited real interest.
        pad_len = (pieces * chunk_size) % tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      else:
        last_chunk_size = tensor_len - (pieces - 1) * chunk_size
        pad_len = chunk_size - last_chunk_size
        piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
        parts = array_ops.split(tensor, piece_lens)
        parts[-1] = array_ops.concat(
            [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        return parts, pad_len
    else:
      return array_ops.split(tensor, pieces), 0
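

# Worked example (illustrative): splitting a length-10 tensor into 4 pieces
# takes the final branch above: chunk_size = 1 + 10 // 4 = 3,
# last_chunk_size = 10 - 3 * 3 = 1, pad_len = 3 - 1 = 2, so the result is
# four length-3 pieces with the last one zero-padded by 2 elements, e.g.
#
#   parts, pad_len = _padded_split(array_ops.ones([10]), 4)
#   # len(parts) == 4, each of shape [3], pad_len == 2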


def _strip_padding(tensors, pad_len):
  """Strip the suffix padding added by _padded_split.

  Args:
    tensors: list of `tf.Tensor` of identical length 1D tensors.
    pad_len: number of elements to be stripped from the end of each tensor.

  Returns:
    list of `tf.Tensor` which are the stripped inputs.

  Raises:
    ValueError: tensors must be a non-empty list of 1D tensors, and
      each must be longer than pad_len.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  if len(shape) > 1:
    raise ValueError("tensors must be 1D")
  prefix_len = int(shape[0] - pad_len)
  if prefix_len < 0:
    raise ValueError("pad_len longer than tensor")
  stripped = []
  for t in tensors:
    with ops.colocate_with(t):
      stripped.append(array_ops.slice(t, [0], [prefix_len]))
  return stripped


def _ragged_split(tensor, pieces):
  """Like split for 1D tensors but allows case where len % pieces != 0.

  Args:
    tensor: `tf.Tensor` that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of `tf.Tensor` of length pieces, which hold the values of
      the input tensor, in order. The final tensor may be shorter
      than the others, which will all be of equal length.

  Raises:
    ValueError: input tensor must be 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape.dims[0].value
  chunk_size = tensor_len // pieces
  with ops.colocate_with(tensor):
    if tensor_len != (pieces * chunk_size):
      # last piece will be short
      assert pieces > 1
      last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
      assert last_chunk_size > 0
      piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
      return array_ops.split(tensor, piece_lens)
    else:
      return array_ops.split(tensor, pieces)


def _ring_permutations(num_workers, num_subchunks, gpu_perm):
  """Generate an array of device index arrays, one for each subchunk.

  In the basic ring reduction algorithm there are size(T)/num_devices
  data chunks and each device processes one chunk per tick, i.e. sending
  one chunk and receiving one chunk. The idea of subchunking is that
  each device processes num_subchunks smaller data regions per tick,
  and the ring rank permutation is different for each subchunk index
  so that a device is potentially sending to and receiving from
  num_subchunks different other devices at each tick. Where multiple
  independent data channels exist between devices, this strategy
  supplies a method of using them in parallel.

  Args:
    num_workers: number of worker tasks
    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
      ring order of GPUs at each worker. Other permutations will be generated
      by rotating this array and splicing together per-worker instances.

  Raises:
    ValueError: the number of subchunks may not exceed the number of GPUs.

  Returns:
    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
        preceding device in the permutation for that subchunk. The
        device index of GPU i at worker j is i + (j * num_gpus).
    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
        local rank of device d in the permutation for that subchunk.
  """
  num_gpus = len(gpu_perm)
  devices = num_workers * num_gpus
  if devices == 0:
    return [], []
  if num_subchunks > num_gpus:
    raise ValueError(
        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
  rotation_interval = max(1, int(num_gpus / num_subchunks))
  perms_by_s = []
  for s in range(0, num_subchunks):
    full_order = []
    offset = s * rotation_interval
    for w in range(0, num_workers):
      default_order = [(w * num_gpus) + i for i in gpu_perm]
      dev_order = default_order[offset:] + default_order[:offset]
      full_order += dev_order
    perms_by_s.append(full_order)
  pred_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  rank_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  for s in range(0, num_subchunks):
    for d in range(0, devices):
      for t in range(0, devices):
        if d == perms_by_s[s][t]:
          rank_by_s_d[s][d] = t
          pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
          break
  return (pred_by_s_d, rank_by_s_d)
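

# Worked example (illustrative): with num_workers=1, num_subchunks=2 and
# gpu_perm=[0, 1, 2, 3], the two subchunk permutations are [0, 1, 2, 3] and
# [2, 3, 0, 1], giving
#
#   pred_by_s_d == [[3, 0, 1, 2], [3, 0, 1, 2]]
#   rank_by_s_d == [[0, 1, 2, 3], [2, 3, 0, 1]]
#
# The second permutation is the first rotated by rotation_interval = 2, so
# each device has a different rank per subchunk and therefore works on a
# different chunk index at each tick.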


def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  Args:
    input_tensors: a list of `tf.Tensor` objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker. All workers must have the same number of
      GPUs with the same rank ordering. If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: empty input_tensors or they don't all have same
    size.

  Returns:
    a list of `tf.Tensor` identical sum-reductions of input_tensors.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  (pred_by_s_d, rank_by_s_d) = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  chunks_by_dev, pad_len = _build_ring_gather(
      input_tensors, devices,
      num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
  if un_op:
    chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
  output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                                       chunks_by_dev)
  if pad_len > 0:
    output_tensors = _strip_padding(output_tensors, pad_len)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
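

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# builds a single-worker, 4-GPU ring all-reduce over identical 1024-element
# tensors, with two subchunks per chunk and the trivial GPU ring order.
# Assumes a graph-building context and that "/gpu:0".."/gpu:3" exist on the
# local machine.
def _example_ring_all_reduce():
  input_tensors = []
  for i in range(4):
    with ops.device("/gpu:%d" % i):
      input_tensors.append(array_ops.ones([1024]))
  return build_ring_all_reduce(input_tensors, num_workers=1, num_subchunks=2,
                               gpu_perm=[0, 1, 2, 3], red_op=math_ops.add)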


def _build_ring_gather(input_tensors, devices, num_subchunks,
                       pred_by_s_d, rank_by_s_d, red_op):
  """Construct a subgraph for the first (reduction) pass of ring all-reduce.

  Args:
    input_tensors: a list of `tf.Tensor` 1D input tensors of same
      shape and type.
    devices: array of device name strings
    num_subchunks: number of subchunks each device should process in one tick.
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    red_op: a binary operator for elementwise reduction

  Raises:
    ValueError: tensors must all be one dimensional.

  Returns:
    list of list of `tf.Tensor` of (partially) reduced values where
    exactly num_subchunks chunks at each device are fully reduced.
  """
  num_devices = len(input_tensors)
  if num_devices == 0:
    return []
  if num_devices == 1:
    return input_tensors
  shape = input_tensors[0].shape
  if 1 != len(shape):
    raise ValueError("input tensors must be 1D")
  num_chunks = num_devices * num_subchunks
  num_ticks = num_devices - 1
  # Initialize chunks_by_dev with splits of the input tensors.
  chunks_by_dev = []
  split_pad_len = 0
  for d in range(0, num_devices):
    with ops.device(devices[d]):
      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
      chunks_by_dev.append(splits)
  # Reduction phase
  for tick in range(0, num_ticks):
    # One new partial reduction for every chunk
    new_partial_reductions = [None for _ in range(0, num_chunks)]
    # Compute reductions with respect to last tick's values
    for d in range(0, num_devices):
      with ops.device(devices[d]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (2 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          new_partial_reductions[chunk_index] = red_op(
              chunks_by_dev[pred_dev][chunk_index],
              chunks_by_dev[d][chunk_index])
    # Update chunks_by_dev with the new values at the end of the tick.
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (2 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
  return chunks_by_dev, split_pad_len


def _apply_unary_to_chunks(f, chunks_by_dev):
  """Apply a unary op to each tensor in chunks_by_dev, on same device.

  Args:
    f: a unary function over `tf.Tensor`.
    chunks_by_dev: list of lists of `tf.Tensor`.

  Returns:
    new list of lists of `tf.Tensor` with the same structure as
    chunks_by_dev containing the derived tensors.
  """
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append([f(t) for t in x])
  return output


def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                        chunks_by_dev):
  """Construct subgraph for second (scatter) pass of ring all-reduce.

  Args:
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    chunks_by_dev: list of list of `tf.Tensor` indexed by ints
      (device, chunk)

  Raises:
    ValueError: chunks_by_dev is not well-formed

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device corresponding to the outer dimension of chunks_by_dev.
  """
  num_devices = len(chunks_by_dev)
  num_chunks = len(chunks_by_dev[0])
  if 0 != num_chunks % num_devices:
    raise ValueError(
        "Expect number of chunks per device to be divisible by num_devices")
  num_subchunks = int(num_chunks / num_devices)
  num_ticks = num_devices - 1
  for tick in range(0, num_ticks):
    passed_values = [None for _ in range(0, num_chunks)]
    for d in range(0, num_devices):
      with ops.colocate_with(chunks_by_dev[d][0]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (1 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          passed_values[chunk_index] = array_ops.identity(
              chunks_by_dev[pred_dev][chunk_index])
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (1 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
  # Join chunks at each device.
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append(array_ops.concat(x, 0))
  return output


def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  (Thakur et al., 2005).

  The concept is to arrange the participating n devices in
  a linear sequence where devices exchange data pairwise
  with one other device in each round. During the gather
  phase there are lg(n) rounds where devices exchange
  increasingly smaller sub-tensors with another device
  at increasingly greater distances, until at the top
  each device has 1/n of the fully reduced values. During the
  scatter phase each device exchanges its fully reduced
  sub-tensor (which doubles in length at each round)
  with one other device at increasingly smaller distances
  until each device has all of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
  power of 2. TODO(tucker): relax this restriction. Also, the
  number of elements in each tensor must be divisible by 2^h where h
  is the number of hops in each phase. This will also be relaxed in
  the future with edge-case specific logic.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.

  References:
    Optimization of Collective Communication Operations in MPICH:
      [Thakur et al., 2005]
      (https://journals.sagepub.com/doi/abs/10.1177/1094342005051521)
      ([pdf](http://wwwi10.lrr.in.tum.de/~gerndt/home/Teaching/HPCSeminar/mpich_multi_coll.pdf))
  """
  devices = [t.device for t in input_tensors]
  input_tensors, shape = _flatten_tensors(input_tensors)
  reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
  if un_op:
    reduced_shards = [un_op(t) for t in reduced_shards]
  output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
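

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# a 4-device recursive halving-doubling all-reduce. The device count is a
# power of 2 and the tensor length (1024) divides evenly by 2 at every hop,
# as required by the preliminary version above. The "/gpu:N" device names
# are assumptions about the local machine.
def _example_recursive_hd_all_reduce():
  input_tensors = []
  for i in range(4):
    with ops.device("/gpu:%d" % i):
      input_tensors.append(array_ops.ones([1024]))
  return build_recursive_hd_all_reduce(input_tensors, math_ops.add)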


def _build_recursive_hd_gather(input_tensors, devices, red_op):
  """Construct the gather phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` to be elementwise reduced.
    devices: a list of strings naming the devices hosting input_tensors,
      which will also be used to host the (partial) reduction values.
    red_op: a binary elementwise reduction Op.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensor shards.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  if num_devices != (2 ** num_hops):
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for h in range(0, num_hops):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_dev = devices[d]
      right_dev = devices[d + span]
      left_split = array_ops.split(chunks[d], 2)
      right_split = array_ops.split(chunks[d + span], 2)
      with ops.device(left_dev):
        new_chunks[d] = red_op(left_split[0], right_split[0])
      with ops.device(right_dev):
        new_chunks[d + span] = red_op(left_split[1], right_split[1])
    chunks = new_chunks
  return chunks


def _build_recursive_hd_scatter(input_tensors, devices):
  """Construct the scatter phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` that are fully-reduced shards.
    devices: a list of strings naming the devices on which the reconstituted
      full tensors should be placed.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
  chunks = input_tensors
  for h in reversed(range(0, num_hops)):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_idx = d
      right_idx = d + span
      left_dev = devices[left_idx]
      right_dev = devices[right_idx]
      with ops.device(left_dev):
        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
                                                 chunks[right_idx]], 0)
      with ops.device(right_dev):
        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
                                                  chunks[right_idx]], 0)
    chunks = new_chunks
  return chunks


def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers. Suppose tensor length is n, there are d devices
  and g gather shards. Each device sends a n/g length sub-tensor to
  each gather shard. The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device. The
  devices then join the g fully reduced fragments they receive from
  the shards. The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction. The first is better where
  reduction Op time is low compared to transmission time, the second
  better in the other case.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary elementwise reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced tensors.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  dst_devices = [t.device for t in input_tensors]
  reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                         red_op, un_op)
  output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors
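

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# a shuffle all-reduce over 4 GPU-resident tensors using 2 gather shards
# hosted on the CPU. Note that red_op here is an n-ary reduction such as
# math_ops.add_n, not a binary op. Device names are assumptions about the
# local machine.
def _example_shuffle_all_reduce():
  input_tensors = []
  for i in range(4):
    with ops.device("/gpu:%d" % i):
      input_tensors.append(array_ops.ones([1024]))
  gather_devices = ["/cpu:0", "/cpu:0"]
  return build_shuffle_all_reduce(input_tensors, gather_devices,
                                  math_ops.add_n)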


def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-ary reduction Op applied to a list of `tf.Tensor`
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of `tf.Tensor` which are the fully reduced shards.

  Raises:
    ValueError: inputs not well-formed.
  """
  num_source_devices = len(input_tensors)
  num_gather_devices = len(gather_devices)
  shape = input_tensors[0].shape
  if len(shape) != 1:
    raise ValueError("input_tensors must be 1D")
  shards_by_source = []
  for d in range(0, num_source_devices):
    with ops.colocate_with(input_tensors[d]):
      shards_by_source.append(
          _ragged_split(input_tensors[d], num_gather_devices))
  reduced_shards = []
  for d in range(0, num_gather_devices):
    with ops.device(gather_devices[d]):
      values = [s[d] for s in shards_by_source]
      red_shard = red_op(values)
      if un_op:
        red_shard = un_op(red_shard)
      reduced_shards.append(red_shard)
  return reduced_shards


def _build_shuffle_scatter(reduced_shards, dst_devices):
  """Build the scatter phase of shuffle all-reduce.

  Args:
    reduced_shards: list of `tf.Tensor` fully reduced shards
    dst_devices: list of names of devices at which the fully-reduced value
      should be reconstituted.

  Returns:
    list of `tf.Tensor` scattered tensors.
  """
  num_devices = len(dst_devices)
  out_tensors = []
  for d in range(0, num_devices):
    with ops.device(dst_devices[d]):
      out_tensors.append(array_ops.concat(reduced_shards, 0))
  return out_tensors


def _split_by_task(devices, values):
  """Partition devices and values by common task.

  Args:
    devices: list of device name strings
    values: list of `tf.Tensor` of same length as devices.

  Returns:
    (per_task_devices, per_task_values) where both values are
    lists of lists with isomorphic structure: the outer list is
    indexed by task, and the inner list has length of the number
    of values belonging to that task. per_task_devices contains
    the specific devices to which the values are local, and
    per_task_values contains the corresponding values.

  Raises:
    ValueError: devices must be same length as values.
  """
  num_devices = len(devices)
  if num_devices != len(values):
    raise ValueError("len(devices) must equal len(values)")
  per_task_devices = collections.OrderedDict()
  per_task_values = collections.OrderedDict()
  for d in range(num_devices):
    d_spec = device_lib.DeviceSpec.from_string(devices[d])
    if not hasattr(d_spec, "task") or d_spec.task is None:
      assert False, "failed to parse device %s" % devices[d]
    index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
    if index not in per_task_devices:
      per_task_devices[index] = []
      per_task_values[index] = []
    per_task_devices[index].append(devices[d])
    per_task_values[index].append(values[d])

  return (list(per_task_devices.values()), list(per_task_values.values()))
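

# Worked example (illustrative): given the device names
#
#   ["/job:worker/replica:0/task:0/gpu:0",
#    "/job:worker/replica:0/task:0/gpu:1",
#    "/job:worker/replica:0/task:1/gpu:0",
#    "/job:worker/replica:0/task:1/gpu:1"]
#
# and a parallel list of values, _split_by_task groups both lists by the
# (job, replica, task) key into two sublists of two entries each, preserving
# the original order within each task.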


def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator. Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduced values.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op == math_ops.add:
    output_tensors = nccl_ops.all_sum(input_tensors)
  else:
    raise ValueError("red_op not supported by NCCL all-reduce: ", red_op)
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors
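

# Illustrative usage sketch (hypothetical helper, not part of the public API):
# NCCL all-reduce of 4 same-shape GPU tensors. Only math_ops.add is accepted
# as red_op, and the "/gpu:N" device names are assumptions about the local
# machine.
def _example_nccl_all_reduce():
  input_tensors = []
  for i in range(4):
    with ops.device("/gpu:%d" % i):
      input_tensors.append(array_ops.ones([1024]))
  return build_nccl_all_reduce(input_tensors, math_ops.add)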


def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = [None for w in range(0, num_workers)]
  up_devices = up_values[:]
  down_values = up_values[:]
  # First stage: reduce within each worker using NCCL
  for w in range(0, num_workers):
    worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used. Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values[w] = array_ops.identity(worker_values[0])
      up_devices[w] = per_worker_devices[w][0]
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  for w in range(0, num_workers):
    dst_tensors = []
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl_ops.broadcast(array_ops.identity(level_2_output[w]))
      for d in per_worker_devices[w]:
        with ops.device(d):
          dst_tensors.append(array_ops.identity(broadcast_src))
    down_values[w] = dst_tensors
  output_tensors = [v for sublist in down_values for v in sublist]
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _reduce_non_singleton(input_tensors, red_f, un_op):
  """If len(input_tensors) > 1, apply red_f, else apply un_op."""
  if len(input_tensors) > 1:
    return red_f(input_tensors)
  else:
    if not un_op:
      return input_tensors
    output_tensors = []
    for t in input_tensors:
      with ops.colocate_with(t):
        output_tensors.append(un_op(t))
    return output_tensors


def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers."""
  def upper_builder(y):
    return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op)
  def upper_level_f(x):
    return _reduce_non_singleton(x, upper_builder, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)
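

# Illustrative usage sketch (hypothetical device names, assumes a running
# two-worker cluster with two GPUs per task): each worker first reduces its
# two GPUs with NCCL, the two per-worker leader values are then combined
# with a ring all-reduce (subdiv=1), and the result is broadcast back to
# every GPU.
def _example_nccl_then_ring():
  input_tensors = []
  for task in range(2):
    for gpu in range(2):
      with ops.device("/job:worker/task:%d/gpu:%d" % (task, gpu)):
        input_tensors.append(array_ops.ones([1024]))
  return build_nccl_then_ring(input_tensors, subdiv=1, red_op=math_ops.add)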


def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
  upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers."""
  def upper_level_f(x):
    return build_shuffle_all_reduce(x, gather_devices, shuffle_red_op, un_op)

  return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f)


def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  Args:
    input_tensors: list of `tf.Tensor` of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of `tf.Tensor` of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  # First stage, reduce across each worker using gather_devices.
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = []
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker. ")
  for w in range(0, num_workers):
    reduced_shards = _build_shuffle_gather(
        per_worker_values[w], [gather_devices[w]], red_op)
    up_values.append(reduced_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  output_tensors = []
  for w in range(0, num_workers):
    output_tensors += _build_shuffle_scatter(
        [level_2_output[w]], per_worker_devices[w])
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers."""
  def upper_builder(tensors):
    return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],
                                 red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op, upper_level_f)


def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers."""
  def upper_builder(tensors):
    return build_shuffle_all_reduce(tensors, second_gather_devices,
                                    red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op, upper_level_f)