STT-tensorflow/tensorflow/python/distribute/input_lib.py

1971 lines
77 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import collections.abc
import functools
import sys

import six

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.types import distribute as distribute_types
from tensorflow.python.util import nest
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            split_batch_by=None,
                            input_context=None):
  """Wraps a `tf.data.Dataset` in the distributed dataset type for this mode.

  Every strategy funnels through this helper. When TF2 behavior is enabled
  (`tf2.enabled()`), a `DistributedDataset` is built; otherwise a
  `DistributedDatasetV1` is built. The two classes expose different APIs.

  Args:
    dataset: a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    split_batch_by: Optional integer. If present, we "split" each batch of the
      dataset by `split_batch_by` value.
    input_context: `InputContext` for sharding. Only pass this in for between
      graph multi-worker cases where there is only one `input_worker`. In
      these cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.

  Returns:
    A distributed dataset instance.
  """
  dataset_cls = DistributedDataset if tf2.enabled() else DistributedDatasetV1
  return dataset_cls(
      dataset,
      input_workers,
      strategy,
      split_batch_by=split_batch_by,
      input_context=input_context)
def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy):
  """Builds a distributed dataset by calling `dataset_fn` per input context.

  Every strategy funnels through this helper. When TF2 behavior is enabled
  (`tf2.enabled()`), a `DistributedDatasetsFromFunction` is built; otherwise a
  `DistributedDatasetsFromFunctionV1` is built. The two classes expose
  different APIs.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    input_contexts: A list of `InputContext` instances to be passed to call(s)
      to `dataset_fn`. Length and order should match the worker order in
      `input_workers`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.

  Returns:
    A distributed dataset instance.
  """
  if tf2.enabled():
    dataset_cls = DistributedDatasetsFromFunction
  else:
    dataset_cls = DistributedDatasetsFromFunctionV1
  return dataset_cls(dataset_fn, input_workers, input_contexts, strategy)
