Small additions to DistributedStrategy's API docs

PiperOrigin-RevId: 308949260
Change-Id: Ib77b03bbcc38083ce64504e29f84c2cfc8073f85
Xinyi Wang 2020-04-28 20:13:57 -07:00 committed by TensorFlower Gardener
parent 9403febb65
commit 4bfe1dce64


@@ -520,7 +520,10 @@ class StrategyBase(object):
 """A state & compute distribution policy on a list of devices.

 See [the guide](https://www.tensorflow.org/guide/distributed_training)
-for overview and examples.
+for an overview and examples. See `tf.distribute.StrategyExtended` and
+[`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
+for a glossary of concepts mentioned on this page such as "per-replica",
+_replica_, and _reduce_.

 In short:
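As a minimal sketch of the glossary terms the added sentence points to ("per-replica", _replica_, _reduce_), assuming a default `MirroredStrategy`, which falls back to a single replica on a CPU-only machine:

```
import tensorflow as tf

# Assumed setup for illustration; MirroredStrategy mirrors across all
# visible GPUs, or uses a single replica on CPU-only machines.
strategy = tf.distribute.MirroredStrategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

def replica_fn():
  # Runs once on each replica; the return value is a "per-replica" value.
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group

per_replica_ids = strategy.run(replica_fn)

# "Reduce" combines the per-replica values into a single value.
summed = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_ids,
                         axis=None)
print(summed)
```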
@@ -736,12 +739,16 @@ class StrategyBase(object):
   # Iterate over the distributed dataset
   for x in dist_dataset:
     # process dataset elements
-    strategy.run(train_step, args=(x,))
+    strategy.run(replica_fn, args=(x,))
   ```

-We will assume that the input dataset is batched by the
-global batch size. With this assumption, we will make a best effort to
-divide each batch across all the replicas (one or more workers).
+In the code snippet above, the dataset `dist_dataset` is batched by
+GLOBAL_BATCH_SIZE, and we iterate through it using `for x in dist_dataset`,
+where `x` is one global batch of GLOBAL_BATCH_SIZE elements, made up of N
+per-replica batches corresponding to the N replicas.
+`tf.distribute.Strategy.run` will take care of feeding
+the right per-replica batch to the right `replica_fn` execution on each
+replica.

 In a multi-worker setting, we will first attempt to distribute the dataset
 by attempting to detect whether the dataset is being created out of
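A short sketch of the batching behavior described above, assuming a default `MirroredStrategy` and an arbitrary `GLOBAL_BATCH_SIZE` of 8; the toy `replica_fn` simply reports the per-replica batch size each replica receives:

```
import tensorflow as tf

GLOBAL_BATCH_SIZE = 8  # assumed value for illustration

strategy = tf.distribute.MirroredStrategy()

# A toy dataset batched by the *global* batch size.
dataset = tf.data.Dataset.range(32).batch(GLOBAL_BATCH_SIZE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def replica_fn(batch):
  # Each replica only sees its own slice of the global batch, i.e. roughly
  # GLOBAL_BATCH_SIZE / num_replicas_in_sync elements.
  return tf.shape(batch)[0]

for x in dist_dataset:
  # `x` is one global batch; Strategy.run feeds each replica its own piece.
  per_replica_sizes = strategy.run(replica_fn, args=(x,))
  print(per_replica_sizes)
```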
@@ -892,8 +899,13 @@ class StrategyBase(object):
 `tf.distribute.DistributedValues` containing tensors or composite tensors.

 IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
-whether eager execution is enabled, `fn` may be called one or more times (
-once for each replica).
+whether eager execution is enabled, `fn` may be called one or more times. If
+`fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
+called inside a `tf.function`, eager execution is disabled and `fn` is
+called once (or once per replica, if you are using `MirroredStrategy`) to
+generate a TensorFlow graph, which will then be reused for execution with
+new inputs. Otherwise, if eager execution is enabled, `fn` will be called
+every step just like regular Python code.

 Example usage:
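To make the tracing behavior described above concrete, here is a hedged sketch: when `Strategy.run` is called inside a `tf.function`, the Python-side `print` in `replica_fn` fires only while the function is traced into a graph, not on every step. The `train_step` wrapper and constant inputs are assumptions for illustration:

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn(x):
  # A Python-side print fires only while `replica_fn` is being traced
  # into a graph, not on every training step.
  print("Tracing replica_fn")
  return x * 2.0

@tf.function
def train_step(x):
  # Because train_step is a tf.function, replica_fn is traced once
  # (once per replica) and the resulting graph is reused afterwards.
  return strategy.run(replica_fn, args=(x,))

for _ in range(3):
  train_step(tf.constant(1.0))
# "Tracing replica_fn" prints once per replica, not three times.
```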