Merge pull request #36713 from tensorflow:terrytangyuan-patch-2

PiperOrigin-RevId: 294943768
Change-Id: I402760b2ff8033f73241edef2e0ebf1957666a9b

commit 11465ab7ae
@@ -32,10 +32,10 @@ def _flatten_tensors(tensors):
   """Check tensors for isomorphism and flatten.

   Args:
-    tensors: list of T `tf.Tensor` which must all have the same shape.
+    tensors: list of `tf.Tensor` which must all have the same shape.

   Returns:
-    tensors: a list of T `tf.Tensor` which are flattened (1D) views of tensors
+    tensors: a list of `tf.Tensor` which are flattened (1D) views of tensors
     shape: the original shape of each element of input tensors

   Raises:
@@ -61,12 +61,12 @@ def _reshape_tensors(tensors, shape):
   """Reshape tensors flattened by _flatten_tensors.

   Args:
-    tensors: list of T `tf.Tensor` of identical length 1D tensors.
+    tensors: list of `tf.Tensor` of identical length 1D tensors.
     shape: list of integers describing the desired shape. Product of
       the elements must equal the length of each tensor.

   Returns:
-    list of T `tf.Tensor` which are the reshaped inputs.
+    list of `tf.Tensor` which are the reshaped inputs.
   """
   reshaped = []
   for t in tensors:
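The two helpers above are inverses: flatten every input to a 1D view, run the reduction, then restore the original shape. A minimal standalone sketch of that round trip using plain tf.reshape; the function names here are illustrative, not the module's private implementations:

    import tensorflow as tf

    def flatten_tensors(tensors):
      # Record the common shape, then view each tensor as 1D.
      shape = tensors[0].shape.as_list()
      return [tf.reshape(t, [-1]) for t in tensors], shape

    def reshape_tensors(tensors, shape):
      # Undo flatten_tensors.
      return [tf.reshape(t, shape) for t in tensors]

    tensors = [tf.ones([2, 3]), tf.zeros([2, 3])]
    flat, shape = flatten_tensors(tensors)    # each piece now has shape [6]
    restored = reshape_tensors(flat, shape)   # back to shape [2, 3]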
@@ -79,13 +79,13 @@ def _padded_split(tensor, pieces):
   """Like split for 1D tensors but pads-out case where len % pieces != 0.

   Args:
-    tensor: T `tf.Tensor` that must be 1D.
+    tensor: `tf.Tensor` that must be 1D.
     pieces: a positive integer specifying the number of pieces into which
       tensor should be split.

   Returns:
-    list of T `tf.Tensor` of length pieces, which hold the values of
-      thin input tensor, in order. The final tensor may
+    list of `tf.Tensor` of length pieces, which hold the values of
+      thin input tensor, in order. The final tensor may
       be zero-padded on the end to make its size equal to those of all
       of the other tensors.

@@ -132,11 +132,11 @@ def _strip_padding(tensors, pad_len):
   """Strip the suffix padding added by _padded_split.

   Args:
-    tensors: list of T `tf.Tensor` of identical length 1D tensors.
+    tensors: list of `tf.Tensor` of identical length 1D tensors.
     pad_len: number of elements to be stripped from the end of each tensor.

   Returns:
-    list of T `tf.Tensor` which are the stripped inputs.
+    list of `tf.Tensor` which are the stripped inputs.

   Raises:
     ValueError: tensors must be a non-empty list of 1D tensors, and
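_padded_split and _strip_padding, whose docstrings are touched above, are a matched pair: pad the flattened tensor so it splits evenly, then drop the padding again once the reduced tensors are reconstituted. A standalone sketch of that pattern with plain TF ops (illustrative names, not the module's actual code):

    import tensorflow as tf

    def padded_split(tensor, pieces):
      # Pad with zeros up to the next multiple of `pieces`, then split evenly.
      n = tensor.shape.as_list()[0]
      chunk = -(-n // pieces)            # ceiling division
      pad_len = chunk * pieces - n
      padded = tf.pad(tensor, [[0, pad_len]])
      return tf.split(padded, pieces), pad_len

    def strip_padding(tensors, pad_len):
      # Drop pad_len trailing elements from each tensor in the list.
      return [t[:t.shape.as_list()[0] - pad_len] for t in tensors]

    pieces, pad_len = padded_split(tf.range(10), 4)   # 4 pieces of length 3, pad_len == 2
    full = tf.concat(pieces, axis=0)                  # length 12, zero-padded
    restored = strip_padding([full], pad_len)[0]      # length 10 again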
@@ -161,13 +161,13 @@ def _ragged_split(tensor, pieces):
   """Like split for 1D tensors but allows case where len % pieces != 0.

   Args:
-    tensor: T `tf.Tensor` that must be 1D.
+    tensor: `tf.Tensor` that must be 1D.
     pieces: a positive integer specifying the number of pieces into which
       tensor should be split.

   Returns:
-    list of T `tf.Tensor` of length pieces, which hold the values of
-      the input tensor, in order. The final tensor may be shorter
+    list of `tf.Tensor` of length pieces, which hold the values of
+      the input tensor, in order. The final tensor may be shorter
       than the others, which will all be of equal length.

   Raises:
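_ragged_split, by contrast, avoids padding and simply lets the last piece be shorter. A minimal sketch using tf.split with explicit sizes (illustrative, not the module's implementation):

    import tensorflow as tf

    def ragged_split(tensor, pieces):
      n = tensor.shape.as_list()[0]
      chunk = -(-n // pieces)   # ceiling division; only the final piece is shorter
      sizes = [chunk] * (pieces - 1) + [n - chunk * (pieces - 1)]
      return tf.split(tensor, sizes)

    parts = ragged_split(tf.range(10), 4)   # piece sizes: [3, 3, 3, 1]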
@@ -256,7 +256,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
   """Construct a subgraph performing a ring-style all-reduce of input_tensors.

   Args:
-    input_tensors: a list of T `tf.Tensor` objects, which must all
+    input_tensors: a list of `tf.Tensor` objects, which must all
       have the same shape and type.
     num_workers: number of worker tasks spanned by input_tensors.
     num_subchunks: number of subchunks each device should process in one tick.
@@ -272,7 +272,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
       size.

   Returns:
-    a list of T `tf.Tensor` identical sum-reductions of input_tensors.
+    a list of `tf.Tensor` identical sum-reductions of input_tensors.
   """
   if len(input_tensors) < 2:
     raise ValueError("input_tensors must be length 2 or longer")
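A hedged usage sketch for build_ring_all_reduce as documented above. The module path and the placement are assumptions: in real use each input tensor lives on a different GPU, and gpu_perm gives the ring order over those devices.

    import tensorflow as tf
    from tensorflow.python.distribute import all_reduce  # path assumed

    # Two same-shape inputs; in practice these sit on distinct devices.
    tensors = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]
    summed = all_reduce.build_ring_all_reduce(
        tensors, num_workers=1, num_subchunks=1,
        gpu_perm=[0, 1], red_op=tf.add)
    # Every element of `summed` should equal [4.0, 6.0].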
@@ -299,7 +299,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks,
   """Construct a subgraph for the first (reduction) pass of ring all-reduce.

   Args:
-    input_tensors: a list of T `tf.Tensor` 1D input tensors of same
+    input_tensors: a list of `tf.Tensor` 1D input tensors of same
       shape and type.
     devices: array of device name strings
     num_subchunks: number of subchunks each device should process in one tick.
@@ -311,7 +311,7 @@ def _build_ring_gather(input_tensors, devices, num_subchunks,
     ValueError: tensors must all be one dimensional.

   Returns:
-    list of list of T `tf.Tensor` of (partially) reduced values where
+    list of list of `tf.Tensor` of (partially) reduced values where
     exactly num_subchunks chunks at each device are fully reduced.
   """
   num_devices = len(input_tensors)
@@ -360,11 +360,11 @@ def _apply_unary_to_chunks(f, chunks_by_dev):
   """Apply a unary op to each tensor in chunks_by_dev, on same device.

   Args:
-    f: a unary function over T `tf.Tensor`.
-    chunks_by_dev: list of lists of T `tf.Tensor`.
+    f: a unary function over `tf.Tensor`.
+    chunks_by_dev: list of lists of `tf.Tensor`.

   Returns:
-    new list of lists of T `tf.Tensor` with the same structure as
+    new list of lists of `tf.Tensor` with the same structure as
     chunks_by_dev containing the derived tensors.
   """
   output = []
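What _apply_unary_to_chunks does is small enough to sketch directly: run f on every chunk while pinning each result to the chunk's own device (illustrative standalone version, not the module's actual code):

    import tensorflow as tf

    def apply_unary_to_chunks(f, chunks_by_dev):
      output = []
      for chunks in chunks_by_dev:
        # Keep derived tensors on the same device as their inputs.
        with tf.device(chunks[0].device):
          output.append([f(c) for c in chunks])
      return output

    halved = apply_unary_to_chunks(
        lambda t: t * 0.5, [[tf.constant([2.0]), tf.constant([4.0])]])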
@@ -381,14 +381,14 @@ def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
   Args:
     pred_by_s_d: as produced by _ring_permutations
     rank_by_s_d: as produced by _ring_permutations
-    chunks_by_dev: list of list of T `tf.Tensor` indexed by ints
+    chunks_by_dev: list of list of `tf.Tensor` indexed by ints
       (device, chunk)

   Raises:
     ValueError: chunks_by_dev is not well-formed

   Returns:
-    list of T `tf.Tensor` which are the fully reduced tensors, one
+    list of `tf.Tensor` which are the fully reduced tensors, one
     at each device corresponding to the outer dimension of chunks_by_dev.
   """
   num_devices = len(chunks_by_dev)
@@ -448,12 +448,12 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
     the future with edge-case specific logic.

   Args:
-    input_tensors: list of T `tf.Tensor` to be elementwise reduced.
+    input_tensors: list of `tf.Tensor` to be elementwise reduced.
     red_op: a binary elementwise reduction Op.
     un_op: an optional unary elementwise Op to apply to reduced values.

   Returns:
-    list of T `tf.Tensor` which are the fully reduced tensors, one
+    list of `tf.Tensor` which are the fully reduced tensors, one
     at each device of input_tensors.

   Raises:
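A hedged usage sketch for build_recursive_hd_all_reduce (module path assumed). Note the constraints that surface in the gather phase below: the device count must be a power of 2 and the flattened tensor length must be divisible by it.

    import tensorflow as tf
    from tensorflow.python.distribute import all_reduce  # path assumed

    # 4 inputs (a power of 2), each of length 4 (divisible by 4).
    tensors = [tf.constant([1.0, 2.0, 3.0, 4.0]) for _ in range(4)]
    reduced = all_reduce.build_recursive_hd_all_reduce(tensors, tf.add)
    # Every element of `reduced` should be [4.0, 8.0, 12.0, 16.0].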
@@ -481,13 +481,13 @@ def _build_recursive_hd_gather(input_tensors, devices, red_op):
   """Construct the gather phase of recursive halving-doubling all-reduce.

   Args:
-    input_tensors: list of T `tf.Tensor` to be elementwise reduced.
+    input_tensors: list of `tf.Tensor` to be elementwise reduced.
     devices: a list of strings naming the devices hosting input_tensors,
       which will also be used to host the (partial) reduction values.
     red_op: a binary elementwise reduction Op.

   Returns:
-    list of T `tf.Tensor` which are the fully reduced tensor shards.
+    list of `tf.Tensor` which are the fully reduced tensor shards.

   Raises:
     ValueError: num_devices not a power of 2, or tensor len not divisible
@@ -522,12 +522,12 @@ def _build_recursive_hd_scatter(input_tensors, devices):
   """Construct the scatter phase of recursive halving-doubling all-reduce.

   Args:
-    input_tensors: list of T `tf.Tensor` that are fully-reduced shards.
+    input_tensors: list of `tf.Tensor` that are fully-reduced shards.
     devices: a list of strings naming the devices on which the reconstituted
       full tensors should be placed.

   Returns:
-    list of T `tf.Tensor` which are the fully reduced tensors.
+    list of `tf.Tensor` which are the fully reduced tensors.
   """
   num_devices = len(devices)
   num_hops = int(math.log(num_devices, 2))
@@ -570,14 +570,14 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
     better in the other case.

   Args:
-    input_tensors: list of T @(tf.Tensor} values to be reduced.
+    input_tensors: list of `tf.Tensor` values to be reduced.
     gather_devices: list of names of devices on which reduction shards
       should be placed.
     red_op: an n-array elementwise reduction Op
     un_op: optional elementwise unary Op to be applied to fully-reduced values.

   Returns:
-    list of T `tf.Tensor` which are the fully reduced tensors.
+    list of `tf.Tensor` which are the fully reduced tensors.
   """
   input_tensors, shape = _flatten_tensors(input_tensors)
   dst_devices = [t.device for t in input_tensors]
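A hedged usage sketch for build_shuffle_all_reduce (module path assumed). Unlike the ring version's binary red_op, the shuffle gather reduces all shards at once, so red_op is n-ary here, e.g. tf.add_n.

    import tensorflow as tf
    from tensorflow.python.distribute import all_reduce  # path assumed

    tensors = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]
    gather_devices = [t.device for t in tensors]  # reuse input placement
    reduced = all_reduce.build_shuffle_all_reduce(
        tensors, gather_devices, red_op=tf.add_n)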
@@ -593,14 +593,14 @@ def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
   """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

   Args:
-    input_tensors: list of T @(tf.Tensor} values to be reduced.
+    input_tensors: list of `tf.Tensor` values to be reduced.
     gather_devices: list of names of devices on which reduction shards
       should be placed.
     red_op: the binary reduction Op
     un_op: optional elementwise unary Op to be applied to fully-reduced values.

   Returns:
-    list of T `tf.Tensor` which are the fully reduced shards.
+    list of `tf.Tensor` which are the fully reduced shards.

   Raises:
     ValueError: inputs not well-formed.
@@ -630,12 +630,12 @@ def _build_shuffle_scatter(reduced_shards, dst_devices):
   """Build the scatter phase of shuffle all-reduce.

   Args:
-    reduced_shards: list of T @(tf.Tensor} fully reduced shards
+    reduced_shards: list of `tf.Tensor` fully reduced shards
     dst_devices: list of names of devices at which the fully-reduced value
       should be reconstituted.

   Returns:
-    list of T `tf.Tensor` scattered tensors.
+    list of `tf.Tensor` scattered tensors.
   """
   num_devices = len(dst_devices)
   out_tensors = []
@@ -650,7 +650,7 @@ def _split_by_task(devices, values):

   Args:
     devices: list of device name strings
-    values: list of T `tf.tensor` of same length as devices.
+    values: list of `tf.Tensor` of same length as devices.

   Returns:
     (per_task_devices, per_task_values) where both values are
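_split_by_task groups per-device values by the worker task encoded in each device name. An illustrative standalone sketch using the public tf.DeviceSpec parser (not the module's actual code):

    import tensorflow as tf

    def split_by_task(devices, values):
      per_task_devices, per_task_values = {}, {}
      for d, v in zip(devices, values):
        task = tf.DeviceSpec.from_string(d).task
        per_task_devices.setdefault(task, []).append(d)
        per_task_values.setdefault(task, []).append(v)
      return list(per_task_devices.values()), list(per_task_values.values())

    devs = ["/job:worker/task:0/device:GPU:0",
            "/job:worker/task:0/device:GPU:1",
            "/job:worker/task:1/device:GPU:0"]
    vals = [tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)]
    dev_groups, val_groups = split_by_task(devs, vals)  # task 0 vs. task 1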
@@ -686,14 +686,14 @@ def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
   """Build a subgraph that does one full all-reduce, using NCCL.

   Args:
-    input_tensors: list of T `tf.Tensor` of same-shape and type values to
+    input_tensors: list of `tf.Tensor` of same-shape and type values to
       be reduced.
-    red_op: binary elementwise reduction operator.  Must be one of
+    red_op: binary elementwise reduction operator. Must be one of
       {tf.add}
     un_op: optional unary elementwise Op to apply to fully-reduce values.

   Returns:
-    list of T `tf.Tensor` of reduced values.
+    list of `tf.Tensor` of reduced values.

   Raises:
     ValueError: red_op not supported.
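A hedged usage sketch for build_nccl_all_reduce (module path assumed). Per the Args above, red_op must be tf.add, and NCCL requires the inputs to live on distinct GPUs.

    import tensorflow as tf
    from tensorflow.python.distribute import all_reduce  # path assumed

    with tf.device("/device:GPU:0"):
      t0 = tf.constant([1.0, 2.0])
    with tf.device("/device:GPU:1"):
      t1 = tf.constant([3.0, 4.0])
    summed = all_reduce.build_nccl_all_reduce([t0, t1], red_op=tf.add)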
@@ -715,14 +715,14 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
   """Construct a subgraph for NCCL hybrid all-reduce.

   Args:
-    input_tensors: list of T `tf.Tensor` of same-shape and type values to
+    input_tensors: list of `tf.Tensor` of same-shape and type values to
       be reduced.
     red_op: binary elementwise reduction operator.
     upper_level_f: function for reducing one value per worker, across
       workers.

   Returns:
-    list of T `tf.Tensor` of reduced values.
+    list of `tf.Tensor` of reduced values.

   Raises:
     ValueError: inputs not well-formed.
@@ -804,7 +804,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
   """Construct a subgraph for Shuffle hybrid all-reduce.

   Args:
-    input_tensors: list of T `tf.Tensor` of same-shape and type values to
+    input_tensors: list of `tf.Tensor` of same-shape and type values to
       be reduced.
     gather_devices: list of device names on which to host gather shards.
     red_op: binary elementwise reduction operator.
@@ -812,7 +812,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
       workers.

   Returns:
-    list of T `tf.Tensor` of reduced values.
+    list of `tf.Tensor` of reduced values.

   Raises:
     ValueError: inputs not well-formed.