@tf_export("distribute.DistributedIterator", v1=[])
class DistributedIteratorInterface(collections.abc.Iterator,
                                   distribute_types.Iterator):
  """An iterator over `tf.distribute.DistributedDataset`.

  `tf.distribute.DistributedIterator` is the primary mechanism for enumerating
  elements of a `tf.distribute.DistributedDataset`. It supports the Python
  Iterator protocol, which means it can be iterated over using a for-loop or by
  fetching individual elements explicitly via `get_next()`.

  You can create a `tf.distribute.DistributedIterator` by calling `iter` on
  a `tf.distribute.DistributedDataset` or creating a python loop over a
  `tf.distribute.DistributedDataset`.

  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
  on distributed input for more examples and caveats.
  """
  # NOTE: the base is `collections.abc.Iterator`; the `collections.Iterator`
  # alias is the same class but was removed from `collections` in Python 3.10.

  def get_next(self):
    """Returns the next input from the iterator for all replicas.

    Example use:

    >>> strategy = tf.distribute.MirroredStrategy()
    >>> dataset = tf.data.Dataset.range(100).batch(2)
    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
    >>> dist_dataset_iterator = iter(dist_dataset)
    >>> @tf.function
    ... def one_step(input):
    ...   return input
    >>> step_num = 5
    >>> for _ in range(step_num):
    ...   strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
    >>> strategy.experimental_local_results(dist_dataset_iterator.get_next())
    (<tf.Tensor: shape=(2,), dtype=int64, numpy=array([10, 11])>,)

    The above example corresponds to the case where you have only one device. If
    you have two devices, for example,
    ```python
    strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
    ```
    Then the final line will print out:
    ```python
    (<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
     <tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)
    ```

    Returns:
      A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains
      the next input for all replicas.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    raise NotImplementedError(
        "DistributedIterator.get_next() must be implemented in descendants.")

  @property
  def element_spec(self):
    # pylint: disable=line-too-long
    """The type specification of an element of `tf.distribute.DistributedIterator`.

    Example usage:

    >>> global_batch_size = 16
    >>> strategy = tf.distribute.MirroredStrategy()
    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    >>> distributed_iterator.element_spec
    (TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
     TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))

    The above example corresponds to the case where you have only one device. If
    you have two devices, for example,
    ```python
    strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
    ```
    Then the final line will print out:
    ```python
    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
    ```

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this `tf.distribute.DistributedIterator`. This returned value
      is typically a `tf.distribute.DistributedValues` object and specifies the
      `tf.TensorSpec` of individual components.
    """
    raise NotImplementedError(
        "DistributedIterator.element_spec() must be implemented in descendants")

  def get_next_as_optional(self):
    """Returns a `tf.experimental.Optional` that contains the next value for all replicas.

    If the `tf.distribute.DistributedIterator` has reached the end of the
    sequence, the returned `tf.experimental.Optional` will have no value.

    Example usage:

    >>> strategy = tf.distribute.MirroredStrategy()
    >>> global_batch_size = 2
    >>> steps_per_loop = 2
    >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
    >>> distributed_iterator = iter(
    ...     strategy.experimental_distribute_dataset(dataset))
    >>> def step_fn(x):
    ...   return x
    >>> @tf.function
    ... def train_fn(distributed_iterator):
    ...   for _ in tf.range(steps_per_loop):
    ...     optional_data = distributed_iterator.get_next_as_optional()
    ...     if not optional_data.has_value():
    ...       break
    ...     tf.print(strategy.run(step_fn, args=(optional_data.get_value(),)))
    >>> train_fn(distributed_iterator)
    ... # ([0 1],)
    ... # ([2 3],)

    Returns:
      An `tf.experimental.Optional` object representing the next value from the
      `tf.distribute.DistributedIterator` (if it has one) or no value.
    """
    raise NotImplementedError(
        "get_next_as_optional() not implemented in descendants")
@tf_export("distribute.DistributedDataset", v1=[])
class DistributedDatasetInterface(collections.abc.Iterable,
                                  distribute_types.Iterable):
  # pylint: disable=line-too-long
  """Represents a dataset distributed among devices and machines.

  A `tf.distribute.DistributedDataset` could be thought of as a "distributed"
  dataset. When you use `tf.distribute` API to scale training to multiple
  devices or machines, you also need to distribute the input data, which leads
  to a `tf.distribute.DistributedDataset` instance, instead of a
  `tf.data.Dataset` instance in the non-distributed case. In TF 2.x,
  `tf.distribute.DistributedDataset` objects are Python iterables.

  Note: `tf.distribute.DistributedDataset` instances are *not* of type
  `tf.data.Dataset`. They only support the two usages described below:
  iteration and `element_spec`. We don't support any other APIs to transform or
  inspect the dataset.

  There are two APIs to create a `tf.distribute.DistributedDataset` object:
  `tf.distribute.Strategy.experimental_distribute_dataset(dataset)` and
  `tf.distribute.Strategy.experimental_distribute_datasets_from_function(dataset_fn)`.
  *When to use which?* When you have a `tf.data.Dataset` instance, and the
  regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance
  with a new batch size that is equal to the global batch size divided by the
  number of replicas in sync) and autosharding (i.e. the
  `tf.data.experimental.AutoShardPolicy` options) work for you, use the former
  API. Otherwise, if you are *not* using a canonical `tf.data.Dataset` instance,
  or you would like to customize the batch splitting or sharding, you can wrap
  this logic in a `dataset_fn` and use the latter API. Both APIs handle
  prefetch to device for the user. For more details and examples, follow the
  links to the APIs.

  There are two main usages of a `DistributedDataset` object:

  1. Iterate over it to generate the input for a single device or multiple
  devices, which is a `tf.distribute.DistributedValues` instance. To do this,
  you can:

    * use a pythonic for-loop construct:

      >>> global_batch_size = 2
      >>> strategy = tf.distribute.MirroredStrategy()
      >>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size)
      >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
      >>> @tf.function
      ... def train_step(input):
      ...   features, labels = input
      ...   return labels - 0.3 * features
      >>> for x in dist_dataset:
      ...   # train_step trains the model using the dataset elements
      ...   loss = strategy.run(train_step, args=(x,))
      ...   print("Loss is", loss)
      Loss is tf.Tensor(
      [[0.7]
       [0.7]], shape=(2, 1), dtype=float32)
      Loss is tf.Tensor(
      [[0.7]
       [0.7]], shape=(2, 1), dtype=float32)

      Placing the loop inside a `tf.function` will give a performance boost.
      However `break` and `return` are currently not supported if the loop is
      placed inside a `tf.function`. We also don't support placing the loop
      inside a `tf.function` when using
      `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
      `tf.distribute.experimental.TPUStrategy` with multiple workers.

    * use `__iter__` to create an explicit iterator, which is of type
      `tf.distribute.DistributedIterator`

      >>> global_batch_size = 4
      >>> strategy = tf.distribute.MirroredStrategy()
      >>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size)
      >>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
      >>> @tf.function
      ... def distributed_train_step(dataset_inputs):
      ...   def train_step(input):
      ...     loss = tf.constant(0.1)
      ...     return loss
      ...   per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
      ...   return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,axis=None)
      >>> EPOCHS = 2
      >>> STEPS = 3
      >>> for epoch in range(EPOCHS):
      ...   total_loss = 0.0
      ...   num_batches = 0
      ...   dist_dataset_iterator = iter(train_dist_dataset)
      ...   for _ in range(STEPS):
      ...     total_loss += distributed_train_step(next(dist_dataset_iterator))
      ...     num_batches += 1
      ...   average_train_loss = total_loss / num_batches
      ...   template = ("Epoch {}, Loss: {}")
      ...   print (template.format(epoch+1, average_train_loss))
      Epoch 1, Loss: 0.10000000894069672
      Epoch 2, Loss: 0.10000000894069672

      To achieve a performance improvement, you can also wrap the `strategy.run`
      call with a `tf.range` inside a `tf.function`. This runs multiple steps in a
      `tf.function`. Autograph will convert it to a `tf.while_loop` on the worker.
      However, it is less flexible compared with running a single step inside
      `tf.function`. For example, you cannot run things eagerly or arbitrary
      python code within the steps.

  2. Inspect the `tf.TypeSpec` of the data generated by `DistributedDataset`.

    `tf.distribute.DistributedDataset` generates
    `tf.distribute.DistributedValues` as input to the devices. If you pass the
    input to a `tf.function` and would like to specify the shape and type of
    each Tensor argument to the function, you can pass a `tf.TypeSpec` object to
    the `input_signature` argument of the `tf.function`. To get the
    `tf.TypeSpec` of the input, you can use the `element_spec` property of the
    `tf.distribute.DistributedDataset` or `tf.distribute.DistributedIterator`
    object.

    For example:

    >>> global_batch_size = 2
    >>> epochs = 1
    >>> steps_per_epoch = 1
    >>> mirrored_strategy = tf.distribute.MirroredStrategy()
    >>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
    >>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
    >>> @tf.function(input_signature=[dist_dataset.element_spec])
    ... def train_step(per_replica_inputs):
    ...   def step_fn(inputs):
    ...     return tf.square(inputs)
    ...   return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))
    >>> for _ in range(epochs):
    ...   iterator = iter(dist_dataset)
    ...   for _ in range(steps_per_epoch):
    ...     output = train_step(next(iterator))
    ...     print(output)
    tf.Tensor(
    [[4.]
     [4.]], shape=(2, 1), dtype=float32)

  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
  on distributed input for more examples and caveats.
  """
  # NOTE: the base is `collections.abc.Iterable`; the `collections.Iterable`
  # alias is the same class but was removed from `collections` in Python 3.10.

  def __iter__(self):
    """Creates an iterator for the `tf.distribute.DistributedDataset`.

    The returned iterator implements the Python Iterator protocol.

    Example usage:

    >>> global_batch_size = 4
    >>> strategy = tf.distribute.MirroredStrategy()
    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    >>> print(next(distributed_iterator))
    tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)

    The above example corresponds to the case where you have only one device. If
    you have two devices, for example,
    ```python
    strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
    ```
    Then the final line will print out:
    ```python
    PerReplica:{
      0: tf.Tensor([1 2], shape=(2,), dtype=int32),
      1: tf.Tensor([3 4], shape=(2,), dtype=int32)
    }
    ```

    Returns:
      An `tf.distribute.DistributedIterator` instance for the given
      `tf.distribute.DistributedDataset` object to enumerate over the
      distributed data.
    """
    raise NotImplementedError("Must be implemented in descendants")

  @property
  def element_spec(self):
    """The type specification of an element of this `tf.distribute.DistributedDataset`.

    Example usage:

    >>> global_batch_size = 16
    >>> strategy = tf.distribute.MirroredStrategy()
    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
    >>> dist_dataset.element_spec
    (TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
     TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))

    The above example corresponds to the case where you have only one device. If
    you have two devices, for example,
    ```python
    strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1'])
    ```
    Then the final line will print out:
    ```python
    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))
    ```

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this `tf.distribute.DistributedDataset`. This returned value is
      typically a `tf.distribute.DistributedValues` object and specifies the
      `tf.TensorSpec` of individual components.
    """
    raise NotImplementedError(
        "DistributedDataset.element_spec must be implemented in descendants.")

  @doc_controls.do_not_generate_docs
  def reduce(self, initial_state, reduce_func):
    raise NotImplementedError(
        "DistributedDataset.reduce must be implemented in descendants.")
class InputWorkers(object):
  """Maps each input worker device to the compute devices it feeds (1-to-many)."""

  def __init__(self, worker_device_pairs):
    """Builds the input-device -> compute-devices mapping.

    Args:
      worker_device_pairs: A sequence of pairs:
        `(input device, a tuple of compute devices fed by that input device)`.
    """
    self._worker_device_pairs = worker_device_pairs
    pairs = self._worker_device_pairs
    # Input device names are kept exactly as given.
    self._input_worker_devices = tuple(
        input_device for input_device, _ in pairs)
    # Compute device names are canonicalized, one tuple per input worker.
    self._fed_devices = tuple(
        tuple(device_util.canonicalize(dev) for dev in compute_devices)
        for _, compute_devices in pairs)

  @property
  def num_workers(self):
    """The number of input worker devices."""
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    """Tuple of input worker device names, in construction order."""
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    """Returns the canonicalized compute devices fed by worker `worker_index`."""
    return self._fed_devices[worker_index]

  def __repr__(self):
    entries = []
    for index, worker in enumerate(self.worker_devices):
      entries.append(" %d %s: %s" % (index, worker, self._fed_devices[index]))
    return "%s:{\n%s}" % (self.__class__.__name__, ",\n".join(entries))

  def serialize(self):
    """Returns the `worker_device_pairs` this object was constructed from."""
    return self._worker_device_pairs

  def deserialize(self, worker_device_pairs):
    """Builds a new `InputWorkers` from the output of `serialize`."""
    return InputWorkers(worker_device_pairs)
def _get_next_as_optional(iterator, strategy, name=None):
  """Returns an empty dataset indicator and the next input from the iterator.

  Fetches the next element on every input worker and all-reduces a per-worker
  "has value" flag so that every worker agrees on whether any data remains.

  Args:
    iterator: a distributed iterator exposing `_input_workers` (an
      `InputWorkers`) and `_iterators` (one per-worker iterator for each input
      worker, each providing `get_next_as_list`).
    strategy: the `tf.distribute.Strategy` used to run the all-reduce.
    name: optional op name prefix; when given, each worker's name gets the
      worker's job and task appended.

  Returns:
    A tuple `(global_has_value, replicas)`. `global_has_value` is a scalar
    boolean tensor that is True when the summed per-worker flags are non-zero.
    `replicas` has one entry per input worker, each being the per-replica
    values returned by that worker's `get_next_as_list`; callers index it as
    `replicas[worker_index][replica_index]`.
  """
  replicas = []
  worker_has_values = []
  for i, worker in enumerate(iterator._input_workers.worker_devices):  # pylint: disable=protected-access
    if name is not None:
      d = tf_device.DeviceSpec.from_string(worker)
      new_name = "%s_%s_%d" % (name, d.job, d.task)
    else:
      new_name = None

    with ops.device(worker):
      worker_has_value, next_element = (
          iterator._iterators[i].get_next_as_list(new_name))  # pylint: disable=protected-access
      # Collective all-reduce requires explicit devices for inputs.
      with ops.device("/cpu:0"):
        # Converting to integers for all-reduce.
        worker_has_value = math_ops.cast(worker_has_value, dtypes.int64)
        worker_has_values.append(worker_has_value)
      # One entry per worker (not flattened): the list of that worker's
      # per-replica values.
      replicas.append(next_element)

  # Run an all-reduce to see whether any worker has values.
  # TODO(b/131423105): we should be able to short-cut the all-reduce in some
  # cases.
  if getattr(strategy.extended, "_support_per_replica_values", True):
    # Slight hack: `reduce` expects a `PerReplica`, so we pass it one, even
    # though it doesn't actually have a value per replica.
    worker_has_values = values.PerReplica(worker_has_values)
    global_has_value = strategy.reduce(
        reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
  else:
    # Without per-replica value support only a single worker is expected.
    assert len(worker_has_values) == 1
    global_has_value = worker_has_values[0]
  global_has_value = array_ops.reshape(
      math_ops.cast(global_has_value, dtypes.bool), [])
  return global_has_value, replicas
def _is_statically_shaped(tensor_class, shape):
  """Test if an iterator output is statically shaped.

  For sparse and ragged tensors this only tests the batch dimension. A sparse
  or ragged output whose rank is unknown is reported as not statically shaped,
  since its batch dimension cannot be known either.

  Args:
    tensor_class: a class from an iterator.output_classes list.
    shape: a TensorShape from an iterator.output_shapes list.

  Returns:
    True if the shape is static, false otherwise.
  """
  if (tensor_class == sparse_tensor.SparseTensor or
      isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
    # For sparse or ragged tensor, we should only check the first
    # dimension in order to get_next_as_optional. This is because
    # when these tensors get batched by dataset only the batch dimension
    # is set.
    if shape.rank is None:
      # Unknown rank. Checking this first also avoids evaluating `None > 0`,
      # which raises TypeError on Python 3.
      return False
    if shape.rank > 0 and shape.as_list()[0] is None:
      return False
    return True
  return shape.is_fully_defined()
def _get_static_shape(iterators):
  """Returns True if every single-worker dataset iterator has static outputs.

  Only `_SingleWorkerOwnedDatasetIterator` and `_SingleWorkerDatasetIterator`
  instances are inspected; other iterators are skipped. Each flattened output
  is checked with `_is_statically_shaped`. The scan stops at the first
  non-static output; previously `break` only left the inner loop, so the
  remaining iterators were still scanned even though the result was known.

  Args:
    iterators: an iterable of per-worker iterators.

  Returns:
    A boolean indicating if all inspected outputs are statically shaped.
  """
  for iterator in iterators:
    if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator,
                                 _SingleWorkerDatasetIterator)):
      continue
    flattened = zip(nest.flatten(iterator.output_shapes),
                    nest.flatten(iterator.output_classes))
    for output_shape, output_class in flattened:
      if not _is_statically_shaped(output_class, output_shape):
        return False
  return True
class DistributedIteratorBase(DistributedIteratorInterface):
  """Common implementation for all input iterators.

  Holds one per-worker iterator for each input worker and regroups the values
  they produce into per-replica values. `get_next_as_optional` reads
  `self._element_spec`, which this class never sets; subclasses are expected
  to provide it.
  """

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers, iterators, strategy):
    """Initializes the iterator.

    Args:
      input_workers: an `InputWorkers` object; must have at least one worker.
      iterators: per-worker iterators, indexed in the same order as
        `input_workers.worker_devices`.
      strategy: the `tf.distribute.Strategy` this iterator belongs to.

    Raises:
      ValueError: if `input_workers` has no worker devices.
    """
    static_shape = _get_static_shape(iterators)

    # TODO(b/133073708): we currently need a flag to control the usage because
    # there is a performance difference between get_next() and
    # get_next_as_optional(). And we only enable get_next_as_optional when the
    # output shapes are not static.
    #
    # TODO(rxsang): We want to always enable the get_next_as_optional behavior
    # when user passed input_fn instead of dataset.
    if getattr(
        strategy.extended, "experimental_enable_get_next_as_optional", False):
      # Also forced on in multi-worker mode, even when shapes are static.
      self._enable_get_next_as_optional = (
          not static_shape) or strategy.extended._in_multi_worker_mode()
    else:
      self._enable_get_next_as_optional = False

    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy

  def next(self):
    """Python 2-style alias for `__next__`."""
    return self.__next__()

  def __next__(self):
    """Returns `get_next()`, translating `OutOfRangeError` to `StopIteration`."""
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def __iter__(self):
    return self

  def get_next_as_optional(self):
    """Returns an `Optional` holding the next per-replica value, or empty.

    The optional has a value when the all-reduced has-value flag computed by
    `_get_next_as_optional` is True; otherwise it is an empty optional built
    from `self._element_spec`.
    """
    global_has_value, replicas = _get_next_as_optional(self, self._strategy)

    def return_none():
      return optional_ops.Optional.empty(self._element_spec)

    def return_value(replicas):
      """Wraps the inputs for replicas in an `tf.experimental.Optional`."""
      # Flatten the per-worker nested list into one list ordered by worker and
      # then by compute device, re-reading each value under its device scope.
      results = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        with ops.device(worker):
          devices = self._input_workers.compute_devices_for_worker(i)
          for j, device in enumerate(devices):
            with ops.device(device):
              result = replicas[i][j]
              results.append(result)
      replicas = results

      return optional_ops.Optional.from_value(
          distribute_utils.regroup(replicas))

    return control_flow_ops.cond(global_has_value,
                                 lambda: return_value(replicas),
                                 lambda: return_none())  # pylint: disable=unnecessary-lambda

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas.

    Two paths:
    * When `_enable_get_next_as_optional` is False, each worker's
      `get_next_as_list_static_shapes` results are concatenated and regrouped.
    * Otherwise every replica's value is guarded by a `cond` on the
      all-reduced has-value flag; when no worker has data the false branch
      calls `get_next` on the exhausted per-worker iterator, which is expected
      to raise OutOfRangeError.

    Args:
      name: optional op name prefix, used only on the first path; each worker's
        name gets the worker's job and task appended.

    Returns:
      The regrouped per-replica values.
    """
    if not self._enable_get_next_as_optional:
      replicas = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        if name is not None:
          d = tf_device.DeviceSpec.from_string(worker)
          new_name = "%s_%s_%d" % (name, d.job, d.task)
        else:
          new_name = None
        with ops.device(worker):
          # Make `replicas` a flat list of values across all replicas.
          replicas.extend(
              self._iterators[i].get_next_as_list_static_shapes(new_name))
      return distribute_utils.regroup(replicas)

    # Collects whatever `out_of_range_fn` produces; not read elsewhere in this
    # method.
    out_of_range_replicas = []
    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # As this will be only called when there is no data left, so calling
      # get_next() will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    global_has_value, replicas = _get_next_as_optional(self, self._strategy)
    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambda to capture variables from the loop as
            # the lambda is executed in the loop as well.
            result = control_flow_ops.cond(
                global_has_value,
                lambda: replicas[i][j],
                lambda: out_of_range_fn(i, device),
                strict=True,
            )
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable
            results.append(result)
    replicas = results

    return distribute_utils.regroup(replicas)
class DistributedIteratorV1(DistributedIteratorBase):
  """Input Iterator for a distributed dataset."""

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize the
  # iterator we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    """Groups the initializer ops of every per-worker iterator into one op."""
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      The op returned by `control_flow_ops.group` over the initializer ops of
      every per-worker iterator.
    """
    return self._initializer

  @property
  def initializer(self):
    """Returns the grouped op that initializes the underlying iterators.

    Reads `_initializer` directly rather than going through the deprecated
    `initialize()` method. That method's deprecation notice points users to
    this property, so routing through it would warn on the recommended path.
    """
    return self._initializer

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    """`output_classes` of the first per-worker iterator."""
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    """`output_shapes` of the first per-worker iterator."""
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    """`output_types` of the first per-worker iterator."""
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    """Returns the per-worker iterator for `worker`, or None if not an input worker."""
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec
class DistributedIteratorSpec(type_spec.TypeSpec):
  """Type specification for `DistributedIterator`."""

  __slots__ = ["_input_workers", "_element_spec", "_strategy"]

  def __init__(self, input_workers, element_spec, strategy):
    """Creates the spec.

    Args:
      input_workers: an `InputWorkers` object. A tuple here indicates an
        attempt to rebuild the spec from serialized form, which is rejected.
      element_spec: the nested element spec of the iterator.
      strategy: the `tf.distribute.Strategy` of the iterator.

    Raises:
      NotImplementedError: if `input_workers` is a tuple.
    """
    # We don't want to allow deserialization of this class because we don't
    # serialize the strategy object. Currently the only places where
    # _deserialize is called is when we save/restore using SavedModels.
    if isinstance(input_workers, tuple):
      raise NotImplementedError("DistributedIteratorSpec does not have support "
                                "for deserialization.")
    else:
      self._input_workers = input_workers
      self._element_spec = element_spec
      self._strategy = strategy

  @property
  def value_type(self):
    return DistributedIterator

  def _serialize(self):
    # We cannot serialize the strategy object so we convert it to an id that we
    # can use for comparison.
    return (self._input_workers.serialize(),
            self._element_spec, id(self._strategy))

  @classmethod
  def _deserialize(cls, serialization=None):
    """Always raises: deserialization is not supported for this spec.

    The base `TypeSpec._deserialize` is a classmethod that receives the
    serialization. Matching that signature ensures a call of that shape reaches
    this `ValueError` instead of failing with a `TypeError` on argument count.
    The argument defaults to None so existing zero-argument calls still work.
    """
    del serialization  # Unused.
    raise ValueError("Deserialization is currently unsupported for "
                     "DistributedIteratorSpec.")

  # Overriding this method so that we can merge and reconstruct the spec object
  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # pylint: disable=protected-access
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    if self._input_workers.serialize() != other._input_workers.serialize():
      raise ValueError("_input_workers is not compatible with both %s "
                       "and %s" % (self, other))
    if self._strategy is not other._strategy:
      raise ValueError("tf.distribute strategy is not compatible with both %s "
                       "and %s" % (self, other))
    element_spec = nest.map_structure(
        lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
        other._element_spec)
    return DistributedIteratorSpec(self._input_workers, element_spec,
                                   self._strategy)

  @property
  def _component_specs(self):
    """One `_SingleWorkerDatasetIteratorSpec` per input worker."""
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access

    for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
      element_spec = nest.map_structure(
          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
      specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
                                                    compute_devices,
                                                    element_spec))
    return specs

  def _to_components(self, value):
    return value._iterators  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedIterator(input_workers=self._input_workers,
                               iterators=None,
                               components=components,
                               element_spec=self._element_spec,
                               strategy=self._strategy)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                   value._strategy)

  def _with_tensor_ranks_only(self):
    element_spec = nest.map_structure(
        lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
        self._element_spec)
    return DistributedIteratorSpec(self._input_workers, element_spec,
                                   self._strategy)
class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Input Iterator for a distributed dataset."""

  def __init__(self, input_workers=None, iterators=None, strategy=None,
               components=None, element_spec=None):
    """Creates the iterator from `iterators`, or rebuilds it from `components`.

    `input_workers` is always required. When `iterators` is given, the base
    class initializer is used; passing both `components` and `element_spec`
    alongside it is an error. When `iterators` is None, both `components` and
    `element_spec` must be given; this is the path used by
    `DistributedIteratorSpec._from_components`.

    Raises:
      ValueError: on a missing or conflicting combination of arguments.
    """
    if input_workers is None:
      raise ValueError("`input_workers` should be "
                       "provided.")

    error_message = ("Either `input_workers` or "
                     "both `components` and `element_spec` need to be "
                     "provided.")

    if iterators is not None:
      if components is not None and element_spec is not None:
        raise ValueError(error_message)
      super(DistributedIterator, self).__init__(input_workers, iterators,
                                                strategy)
      return

    # Rebuild from components.
    if components is None or element_spec is None:
      raise ValueError(error_message)
    self._element_spec = element_spec
    self._input_workers = input_workers
    self._iterators = components
    static_shape = _get_static_shape(self._iterators)
    self._strategy = strategy
    extended = strategy.extended
    if getattr(extended, "experimental_enable_get_next_as_optional", False):
      self._enable_get_next_as_optional = (
          not static_shape) or extended._in_multi_worker_mode()
    else:
      self._enable_get_next_as_optional = False

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec

  @property
  def _type_spec(self):
    """A `DistributedIteratorSpec` describing this iterator."""
    spec = DistributedIteratorSpec(self._input_workers,
                                   self.element_spec,
                                   self._strategy)
    return spec
class _IterableInput(DistributedDatasetInterface):
  """Base class for iterable inputs for distribution strategies.

  Subclasses must implement `__iter__`. `reduce` reads `self._strategy`,
  which this base class does not set; subclasses are expected to assign it.
  """

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers):
    assert isinstance(input_workers, InputWorkers)
    self._input_workers = input_workers

  def __iter__(self):
    raise NotImplementedError("must be implemented in descendants")

  def reduce(self, initial_state, reduce_fn):
    """Execute a `reduce_fn` over all the elements of the input.

    Args:
      initial_state: The initial value of the accumulated state.
      reduce_fn: A function `(state, element) -> new_state`; `element` is the
        regrouped per-replica data for one step.

    Returns:
      The final state once no data is left.

    Raises:
      ValueError: if the input spans more than one worker (raised while the
        loop body is built).
    """
    it = iter(self)
    first_has_data, first_data = _get_next_as_optional(it, self._strategy)

    def _keep_going(has_data, data, state):
      del data, state  # Unused.
      return has_data

    def _step(has_data, data, state):
      """Folds one element into `state`, then fetches the next element."""
      del has_data  # Unused.
      # `data` is a list of lists, one entry per worker.
      # TODO(b/130570614): Add support for the multiworker and TPU pods use
      # case.
      if self._input_workers.num_workers != 1:
        raise ValueError("Dataset iteration within a tf.function is"
                         " not supported for multiple workers.")
      new_state = reduce_fn(state, distribute_utils.regroup(data[0]))
      next_has_data, next_data = _get_next_as_optional(it, self._strategy)
      return next_has_data, next_data, new_state

    loop_vars = [first_has_data, first_data, initial_state]
    _, _, final_state = control_flow_ops.while_loop(
        _keep_going, _step, loop_vars, parallel_iterations=1)
    return final_state
class DistributedDataset(_IterableInput):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    """Distribute the dataset on all workers.

    If `split_batch_by` is not None, we "split" each batch of the dataset by
    `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.

    Raises:
      ValueError: if `split_batch_by` is set and rebatching fails with an
        error mentioning "without encountering a batch" (i.e. the input
        dataset was never batched).
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)

    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, will attempt to shard
    # the final input such that each worker will run the entire preprocessing
    # pipeline and only receive its own shard of the dataset.
    if split_batch_by:
      try:
        # pylint: disable=protected-access
        with ops.colocate_with(dataset._variant_tensor):
          dataset = distribute._RebatchDataset(dataset, split_batch_by)
          # Add a prefetch to pipeline rebatching for performance.
          # TODO(rachelim): Instead of inserting an extra prefetch stage here,
          # leverage static graph rewrites to insert _RebatchDataset before
          # the final `prefetch` if it exists.
          dataset = dataset.prefetch(split_batch_by)
      except errors.InvalidArgumentError as e:
        if "without encountering a batch" in str(e):
          # Re-raise as a ValueError with a user-facing hint, keeping the
          # original traceback.
          six.reraise(
              ValueError,
              ValueError(
                  "Call the `batch` method on the input Dataset in order to be "
                  "able to split your input across {} replicas.\n Please "
                  "see the tf.distribute.Strategy guide. {}".format(
                      split_batch_by, e)),
              sys.exc_info()[2])
        else:
          raise

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      dataset = input_ops.auto_shard_dataset(dataset,
                                             input_context.num_input_pipelines,
                                             input_context.input_pipeline_id)
      self._cloned_datasets.append(dataset)
    else:
      replicated_ds = distribute.replicate(dataset,
                                           input_workers.worker_devices)
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = replicated_ds[worker]
          cloned_dataset = cloned_dataset.with_options(dataset.options())
          cloned_dataset = input_ops.auto_shard_dataset(
              cloned_dataset, len(input_workers.worker_devices), i)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    self._strategy = strategy
    self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                         dataset.element_spec)

  def __iter__(self):
    """Creates a distributed iterator over the cloned per-worker datasets.

    Raises:
      RuntimeError: if called outside eager execution and outside a
        `tf.function`.
    """
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    # This is an optional flag that can be used to turn off using
    # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators
    # as a stop gap solution that will allow us to roll out this change.
    enable_legacy_iterators = getattr(self._strategy,
                                      "_enable_legacy_iterators", False)
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers,
                                                    enable_legacy_iterators)
    if enable_legacy_iterators:
      iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                       self._strategy)
    else:
      iterator = DistributedIterator(self._input_workers, worker_iterators,
                                     self._strategy)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access
    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    return self._element_spec
class DistributedDatasetV1(DistributedDataset):
  """Distributed dataset that supports prefetching to multiple devices.

  TF1-compatible variant: it adds one-shot and initializable iterator
  constructors, and every iterator it produces is a `DistributedIteratorV1`
  built on legacy per-worker iterators.
  """

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    # NOTE: `_input_workers` is also assigned by the parent constructor.
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)

  def make_one_shot_iterator(self):
    """Get a one time use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1.

    Raises:
      ValueError: if called when not executing eagerly.
    """
    # In graph mode the iterator would need an explicit `initialize` call
    # (required when using a tf.distribute strategy), so one-shot iterators
    # are only offered eagerly.
    if context.executing_eagerly():
      return self._get_iterator()
    raise ValueError("Cannot create a one shot iterator. Please use "
                     "`make_initializable_iterator()` instead.")

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetV1.

    Args:
      shared_name: Unused.

    Raises:
      ValueError: if called while executing eagerly.
    """
    del shared_name  # Unused.
    # Eagerly created iterators are already initialized, so an initializable
    # iterator makes no sense there.
    if not context.executing_eagerly():
      return self._get_iterator()
    raise ValueError("Cannot create initializable iterator in Eager mode. "
                     "Please use `iter()` instead.")

  def _get_iterator(self):
    """Builds a `DistributedIteratorV1` over legacy per-worker iterators."""
    per_worker = _create_iterators_per_worker(
        self._cloned_datasets,
        self._input_workers,
        enable_legacy_iterators=True)
    dist_iterator = DistributedIteratorV1(self._input_workers, per_worker,
                                          self._strategy)
    dist_iterator._element_spec = self.element_spec  # pylint: disable=protected-access
    return dist_iterator

  def __iter__(self):
    usable_here = (ops.executing_eagerly_outside_functions() or
                   ops.get_default_graph().building_function)
    if not usable_here:
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")
    return self._get_iterator()
# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput):
  """Inputs created from dataset function."""

  def __init__(self, dataset_fn, input_workers, input_contexts, strategy):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `dataset_fn`. Length and order should match the worker order in
        `input_workers`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.

    Raises:
      ValueError: if the number of input workers differs from the number of
        input contexts.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)

    worker_count = input_workers.num_workers
    context_count = len(input_contexts)
    if worker_count != context_count:
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" % (worker_count, context_count))

    self._input_workers = input_workers
    self._input_contexts = input_contexts
    self._strategy = strategy
    # `dataset_fn` is invoked once per input context, on that worker's device.
    datasets, per_worker_spec = _create_datasets_per_worker_with_input_context(
        self._input_contexts, self._input_workers, dataset_fn)
    self._datasets = datasets
    self._element_spec = _create_distributed_tensor_spec(
        self._strategy, per_worker_spec)

  def __iter__(self):
    """Creates a distributed iterator over the per-worker datasets.

    Raises:
      RuntimeError: if called outside eager execution and outside a
        `tf.function`.
    """
    if not (ops.executing_eagerly_outside_functions() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    # Optional flag to fall back from OwnedMultiDeviceIterators to the legacy
    # MultiDeviceIterators as a stop gap while this change rolls out.
    use_legacy = getattr(self._strategy, "_enable_legacy_iterators", False)
    per_worker = _create_iterators_per_worker(self._datasets,
                                              self._input_workers,
                                              use_legacy)
    iterator_cls = DistributedIteratorV1 if use_legacy else DistributedIterator
    dist_iterator = iterator_cls(self._input_workers, per_worker,
                                 self._strategy)
    dist_iterator._element_spec = self._element_spec  # pylint: disable=protected-access
    return dist_iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    spec = self._element_spec
    if spec is None:
      raise ValueError("You must create an iterator before calling "
                       "`element_spec` on the distributed dataset or iterator. "
                       "This is because the dataset function is not called "
                       "before an iterator is created.")
    return spec
class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
  """Inputs created from dataset function.

  TF1-compatible variant whose iterators are always `DistributedIteratorV1`
  instances built on legacy per-worker iterators.
  """

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1.

    Raises:
      ValueError: if called while executing eagerly.
    """
    del shared_name  # Unused
    # Eagerly created iterators are already initialized, so an initializable
    # iterator is only offered in graph mode.
    if not context.executing_eagerly():
      return self._get_iterator()
    raise ValueError("Cannot create initializable iterator in Eager mode. "
                     "Please use `iter()` instead.")

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1.

    Raises:
      ValueError: if called when not executing eagerly.
    """
    # In graph mode the iterator would need an explicit `initialize` call
    # (required when using a tf.distribute strategy).
    if context.executing_eagerly():
      return self._get_iterator()
    raise ValueError("Cannot create a one shot iterator. Please use "
                     "`make_initializable_iterator()` instead.")

  def _get_iterator(self):
    """Builds a `DistributedIteratorV1` over legacy per-worker iterators."""
    per_worker = _create_iterators_per_worker(
        self._datasets, self._input_workers, enable_legacy_iterators=True)
    dist_iterator = DistributedIteratorV1(self._input_workers, per_worker,
                                          self._strategy)
    dist_iterator._element_spec = self._element_spec  # pylint: disable=protected-access
    return dist_iterator

  def __iter__(self):
    usable_here = (ops.executing_eagerly_outside_functions() or
                   ops.get_default_graph().building_function)
    if not usable_here:
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")
    return self._get_iterator()
# TODO(anjalisridhar): This class will soon be removed in favor of newer
# APIs.
class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object or a
        callable.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match the worker order in
        `input_workers`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.

    Raises:
      ValueError: if the number of input workers differs from the number of
        input contexts, or if `input_fn` returns neither a dataset nor a
        callable.
    """
    assert isinstance(input_workers, InputWorkers)
    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    per_worker_iterators = []
    for worker_index, input_ctx in enumerate(input_contexts):
      worker_device = input_workers.worker_devices[worker_index]
      with ops.device(worker_device):
        fn_result = input_fn(input_ctx)
        compute_devices = input_workers.compute_devices_for_worker(
            worker_index)
        # Datasets get a multi-device iterator; plain callables are invoked
        # directly on every fetch.
        if isinstance(fn_result, dataset_ops.DatasetV2):
          iterator_cls = _SingleWorkerDatasetIterator
        elif callable(fn_result):
          iterator_cls = _SingleWorkerCallableIterator
        else:
          raise ValueError(
              "input_fn must return a tf.data.Dataset or a callable.")
        per_worker_iterators.append(
            iterator_cls(fn_result, worker_device, compute_devices))

    super(InputFunctionIterator, self).__init__(input_workers,
                                                per_worker_iterators,
                                                strategy)
