Small additions to DistributedStrategy's API docs

PiperOrigin-RevId: 308949260
Change-Id: Ib77b03bbcc38083ce64504e29f84c2cfc8073f85
This commit is contained in:
Xinyi Wang 2020-04-28 20:13:57 -07:00 committed by TensorFlower Gardener
parent 9403febb65
commit 4bfe1dce64

View File

@ -520,7 +520,10 @@ class StrategyBase(object):
"""A state & compute distribution policy on a list of devices.
See [the guide](https://www.tensorflow.org/guide/distributed_training)
for overview and examples.
for overview and examples. See `tf.distribute.StrategyExtended` and
[`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
for a glossory of concepts mentioned on this page such as "per-replica",
_replica_, and _reduce_.
In short:
@ -736,12 +739,16 @@ class StrategyBase(object):
# Iterate over the distributed dataset
for x in dist_dataset:
# process dataset elements
strategy.run(train_step, args=(x,))
strategy.run(replica_fn, args=(x,))
```
We will assume that the input dataset is batched by the
global batch size. With this assumption, we will make a best effort to
divide each batch across all the replicas (one or more workers).
In the code snippet above, the dataset `dist_dataset` is batched by
GLOBAL_BATCH_SIZE, and we iterate through it using `for x in dist_dataset`,
where x is one batch of data of GLOBAL_BATCH_SIZE containing N batches of
data of per-replica batch size, corresponding to N replicas.
`tf.distribute.Strategy.run` will take care of feeding
the right per-replica batch to the right `replica_fn` execution on each
replica.
In a multi-worker setting, we will first attempt to distribute the dataset
by attempting to detect whether the dataset is being created out of
@ -892,8 +899,13 @@ class StrategyBase(object):
`tf.distribute.DistributedValues` containing tensors or composite tensors.
IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
whether eager execution is enabled, `fn` may be called one or more times (
once for each replica).
whether eager execution is enabled, `fn` may be called one or more times. If
`fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
called inside a `tf.function`, eager execution is disabled and `fn` is
called once (or once per replica, if you are using MirroredStrategy) to
generate a Tensorflow graph, which will then be reused for execution with
new inputs. Otherwise, if eager execution is enabled, `fn` will be called
every step just like regular python code.
Example usage: