Improve docstring of strategy.run.
PiperOrigin-RevId: 324938654 Change-Id: I3e13d90c026fad42657bb8094ccb32dc86e36b4b
This commit is contained in:
parent
6388aa43d7
commit
eaa5235e00
@ -1146,10 +1146,11 @@ class StrategyBase(object):
|
||||
dataset_fn, options)
|
||||
|
||||
def run(self, fn, args=(), kwargs=None, options=None):
|
||||
"""Run `fn` on each replica, with the given arguments.
|
||||
"""Invokes `fn` on each replica, with the given arguments.
|
||||
|
||||
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
|
||||
`tf.distribute.DistributedValues`, such as those produced by a
|
||||
This method is the primary way to distribute your computation with a
|
||||
tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs`
|
||||
have `tf.distribute.DistributedValues`, such as those produced by a
|
||||
`tf.distribute.DistributedDataset` from
|
||||
`tf.distribute.Strategy.experimental_distribute_dataset` or
|
||||
`tf.distribute.Strategy.experimental_distribute_datasets_from_function`,
|
||||
@ -1157,20 +1158,27 @@ class StrategyBase(object):
|
||||
component of `tf.distribute.DistributedValues` that correspond to that
|
||||
replica.
|
||||
|
||||
`fn` may call `tf.distribute.get_replica_context()` to access members such
|
||||
as `all_reduce`.
|
||||
`fn` is invoked under a replica context. `fn` may call
|
||||
`tf.distribute.get_replica_context()` to access members such as
|
||||
`all_reduce`. Please see the module-level docstring of tf.distribute for the
|
||||
concept of replica context.
|
||||
|
||||
All arguments in `args` or `kwargs` should either be nest of tensors or
|
||||
`tf.distribute.DistributedValues` containing tensors or composite tensors.
|
||||
All arguments in `args` or `kwargs` should either be Python values of a
|
||||
nested structure of tensors, e.g. a list of tensors, in which case `args`
|
||||
and `kwargs` will be passed to the `fn` invoked on each replica. Or `args`
|
||||
or `kwargs` can be `tf.distribute.DistributedValues` containing tensors or
|
||||
composite tensors, i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which
|
||||
case each `fn` call will get the component of a
|
||||
`tf.distribute.DistributedValues` corresponding to its replica.
|
||||
|
||||
IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
|
||||
whether eager execution is enabled, `fn` may be called one or more times. If
|
||||
`fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
|
||||
called inside a `tf.function`, eager execution is disabled and `fn` is
|
||||
called once (or once per replica, if you are using MirroredStrategy) to
|
||||
generate a Tensorflow graph, which will then be reused for execution with
|
||||
new inputs. Otherwise, if eager execution is enabled, `fn` will be called
|
||||
every step just like regular python code.
|
||||
called inside a `tf.function` (eager execution is disabled inside a
|
||||
`tf.function` by default), `fn` is called once per replica to generate a
|
||||
Tensorflow graph, which will then be reused for execution with new inputs.
|
||||
Otherwise, if eager execution is enabled, `fn` will be called once per
|
||||
replica every step just like regular python code.
|
||||
|
||||
Example usage:
|
||||
|
||||
@ -1205,11 +1213,33 @@ class StrategyBase(object):
|
||||
>>> result
|
||||
<tf.Tensor: shape=(), dtype=int32, numpy=4>
|
||||
|
||||
3. Use `tf.distribute.ReplicaContext` to allreduce values.
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
|
||||
>>> @tf.function
|
||||
... def run():
|
||||
... def value_fn(value_context):
|
||||
... return tf.constant(value_context.replica_id_in_sync_group)
|
||||
... distributed_values = (
|
||||
... strategy.experimental_distribute_values_from_function(
|
||||
... value_fn))
|
||||
... def replica_fn(input):
|
||||
... return tf.distribute.get_replica_context().all_reduce("sum", input)
|
||||
... return strategy.run(replica_fn, args=(distributed_values,))
|
||||
>>> result = run()
|
||||
>>> result
|
||||
PerReplica:{
|
||||
0: <tf.Tensor: shape=(), dtype=int32, numpy=1>,
|
||||
1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
|
||||
}
|
||||
|
||||
Args:
|
||||
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
|
||||
args: (Optional) Positional arguments to `fn`.
|
||||
kwargs: (Optional) Keyword arguments to `fn`.
|
||||
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
|
||||
fn: The function to run on each replica.
|
||||
args: Optional positional arguments to `fn`. Its element can be a Python
|
||||
value, a tensor or a `tf.distribute.DistributedValues`.
|
||||
kwargs: Optional keyword arguments to `fn`. Its element can be a Python
|
||||
value, a tensor or a `tf.distribute.DistributedValues`.
|
||||
options: An optional instance of `tf.distribute.RunOptions` specifying
|
||||
the options to run `fn`.
|
||||
|
||||
Returns:
|
||||
|
Loading…
Reference in New Issue
Block a user