# TODO(anjalisridhar): This class will soon be removed and users should move
# to using DistributedIterator.
class DatasetIterator(DistributedIteratorV1):
  """Iterator created from input dataset."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               split_batch_by=None,
               input_context=None):
    """Make an iterator for the dataset on given devices.

    If `split_batch_by` is not None, we "split" each batch of the
    dataset by `split_batch_by` value.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      split_batch_by: Optional integer. If present, we "split" each batch of the
        dataset by `split_batch_by` value.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    # Reuse DistributedDatasetV1 for cloning/sharding, then build legacy
    # per-worker iterators over its cloned datasets.
    distributed = DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        split_batch_by=split_batch_by,
        input_context=input_context)
    cloned = distributed._cloned_datasets  # pylint: disable=protected-access
    per_worker = _create_iterators_per_worker(cloned, input_workers, True)
    super(DatasetIterator, self).__init__(input_workers, per_worker, strategy)
    self._element_spec = distributed.element_spec
def _dummy_tensor_fn(value_structure):
  """A function to create dummy tensors from `value_structure`.

  Used as the "no data left" branch when fetching a per-device element.

  Args:
    value_structure: A nested structure of `TypeSpec`s (dense, sparse or
      ragged tensor specs).

  Returns:
    A structure matching `value_structure` in which each spec is replaced by
    a dummy value whose possible batch dimension is 0: a `SparseTensor` with
    no values for sparse specs, a zero-filled tensor for dense specs, and a
    `RaggedTensor` with empty row splits for ragged specs.
  """

  def create_dummy_tensor(spec):
    """Create a dummy tensor with possible batch dimensions set to 0."""
    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
      # Splice out the ragged dimensions.
      # pylint: disable=protected-access
      feature_shape = spec._shape[:1].concatenate(
          spec._shape[(1 + spec._ragged_rank):])
      feature_type = spec._dtype
      # pylint: enable=protected-access
    else:
      feature_shape = spec.shape
      feature_type = spec.dtype
    # Ideally we should set the batch dimension to 0, however as in
    # DistributionStrategy we don't know the batch dimension, we try to
    # guess it as much as possible. If the feature has unknown dimensions, we
    # will set them to 0. If the feature shape is already static, we guess the
    # first dimension as batch dimension and set it to 0.
    # An unknown-rank shape is falsy, which yields an empty `dims`.
    dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
            if feature_shape else [])
    if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
                 feature_shape.is_fully_defined()):
      dims[0] = tensor_shape.Dimension(0)

    if isinstance(spec, sparse_tensor.SparseTensorSpec):
      # An empty SparseTensor: zero values and zero index rows, with one
      # index column per dimension.
      return sparse_tensor.SparseTensor(
          values=array_ops.zeros(0, feature_type),
          indices=array_ops.zeros((0, len(dims)), dtypes.int64),
          dense_shape=dims)

    # Create the dummy tensor.
    dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
    if isinstance(spec, ragged_tensor.RaggedTensorSpec):
      # Reinsert the ragged dimensions with size 0.
      # pylint: disable=protected-access
      row_splits = array_ops.zeros(1, spec._row_splits_dtype)
      dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
          dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
      # pylint: enable=protected-access
    return dummy_tensor

  return nest.map_structure(create_dummy_tensor, value_structure)
