Merge pull request from tensorflow:terrytangyuan-patch-2

PiperOrigin-RevId: 294943768
Change-Id: I402760b2ff8033f73241edef2e0ebf1957666a9b
TensorFlower Gardener 2020-02-13 10:26:24 -08:00
commit 11465ab7ae


@@ -32,10 +32,10 @@ def _flatten_tensors(tensors):
"""Check tensors for isomorphism and flatten.
Args:
- tensors: list of T `tf.Tensor` which must all have the same shape.
+ tensors: list of `tf.Tensor` which must all have the same shape.
Returns:
- tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors
+ tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
shape: the original shape of each element of input tensors
Raises:
@@ -61,12 +61,12 @@ def _reshape_tensors(tensors, shape):
"""Reshape tensors flattened by _flatten_tensors.
Args:
- tensors: list of T `tf.Tensor` of identical length 1D tensors.
+ tensors: list of `tf.Tensor` of identical length 1D tensors.
shape: list of integers describing the desired shape. Product of
the elements must equal the length of each tensor.
Returns:
- list of T `tf.Tensor` which are the reshaped inputs.
+ list of `tf.Tensor` which are the reshaped inputs.
"""
reshaped = []
for t in tensors:
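
The two helpers above just flatten each tensor to 1-D and later restore the original shape. A minimal sketch of that round trip (illustrative only, not the library code; assumes TF2 eager execution):

import tensorflow as tf

def flatten_tensors_sketch(tensors):
  # Remember the common shape, then view every tensor as 1-D.
  shape = tensors[0].shape.as_list()
  return [tf.reshape(t, [-1]) for t in tensors], shape

def reshape_tensors_sketch(tensors, shape):
  # Inverse of the flatten step: restore the original shape.
  return [tf.reshape(t, shape) for t in tensors]

flat, shape = flatten_tensors_sketch([tf.ones([2, 3]), tf.zeros([2, 3])])
restored = reshape_tensors_sketch(flat, shape)   # back to shape [2, 3]
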
@@ -79,13 +79,13 @@ def _padded_split(tensor, pieces):
"""Like split for 1D tensors but pads-out case where len % pieces != 0.
Args:
- tensor: T `tf.Tensor` that must be 1D.
+ tensor: `tf.Tensor` that must be 1D.
pieces: a positive integer specifying the number of pieces into which
tensor should be split.
Returns:
- list of T `tf.Tensor` of length pieces, which hold the values of
- thin input tensor, in order. The final tensor may
+ list of `tf.Tensor` of length pieces, which hold the values of
+ thin input tensor, in order. The final tensor may
be zero-padded on the end to make its size equal to those of all
of the other tensors.
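
A rough sketch of the pad-then-split behavior described in this docstring (illustrative only; the real helper also returns the pad length so it can be stripped later), assuming TF2 eager execution:

import tensorflow as tf

def padded_split_sketch(tensor, pieces):
  n = int(tensor.shape[0])
  piece_len = -(-n // pieces)              # ceiling division
  pad_len = piece_len * pieces - n
  padded = tf.pad(tensor, [[0, pad_len]])  # zero-pad the tail
  return tf.split(padded, pieces), pad_len

parts, pad_len = padded_split_sketch(tf.range(10), 4)
# 4 pieces of length 3 each; the last piece is zero-padded and pad_len == 2.
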
@@ -132,11 +132,11 @@ def _strip_padding(tensors, pad_len):
"""Strip the suffix padding added by _padded_split.
Args:
- tensors: list of T `tf.Tensor` of identical length 1D tensors.
+ tensors: list of `tf.Tensor` of identical length 1D tensors.
pad_len: number of elements to be stripped from the end of each tensor.
Returns:
- list of T `tf.Tensor` which are the stripped inputs.
+ list of `tf.Tensor` which are the stripped inputs.
Raises:
ValueError: tensors must be a non-empty list of 1D tensors, and
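
The corresponding strip step is just a slice that drops the zero padding again; a hedged sketch under the same assumptions:

import tensorflow as tf

def strip_padding_sketch(tensors, pad_len):
  # Remove pad_len trailing elements from each 1-D tensor.
  return [t[:-pad_len] for t in tensors] if pad_len > 0 else list(tensors)

stripped = strip_padding_sketch([tf.constant([1, 2, 3, 0, 0])], pad_len=2)
# stripped[0] is [1, 2, 3]
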
@@ -161,13 +161,13 @@ def _ragged_split(tensor, pieces):
"""Like split for 1D tensors but allows case where len % pieces != 0.
Args:
- tensor: T `tf.Tensor` that must be 1D.
+ tensor: `tf.Tensor` that must be 1D.
pieces: a positive integer specifying the number of pieces into which
tensor should be split.
Returns:
- list of T `tf.Tensor` of length pieces, which hold the values of
- the input tensor, in order. The final tensor may be shorter
+ list of `tf.Tensor` of length pieces, which hold the values of
+ the input tensor, in order. The final tensor may be shorter
than the others, which will all be of equal length.
Raises:
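
Unlike the padded variant, the ragged split keeps the original length and simply lets the last piece come up short; a sketch under the same assumptions:

import tensorflow as tf

def ragged_split_sketch(tensor, pieces):
  n = int(tensor.shape[0])
  piece_len = -(-n // pieces)              # ceiling division
  sizes = [piece_len] * (pieces - 1) + [n - piece_len * (pieces - 1)]
  return tf.split(tensor, sizes)

parts = ragged_split_sketch(tf.range(10), 3)
# piece sizes are [4, 4, 2]: equal-length pieces except a shorter final one.
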
@@ -256,7 +256,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
"""Construct a subgraph performing a ring-style all-reduce of input_tensors.
Args:
- input_tensors: a list of T `tf.Tensor` objects, which must all
+ input_tensors: a list of `tf.Tensor` objects, which must all
have the same shape and type.
num_workers: number of worker tasks spanned by input_tensors.
num_subchunks: number of subchunks each device should process in one tick.
@@ -272,7 +272,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
size.
Returns:
- a list of T `tf.Tensor` identical sum-reductions of input_tensors.
+ a list of `tf.Tensor` identical sum-reductions of input_tensors.
"""
if len(input_tensors) < 2:
raise ValueError("input_tensors must be length 2 or longer")
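
The ring pattern this builder implements works in two passes: a scatter-reduce in which partial sums circulate around the ring, and an all-gather in which the fully reduced chunks circulate once more. A self-contained, single-process sketch of the idea (illustrative only, not the library implementation; assumes each tensor's length is divisible by the number of participants):

import tensorflow as tf

def ring_all_reduce_sketch(tensors):
  n = len(tensors)
  # chunks[d][c] is participant d's current copy of chunk c.
  chunks = [list(tf.split(t, n)) for t in tensors]
  # Scatter-reduce: after n-1 steps, participant d holds the full sum of
  # chunk (d + 1) % n.
  for step in range(n - 1):
    for d in range(n):
      c = (d - step) % n
      dst = (d + 1) % n
      chunks[dst][c] = chunks[dst][c] + chunks[d][c]
  # All-gather: circulate the fully reduced chunks once more around the ring.
  for step in range(n - 1):
    for d in range(n):
      c = (d + 1 - step) % n
      dst = (d + 1) % n
      chunks[dst][c] = chunks[d][c]
  return [tf.concat(chunks[d], axis=0) for d in range(n)]

outputs = ring_all_reduce_sketch([tf.constant([1., 2., 3.]),
                                  tf.constant([10., 20., 30.]),
                                  tf.constant([100., 200., 300.])])
# Every element of `outputs` is the identical sum [111., 222., 333.].
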
@@ -299,7 +299,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks,
"""Construct a subgraph for the first (reduction) pass of ring all-reduce.
Args:
- input_tensors: a list of T `tf.Tensor` 1D input tensors of same
+ input_tensors: a list of `tf.Tensor` 1D input tensors of same
shape and type.
devices: array of device name strings
num_subchunks: number of subchunks each device should process in one tick.
@@ -311,7 +311,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks,
ValueError: tensors must all be one dimensional.
Returns:
- list of list of T `tf.Tensor` of (partially) reduced values where
+ list of list of `tf.Tensor` of (partially) reduced values where
exactly num_subchunks chunks at each device are fully reduced.
"""
num_devices = len(input_tensors)
@@ -360,11 +360,11 @@ def _apply_unary_to_chunks(f, chunks_by_dev):
"""Apply a unary op to each tensor in chunks_by_dev, on same device.
Args:
- f: a unary function over T `tf.Tensor`.
- chunks_by_dev: list of lists of T `tf.Tensor`.
+ f: a unary function over `tf.Tensor`.
+ chunks_by_dev: list of lists of `tf.Tensor`.
Returns:
- new list of lists of T `tf.Tensor` with the same structure as
+ new list of lists of `tf.Tensor` with the same structure as
chunks_by_dev containing the derived tensors.
"""
output = []
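
Ignoring the per-device placement, this helper is essentially a nested map; a minimal illustration of the operation's shape (tf.math.negative is only an example unary op):

import tensorflow as tf

f = tf.math.negative                          # any unary op over `tf.Tensor`
chunks_by_dev = [[tf.constant([1., 2.])], [tf.constant([3., 4.])]]
# Apply f to every chunk while preserving the per-device nesting.
output = [[f(t) for t in chunks] for chunks in chunks_by_dev]
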
@@ -381,14 +381,14 @@ def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
Args:
pred_by_s_d: as produced by _ring_permutations
rank_by_s_d: as produced by _ring_permutations
- chunks_by_dev: list of list of T `tf.Tensor` indexed by ints
+ chunks_by_dev: list of list of `tf.Tensor` indexed by ints
(device, chunk)
Raises:
ValueError: chunks_by_dev is not well-formed
Returns:
- list of T `tf.Tensor` which are the fully reduced tensors, one
+ list of `tf.Tensor` which are the fully reduced tensors, one
at each device corresponding to the outer dimension of chunks_by_dev.
"""
num_devices = len(chunks_by_dev)
@@ -448,12 +448,12 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
the future with edge-case specific logic.
Args:
- input_tensors: list of T `tf.Tensor` to be elementwise reduced.
+ input_tensors: list of `tf.Tensor` to be elementwise reduced.
red_op: a binary elementwise reduction Op.
un_op: an optional unary elementwise Op to apply to reduced values.
Returns:
- list of T `tf.Tensor` which are the fully reduced tensors, one
+ list of `tf.Tensor` which are the fully reduced tensors, one
at each device of input_tensors.
Raises:
@@ -481,13 +481,13 @@ def _build_recursive_hd_gather(input_tensors, devices, red_op):
"""Construct the gather phase of recursive halving-doubling all-reduce.
Args:
- input_tensors: list of T `tf.Tensor` to be elementwise reduced.
+ input_tensors: list of `tf.Tensor` to be elementwise reduced.
devices: a list of strings naming the devices hosting input_tensors,
which will also be used to host the (partial) reduction values.
red_op: a binary elementwise reduction Op.
Returns:
- list of T `tf.Tensor` which are the fully reduced tensor shards.
+ list of `tf.Tensor` which are the fully reduced tensor shards.
Raises:
ValueError: num_devices not a power of 2, or tensor len not divisible
@@ -522,12 +522,12 @@ def _build_recursive_hd_scatter(input_tensors, devices):
"""Construct the scatter phase of recursive halving-doubling all-reduce.
Args:
- input_tensors: list of T `tf.Tensor` that are fully-reduced shards.
+ input_tensors: list of `tf.Tensor` that are fully-reduced shards.
devices: a list of strings naming the devices on which the reconstituted
full tensors should be placed.
Returns:
- list of T `tf.Tensor` which are the fully reduced tensors.
+ list of `tf.Tensor` which are the fully reduced tensors.
"""
num_devices = len(devices)
num_hops = int(math.log(num_devices, 2))
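
The recursive halving-doubling builder and its gather/scatter phases follow a classic power-of-two butterfly: partners exchange and reduce half of their current range until each device owns one fully reduced shard, then the exchanges are replayed in reverse to rebuild the full tensor everywhere. A self-contained, single-process sketch of that idea (illustrative only; assumes a power-of-two number of inputs whose length is divisible by that count):

import tensorflow as tf

def recursive_hd_all_reduce_sketch(tensors):
  n = len(tensors)                                  # must be a power of two
  chunks = [list(tf.split(t, n)) for t in tensors]  # n chunks per participant
  lo, hi = [0] * n, [n] * n                         # chunk range owned by each
  # Gather phase (recursive halving): keep reducing half of the current range.
  dist = n // 2
  while dist >= 1:
    snap = [list(c) for c in chunks]
    for d in range(n):
      p = d ^ dist                                  # partner for this round
      mid = (lo[d] + hi[d]) // 2
      lo[d], hi[d] = (lo[d], mid) if d < p else (mid, hi[d])
      for c in range(lo[d], hi[d]):
        chunks[d][c] = snap[d][c] + snap[p][c]
    dist //= 2
  # Scatter phase (recursive doubling): replay the exchanges to regrow ranges.
  dist = 1
  while dist < n:
    snap = [list(c) for c in chunks]
    snap_lo, snap_hi = list(lo), list(hi)
    for d in range(n):
      p = d ^ dist
      for c in range(snap_lo[p], snap_hi[p]):
        chunks[d][c] = snap[p][c]
      lo[d] = min(snap_lo[d], snap_lo[p])
      hi[d] = max(snap_hi[d], snap_hi[p])
    dist *= 2
  return [tf.concat(c, axis=0) for c in chunks]

outputs = recursive_hd_all_reduce_sketch(
    [tf.fill([4], float(i)) for i in range(4)])     # 4 inputs of length 4
# Every element of `outputs` equals [6., 6., 6., 6.] (0 + 1 + 2 + 3).
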
@@ -570,14 +570,14 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
better in the other case.
Args:
- input_tensors: list of T @(tf.Tensor} values to be reduced.
+ input_tensors: list of `tf.Tensor` values to be reduced.
gather_devices: list of names of devices on which reduction shards
should be placed.
red_op: an n-array elementwise reduction Op
un_op: optional elementwise unary Op to be applied to fully-reduced values.
Returns:
- list of T `tf.Tensor` which are the fully reduced tensors.
+ list of `tf.Tensor` which are the fully reduced tensors.
"""
input_tensors, shape = _flatten_tensors(input_tensors)
dst_devices = [t.device for t in input_tensors]
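
Ignoring device placement, the shuffle pattern is a reduce of per-input shards onto the gather devices followed by a broadcast of the concatenated result; a hedged sketch (assumes the tensor length divides evenly into the shard count):

import tensorflow as tf

def shuffle_all_reduce_sketch(tensors, num_shards, red_op=tf.add_n, un_op=None):
  shards_per_input = [tf.split(t, num_shards) for t in tensors]
  # Gather phase: reduce shard i across all inputs (hosted on gather device i
  # in the real helper).
  reduced = [red_op([s[i] for s in shards_per_input]) for i in range(num_shards)]
  if un_op is not None:
    reduced = [un_op(r) for r in reduced]
  # Scatter phase: reconstitute the full tensor for every destination.
  full = tf.concat(reduced, axis=0)
  return [tf.identity(full) for _ in tensors]

outputs = shuffle_all_reduce_sketch(
    [tf.constant([1., 2.]), tf.constant([3., 4.])], num_shards=2)
# Each element of `outputs` is [4., 6.].
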
@@ -593,14 +593,14 @@ def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
"""Construct the gather (concentrate and reduce) phase of shuffle all-reduce.
Args:
- input_tensors: list of T @(tf.Tensor} values to be reduced.
+ input_tensors: list of `tf.Tensor` values to be reduced.
gather_devices: list of names of devices on which reduction shards
should be placed.
red_op: the binary reduction Op
un_op: optional elementwise unary Op to be applied to fully-reduced values.
Returns:
- list of T `tf.Tensor` which are the fully reduced shards.
+ list of `tf.Tensor` which are the fully reduced shards.
Raises:
ValueError: inputs not well-formed.
@@ -630,12 +630,12 @@ def _build_shuffle_scatter(reduced_shards, dst_devices):
"""Build the scatter phase of shuffle all-reduce.
Args:
- reduced_shards: list of T @(tf.Tensor} fully reduced shards
+ reduced_shards: list of `tf.Tensor` fully reduced shards
dst_devices: list of names of devices at which the fully-reduced value
should be reconstituted.
Returns:
- list of T `tf.Tensor` scattered tensors.
+ list of `tf.Tensor` scattered tensors.
"""
num_devices = len(dst_devices)
out_tensors = []
@@ -650,7 +650,7 @@ def _split_by_task(devices, values):
Args:
devices: list of device name strings
- values: list of T `tf.tensor` of same length as devices.
+ values: list of `tf.Tensor` of same length as devices.
Returns:
(per_task_devices, per_task_values) where both values are
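
A hedged sketch of the per-task grouping this helper performs, assuming device names of the usual "/job:worker/replica:0/task:N/device:GPU:M" form (the library's exact parsing may differ):

from collections import OrderedDict

def split_by_task_sketch(devices, values):
  groups = OrderedDict()
  for dev, val in zip(devices, values):
    task = dev.split("/task:")[1].split("/")[0]   # crude task-id extraction
    groups.setdefault(task, ([], []))
    groups[task][0].append(dev)
    groups[task][1].append(val)
  per_task_devices = [devs for devs, _ in groups.values()]
  per_task_values = [vals for _, vals in groups.values()]
  return per_task_devices, per_task_values

devs = ["/job:worker/replica:0/task:0/device:GPU:0",
        "/job:worker/replica:0/task:0/device:GPU:1",
        "/job:worker/replica:0/task:1/device:GPU:0"]
per_task_devices, per_task_values = split_by_task_sketch(devs, [1, 2, 3])
# per_task_values == [[1, 2], [3]]
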
@@ -686,14 +686,14 @@ def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
"""Build a subgraph that does one full all-reduce, using NCCL.
Args:
- input_tensors: list of T `tf.Tensor` of same-shape and type values to
+ input_tensors: list of `tf.Tensor` of same-shape and type values to
be reduced.
- red_op: binary elementwise reduction operator. Must be one of
+ red_op: binary elementwise reduction operator. Must be one of
{tf.add}
un_op: optional unary elementwise Op to apply to fully-reduce values.
Returns:
- list of T `tf.Tensor` of reduced values.
+ list of `tf.Tensor` of reduced values.
Raises:
ValueError: red_op not supported.
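
A conceptual stand-in for the contract described here (every output is the elementwise reduction of all inputs, optionally post-processed by un_op); the real builder delegates the communication to NCCL kernels, which this sketch does not attempt to model:

import tensorflow as tf

def nccl_all_reduce_stand_in(input_tensors, red_op=tf.add, un_op=None):
  total = input_tensors[0]
  for t in input_tensors[1:]:
    total = red_op(total, t)        # binary elementwise reduction, e.g. tf.add
  if un_op is not None:
    total = un_op(total)            # optional elementwise post-processing
  return [tf.identity(total) for _ in input_tensors]

outputs = nccl_all_reduce_stand_in([tf.constant([1., 2.]), tf.constant([3., 4.])])
# Each element of `outputs` is [4., 6.].
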
@@ -715,14 +715,14 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
"""Construct a subgraph for NCCL hybrid all-reduce.
Args:
- input_tensors: list of T `tf.Tensor` of same-shape and type values to
+ input_tensors: list of `tf.Tensor` of same-shape and type values to
be reduced.
red_op: binary elementwise reduction operator.
upper_level_f: function for reducing one value per worker, across
workers.
Returns:
- list of T `tf.Tensor` of reduced values.
+ list of `tf.Tensor` of reduced values.
Raises:
ValueError: inputs not well-formed.
@@ -804,7 +804,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
"""Construct a subgraph for Shuffle hybrid all-reduce.
Args:
- input_tensors: list of T `tf.Tensor` of same-shape and type values to
+ input_tensors: list of `tf.Tensor` of same-shape and type values to
be reduced.
gather_devices: list of device names on which to host gather shards.
red_op: binary elementwise reduction operator.
@@ -812,7 +812,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
workers.
Returns:
- list of T `tf.Tensor` of reduced values.
+ list of `tf.Tensor` of reduced values.
Raises:
ValueError: inputs not well-formed.
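
The two hybrid builders share one structure: reduce inside each worker, reduce one representative value per worker across workers with upper_level_f, then fan the result back out within each worker. A minimal single-process sketch of that composition (placement and the NCCL/shuffle details are elided; tf.add_n stands in for the intra-worker reduction):

import tensorflow as tf

def hybrid_all_reduce_sketch(per_worker_tensors, upper_level_f):
  # per_worker_tensors: one list of `tf.Tensor` per worker.
  local = [tf.add_n(ts) for ts in per_worker_tensors]      # intra-worker reduce
  reduced = upper_level_f(local)                           # cross-worker reduce
  # Broadcast each worker's fully reduced value back to all of its tensors.
  return [[tf.identity(r) for _ in ts]
          for r, ts in zip(reduced, per_worker_tensors)]

outputs = hybrid_all_reduce_sketch(
    [[tf.constant(1.), tf.constant(2.)], [tf.constant(3.), tf.constant(4.)]],
    upper_level_f=lambda vals: [tf.add_n(vals)] * len(vals))
# Every tensor in `outputs` equals 10.0.
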