Export and Document DistributedDataset and DistributedIterator APIs

PiperOrigin-RevId: 317007583
Change-Id: I7d7c4615a12a19fb4fd151a0457f176ffe2cd765

Parent: 56db128697
Commit: c0ba8a09a7
@@ -684,7 +684,8 @@ class StrategyBase(object):
 instead.
 * Use `tf.distribute.Strategy.run` to run a function
 once per replica, taking values that may be "per-replica" (e.g.
-from a distributed dataset) and returning "per-replica" values.
+from a `tf.distribute.DistributedDataset` object) and returning
+"per-replica" values.
 This function is executed in "replica context", which means each
 operation is performed separately on each replica.
 * Finally use a method (such as `tf.distribute.Strategy.reduce`) to
@@ -720,7 +721,8 @@ class StrategyBase(object):
 distributed-specific behavior.
 
 You can use the `reduce` API to aggregate results across replicas and use
-this as a return value from one iteration over the distributed dataset. Or
+this as a return value from one iteration over a
+`tf.distribute.DistributedDataset`. Or
 you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
 accumulate metrics across steps in a given epoch.
 
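To make the `run` / `reduce` workflow described in the two hunks above concrete, here is a minimal illustrative sketch (the toy dataset and `step_fn` are placeholders, not taken from the commit):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 4
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(16).batch(global_batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def step_fn(inputs):
  features, labels = inputs
  # Placeholder per-replica "loss" for one step.
  return tf.reduce_sum(labels - 0.3 * features)

for x in dist_dataset:
  # Runs once per replica with that replica's slice of the batch.
  per_replica_losses = strategy.run(step_fn, args=(x,))
  # Aggregate the per-replica values into one value for this iteration.
  loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)
```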
@@ -859,12 +861,12 @@ class StrategyBase(object):
 return self.run(fn, args=args)
 
 def experimental_distribute_dataset(self, dataset, options=None):
-"""Distributes a tf.data.Dataset instance provided via `dataset`.
+"""Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.
 
-The returned distributed dataset can be iterated over similar to how
-regular datasets can.
-NOTE: Currently, the user cannot add any more transformations to a
-distributed dataset.
+The returned `tf.distribute.DistributedDataset` can be iterated over
+similar to how regular datasets can.
+NOTE: The user cannot add any more transformations to a
+`tf.distribute.DistributedDataset`.
 
 The following is an example:
 
@@ -878,48 +880,53 @@ class StrategyBase(object):
 # Distribute that dataset
 dist_dataset = strategy.experimental_distribute_dataset(dataset)
 
-# Iterate over the distributed dataset
+# Iterate over the `tf.distribute.DistributedDataset`
 for x in dist_dataset:
 # process dataset elements
 strategy.run(replica_fn, args=(x,))
 ```
 
-In the code snippet above, the dataset `dist_dataset` is batched by
-GLOBAL_BATCH_SIZE, and we iterate through it using `for x in dist_dataset`,
-where x is one batch of data of GLOBAL_BATCH_SIZE containing N batches of
-data of per-replica batch size, corresponding to N replicas.
-`tf.distribute.Strategy.run` will take care of feeding
-the right per-replica batch to the right `replica_fn` execution on each
+In the code snippet above, the `tf.distribute.DistributedDataset`
+`dist_dataset` is batched by `GLOBAL_BATCH_SIZE`, and we iterate through it
+using `for x in dist_dataset`. `x` is a `tf.distribute.DistributedValues`
+containing data for all replicas, which aggregates to a batch of
+`GLOBAL_BATCH_SIZE`. `tf.distribute.Strategy.run` will take care of feeding
+the right per-replica data in `x` to the right `replica_fn` executed on each
 replica.
 
-In a multi-worker setting, we will first attempt to distribute the dataset
-by attempting to detect whether the dataset is being created out of
-ReaderDatasets (e.g. TFRecordDataset, TextLineDataset, etc.) and if so,
-attempting to shard the input files. Note that there has to be at least one
-input file per worker. If you have less than one input file per worker, we
-suggest that you should disable distributing your dataset using the method
-below.
+What's under the hood of this method, when we say the `tf.data.Dataset`
+instance - `dataset` - gets distributed? It depends on how you set the
+`tf.data.experimental.AutoShardPolicy` through
+`tf.data.experimental.DistributeOptions`. By default, it is set to
+`tf.data.experimental.AutoShardPolicy.AUTO`. In a multi-worker setting, we
+will first attempt to distribute `dataset` by detecting whether `dataset` is
+being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
+`tf.data.TextLineDataset`, etc.) and if so, try to shard the input files.
+Note that there has to be at least one input file per worker. If you have
+less than one input file per worker, we suggest that you disable dataset
+sharding across workers, by setting the
+`tf.data.experimental.DistributeOptions.auto_shard_policy` to be
+`tf.data.experimental.AutoShardPolicy.OFF`.
 
-If that attempt is unsuccessful (e.g. the dataset is created from a
-Dataset.range), we will shard the dataset evenly at the end by appending a
-`.shard` operation to the end of the processing pipeline. This will cause
-the entire preprocessing pipeline for all the data to be run on every
-worker, and each worker will do redundant work. We will print a warning
-if this method of sharding is selected.
+If the attempt to shard by file is unsuccessful (i.e. the dataset is not
+read from files), we will shard the dataset evenly at the end by
+appending a `.shard` operation to the end of the processing pipeline. This
+will cause the entire preprocessing pipeline for all the data to be run on
+every worker, and each worker will do redundant work. We will print a
+warning if this route is selected.
 
-You can disable dataset sharding across workers using the
-`auto_shard_policy` option in `tf.data.experimental.DistributeOptions`.
-
-Within each worker, we will also split the data among all the worker
-devices (if more than one a present), and this will happen even if
-multi-worker sharding is disabled using the method above.
+As mentioned before, within each worker, we will also split the data among
+all the worker devices (if more than one is present). This will happen
+even if multi-worker sharding is disabled.
 
 If the above batch splitting and dataset sharding logic is undesirable,
-please use `experimental_distribute_datasets_from_function` instead, which
-does not do any automatic splitting or sharding.
+please use
+`tf.distribute.Strategy.experimental_distribute_datasets_from_function`
+instead, which does not do any automatic splitting or sharding.
 
-You can also use the `element_spec` property of the distributed dataset
-returned by this API to query the `tf.TypeSpec` of the elements returned
+You can also use the `element_spec` property of the
+`tf.distribute.DistributedDataset` instance returned by this API to query
+the `tf.TypeSpec` of the elements returned
 by the iterator. This can be used to set the `input_signature` property
 of a `tf.function`.
 
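The paragraph above mentions turning off dataset sharding across workers by setting `tf.data.experimental.DistributeOptions.auto_shard_policy` to `OFF`. A minimal sketch of how that is typically wired up through `tf.data.Options` (the toy dataset is illustrative, not from the commit):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(64).batch(8)

# Turn off auto-sharding, e.g. when there are fewer input files than workers.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)

dist_dataset = strategy.experimental_distribute_dataset(dataset)
```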
@@ -938,12 +945,21 @@ class StrategyBase(object):
 # train model with inputs
 return
 
-# Iterate over the distributed dataset
+# Iterate over the `tf.distribute.DistributedDataset`
 for x in dist_dataset:
 # process dataset elements
 strategy.run(train_step, args=(x,))
 ```
 
+Note: The order in which the data is processed by the workers when using
+`tf.distribute.Strategy.experimental_distribute_dataset` or
+`tf.distribute.Strategy.experimental_distribute_datasets_from_function` is
+not guaranteed. This is typically required if you are using
+`tf.distribute` to scale prediction. You can however insert an index for
+each element in the batch and order outputs accordingly. Refer to [this
+snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
+for an example of how to order outputs.
+
 Args:
 dataset: `tf.data.Dataset` that will be sharded across all replicas using
 the rules stated above.
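For the ordering caveat in the note above, one common approach is to attach an index to every element so outputs can be re-ordered afterwards. An illustrative sketch (toy features and a placeholder `predict_step`, not from the commit):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 4

# Tag each example with its position before batching, since processing order
# across replicas/workers is not guaranteed.
features = tf.random.normal([16, 3])
dataset = tf.data.Dataset.from_tensor_slices(features).enumerate()
dist_dataset = strategy.experimental_distribute_dataset(
    dataset.batch(global_batch_size))

def predict_step(index, x):
  # Toy "prediction": the per-example feature sum.
  return index, tf.reduce_sum(x, axis=-1)

for index, x in dist_dataset:
  idx, preds = strategy.run(predict_step, args=(index, x))
  # `idx` records where each prediction belongs in the original order.
```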
@@ -951,8 +967,7 @@ class StrategyBase(object):
 dataset is distributed.
 
 Returns:
-A "distributed `Dataset`", which acts like a `tf.data.Dataset` except
-it produces "per-replica" values.
+A `tf.distribute.DistributedDataset`.
 """
 return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access
 
@@ -978,10 +993,10 @@ class StrategyBase(object):
 The `dataset_fn` should take an `tf.distribute.InputContext` instance where
 information about batching and input replication can be accessed.
 
-You can also use the `element_spec` property of the distributed dataset
-returned by this API to query the `tf.TypeSpec` of the elements returned
-by the iterator. This can be used to set the `input_signature` property
-of a `tf.function`.
+You can also use the `element_spec` property of the
+`tf.distribute.DistributedDataset` returned by this API to query the
+`tf.TypeSpec` of the elements returned by the iterator. This can be used to
+set the `input_signature` property of a `tf.function`.
 
 >>> global_batch_size = 8
 >>> def dataset_fn(input_context):
@@ -1010,6 +1025,16 @@ class StrategyBase(object):
 the global batch size. This may be computed using
 `input_context.get_per_replica_batch_size`.
 
+
+Note: The order in which the data is processed by the workers when using
+`tf.distribute.Strategy.experimental_distribute_dataset` or
+`tf.distribute.Strategy.experimental_distribute_datasets_from_function` is
+not guaranteed. This is typically required if you are using
+`tf.distribute` to scale prediction. You can however insert an index for
+each element in the batch and order outputs accordingly. Refer to [this
+snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
+for an example of how to order outputs.
+
 Args:
 dataset_fn: A function taking a `tf.distribute.InputContext` instance and
 returning a `tf.data.Dataset`.
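A short sketch of a `dataset_fn` for this API, using the `tf.distribute.InputContext` facilities mentioned above (per-replica batching and manual sharding); the dataset contents are placeholders:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH_SIZE = 8

def dataset_fn(input_context):
  # Batch by the per-replica batch size, derived from the global batch size.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  dataset = tf.data.Dataset.range(64)
  # This API does no automatic sharding, so shard across input pipelines here.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size)

dist_dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
```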
@@ -1017,8 +1042,7 @@ class StrategyBase(object):
 dataset is distributed.
 
 Returns:
-A "distributed `Dataset`", which acts like a `tf.data.Dataset` except
-it produces "per-replica" values.
+A `tf.distribute.DistributedDataset`.
 """
 return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access
 dataset_fn, options)
@@ -1028,7 +1052,9 @@ class StrategyBase(object):
 
 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have
 `tf.distribute.DistributedValues`, such as those produced by a
-"distributed `Dataset`" or `experimental_distribute_values_from_function`
+`tf.distribute.DistributedDataset` from
+`tf.distribute.Strategy.experimental_distribute_dataset` or
+`tf.distribute.Strategy.experimental_distribute_datasets_from_function`,
 when `fn` is executed on a particular replica, it will be executed with the
 component of `tf.distribute.DistributedValues` that correspond to that
 replica.
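`tf.distribute.Strategy.run` also accepts `tf.distribute.DistributedValues` built without a dataset, e.g. via `experimental_distribute_values_from_function` as the removed line mentions. A minimal illustrative sketch (the constant value and `double` function are placeholders):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def value_fn(ctx):
  # ctx is a tf.distribute.experimental.ValueContext describing this replica.
  return tf.constant(float(ctx.replica_id_in_sync_group))

distributed_values = strategy.experimental_distribute_values_from_function(value_fn)

@tf.function
def double(x):
  return 2.0 * x

per_replica = strategy.run(double, args=(distributed_values,))
print(strategy.experimental_local_results(per_replica))
```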
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import functools
 import sys
 
@@ -52,6 +53,8 @@ from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.types import distribute as distribute_types
 from tensorflow.python.util import nest
 from tensorflow.python.util.deprecation import deprecated
+from tensorflow.python.util.tf_export import tf_export
+from tensorflow.tools.docs import doc_controls
 
 
 def get_distributed_dataset(dataset,
@@ -138,6 +141,321 @@ def get_distributed_datasets_from_function(dataset_fn,
 strategy)
 
 
+@tf_export("distribute.DistributedIterator", v1=[])
+class DistributedIteratorInterface(collections.Iterator,
+distribute_types.Iterator):
+"""An iterator over `tf.distribute.DistributedDataset`.
+
+`tf.distribute.DistributedIterator` is the primary mechanism for enumerating
+elements of a `tf.distribute.DistributedDataset`. It supports the Python
+Iterator protocol, which means it can be iterated over using a for-loop or by
+fetching individual elements explicitly via `get_next()`.
+
+You can create a `tf.distribute.DistributedIterator` by calling `iter` on
+a `tf.distribute.DistributedDataset` or creating a python loop over a
+`tf.distribute.DistributedDataset`.
+
+Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
+on distributed input for more examples and caveats.
+"""
+
+def get_next(self):
+"""Returns the next input from the iterator for all replicas.
+
+Example use:
+
+>>> strategy = tf.distribute.MirroredStrategy()
+>>> dataset = tf.data.Dataset.range(100).batch(2)
+>>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
+>>> dist_dataset_iterator = iter(dist_dataset)
+>>> @tf.function
+... def one_step(input):
+...   return input
+>>> step_num = 5
+>>> for _ in range(step_num):
+...   strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
+>>> strategy.experimental_local_results(dist_dataset_iterator.get_next())
+(<tf.Tensor: shape=(2,), dtype=int64, numpy=array([10, 11])>,)
+
+The above example corresponds to the case where you have only one device. If
+you have two devices, for example,
+```python
+strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
+```
+Then the final line will print out:
+```python
+(<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
+<tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)
+```
+
+Returns:
+A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains
+the next input for all replicas.
+
+Raises:
+`tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
+"""
+raise NotImplementedError(
+"DistributedIterator.get_next() must be implemented in descendants.")
+
+@property
+def element_spec(self):
+# pylint: disable=line-too-long
+"""The type specification of an element of `tf.distribute.DistributedIterator`.
+
+Example usage:
+
+>>> global_batch_size = 16
+>>> strategy = tf.distribute.MirroredStrategy()
+>>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
+>>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
+>>> distributed_iterator.element_spec
+(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
+TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))
+
+The above example corresponds to the case where you have only one device. If
+you have two devices, for example,
+```python
+strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
+```
+Then the final line will print out:
+```python
+(PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
+TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
+PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
+TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
+```
+
+Returns:
+A nested structure of `tf.TypeSpec` objects matching the structure of an
+element of this `tf.distribute.DistributedIterator`. This returned value
+is typically a `tf.distribute.DistributedValues` object and specifies the
+`tf.TensorSpec` of individual components.
+"""
+raise NotImplementedError(
+"DistributedIterator.element_spec() must be implemented in descendants")
+
+
+@tf_export("distribute.DistributedDataset", v1=[])
+class DistributedDatasetInterface(collections.Iterable,
+distribute_types.Iterable):
+# pylint: disable=line-too-long
+"""Represents a dataset distributed among devices and machines.
+
+A `tf.distribute.DistributedDataset` could be thought of as a "distributed"
+dataset. When you use `tf.distribute` API to scale training to multiple
+devices or machines, you also need to distribute the input data, which leads
+to a `tf.distribute.DistributedDataset` instance, instead of a
+`tf.data.Dataset` instance in the non-distributed case. In TF 2.x,
+`tf.distribute.DistributedDataset` objects are Python iterables.
+
+Note: `tf.distribute.DistributedDataset` instances are *not* of type
+`tf.data.Dataset`. It only supports two usages we will mention below:
+iteration and `element_spec`. We don't support any other APIs to transform or
+inspect the dataset.
+
+There are two APIs to create a `tf.distribute.DistributedDataset` object:
+`tf.distribute.Strategy.experimental_distribute_dataset(dataset)` and
+`tf.distribute.Strategy.experimental_distribute_datasets_from_function(dataset_fn)`.
+*When to use which?* When you have a `tf.data.Dataset` instance, and the
+regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance
+with a new batch size that is equal to the global batch size divided by the
+number of replicas in sync) and autosharding (i.e. the
+`tf.data.experimental.AutoShardPolicy` options) work for you, use the former
+API. Otherwise, if you are *not* using a canonical `tf.data.Dataset` instance,
+or you would like to customize the batch splitting or sharding, you can wrap
+this logic in a `dataset_fn` and use the latter API. Both APIs handle
+prefetching to the device for the user. For more details and examples, follow
+the links to the APIs.
+
+
+There are two main usages of a `DistributedDataset` object:
+
+1. Iterate over it to generate the input for a single device or multiple
+devices, which is a `tf.distribute.DistributedValues` instance. To do this,
+you can:
+
+* use a pythonic for-loop construct:
+
+>>> global_batch_size = 2
+>>> strategy = tf.distribute.MirroredStrategy()
+>>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size)
+>>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
+>>> @tf.function
+... def train_step(input):
+...   features, labels = input
+...   return labels - 0.3 * features
+>>> for x in dist_dataset:
+...   # train_step trains the model using the dataset elements
+...   loss = strategy.run(train_step, args=(x,))
+...   print("Loss is", loss)
+Loss is tf.Tensor(
+[[0.7]
+[0.7]], shape=(2, 1), dtype=float32)
+Loss is tf.Tensor(
+[[0.7]
+[0.7]], shape=(2, 1), dtype=float32)
+
+Placing the loop inside a `tf.function` will give a performance boost.
+However `break` and `return` are currently not supported if the loop is
+placed inside a `tf.function`. We also don't support placing the loop
+inside a `tf.function` when using
+`tf.distribute.experimental.MultiWorkerMirroredStrategy` or
+`tf.distribute.experimental.TPUStrategy` with multiple workers.
+
+* use `__iter__` to create an explicit iterator, which is of type
+`tf.distribute.DistributedIterator`
+
+>>> global_batch_size = 4
+>>> strategy = tf.distribute.MirroredStrategy()
+>>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size)
+>>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
+>>> @tf.function
+... def distributed_train_step(dataset_inputs):
+...   def train_step(input):
+...     loss = tf.constant(0.1)
+...     return loss
+...   per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
+...   return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
+>>> EPOCHS = 2
+>>> STEPS = 3
+>>> for epoch in range(EPOCHS):
+...   total_loss = 0.0
+...   num_batches = 0
+...   dist_dataset_iterator = iter(train_dist_dataset)
+...   for _ in range(STEPS):
+...     total_loss += distributed_train_step(next(dist_dataset_iterator))
+...     num_batches += 1
+...   average_train_loss = total_loss / num_batches
+...   template = ("Epoch {}, Loss: {}")
+...   print (template.format(epoch+1, average_train_loss))
+Epoch 1, Loss: 0.10000000894069672
+Epoch 2, Loss: 0.10000000894069672
+
+
+To achieve a performance improvement, you can also wrap the `strategy.run`
+call with a `tf.range` inside a `tf.function`. This runs multiple steps in a
+`tf.function`. Autograph will convert it to a `tf.while_loop` on the worker.
+However, it is less flexible compared with running a single step inside
+`tf.function`. For example, you cannot run things eagerly or arbitrary
+python code within the steps.
+
+
+2. Inspect the `tf.TypeSpec` of the data generated by `DistributedDataset`.
+
+`tf.distribute.DistributedDataset` generates
+`tf.distribute.DistributedValues` as input to the devices. If you pass the
+input to a `tf.function` and would like to specify the shape and type of
+each Tensor argument to the function, you can pass a `tf.TypeSpec` object to
+the `input_signature` argument of the `tf.function`. To get the
+`tf.TypeSpec` of the input, you can use the `element_spec` property of the
+`tf.distribute.DistributedDataset` or `tf.distribute.DistributedIterator`
+object.
+
+For example:
+
+>>> global_batch_size = 2
+>>> epochs = 1
+>>> steps_per_epoch = 1
+>>> mirrored_strategy = tf.distribute.MirroredStrategy()
+>>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
+>>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
+>>> @tf.function(input_signature=[dist_dataset.element_spec])
+... def train_step(per_replica_inputs):
+...   def step_fn(inputs):
+...     return tf.square(inputs)
+...   return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))
+>>> for _ in range(epochs):
+...   iterator = iter(dist_dataset)
+...   for _ in range(steps_per_epoch):
+...     output = train_step(next(iterator))
+...     print(output)
+tf.Tensor(
+[[4.]
+[4.]], shape=(2, 1), dtype=float32)
+
+
+Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
+on distributed input for more examples and caveats.
+"""
+
+def __iter__(self):
+"""Creates an iterator for the `tf.distribute.DistributedDataset`.
+
+The returned iterator implements the Python Iterator protocol.
+
+Example usage:
+
+>>> global_batch_size = 4
+>>> strategy = tf.distribute.MirroredStrategy()
+>>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
+>>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
+>>> print(next(distributed_iterator))
+tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
+
+
+The above example corresponds to the case where you have only one device. If
+you have two devices, for example,
+```python
+strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
+```
+Then the final line will print out:
+```python
+PerReplica:{
+0: tf.Tensor([1 2], shape=(2,), dtype=int32),
+1: tf.Tensor([3 4], shape=(2,), dtype=int32)
+}
+```
+
+Returns:
+A `tf.distribute.DistributedIterator` instance for the given
+`tf.distribute.DistributedDataset` object to enumerate over the
+distributed data.
+"""
+raise NotImplementedError("Must be implemented in descendants")
+
+@property
+def element_spec(self):
+"""The type specification of an element of this `tf.distribute.DistributedDataset`.
+
+Example usage:
+
+>>> global_batch_size = 16
+>>> strategy = tf.distribute.MirroredStrategy()
+>>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
+>>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
+>>> dist_dataset.element_spec
+(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
+TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))
+
+The above example corresponds to the case where you have only one device. If
+you have two devices, for example,
+```python
+strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
+```
+Then the final line will print out:
+```python
+(PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
+TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
+PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
+TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
+```
+
+Returns:
+A nested structure of `tf.TypeSpec` objects matching the structure of an
+element of this `tf.distribute.DistributedDataset`. This returned value is
+typically a `tf.distribute.DistributedValues` object and specifies the
+`tf.TensorSpec` of individual components.
+"""
+raise NotImplementedError(
+"DistributedDataset.element_spec must be implemented in descendants.")
+
+@doc_controls.do_not_generate_docs
+def reduce(self, initial_state, reduce_func):
+raise NotImplementedError(
+"DistributedDataset.reduce must be implemented in descendants.")
+
+
 class InputWorkers(object):
 """A 1-to-many mapping from input worker devices to compute devices."""
 
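The `DistributedDataset` docstring above recommends wrapping `strategy.run` in a `tf.range` loop inside a `tf.function` to run several steps per call. A minimal sketch of that pattern (toy data and `step_fn` are illustrative, not from the commit):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(64).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

def step_fn(inputs):
  features, labels = inputs
  return labels - 0.3 * features

@tf.function
def run_steps(iterator, steps):
  # AutoGraph converts this tf.range loop into a tf.while_loop on the worker.
  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

run_steps(iter(dist_dataset), tf.constant(8))
```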
@@ -259,9 +577,10 @@ def _get_static_shape(iterators):
 return static_shape
 
 
-class DistributedIteratorBase(distribute_types.Iterator):
+class DistributedIteratorBase(DistributedIteratorInterface):
 """Common implementation for all input iterators."""
 
 # pylint: disable=super-init-not-called
 def __init__(self, input_workers, iterators, strategy):
 static_shape = _get_static_shape(iterators)
 
@@ -548,9 +867,10 @@ class DistributedIterator(DistributedIteratorBase,
 self._strategy)
 
 
-class _IterableInput(distribute_types.Iterable):
+class _IterableInput(DistributedDatasetInterface):
 """Base class for iterable inputs for distribution strategies."""
 
 # pylint: disable=super-init-not-called
 def __init__(self, input_workers):
 assert isinstance(input_workers, InputWorkers)
 self._input_workers = input_workers
@@ -75,20 +75,21 @@ def _on_write_update_replica(var, update_fn, value, **kwargs):
 class DistributedValues(object):
 """Base class for representing distributed values.
 
-A subclass instance of DistributedValues is created when creating variables
-within a distribution strategy, iterating a `tf.Dataset` or through
-`strategy.run`. This base class should never be instantiated
-directly. DistributedValues contains a value per replica. Depending on
+A subclass instance of `tf.distribute.DistributedValues` is created when
+creating variables within a distribution strategy, iterating a
+`tf.distribute.DistributedDataset` or through `tf.distribute.Strategy.run`.
+This base class should never be instantiated directly.
+`tf.distribute.DistributedValues` contains a value per replica. Depending on
 the subclass, the values could either be synced on update, synced on demand,
 or never synced.
 
-DistributedValues can be reduced to obtain single value across replicas,
-as input into `run` or the per replica values inspected
-using `experimental_local_results`.
+`tf.distribute.DistributedValues` can be reduced to obtain a single value
+across replicas, as input into `tf.distribute.Strategy.run` or the per-replica
+values inspected using `tf.distribute.Strategy.experimental_local_results`.
 
 Example usage:
 
-1. Created from Dataset:
+1. Created from a `tf.distribute.DistributedDataset`:
 
 >>> strategy = tf.distribute.MirroredStrategy()
 >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
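A small illustrative sketch of producing a `tf.distribute.DistributedValues` from a `tf.distribute.DistributedDataset` and inspecting its per-replica components (this is a standalone sketch, not the continuation of the doctest that the hunk truncates):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Each element of the distributed dataset is a DistributedValues instance.
distributed_values = next(iter(dist_dataset))

# Inspect the per-replica components (one tensor per replica/device).
print(strategy.experimental_local_results(distributed_values))
```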
@@ -0,0 +1,16 @@
+path: "tensorflow.distribute.DistributedDataset"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.input_lib.DistributedDatasetInterface\'>"
+  is_instance: "<class \'collections.abc.Iterable\'>"
+  member {
+    name: "element_spec"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "reduce"
+    argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+  }
+}
@@ -0,0 +1,16 @@
+path: "tensorflow.distribute.DistributedIterator"
+tf_class {
+  is_instance: "<class \'tensorflow.python.distribute.input_lib.DistributedIteratorInterface\'>"
+  is_instance: "<class \'collections.abc.Iterator\'>"
+  member {
+    name: "element_spec"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "get_next"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+}
@@ -4,6 +4,14 @@ tf_module {
     name: "CrossDeviceOps"
     mtype: "<type \'type\'>"
   }
   member {
+    name: "DistributedDataset"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "DistributedIterator"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "DistributedValues"
     mtype: "<type \'type\'>"