diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 5259f27d96a..d17a594cb5e 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -520,7 +520,10 @@ class StrategyBase(object): """A state & compute distribution policy on a list of devices. See [the guide](https://www.tensorflow.org/guide/distributed_training) - for overview and examples. + for overview and examples. See `tf.distribute.StrategyExtended` and + [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute) + for a glossary of concepts mentioned on this page such as "per-replica", + _replica_, and _reduce_. In short: @@ -736,12 +739,16 @@ class StrategyBase(object): # Iterate over the distributed dataset for x in dist_dataset: # process dataset elements - strategy.run(train_step, args=(x,)) + strategy.run(replica_fn, args=(x,)) ``` - We will assume that the input dataset is batched by the - global batch size. With this assumption, we will make a best effort to - divide each batch across all the replicas (one or more workers). + In the code snippet above, the dataset `dist_dataset` is batched by + GLOBAL_BATCH_SIZE, and we iterate through it using `for x in dist_dataset`, + where `x` is one batch of data of GLOBAL_BATCH_SIZE containing N batches of + data of per-replica batch size, corresponding to N replicas. + `tf.distribute.Strategy.run` will take care of feeding + the right per-replica batch to the right `replica_fn` execution on each + replica. In a multi-worker setting, we will first attempt to distribute the dataset by attempting to detect whether the dataset is being created out of @@ -892,8 +899,13 @@ class StrategyBase(object): `tf.distribute.DistributedValues` containing tensors or composite tensors. 
IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and - whether eager execution is enabled, `fn` may be called one or more times ( - once for each replica). + whether eager execution is enabled, `fn` may be called one or more times. If + `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is + called inside a `tf.function`, eager execution is disabled and `fn` is + called once (or once per replica, if you are using MirroredStrategy) to + generate a TensorFlow graph, which will then be reused for execution with + new inputs. Otherwise, if eager execution is enabled, `fn` will be called + every step just like regular Python code. Example usage: