Update the docstring of all_reduce

PiperOrigin-RevId: 328974827
Change-Id: Ic285cfe57915c2cd6ccdb5252e77b7d20325a314
Ran Chen, 2020-08-28 11:21:00 -07:00, committed by TensorFlower Gardener
parent 454ebebeab
commit 0c92bd7381


@@ -3010,32 +3010,63 @@ class ReplicaContext(object):
     return (device_util.current(),)
   def all_reduce(self, reduce_op, value, experimental_hints=None):
-    """All-reduces the given `value Tensor` nest across replicas.
+    """All-reduces `value` across all replicas.
-    If `all_reduce` is called in any replica, it must be called in all replicas.
-    The nested structure and `Tensor` shapes must be identical in all replicas.
+    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+    >>> def step_fn():
+    ...   ctx = tf.distribute.get_replica_context()
+    ...   value = tf.identity(1.)
+    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
+    >>> strategy.experimental_local_results(strategy.run(step_fn))
+    (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
+     <tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
-    IMPORTANT: The ordering of communications must be identical in all replicas.
+    It supports batched operations. You can pass a list of values and it
+    attempts to batch them when possible. You can also specify `experimental_hints`
+    to indicate the desired batching behavior, e.g. batch the values into
+    multiple packs so that they can better overlap with computations.
-    Example with two replicas:
-      Replica 0 `value`: {'a': 1, 'b': [40, 1]}
-      Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}
+    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
+    >>> def step_fn():
+    ...   ctx = tf.distribute.get_replica_context()
+    ...   value1 = tf.identity(1.)
+    ...   value2 = tf.identity(2.)
+    ...   return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
+    >>> strategy.experimental_local_results(strategy.run(step_fn))
+    ([PerReplica:{
+      0: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
+      1: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>
+    }, PerReplica:{
+      0: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>,
+      1: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
+    }],)
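
A minimal sketch (not from this commit) of passing the `experimental_hints` mentioned in the paragraph above, assuming the TF 2.3-era `tf.distribute.experimental.CollectiveHints` API and its `bytes_per_pack` argument; the pack size is an arbitrary illustrative value:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
# Request that tensors be batched into packs of roughly 32 MB before the all-reduce.
hints = tf.distribute.experimental.CollectiveHints(bytes_per_pack=32 * 1024 * 1024)

def step_fn():
  ctx = tf.distribute.get_replica_context()
  value1 = tf.identity(1.)
  value2 = tf.identity(2.)
  # Both values go through one batched all-reduce; the hints steer how they are packed.
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2],
                        experimental_hints=hints)

strategy.run(step_fn)
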
-      If `reduce_op` == `SUM`:
-        Result (on all replicas): {'a': 4, 'b': [42, 99]}
+    Note that all replicas need to participate in the all-reduce, otherwise this
+    operation hangs. If there are multiple all-reduces, they need to execute in
+    the same order on all replicas. Dispatching all-reduce based on conditions
+    is usually error-prone.
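
To make the pitfall above concrete, a hedged editor's sketch (not part of the commit): gating the call on a per-replica condition can hang, because the other replicas never enter the collective; instead, every replica calls `all_reduce` and the non-participating ones contribute a neutral value. `has_data` and the constants are hypothetical.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def step_fn():
  ctx = tf.distribute.get_replica_context()
  # Hypothetical per-replica condition: only replica 0 has something to contribute.
  has_data = tf.equal(ctx.replica_id_in_sync_group, 0)
  # Do not skip the call on replicas where has_data is False; contribute 0. instead,
  # so every replica issues the same all-reduce in the same order.
  value = tf.where(has_data, 5., 0.)
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)

strategy.run(step_fn)
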
-      If `reduce_op` == `MEAN`:
-        Result (on all replicas): {'a': 2, 'b': [21, 49.5]}
+    This API can currently only be called in the replica context. Other
+    variants to reduce values across replicas are:
+    * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
+      in the cross-replica context.
+    * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
+      all-reduce API in the cross-replica context.
+    * `tf.distribute.Strategy.reduce`: a more convenient method to reduce
+      to the host in cross-replica context.
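
For contrast with the cross-replica variants listed above, a short editor's sketch using only standard `tf.distribute` APIs: `tf.distribute.Strategy.reduce` is called outside `strategy.run`, i.e. in the cross-replica context, on the per-replica results.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def step_fn():
  return tf.identity(1.)

per_replica = strategy.run(step_fn)  # a PerReplica value, one tensor per replica
# Cross-replica context: collapse the per-replica values into a single host tensor.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
print(total)  # tf.Tensor(2.0, shape=(), dtype=float32)
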
     Args:
-      reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
-      value: The nested structure of `Tensor`s to all-reduce. The structure must
-        be compatible with `tf.nest`.
-      experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
+      reduce_op: a `tf.distribute.ReduceOp` enum or its string form, which
+        specifies how to reduce the value.
+      value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts.
+        The structure and the shapes of the `tf.Tensor` need to be the same on
+        all replicas.
+      experimental_hints: a `tf.distribute.experimental.CollectiveHints`. Hints
         to perform collective operations.
     Returns:
-      A `Tensor` nest with the reduced `value`s from each replica.
+      A nested structure of `tf.Tensor` with the reduced values. The structure
+      is the same as `value`.
     """
     if isinstance(reduce_op, six.string_types):
       reduce_op = reduce_util.ReduceOp(reduce_op.upper())
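
A small usage sketch from the editor tying the new `Args` text to the conversion shown in the last two lines: `reduce_op` may be given as a string, which is upper-cased into the enum, and `value` may be any nest that `tf.nest.flatten` accepts. The dict below is illustrative only.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

def step_fn():
  ctx = tf.distribute.get_replica_context()
  # "sum" is converted to tf.distribute.ReduceOp.SUM by the isinstance check above.
  value = {'a': tf.identity(1.), 'b': [tf.identity(40.), tf.identity(1.)]}
  return ctx.all_reduce("sum", value)  # returns a nest with the same structure

strategy.experimental_local_results(strategy.run(step_fn))
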