def _recover_shape_fn(data, value_structure):
  """Recover the shape of `data` the same as shape of `value_structure`.

  Every (composite-expanded) component of `data` gets `set_shape` applied
  from the matching component of `value_structure`. `SparseTensor`s are
  rebuilt with a `dense_shape` taken from the spec; any unknown dimensions are
  filled in from the tensor's own `dense_shape`.

  Args:
    data: A nested structure of tensors, matching `value_structure`.
    value_structure: A nested structure of `TypeSpec`s.

  Returns:
    `data` repacked with the recovered shapes.
  """
  flat_data = nest.flatten(data)
  flat_specs = nest.flatten(value_structure)
  for idx, spec in enumerate(flat_specs):
    item = flat_data[idx]
    component_pairs = zip(nest.flatten(item, expand_composites=True),
                          nest.flatten(spec, expand_composites=True))
    for target, source in component_pairs:
      target.set_shape(source.shape)

    # A `SparseTensor`'s shape is not determined by the shapes of its
    # component tensors but by a tensor value, so it is rebuilt separately.
    if not (isinstance(spec, sparse_tensor.SparseTensorSpec) and spec.shape):
      continue

    dense_shape = spec.shape
    with ops.device(item.op.device):
      if not dense_shape.is_fully_defined():
        # Take any missing dimensions from the runtime dense_shape.
        dense_shape = array_ops.stack([
            item.dense_shape[axis] if dim is None else dim
            for axis, dim in enumerate(dense_shape.as_list())
        ])
      flat_data[idx] = sparse_tensor.SparseTensor(
          indices=item.indices,
          values=item.values,
          dense_shape=dense_shape)

  return nest.pack_sequence_as(data, flat_data)
class _SingleWorkerDatasetIteratorBase(object):
  """Iterator for a single `tf.data.Dataset`.

  Subclasses implement `_make_iterator`, which must assign `self._iterator`.
  """

  def __init__(self, dataset, worker, devices):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
    input to the devices on the given worker.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._element_spec = dataset.element_spec
    self._make_iterator()

  def _make_iterator(self):
    raise NotImplementedError("must be implemented in descendants")

  def get_next(self, device, name=None):
    """Get next element for the given device."""
    del name
    with ops.device(self._worker):
      next_element = self._iterator.get_next(device)
      return next_element

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the underlying iterator.

    Runs the iterator get_next() within a device scope. Since this doesn't use
    get_next_as_optional(), it is considerably faster than get_next_as_list()
    (but can only be used when the shapes are static).

    Args:
      name: not used.

    Returns:
      A list consisting of the next data from each device.
    """
    del name
    with ops.device(self._worker):
      per_device = self._iterator.get_next()
      return per_device

  def get_next_as_list(self, name=None):
    """Get next element from underlying iterator.

    If there is no data left, a list of dummy tensors with possible batch
    dimensions set to 0 will be returned. Use of get_next_as_optional() and
    extra logic adds overhead compared to get_next_as_list_static_shapes(), but
    allows us to handle non-static shapes.

    Args:
      name: not used.

    Returns:
      A boolean tensor indicating whether the first device received data, and
      a list with, per device, the real data or dummy tensors if no data is
      left.
    """
    del name
    with ops.device(self._worker):
      optionals = self._iterator.get_next_as_optional()
      per_device = []
      for device_index, opt in enumerate(optionals):
        # Build the condition on the data's own device so the data does not
        # need to be sent back to the worker.
        with ops.device(self._devices[device_index]):
          # Data is fetched in order, so the first replica alone tells us
          # whether this worker has any data left.
          if device_index == 0:
            worker_has_value = opt.has_value()

          # pylint: disable=unnecessary-lambda
          # pylint: disable=cell-var-from-loop
          value = control_flow_ops.cond(
              opt.has_value(),
              lambda: opt.get_value(),
              lambda: _dummy_tensor_fn(opt.element_spec),
              strict=True,
          )
          # The cond can make some dimensions unknown; restore them from the
          # element spec. Eager tensors always carry their runtime shape, so
          # this is only needed when building a graph.
          if not context.executing_eagerly():
            value = _recover_shape_fn(value, opt.element_spec)
          per_device.append(value)
          # pylint: enable=cell-var-from-loop
          # pylint: enable=unnecessary-lambda

      return worker_has_value, per_device
class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
  """Type specification for `_SingleWorkerOwnedDatasetIterator`."""

  __slots__ = ["_worker", "_devices", "_element_spec"]

  def __init__(self, worker, devices, element_spec):
    self._worker = worker
    # Device names are stored canonicalized, as an (immutable) tuple.
    self._devices = tuple(map(device_util.canonicalize, devices))
    self._element_spec = element_spec

  @property
  def value_type(self):
    return _SingleWorkerOwnedDatasetIterator

  def _serialize(self):
    serialized = (self._worker, self._devices, self._element_spec)
    return serialized

  @property
  def _component_specs(self):
    """A single `MultiDeviceIteratorSpec` for the owned iterator resource."""
    iterator_spec = multi_device_iterator_ops.MultiDeviceIteratorSpec(
        self._devices, self._worker, element_spec=self._element_spec)
    return [iterator_spec]

  def _to_components(self, value):
    owned_iterator = value._iterator  # pylint: disable=protected-access
    return [owned_iterator]

  def _from_components(self, components):
    return _SingleWorkerOwnedDatasetIterator(
        dataset=None,
        worker=self._worker,
        devices=self._devices,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    worker, devices, spec = value._worker, value._devices, value._element_spec
    # pylint: enable=protected-access
    return _SingleWorkerDatasetIteratorSpec(worker, devices, spec)
class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
                                        composite_tensor.CompositeTensor):
  """Iterator for a DistributedDataset instance."""

  def __init__(self, dataset=None, worker=None, devices=None, components=None,
               element_spec=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
    given worker. The lifetime of this iterator is tied to the encompassing
    python object. Once we go out of scope of the python object or return from
    a tf.function the underlying iterator resource is deleted.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      components: Tensor components to construct the
        _SingleWorkerOwnedDatasetIterator from.
      element_spec: A nested structure of `TypeSpec` objects that represents the
        type specification of elements of the iterator.

    Raises:
      ValueError: if `worker` or `devices` is None; if `dataset` is None and
        either `components` or `element_spec` is None; or if `dataset` is given
        together with `components` or `element_spec`.
    """
    if worker is None or devices is None:
      raise ValueError("Both `worker` and `devices` should be provided")

    error_message = ("Either `dataset` or both `components` and `element_spec` "
                     "need to be provided.")

    if dataset is None:
      # Rebuilding from components (see
      # `_SingleWorkerDatasetIteratorSpec._from_components`): the single
      # component is the already-created iterator.
      if components is None or element_spec is None:
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._worker = worker
      self._devices = devices
      self._iterator = components[0]
    else:
      if components is not None or element_spec is not None:
        raise ValueError(error_message)
      super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker,
                                                              devices)

  def _make_iterator(self):
    """Make appropriate iterator on the dataset.

    Raises:
      ValueError: if `self._worker` is empty.
    """
    if not self._worker:
      raise ValueError("Worker device must be specified when creating an "
                       "owned iterator.")
    host_device = device_util.get_host_for_device(self._worker)
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
          self._dataset, self._devices, source_device=host_device)

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
                                            self._element_spec)

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)
class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
  """Iterator for a single DistributedDatasetV1 instance.

  Backed by a legacy `MultiDeviceIterator`.
  """

  def _make_iterator(self):
    """Creates the `MultiDeviceIterator` under this worker's device scope."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def initialize(self):
    """Initialize underlying iterator.

    Eagerly, this resets (recreates) the underlying iterator and returns no
    ops. When building a graph, it returns the underlying iterator's
    initializer op.

    Returns:
      A list of any initializer ops that should be run.
    """
    if not ops.executing_eagerly_outside_functions():
      return [self._iterator.initializer]
    self._iterator._eager_reset()  # pylint: disable=protected-access
    return []

  @property
  def output_classes(self):
    legacy_classes = dataset_ops.get_legacy_output_classes(self._iterator)
    return legacy_classes

  @property
  def output_shapes(self):
    legacy_shapes = dataset_ops.get_legacy_output_shapes(self._iterator)
    return legacy_shapes

  @property
  def output_types(self):
    legacy_types = dataset_ops.get_legacy_output_types(self._iterator)
    return legacy_types
class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable.

  Every fetch simply invokes the callable (once per device for list fetches)
  under the worker's device scope.
  """

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def _call_once_per_device(self):
    """Invokes `self._fn` once for each entry of `self._devices`."""
    return [self._fn() for _ in self._devices]

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      element = self._fn()
      return element

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      return self._call_once_per_device()

  def get_next_as_list(self, name=None):
    """Get next element from the callable.

    Returns:
      A constant `True` (a callable never runs out of data) and the list of
      per-device results.
    """
    del name
    with ops.device(self._worker):
      per_device = self._call_once_per_device()
      return constant_op.constant(True), per_device

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []
def _create_iterators_per_worker(worker_datasets, input_workers,
                                 enable_legacy_iterators):
  """Create a multidevice iterator on each of the workers.

  Args:
    worker_datasets: One dataset per worker, in worker order.
    input_workers: an `InputWorkers` object.
    enable_legacy_iterators: If truthy, always use the legacy
      `_SingleWorkerDatasetIterator`; otherwise use
      `_SingleWorkerOwnedDatasetIterator` when TF2 behavior is enabled.

  Returns:
    A list with one single-worker iterator per worker.
  """
  assert isinstance(input_workers, InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)

  iterators = []
  for index, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      compute_devices = input_workers.compute_devices_for_worker(index)
      use_owned = tf2.enabled() and not enable_legacy_iterators
      iterator_cls = (_SingleWorkerOwnedDatasetIterator if use_owned
                      else _SingleWorkerDatasetIterator)
      iterators.append(
          iterator_cls(worker_datasets[index], worker, compute_devices))
  return iterators
