Update the docstring of all_reduce
PiperOrigin-RevId: 328974827 Change-Id: Ic285cfe57915c2cd6ccdb5252e77b7d20325a314
This commit is contained in:
parent
454ebebeab
commit
0c92bd7381
@ -3010,32 +3010,63 @@ class ReplicaContext(object):
|
||||
return (device_util.current(),)
|
||||
|
||||
def all_reduce(self, reduce_op, value, experimental_hints=None):
|
||||
"""All-reduces the given `value Tensor` nest across replicas.
|
||||
"""All-reduces `value` across all replicas.
|
||||
|
||||
If `all_reduce` is called in any replica, it must be called in all replicas.
|
||||
The nested structure and `Tensor` shapes must be identical in all replicas.
|
||||
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
||||
>>> def step_fn():
|
||||
... ctx = tf.distribute.get_replica_context()
|
||||
... value = tf.identity(1.)
|
||||
... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
|
||||
>>> strategy.experimental_local_results(strategy.run(step_fn))
|
||||
(<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
|
||||
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
|
||||
|
||||
IMPORTANT: The ordering of communications must be identical in all replicas.
|
||||
It supports batched operations. You can pass a list of values and it
|
||||
attempts to batch them when possible. You can also specify `experimental_hints`
|
||||
to indicate the desired batching behavior, e.g. batch the values into
|
||||
multiple packs so that they can better overlap with computations.
|
||||
|
||||
Example with two replicas:
|
||||
Replica 0 `value`: {'a': 1, 'b': [40, 1]}
|
||||
Replica 1 `value`: {'a': 3, 'b': [ 2, 98]}
|
||||
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
|
||||
>>> def step_fn():
|
||||
... ctx = tf.distribute.get_replica_context()
|
||||
... value1 = tf.identity(1.)
|
||||
... value2 = tf.identity(2.)
|
||||
... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
|
||||
>>> strategy.experimental_local_results(strategy.run(step_fn))
|
||||
([PerReplica:{
|
||||
0: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
|
||||
1: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>
|
||||
}, PerReplica:{
|
||||
0: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>,
|
||||
1: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
|
||||
}],)
|
||||
|
||||
If `reduce_op` == `SUM`:
|
||||
Result (on all replicas): {'a': 4, 'b': [42, 99]}
|
||||
Note that all replicas need to participate in the all-reduce, otherwise this
|
||||
operation hangs. Note that if there're multiple all-reduces, they need to
|
||||
execute in the same order on all replicas. Dispatching all-reduce based on
|
||||
conditions is usually error-prone.
|
||||
|
||||
If `reduce_op` == `MEAN`:
|
||||
Result (on all replicas): {'a': 2, 'b': [21, 49.5]}
|
||||
This API currently can only be called in the replica context. Other
|
||||
variants to reduce values across replicas are:
|
||||
* `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API
|
||||
in the cross-replica context.
|
||||
* `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and
|
||||
all-reduce API in the cross-replica context.
|
||||
* `tf.distribute.Strategy.reduce`: a more convenient method to reduce
|
||||
to the host in cross-replica context.
|
||||
|
||||
Args:
|
||||
reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum.
|
||||
value: The nested structure of `Tensor`s to all-reduce. The structure must
|
||||
be compatible with `tf.nest`.
|
||||
experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
reduce_op: a `tf.distribute.ReduceOp` enum or its string form, which
|
||||
specifies how to reduce the value.
|
||||
value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts.
|
||||
The structure and the shapes of the `tf.Tensor` need to be same on all
|
||||
replicas.
|
||||
experimental_hints: a `tf.distrbute.experimental.CollectiveHints`. Hints
|
||||
to perform collective operations.
|
||||
|
||||
Returns:
|
||||
A `Tensor` nest with the reduced `value`s from each replica.
|
||||
A nested structure of `tf.Tensor` with the reduced values. The structure
|
||||
is the same as `value`.
|
||||
"""
|
||||
if isinstance(reduce_op, six.string_types):
|
||||
reduce_op = reduce_util.ReduceOp(reduce_op.upper())
|
||||
|
Loading…
x
Reference in New Issue
Block a user