Add CompositeTensor support to distributed iterators.

PiperOrigin-RevId: 308423654
Change-Id: I406d7adeca6059112768f0b8a02a80fbbd463463
Anjali Sridhar 2020-04-25 10:20:14 -07:00 committed by TensorFlower Gardener
parent be3e3961f8
commit dac6d6ae7c
6 changed files with 928 additions and 189 deletions


@ -847,6 +847,38 @@ distribute_py_test(
srcs = ["input_lib_test.py"],
main = "input_lib_test.py",
shard_count = 10,
tags = [
"multi_and_single_gpu",
"no_gpu_presubmit", # TODO(b/154660040)
],
deps = [
":collective_all_reduce_strategy",
":combinations",
":input_lib",
":mirrored_strategy",
":multi_worker_test_base",
":reduce_util",
":strategy_combinations",
":values",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
"//tensorflow/python/ops/ragged:ragged_tensor",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
distribute_py_test(
name = "input_lib_type_spec_test",
srcs = ["input_lib_type_spec_test.py"],
main = "input_lib_type_spec_test.py",
shard_count = 10,
tags = [
"multi_and_single_gpu",
],
@ -1453,9 +1485,10 @@ distribute_py_test(
name = "ctl_correctness_test",
srcs = ["ctl_correctness_test.py"],
main = "ctl_correctness_test.py",
shard_count = 10,
shard_count = 20,
tags = [
"multi_and_single_gpu",
"no_gpu_presubmit", # TODO(b/154660040)
"noguitar", # b/140755528
],
deps = [


@ -330,8 +330,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
communication=self._communication)
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
local_devices)
host_device = device_util.get_host_for_device(self._worker_device)
self._input_workers = input_lib.InputWorkers(
[(self._worker_device, self.worker_devices)])
[(host_device, self.worker_devices)])
# Add a default device so that ops without specified devices will not end up
# on other workers.
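For context on the hunk above: the input workers are now keyed by the worker's host CPU rather than the worker device itself, so input pipeline ops are placed on the host. A minimal sketch of the mapping this relies on, assuming `device_util.get_host_for_device` resolves a device string to its task's CPU as its usage here implies (device strings are illustrative):

from tensorflow.python.distribute import device_util

worker = "/job:worker/replica:0/task:1/device:GPU:0"  # illustrative device string
host = device_util.get_host_for_device(worker)
# Expected to keep the job/replica/task fields and swap in the host CPU, e.g.
# "/job:worker/replica:0/task:1/device:CPU:0", which is what InputWorkers now receives.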


@ -33,6 +33,7 @@ from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
@ -41,6 +42,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@ -143,9 +145,10 @@ class InputWorkers(object):
worker_device_pairs: A sequence of pairs:
`(input device, a tuple of compute devices fed by that input device)`.
"""
self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
self._worker_device_pairs = worker_device_pairs
self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
for _, f in worker_device_pairs)
for _, f in self._worker_device_pairs)
@property
def num_workers(self):
@ -165,6 +168,12 @@ class InputWorkers(object):
for i in range(len(devices)))
return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
def serialize(self):
return self._worker_device_pairs
def deserialize(self, worker_device_pairs):
return InputWorkers(worker_device_pairs)
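The `serialize`/`deserialize` pair added here is what lets `DistributedIteratorSpec` below carry an `InputWorkers` by value: the spec stores only the worker/device pairs and rebuilds the object on demand. A minimal round-trip sketch (device strings are illustrative):

from tensorflow.python.distribute import input_lib

pairs = [("/device:CPU:0", ("/device:GPU:0", "/device:CPU:0"))]  # illustrative
workers = input_lib.InputWorkers(pairs)
restored = workers.deserialize(workers.serialize())
assert restored.num_workers == workers.num_workers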
def _get_next_as_optional(iterator, strategy, name=None):
"""Returns an empty dataset indicator and the next input from the iterator."""
@ -208,7 +217,7 @@ def _get_next_as_optional(iterator, strategy, name=None):
def _is_statically_shaped(tensor_class, shape):
"""Test if an iteratort output is statically shaped.
"""Test if an iterator output is statically shaped.
For sparse and ragged tensors this only tests the batch dimension.
@ -231,13 +240,12 @@ def _is_statically_shaped(tensor_class, shape):
return shape.is_fully_defined()
class DistributedIterator(object):
"""Common implementation for all input iterators."""
def __init__(self, input_workers, iterators, strategy):
def _get_static_shape(iterators):
"""Returns a boolean indicating if the input is fully defined."""
static_shape = True
for iterator in iterators:
if not isinstance(iterator, _SingleWorkerDatasetIterator):
if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator,
_SingleWorkerDatasetIterator)):
continue
flattened = zip(nest.flatten(iterator.output_shapes),
nest.flatten(iterator.output_classes))
@ -245,6 +253,14 @@ class DistributedIterator(object):
if not _is_statically_shaped(output_class, output_shape):
static_shape = False
break
return static_shape
class DistributedIteratorBase(object):
"""Common implementation for all input iterators."""
def __init__(self, input_workers, iterators, strategy):
static_shape = _get_static_shape(iterators)
# TODO(b/133073708): we currently need a flag to control the usage because
# there is a performance difference between get_next() and
@ -360,6 +376,10 @@ class DistributedIterator(object):
return values.regroup(replicas)
class DistributedIteratorV1(DistributedIteratorBase):
"""Input Iterator for a distributed dataset."""
# We need a private initializer method for re-initializing multidevice
# iterators when used with Keras training loops. If we don't reinitialize the
# iterator we run into memory leak issues (b/123315763).
@ -370,23 +390,14 @@ class DistributedIterator(object):
init_ops.extend(it.initialize())
return control_flow_ops.group(init_ops)
@property
def element_spec(self):
"""The type specification of an element of this iterator."""
return self._element_spec
class DistributedIteratorV1(DistributedIterator):
"""Input Iterator for a distributed dataset instance."""
@deprecated(None, "Use the iterator's `initializer` property instead.")
def initialize(self):
"""Initialze underlying iterators.
"""Initialize underlying iterators.
Returns:
A list of any initializer ops that should be run.
"""
return super(DistributedIteratorV1, self)._initializer
return self._initializer
@property
def initializer(self):
@ -415,6 +426,161 @@ class DistributedIteratorV1(DistributedIterator):
return self._iterators[i]
return None
@property
def element_spec(self):
"""The type specification of an element of this iterator."""
return self._element_spec
class DistributedIteratorSpec(type_spec.TypeSpec):
"""Type specification for `DistributedIterator`."""
__slots__ = ["_input_workers", "_element_spec", "_strategy"]
def __init__(self, input_workers, element_spec, strategy):
# We don't want to allow deserialization of this class because we don't
# serialize the strategy object. Currently the only place where
# _deserialize is called is when we save/restore using SavedModels.
if isinstance(input_workers, tuple):
raise NotImplementedError("DistributedIteratorSpec does not have support "
"for deserialization.")
else:
self._input_workers = input_workers
self._element_spec = element_spec
self._strategy = strategy
@property
def value_type(self):
return DistributedIterator
def _serialize(self):
# We cannot serialize the strategy object so we convert it to an id that we
# can use for comparison.
return (self._input_workers.serialize(),
self._element_spec, id(self._strategy))
def _deserialize(self):
raise ValueError("Deserialization is currently unsupported for "
"DistributedIteratorSpec.")
@staticmethod
def _is_compatible(a, b):
"""Returns true if the given type serializations compatible."""
if type(a) is not type(b):
return False
if isinstance(a, tuple):
return (len(a) == len(b) and
all(DistributedIteratorSpec._is_compatible(x, y) for (x, y) in
zip(a, b)))
if isinstance(a, dict):
return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
DistributedIteratorSpec._is_compatible(a[k], b[k]) for k in a.keys()))
if isinstance(a, (type_spec.TypeSpec, tensor_shape.TensorShape,
dtypes.DType)):
return a.is_compatible_with(b)
return a == b
# Overriding this method so that we can merge and reconstruct the spec object
def most_specific_compatible_type(self, other):
"""Returns the most specific TypeSpec compatible with `self` and `other`.
Args:
other: A `TypeSpec`.
Raises:
ValueError: If there is no TypeSpec that is compatible with both `self`
and `other`.
"""
# pylint: disable=protected-access
if type(self) is not type(other):
raise ValueError("No TypeSpec is compatible with both %s and %s" %
(self, other))
if not self._is_compatible(self._input_workers.serialize(),
other._input_workers.serialize()):
raise ValueError("_input_workers is not compatible with both %s "
"and %s" % (self, other))
if self._element_spec != other._element_spec:
raise ValueError("_element_spec is not compatible with both %s "
"and %s" % (self, other))
if id(self._strategy) != id(other._strategy):
raise ValueError("tf.distribute strategy is not compatible with both %s "
"and %s" % (self, other))
return DistributedIteratorSpec(self._input_workers, self._element_spec,
self._strategy)
@property
def _component_specs(self):
specs = []
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
for i in range(len(worker_device_pairs)):
input_device, compute_devices = worker_device_pairs[i]
specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
compute_devices,
element_spec=
self._element_spec))
return specs
def _to_components(self, value):
return value._iterators # pylint: disable=protected-access
def _from_components(self, components):
return DistributedIterator(input_workers=self._input_workers,
iterators=None,
components=components,
element_spec=self._element_spec,
strategy=self._strategy)
@staticmethod
def from_value(value):
# pylint: disable=protected-access
return DistributedIteratorSpec(value._input_workers, value._element_spec,
value._strategy)
class DistributedIterator(DistributedIteratorBase,
composite_tensor.CompositeTensor):
"""Input Iterator for a distributed dataset."""
def __init__(self, input_workers=None, iterators=None, strategy=None,
components=None, element_spec=None):
if input_workers is None:
raise ValueError("`input_workers` should be "
"provided.")
error_message = ("Either `input_workers` or "
"both `components` and `element_spec` need to be "
"provided.")
if iterators is None:
if (components is None or element_spec is None):
raise ValueError(error_message)
self._element_spec = element_spec
self._input_workers = input_workers
self._iterators = components
static_shape = _get_static_shape(self._iterators)
self._strategy = strategy
if getattr(
strategy.extended, "experimental_enable_get_next_as_optional", False):
self._enable_get_next_as_optional = not static_shape
else:
self._enable_get_next_as_optional = False
else:
if (components is not None and element_spec is not None):
raise ValueError(error_message)
super(DistributedIterator, self).__init__(input_workers, iterators,
strategy)
@property
def element_spec(self):
return self._element_spec
@property
def _type_spec(self):
return DistributedIteratorSpec(self._input_workers,
self.element_spec,
self._strategy)
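Making `DistributedIterator` a `CompositeTensor` with the `DistributedIteratorSpec` above is what allows it to cross `tf.function` boundaries without retracing for every new iterator instance. A condensed usage sketch, mirroring the new `input_lib_type_spec_test.py` tests later in this commit (strategy and devices are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])  # illustrative strategy
dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(10).batch(2))

@tf.function
def step(iterator):
  return next(iterator)

# Iterators created from the same distributed dataset share a _type_spec,
# so repeated calls reuse one traced function instead of retracing each time.
for _ in range(3):
  step(iter(dist_dataset))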
class _IterableInput(object):
"""Base class for iterable inputs for distribution strategies."""
@ -482,7 +648,6 @@ class DistributedDataset(_IterableInput):
`num_input_pipelines` in the `InputContext`.
"""
super(DistributedDataset, self).__init__(input_workers=input_workers)
# We clone and shard the dataset on each worker. The current setup tries to
# shard the dataset by files if possible so that each worker sees a
# different subset of files. If that is not possible, will attempt to shard
@ -541,8 +706,18 @@ class DistributedDataset(_IterableInput):
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
# This is an optional flag that can be used to turn off using
# OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators
# as a stopgap solution that will allow us to roll out this change.
enable_legacy_iterators = getattr(self._strategy,
"_enable_legacy_iterators", False)
worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
self._input_workers)
self._input_workers,
enable_legacy_iterators)
if enable_legacy_iterators:
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy)
else:
iterator = DistributedIterator(self._input_workers, worker_iterators,
self._strategy)
iterator._element_spec = self.element_spec # pylint: disable=protected-access
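As the comment above notes, the owned-iterator path is the new default in TF 2 and the legacy path is a temporary opt-out probed with `getattr`, so existing strategies need no API change. A short sketch of the opt-out, reusing `strategy` and `dist_dataset` from the sketch above (the new `testDisablingOwnedIteratorsInTF2` below does the same):

strategy._enable_legacy_iterators = True  # pylint: disable=protected-access
legacy_iterator = iter(dist_dataset)      # yields a DistributedIteratorV1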
@ -615,12 +790,21 @@ class DistributedDatasetV1(DistributedDataset):
def _get_iterator(self):
worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
self._input_workers)
self._input_workers,
True)
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
self._strategy)
iterator._element_spec = self.element_spec # pylint: disable=protected-access
return iterator
def __iter__(self):
if (ops.executing_eagerly_outside_functions() or
ops.get_default_graph().building_function):
return self._get_iterator()
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput):
@ -653,14 +837,27 @@ class DistributedDatasetsFromFunction(_IterableInput):
self._strategy = strategy
self._element_spec = None
super(DistributedDatasetsFromFunction, self).__init__(
input_workers=input_workers)
def __iter__(self):
if not (context.executing_eagerly() or
if (ops.executing_eagerly_outside_functions() or
ops.get_default_graph().building_function):
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
# This is an optional flag that can be used to turn off using
# OwnedMultiDeviceIterators and instead use the legacy
# MultiDeviceIterators as a stopgap solution that will allow us to roll
# out this change.
enable_legacy_iterators = getattr(self._strategy,
"_enable_legacy_iterators", False)
iterators, element_spec = _create_iterators_per_worker_with_input_context(
self._input_contexts, self._input_workers, self._dataset_fn)
self._input_contexts, self._input_workers, self._dataset_fn,
enable_legacy_iterators)
if enable_legacy_iterators:
iterator = DistributedIteratorV1(self._input_workers, iterators,
self._strategy)
else:
iterator = DistributedIterator(self._input_workers, iterators,
self._strategy)
self._element_spec = _create_distributed_tensor_spec(self._strategy,
@ -668,6 +865,9 @@ class DistributedDatasetsFromFunction(_IterableInput):
iterator._element_spec = self._element_spec # pylint: disable=protected-access
return iterator
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
@property
def element_spec(self):
"""The type specification of an element of this dataset."""
@ -705,7 +905,8 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
def _get_iterator(self):
iterators, element_spec = _create_iterators_per_worker_with_input_context(
self._input_contexts, self._input_workers, self._dataset_fn)
self._input_contexts, self._input_workers, self._dataset_fn,
True)
iterator = DistributedIteratorV1(self._input_workers, iterators,
self._strategy)
self._element_spec = _create_distributed_tensor_spec(self._strategy,
@ -713,6 +914,14 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
iterator._element_spec = self._element_spec # pylint: disable=protected-access
return iterator
def __iter__(self):
if (ops.executing_eagerly_outside_functions() or
ops.get_default_graph().building_function):
return self._get_iterator()
raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.")
# TODO(anjalisridhar): This class will be soon be removed in favor of newer
# APIs.
@ -797,7 +1006,7 @@ class DatasetIterator(DistributedIteratorV1):
split_batch_by=split_batch_by,
input_context=input_context)
worker_iterators = _create_iterators_per_worker(
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access
super(DatasetIterator, self).__init__(
input_workers,
worker_iterators, # pylint: disable=protected-access
@ -808,18 +1017,18 @@ class DatasetIterator(DistributedIteratorV1):
def _dummy_tensor_fn(value_structure):
"""A function to create dummy tensors from `value_structure`."""
def create_dummy_tensor(type_spec):
def create_dummy_tensor(spec):
"""Create a dummy tensor with possible batch dimensions set to 0."""
if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
if isinstance(spec, ragged_tensor.RaggedTensorSpec):
# Splice out the ragged dimensions.
# pylint: disable=protected-access
feature_shape = type_spec._shape[:1].concatenate(
type_spec._shape[(1 + type_spec._ragged_rank):])
feature_type = type_spec._dtype
feature_shape = spec._shape[:1].concatenate(
spec._shape[(1 + spec._ragged_rank):])
feature_type = spec._dtype
# pylint: enable=protected-access
else:
feature_shape = type_spec.shape
feature_type = type_spec.dtype
feature_shape = spec.shape
feature_type = spec.dtype
# Ideally we should set the batch dimension to 0, however as in
# DistributionStrategy we don't know the batch dimension, we try to
# guess it as much as possible. If the feature has unknown dimensions, we
@ -827,11 +1036,11 @@ def _dummy_tensor_fn(value_structure):
# first dimension as batch dimension and set it to 0.
dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
if feature_shape else [])
if dims and (isinstance(type_spec, ragged_tensor.RaggedTensorSpec) or
if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
feature_shape.is_fully_defined()):
dims[0] = tensor_shape.Dimension(0)
if isinstance(type_spec, sparse_tensor.SparseTensorSpec):
if isinstance(spec, sparse_tensor.SparseTensorSpec):
return sparse_tensor.SparseTensor(
values=array_ops.zeros(0, feature_type),
indices=array_ops.zeros((0, len(dims)), dtypes.int64),
@ -839,26 +1048,26 @@ def _dummy_tensor_fn(value_structure):
# Create the dummy tensor.
dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
if isinstance(spec, ragged_tensor.RaggedTensorSpec):
# Reinsert the ragged dimensions with size 0.
# pylint: disable=protected-access
row_splits = array_ops.zeros(1, type_spec._row_splits_dtype)
row_splits = array_ops.zeros(1, spec._row_splits_dtype)
dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
dummy_tensor, (row_splits,) * type_spec._ragged_rank, validate=False)
dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
# pylint: enable=protected-access
return dummy_tensor
return nest.map_structure(create_dummy_tensor, value_structure)
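A concrete illustration of the batch-dimension guessing documented above, calling the (private) helper directly; this is a hedged sketch for clarity, not part of the change:

import tensorflow as tf
from tensorflow.python.distribute import input_lib

# Fully defined shape: the first dimension is assumed to be the batch and zeroed.
d1 = input_lib._dummy_tensor_fn(tf.TensorSpec([32, 10], tf.float32))
assert d1.shape == (0, 10)
# Partially defined shape: every unknown dimension becomes 0 instead.
d2 = input_lib._dummy_tensor_fn(tf.TensorSpec([None, None, 3], tf.float32))
assert d2.shape == (0, 0, 3)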
class _SingleWorkerDatasetIterator(object):
class _SingleWorkerDatasetIteratorBase(object):
"""Iterator for a single `tf.data.Dataset`."""
def __init__(self, dataset, worker, devices):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
`MultiDeviceIterator` is used to prefetch input to the devices on the
given worker.
A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
input to the devices on the given worker.
Args:
dataset: A `tf.data.Dataset` instance.
@ -868,13 +1077,11 @@ class _SingleWorkerDatasetIterator(object):
self._dataset = dataset
self._worker = worker
self._devices = devices
self._element_spec = dataset.element_spec
self._make_iterator()
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._devices)
raise NotImplementedError("must be implemented in descendants")
def get_next(self, device, name=None):
"""Get next element for the given device."""
@ -923,9 +1130,9 @@ class _SingleWorkerDatasetIterator(object):
# Place the condition op in the same device as the data so the data
# doesn't need to be sent back to the worker.
with ops.device(self._devices[i]):
# As MultiDeviceIterator will fetch data in order, so we only need to
# check if the first replica has value to see whether there is data
# left for this single worker.
# Data will be fetched in order, so we only need to check if the first
# replica has value to see whether there is data left for this single
# worker.
if i == 0:
worker_has_value = data.has_value()
@ -943,8 +1150,159 @@ class _SingleWorkerDatasetIterator(object):
return worker_has_value, result
class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
"""Type specification for `_SingleWorkerOwnedDatasetIterator`."""
__slots__ = ["_worker", "_devices", "_element_spec"]
def __init__(self, worker, devices, element_spec):
self._worker = worker
self._devices = devices
self._element_spec = element_spec
@property
def value_type(self):
return _SingleWorkerOwnedDatasetIterator
def _serialize(self):
return (self._worker, tuple(self._devices), self._element_spec)
@property
def _component_specs(self):
specs = []
specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec(
self._devices, self._worker, element_spec=self._element_spec))
return specs
def _to_components(self, value):
return [value._iterator] # pylint: disable=protected-access
def _from_components(self, components):
return _SingleWorkerOwnedDatasetIterator(
dataset=None,
worker=self._worker,
devices=self._devices,
components=components,
element_spec=self._element_spec)
@staticmethod
def from_value(value):
# pylint: disable=protected-access
return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
value._element_spec)
class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
composite_tensor.CompositeTensor):
"""Iterator for a DistributedDataset instance."""
def __init__(self, dataset=None, worker=None, devices=None, components=None,
element_spec=None):
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
`OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
given worker. The lifetime of this iterator is tied to the encompassing
python object. Once we go out of scope of the python object or return from
a tf.function the underlying iterator resource is deleted.
Args:
dataset: A `tf.data.Dataset` instance.
worker: Worker on which ops should be created.
devices: Distribute data from `dataset` to these devices.
components: Tensor components to construct the
_SingleWorkerOwnedDatasetIterator from.
element_spec: A nested structure of `TypeSpec` objects that represents the
type specification of elements of the iterator.
"""
if worker is None or devices is None:
raise ValueError("Both `worker` and `devices` should be provided")
error_message = ("Either `dataset` or both `components` and `element_spec` "
"need to be provided.")
if dataset is None:
if (components is None or element_spec is None):
raise ValueError(error_message)
self._element_spec = element_spec
self._worker = worker
self._devices = devices
self._iterator = components[0]
else:
if (components is not None or element_spec is not None):
raise ValueError(error_message)
super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker,
devices)
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
if not self._worker:
raise ValueError("Worked device must be specified when creating an "
"owned iterator.")
host_device = device_util.get_host_for_device(self._worker)
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
self._dataset, self._devices, source_device=host_device)
@property
def element_spec(self):
return self._element_spec
@property
def _type_spec(self):
return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
self._element_spec)
@property
def output_classes(self):
"""Returns the class of each component of an element of this iterator.
The expected values are `tf.Tensor` and `tf.SparseTensor`.
Returns:
A nested structure of Python `type` objects corresponding to each
component of an element of this dataset.
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
self._element_spec)
@property
def output_shapes(self):
"""Returns the shape of each component of an element of this iterator.
Returns:
A nested structure of `tf.TensorShape` objects corresponding to each
component of an element of this dataset.
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
self._element_spec)
@property
def output_types(self):
"""Returns the type of each component of an element of this iterator.
Returns:
A nested structure of `tf.DType` objects corresponding to each component
of an element of this dataset.
"""
return nest.map_structure(
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
self._element_spec)
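The docstring above is the behavioral contract that distinguishes the owned path from the legacy one: the `OwnedMultiDeviceIterator` resource lives and dies with the Python object, so there is no separate initializer to run. A minimal eager-mode sketch of the underlying primitive, assuming a single CPU device (illustrative):

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops

dataset = dataset_ops.DatasetV2.range(8).batch(2)
it = multi_device_iterator_ops.OwnedMultiDeviceIterator(
    dataset, ["/device:CPU:0"], source_device="/device:CPU:0")
first = it.get_next("/device:CPU:0")  # data is prefetched to the target device
del it  # the underlying iterator resource is released with the Python object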
class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
"""Iterator for a single DistributedDatasetV1 instance."""
def _make_iterator(self):
"""Make appropriate iterator on the dataset."""
with ops.device(self._worker):
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
self._dataset, self._devices)
def initialize(self):
"""Initialze underlying iterator.
"""Initialize underlying iterator.
In eager execution, this simply recreates the underlying iterator.
In graph execution, it returns the initializer ops for the underlying
@ -1005,7 +1363,8 @@ class _SingleWorkerCallableIterator(object):
return []
def _create_iterators_per_worker(worker_datasets, input_workers):
def _create_iterators_per_worker(worker_datasets, input_workers,
enable_legacy_iterators):
"""Create a multidevice iterator on each of the workers."""
assert isinstance(input_workers, InputWorkers)
@ -1014,6 +1373,10 @@ def _create_iterators_per_worker(worker_datasets, input_workers):
for i, worker in enumerate(input_workers.worker_devices):
with ops.device(worker):
worker_devices = input_workers.compute_devices_for_worker(i)
if tf2.enabled() and not enable_legacy_iterators:
iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker,
worker_devices)
else:
iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
worker_devices)
iterators.append(iterator)
@ -1022,15 +1385,23 @@ def _create_iterators_per_worker(worker_datasets, input_workers):
def _create_iterators_per_worker_with_input_context(input_contexts,
input_workers,
dataset_fn):
dataset_fn,
enable_legacy_iterators):
"""Create a multidevice iterator per workers given a dataset function."""
iterators = []
element_specs = []
for i, ctx in enumerate(input_contexts):
worker = input_workers.worker_devices[i]
with ops.device(worker):
dataset = dataset_fn(ctx)
element_specs.append(dataset.element_spec)
devices = input_workers.compute_devices_for_worker(i)
iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
if tf2.enabled() and not enable_legacy_iterators:
iterator = _SingleWorkerOwnedDatasetIterator(dataset, worker,
devices)
else:
iterator = _SingleWorkerDatasetIterator(dataset, worker,
devices)
iterators.append(iterator)
return iterators, dataset.element_spec


@ -45,6 +45,7 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@ -105,15 +106,16 @@ class DistributedIteratorTestBase(test.TestCase):
split_batch_by,
strategy,
input_context=None):
if isinstance(dataset, (dataset_ops.Dataset, dataset_ops.DatasetV1Adapter)):
return input_lib.DistributedDatasetV1(
if input_type == "dataset":
if tf2.enabled():
return input_lib.DistributedDataset(
dataset,
input_workers,
strategy,
split_batch_by=split_batch_by,
input_context=input_context)
elif input_type == "dataset":
return input_lib.DistributedDataset(
else:
return input_lib.DistributedDatasetV1(
dataset,
input_workers,
strategy,
@ -139,6 +141,9 @@ class DistributedIteratorTestBase(test.TestCase):
if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
self.skipTest("unsupported test combination.")
if api_type == "wrap_into_iterator" and input_type == "input_fn":
self.skipTest("unsupported test combination.")
devices = nest.flatten([ds for _, ds in worker_device_pairs])
input_workers = input_lib.InputWorkers(worker_device_pairs)
@ -161,7 +166,7 @@ class DistributedIteratorTestBase(test.TestCase):
strategy,
input_context=input_context)
if context.executing_eagerly():
if ops.executing_eagerly_outside_functions():
iterator = iter(dataset)
else:
if isinstance(dataset, input_lib.DistributedDatasetV1):
@ -171,7 +176,7 @@ class DistributedIteratorTestBase(test.TestCase):
if iteration_type == "get_next":
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
if isinstance(iterator, input_lib.DistributedIteratorV1):
if not ops.executing_eagerly_outside_functions():
evaluate(control_flow_ops.group(iterator.initializer))
for expected_value in expected_values:
@ -190,10 +195,13 @@ class DistributedIteratorTestBase(test.TestCase):
next_element) for r in range(len(devices))])
# After re-initializing the iterator, should be able to iterate again.
if isinstance(iterator, input_lib.DistributedIteratorV1):
if not ops.executing_eagerly_outside_functions():
evaluate(control_flow_ops.group(iterator.initializer))
else:
evaluate(control_flow_ops.group(iterator._initializer))
if api_type == "wrap_into_iterator":
self.skipTest("unsupported test combination")
else:
iterator = iter(dataset)
for expected_value in expected_values:
next_element = iterator.get_next()
@ -225,6 +233,48 @@ class DistributedIteratorTestBase(test.TestCase):
class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
parameterized.TestCase):
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["input_fn", "dataset"],
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu
]))
def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
if not tf2.enabled():
self.skipTest("unsupported test combination")
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
input_workers = input_lib.InputWorkers(worker_device_pairs)
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn)
input_workers = input_lib.InputWorkers(worker_device_pairs)
if input_type == "dataset":
dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
input_workers,
distribution)
else:
dist_dataset = input_lib.get_distributed_datasets_from_function(
dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
distribution)
# Default Iterator types in TF2.
iterator = iter(dist_dataset)
self.assertIsInstance(iterator, input_lib.DistributedIterator)
self.assertIsInstance(iterator._iterators[0],
input_lib._SingleWorkerOwnedDatasetIterator)
# Disable creating owned iterators by setting a property on the strategy.
distribution._enable_legacy_iterators = True
iterator = iter(dist_dataset)
self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
self.assertIsInstance(iterator._iterators[0],
input_lib._SingleWorkerDatasetIterator)
@combinations.generate(
combinations.combine(
mode=["eager"],
@ -234,7 +284,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
def testMultiDeviceIterInitialize(self, distribution):
if tf2.enabled():
self.skipTest("Only V1 is supported.")
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])]
dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
input_workers = input_lib.InputWorkers(worker_device_pairs)
@ -250,25 +301,6 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
init_func_for_iter()
@combinations.generate(
combinations.combine(
mode=["graph"],
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_one_cpu
]))
def testDatasetV2IterError(self, distribution):
worker_device_pairs = [("", ["/device:CPU:0"])]
input_workers = input_lib.InputWorkers(worker_device_pairs)
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
dist_dataset = input_lib.get_distributed_dataset(
dataset_fn(distribute_lib.InputContext()), input_workers, distribution)
with self.assertRaisesRegexp(RuntimeError,
"or when eager execution is enabled"):
iter(dist_dataset)
@combinations.generate(
combinations.combine(
mode=["graph", "eager"],
@ -282,11 +314,11 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
enable_get_next_as_optional=[True, False]))
def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
enable_get_next_as_optional):
worker_device_pairs = [("", ["/device:CPU:0"])]
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
dataset_or_input_fn = self._create_dataset_or_input_fn(
input_type, dataset_fn)
@ -316,7 +348,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
enable_get_next_as_optional=[True, False]))
def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
distribution, enable_get_next_as_optional):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])]
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
else:
@ -386,7 +419,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
enable_get_next_as_optional=[True, False]))
def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
enable_get_next_as_optional):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])]
def dataset_fn(ctx):
del ctx
@ -422,7 +456,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
strategy_combinations.mirrored_strategy_with_one_cpu
]))
def testIterableIterator(self, distribution):
worker_device_pairs = [("", ["/device:CPU:0"])]
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
input_workers = input_lib.InputWorkers(worker_device_pairs)
dataset = dataset_ops.DatasetV2.range(10)
@ -446,7 +480,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
]))
def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
drop_remainder, distribution):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])]
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch( # pylint: disable=g-long-lambda
2, drop_remainder=drop_remainder)
@ -486,7 +521,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
def testBatchSplitting(self, input_type, api_type, iteration_type,
split_batch_by, distribution,
enable_get_next_as_optional):
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
"/device:CPU:0"])]
batch_size = 10
if tf2.enabled():
dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size)
@ -1075,68 +1111,5 @@ class DistributedIteratorMultiWorkerTest(
strategy,
sess=sess)
class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
input_type=["dataset", "dataset_fn"],
))
def testInputSignatureForPerReplicaValues(self, distribution, input_type):
def dataset_fn(ctx):
del ctx # unused
return dataset_ops.DatasetV2.from_tensor_slices(
np.ones([10, 12]).astype(np.float32)).batch(4)
if input_type == "dataset":
ds = distribution.experimental_distribute_dataset(
dataset_fn(distribute_lib.InputContext()))
type_spec = ds.element_spec
else:
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn)
iterator = iter(ds)
type_spec = iterator.element_spec
@def_function.function(input_signature=[type_spec])
def process_inputs(inputs):
distribution.run(lambda inputs: inputs, args=(inputs,))
for x in ds:
process_inputs(x)
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
))
def testInputSignatureForNestedPerReplicaValues(self, distribution):
a = np.ones((10, 2)) * 5
b = np.ones((10, 3)) * 6
dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
@def_function.function(input_signature=[dist_dataset.element_spec])
def process_inputs(inputs):
distribution.run(lambda inputs: inputs, args=(inputs,))
for x in dist_dataset:
process_inputs(x)
if __name__ == "__main__":
test.main()


@ -0,0 +1,366 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the input_lib library which tests iterator type specs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import values
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
class DistributedIteratorTest(test.TestCase,
parameterized.TestCase):
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
enable_get_next_as_optional=[True, False]))
def testTypeSpec(self, input_type, distribution,
enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator has CompositeTensor support in "
"TF 2 only.")
dataset = dataset_ops.DatasetV2.range(10).batch(2)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
iterator = iter(dist_dataset)
spec = iterator._type_spec
self.assertEqual(spec._input_workers, iterator._input_workers)
self.assertEqual(spec._element_spec._value_specs,
(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64,
name=None),
tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64,
name=None)))
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
enable_get_next_as_optional=[True, False]))
def testTypeSpecRoundTrip(self, input_type,
distribution, enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator CompositeTensor support is only "
"present in TF 2.0 only.")
dataset = dataset_ops.DatasetV2.range(10).batch(2)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
iterator = iter(dist_dataset)
spec = iterator._type_spec
tensor_list = spec._to_components(iterator)
re_iterator = spec._from_components(tensor_list)
self.assertEqual(iterator._input_workers, re_iterator._input_workers)
self.assertAllEqual(iterator._iterators, re_iterator._iterators)
@combinations.generate(
combinations.combine(
mode=["eager"],
input_type=["dataset"],
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
enable_get_next_as_optional=[True, False]))
def testDoesNotTriggerFunctionTracing(self, input_type, distribution,
enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator CompositeTensor support is only "
"present in TF 2.0 only.")
trace_count = [0]
@def_function.function
def f(iterator):
trace_count[0] += 1
counter = np.int64(0)
for _ in range(5):
next(iterator)
counter += 1
return counter
dataset = dataset_ops.DatasetV2.range(10).batch(2)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
for _ in range(3):
iterator = iter(dist_dataset)
counter = f(iterator)
self.assertEqual(trace_count[0], 1)
self.assertEqual(counter, 5)
class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
input_type=["dataset", "dataset_fn"],
))
def testInputSignatureForPerReplicaValues(self, distribution, input_type):
def dataset_fn(ctx):
del ctx # unused
return dataset_ops.DatasetV2.from_tensor_slices(
np.ones([10, 12]).astype(np.float32)).batch(4)
if input_type == "dataset":
ds = distribution.experimental_distribute_dataset(
dataset_fn(distribute_lib.InputContext()))
type_spec = ds.element_spec
else:
ds = distribution.experimental_distribute_datasets_from_function(
dataset_fn)
iterator = iter(ds)
type_spec = iterator.element_spec
@def_function.function(input_signature=[type_spec])
def process_inputs(inputs):
distribution.run(lambda inputs: inputs, args=(inputs,))
for x in ds:
process_inputs(x)
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.one_device_strategy,
strategy_combinations.mirrored_strategy_with_one_cpu,
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
))
def testInputSignatureForNestedPerReplicaValues(self, distribution):
a = np.ones((10, 2)) * 5
b = np.ones((10, 3)) * 6
dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
@def_function.function(input_signature=[dist_dataset.element_spec])
def process_inputs(inputs):
distribution.run(lambda inputs: inputs, args=(inputs,))
for x in dist_dataset:
process_inputs(x)
class RaggedTensorDistributedIteratorTest(test.TestCase,
parameterized.TestCase):
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
],
enable_get_next_as_optional=[True, False]))
def testTypeSpec(self, distribution, enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator has CompositeTensor support in "
"TF 2.0 only.")
ctx = distribute_lib.InputContext()
batch_size = ctx.get_per_replica_batch_size(8)
# Use 20 which isn't divisible by 8 to test partial batch behavior.
row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
dataset = dataset_ops.DatasetV2.from_tensor_slices({
"dense": ragged_tensor.to_tensor(),
"ragged": ragged_tensor,
"sparse": ragged_tensor.to_sparse(),
})
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
dataset = dataset.batch(batch_size)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
iterator = iter(dist_dataset)
spec = iterator._type_spec
self.assertEqual(spec._input_workers, iterator._input_workers)
self.assertEqual(
spec._element_spec, {
"sparse":
values.PerReplicaSpec(
sparse_tensor.SparseTensorSpec(
tensor_shape.TensorShape([None, 3]), dtypes.float32),
sparse_tensor.SparseTensorSpec(
tensor_shape.TensorShape([None, 3]), dtypes.float32)),
"dense":
values.PerReplicaSpec(
tensor_spec.TensorSpec(
shape=(None, 3), dtype=dtypes.float32, name=None),
tensor_spec.TensorSpec(
shape=(None, 3), dtype=dtypes.float32, name=None)),
"ragged":
values.PerReplicaSpec(
ragged_tensor_lib.RaggedTensorSpec(
tensor_shape.TensorShape([None, None]), dtypes.float32,
1, dtypes.int64),
ragged_tensor_lib.RaggedTensorSpec(
tensor_shape.TensorShape([None, None]), dtypes.float32,
1, dtypes.int64))
})
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
enable_get_next_as_optional=[True, False]))
def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator CompositeTensor support is only "
"present in TF 2.0 only.")
ctx = distribute_lib.InputContext()
batch_size = ctx.get_per_replica_batch_size(8)
# Use 20 which isn't divisible by 8 to test partial batch behavior.
row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
dataset = dataset_ops.DatasetV2.from_tensor_slices({
"dense": ragged_tensor.to_tensor(),
"ragged": ragged_tensor,
"sparse": ragged_tensor.to_sparse(),
})
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
dataset = dataset.batch(batch_size)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
iterator = iter(dist_dataset)
spec = iterator._type_spec
tensor_list = spec._to_components(iterator)
re_iterator = spec._from_components(tensor_list)
self.assertEqual(iterator._input_workers, re_iterator._input_workers)
self.assertAllEqual(iterator._iterators, re_iterator._iterators)
@combinations.generate(
combinations.combine(
mode=["eager"],
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
],
enable_get_next_as_optional=[True, False]))
def testDoesNotTriggerFunctionTracing(self, distribution,
enable_get_next_as_optional):
if not tf2.enabled():
self.skipTest("DistributedIterator CompositeTensor support is only "
"present in TF 2.0 only.")
trace_count = [0]
@def_function.function
def f(iterator):
trace_count[0] += 1
counter = np.int64(0)
for _ in range(5):
next(iterator)
counter += 1
return counter
ctx = distribute_lib.InputContext()
batch_size = ctx.get_per_replica_batch_size(8)
# Use 50 which isn't divisible by 8 to test partial batch behavior.
row_lengths = np.mod(np.arange(50), 4).astype(np.int64)
ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
np.repeat(np.arange(50, dtype=np.float32), row_lengths), row_lengths)
dataset = dataset_ops.DatasetV2.from_tensor_slices({
"dense": ragged_tensor.to_tensor(),
"ragged": ragged_tensor,
"sparse": ragged_tensor.to_sparse(),
})
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
dataset = dataset.batch(batch_size)
distribution.extended.experimental_enable_get_next_as_optional = (
enable_get_next_as_optional)
dist_dataset = distribution.experimental_distribute_dataset(dataset)
with distribution.scope():
for _ in range(3):
iterator = iter(dist_dataset)
counter = f(iterator)
self.assertEqual(trace_count[0], 1)
self.assertEqual(counter, 5)
if __name__ == "__main__":
test.main()


@ -1161,11 +1161,6 @@ class DataHandler(object):
if self._insufficient_data: # Set by `catch_stop_iteration`.
break
if self._adapter.should_recreate_iterator():
if ds_context.has_strategy():
# TODO(b/138326910): remove this when MultiDeviceIterator is a
# CompositeTensor (unless this is more efficient)
data_iterator._initializer # pylint: disable=pointless-statement, protected-access
else:
data_iterator = iter(self._dataset)
yield epoch, data_iterator
self._adapter.on_epoch_end()
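With distributed iterators now plain CompositeTensors, the Keras DataHandler no longer needs the strategy-specific re-initialization branch removed above; recreating the iterator works the same way with or without a strategy. A condensed sketch of the resulting epoch loop (not the full `DataHandler`, attribute names as in the surrounding class):

def enumerate_epochs(self):
  """Yields `(epoch, data_iterator)` pairs; condensed sketch."""
  data_iterator = iter(self._dataset)
  for epoch in range(self._initial_epoch, self._epochs):
    if self._insufficient_data:  # Set by `catch_stop_iteration`.
      break
    if self._adapter.should_recreate_iterator():
      # A fresh `iter()` now also covers distributed datasets.
      data_iterator = iter(self._dataset)
    yield epoch, data_iterator
    self._adapter.on_epoch_end()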