def _create_datasets_per_worker_with_input_context(input_contexts,
                                                   input_workers, dataset_fn):
  """Create device datasets per worker given a dataset function.

  Args:
    input_contexts: A sequence of `InputContext`s, one per worker, in worker
      order.
    input_workers: an `InputWorkers` object.
    dataset_fn: A function that returns a `Dataset` given an `InputContext`.

  Returns:
    A tuple `(datasets, element_spec)`: the list of datasets (one per input
    context, each created under its worker's device scope) and the
    `element_spec` of the last dataset created.

  Raises:
    ValueError: if `input_contexts` is empty. Previously this surfaced as an
      `UnboundLocalError` when reading the element spec.
  """
  if not input_contexts:
    raise ValueError("`input_contexts` must contain at least one "
                     "`InputContext`; got an empty sequence.")
  datasets = []
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      dataset = dataset_fn(ctx)
      datasets.append(dataset)
  return datasets, datasets[-1].element_spec
# TODO(sourabhbajaj): Remove this in favor of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`.

  Unwraps a `DatasetV1Adapter` and recurses through `PrefetchDataset` and
  `_OptionsDataset` stages until a `BatchDataset` or `_MapAndBatchDataset` is
  reached.

  Args:
    d: A `tf.data` dataset.

  Returns:
    The `BatchDataset` or `_MapAndBatchDataset` found.

  Raises:
    ValueError: if any other kind of dataset is reached first.
  """
  # pylint: disable=protected-access
  if isinstance(d, dataset_ops.DatasetV1Adapter):
    d = d._dataset

  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
    return d
  elif isinstance(d, (dataset_ops.PrefetchDataset,
                      dataset_ops._OptionsDataset)):
    return _get_batched_dataset(d._input_dataset)

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` or "
      "`map_and_batch` needs to be the last operation on the dataset. "
      "The batch operations can be followed by a prefetch.")
def _get_batched_dataset_attributes(d):
  """Get `batch_size`, `drop_remainder` of dataset.

  Args:
    d: A `BatchDataset` or `_MapAndBatchDataset`.

  Returns:
    A `(batch_size, drop_remainder)` tuple. Values stored as tensors are
    replaced by `tensor_util.constant_value` of that tensor.
  """
  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size, drop_remainder = d._batch_size, d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size, drop_remainder = d._batch_size_t, d._drop_remainder_t
  # pylint: enable=protected-access

  def _as_static(value):
    """Returns the constant value of a tensor, or `value` unchanged."""
    if tensor_util.is_tensor(value):
      return tensor_util.constant_value(value)
    return value

  return _as_static(batch_size), _as_static(drop_remainder)
# TODO(sourabhbajaj): Remove this in favor of distributed datasets
def _get_dataset_attributes(dataset):
  """Get the underlying attributes from the dataset object.

  Args:
    dataset: A `tf.data` dataset whose final operations are a batch, optionally
      followed by prefetch/options stages.

  Returns:
    A `(batch_size, drop_remainder, prefetch_buffer)` tuple. `prefetch_buffer`
    is None unless `dataset` itself (or the dataset wrapped by a
    `DatasetV1Adapter`) is a `PrefetchDataset`.
  """
  # pylint: disable=protected-access
  # The batch attributes live on the batched stage, so walk back to it.
  batch_size, drop_remainder = _get_batched_dataset_attributes(
      _get_batched_dataset(dataset))

  # The prefetch buffer, by contrast, is read from the outermost dataset.
  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif isinstance(dataset, dataset_ops.DatasetV1Adapter):
    wrapped = dataset._dataset
    if isinstance(wrapped, dataset_ops.PrefetchDataset):
      prefetch_buffer = wrapped._buffer_size

  return batch_size, drop_remainder, prefetch_buffer
class MultiStepContext(object):
  """A context object that can be used to capture things when running steps.

  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For e.g. it allows the user's step
  function to specify which outputs to emit at what frequency. Currently it
  supports capturing output from the last step, as well as capturing non tensor
  outputs. In the future it will be augmented to support other use cases such
  as output each N steps.
  """

  def __init__(self):
    """Initializes an empty output context."""
    # Outputs captured on the last step, keyed by the user-provided name.
    self._last_step_outputs = {}
    # The `reduce_op` (possibly None) recorded for each last-step output name,
    # used later to tell whether an output has already been reduced.
    self._last_step_outputs_reduce_ops = {}
    # Non-tensor outputs, keyed by the user-provided name.
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on last step.

    Keys in the dictionary are names of tensors to be captured, as specified
    when `set_last_step_output` is called.
    Values in the dictionary are the tensors themselves. If
    `set_last_step_output` was called with a `reduce_op` for this output,
    then the value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replaces the entire dictionary of last step outputs.

    Args:
      outputs: A `dict` mapping output names to output values.

    Raises:
      ValueError: If `outputs` is not a `dict`.
    """
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For e.g. if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in a dictionary
        `_last_step_outputs_reduce_ops` for later interpreting of the
        outputs as already reduced or not.

    Raises:
      ValueError: If called in a replica context without a `reduce_op`.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
      return

    # An explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let a missing `reduce_op` reach
    # `distribution.reduce` below.
    if reduce_op is None:
      raise ValueError(
          "`reduce_op` must be specified when `set_last_step_output` is "
          "called in a replica context (name: %r)." % (name,))

    def merge_fn(distribution, value):
      self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                          axis=None)
      # Setting this inside the `merge_fn` because all replicas share the same
      # context object, so it's more robust to set it only once (even if all
      # the replicas are trying to set the same value).
      self._last_step_outputs_reduce_ops[name] = reduce_op

    distribution_strategy_context.get_replica_context().merge_call(
        merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non tensor output.

    In a cross-replica context `output` is stored as-is. In a replica context
    the per-replica values are gathered via `merge_call` and stored as the
    strategy's `experimental_local_results` for them.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non tensor outputs, we simply store all the local
        # values, as reduction doesn't make sense on non tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))
def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Builds the `tf.TypeSpec` a given strategy produces for `tensor_spec`.

  Args:
    strategy: The given `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` (or nested structure of specs) of a given
      value. The batch dimension of the shape should be None if you have
      partial batches.

  Returns:
    `tensor_spec` unchanged when the strategy has a single worker device.
    Otherwise, a structure matching `tensor_spec` in which every leaf spec is
    replaced by a `PerReplicaSpec` holding one copy of that leaf per device.
  """
  replica_count = len(strategy.extended.worker_devices)

  # A single device means values are not wrapped per replica, so the input
  # spec already matches what the strategy yields.
  if replica_count == 1:
    return tensor_spec

  # With several devices, the input to tf.function is assumed to be a
  # per-replica value, so each leaf spec is repeated once per device.
  def _to_per_replica_spec(leaf_spec):
    return values.PerReplicaSpec(*([leaf_spec] * replica_count))

  return nest.map_structure(_to_per_replica_spec, tensor_spec)
def _replace_per_replica_spec(spec, i):
  """Returns component `i` of `spec` if it is a `PerReplicaSpec`.

  Any spec that is not a `PerReplicaSpec` is passed through unchanged.
  """
  if not isinstance(spec, values.PerReplicaSpec):
    return spec
  return spec._value_specs[i]  # pylint: disable=protected-access