Small additions to DistributedStrategy's API docs
PiperOrigin-RevId: 308949260 Change-Id: Ib77b03bbcc38083ce64504e29f84c2cfc8073f85
This commit is contained in:
parent
9403febb65
commit
4bfe1dce64
@ -520,7 +520,10 @@ class StrategyBase(object):
|
||||
"""A state & compute distribution policy on a list of devices.
|
||||
|
||||
See [the guide](https://www.tensorflow.org/guide/distributed_training)
|
||||
for overview and examples.
|
||||
for overview and examples. See `tf.distribute.StrategyExtended` and
|
||||
[`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
|
||||
for a glossory of concepts mentioned on this page such as "per-replica",
|
||||
_replica_, and _reduce_.
|
||||
|
||||
In short:
|
||||
|
||||
@ -736,12 +739,16 @@ class StrategyBase(object):
|
||||
# Iterate over the distributed dataset
|
||||
for x in dist_dataset:
|
||||
# process dataset elements
|
||||
strategy.run(train_step, args=(x,))
|
||||
strategy.run(replica_fn, args=(x,))
|
||||
```
|
||||
|
||||
We will assume that the input dataset is batched by the
|
||||
global batch size. With this assumption, we will make a best effort to
|
||||
divide each batch across all the replicas (one or more workers).
|
||||
In the code snippet above, the dataset `dist_dataset` is batched by
|
||||
GLOBAL_BATCH_SIZE, and we iterate through it using `for x in dist_dataset`,
|
||||
where x is one batch of data of GLOBAL_BATCH_SIZE containing N batches of
|
||||
data of per-replica batch size, corresponding to N replicas.
|
||||
`tf.distribute.Strategy.run` will take care of feeding
|
||||
the right per-replica batch to the right `replica_fn` execution on each
|
||||
replica.
|
||||
|
||||
In a multi-worker setting, we will first attempt to distribute the dataset
|
||||
by attempting to detect whether the dataset is being created out of
|
||||
@ -892,8 +899,13 @@ class StrategyBase(object):
|
||||
`tf.distribute.DistributedValues` containing tensors or composite tensors.
|
||||
|
||||
IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
|
||||
whether eager execution is enabled, `fn` may be called one or more times (
|
||||
once for each replica).
|
||||
whether eager execution is enabled, `fn` may be called one or more times. If
|
||||
`fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
|
||||
called inside a `tf.function`, eager execution is disabled and `fn` is
|
||||
called once (or once per replica, if you are using MirroredStrategy) to
|
||||
generate a Tensorflow graph, which will then be reused for execution with
|
||||
new inputs. Otherwise, if eager execution is enabled, `fn` will be called
|
||||
every step just like regular python code.
|
||||
|
||||
Example usage:
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user