Improve docstring of strategy.run.

PiperOrigin-RevId: 324938654
Change-Id: I3e13d90c026fad42657bb8094ccb32dc86e36b4b
This commit is contained in:
Yuefeng Zhou 2020-08-04 19:12:35 -07:00 committed by TensorFlower Gardener
parent 6388aa43d7
commit eaa5235e00

View File

@ -1146,10 +1146,11 @@ class StrategyBase(object):
dataset_fn, options)
def run(self, fn, args=(), kwargs=None, options=None):
"""Run `fn` on each replica, with the given arguments.
"""Invokes `fn` on each replica, with the given arguments.
Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
`tf.distribute.DistributedValues`, such as those produced by a
This method is the primary way to distribute your computation with a
tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs`
have `tf.distribute.DistributedValues`, such as those produced by a
`tf.distribute.DistributedDataset` from
`tf.distribute.Strategy.experimental_distribute_dataset` or
`tf.distribute.Strategy.experimental_distribute_datasets_from_function`,
@ -1157,20 +1158,27 @@ class StrategyBase(object):
component of `tf.distribute.DistributedValues` that correspond to that
replica.
`fn` may call `tf.distribute.get_replica_context()` to access members such
as `all_reduce`.
`fn` is invoked under a replica context. `fn` may call
`tf.distribute.get_replica_context()` to access members such as
`all_reduce`. Please see the module-level docstring of tf.distribute for the
concept of replica context.
All arguments in `args` or `kwargs` should either be nest of tensors or
`tf.distribute.DistributedValues` containing tensors or composite tensors.
All arguments in `args` or `kwargs` should either be Python values of a
nested structure of tensors, e.g. a list of tensors, in which case `args`
and `kwargs` will be passed to the `fn` invoked on each replica. Or `args`
or `kwargs` can be `tf.distribute.DistributedValues` containing tensors or
composite tensors, i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which
case each `fn` call will get the component of a
`tf.distribute.DistributedValues` corresponding to its replica.
IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
whether eager execution is enabled, `fn` may be called one or more times. If
`fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
called inside a `tf.function`, eager execution is disabled and `fn` is
called once (or once per replica, if you are using MirroredStrategy) to
generate a Tensorflow graph, which will then be reused for execution with
new inputs. Otherwise, if eager execution is enabled, `fn` will be called
every step just like regular python code.
called inside a `tf.function` (eager execution is disabled inside a
`tf.function` by default), `fn` is called once per replica to generate a
Tensorflow graph, which will then be reused for execution with new inputs.
Otherwise, if eager execution is enabled, `fn` will be called once per
replica every step just like regular python code.
Example usage:
@ -1205,11 +1213,33 @@ class StrategyBase(object):
>>> result
<tf.Tensor: shape=(), dtype=int32, numpy=4>
3. Use `tf.distribute.ReplicaContext` to allreduce values.
>>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
>>> @tf.function
... def run():
... def value_fn(value_context):
... return tf.constant(value_context.replica_id_in_sync_group)
... distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
... def replica_fn(input):
... return tf.distribute.get_replica_context().all_reduce("sum", input)
... return strategy.run(replica_fn, args=(distributed_values,))
>>> result = run()
>>> result
PerReplica:{
0: <tf.Tensor: shape=(), dtype=int32, numpy=1>,
1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}
Args:
fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
args: (Optional) Positional arguments to `fn`.
kwargs: (Optional) Keyword arguments to `fn`.
options: (Optional) An instance of `tf.distribute.RunOptions` specifying
fn: The function to run on each replica.
args: Optional positional arguments to `fn`. Its element can be a Python
value, a tensor or a `tf.distribute.DistributedValues`.
kwargs: Optional keyword arguments to `fn`. Its element can be a Python
value, a tensor or a `tf.distribute.DistributedValues`.
options: An optional instance of `tf.distribute.RunOptions` specifying
the options to run `fn`.
Returns: