Add CompositeTensor support to distributed iterators.

PiperOrigin-RevId: 308423654
Change-Id: I406d7adeca6059112768f0b8a02a80fbbd463463

parent be3e3961f8
commit dac6d6ae7c
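
Because DistributedIterator and its per-worker _SingleWorkerOwnedDatasetIterator components now implement CompositeTensor with matching TypeSpecs, an iterator can be passed straight into a tf.function, and fresh iterator instances reuse the same trace. A minimal sketch of that behavior, mirroring the testDoesNotTriggerFunctionTracing test added below and assuming a TF 2.x runtime (names such as consume are illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dist_dataset = strategy.experimental_distribute_dataset(
    tf.data.Dataset.range(10).batch(2))

trace_count = [0]

@tf.function
def consume(iterator):
  # Python side effects only run while tracing, so this counts traces.
  trace_count[0] += 1
  for _ in range(3):
    next(iterator)

with strategy.scope():
  for _ in range(3):
    # Each call gets a brand-new iterator object; because they all share one
    # TypeSpec, the function is traced only once.
    consume(iter(dist_dataset))

assert trace_count[0] == 1

The diff follows.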
tensorflow/python/distribute/BUILD
@ -847,6 +847,38 @@ distribute_py_test(
|
||||
srcs = ["input_lib_test.py"],
|
||||
main = "input_lib_test.py",
|
||||
shard_count = 10,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_gpu_presubmit", # TODO(b/154660040)
|
||||
],
|
||||
deps = [
|
||||
":collective_all_reduce_strategy",
|
||||
":combinations",
|
||||
":input_lib",
|
||||
":mirrored_strategy",
|
||||
":multi_worker_test_base",
|
||||
":reduce_util",
|
||||
":strategy_combinations",
|
||||
":values",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:sparse_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
distribute_py_test(
|
||||
name = "input_lib_type_spec_test",
|
||||
srcs = ["input_lib_type_spec_test.py"],
|
||||
main = "input_lib_type_spec_test.py",
|
||||
shard_count = 10,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
],
|
||||
@ -1453,9 +1485,10 @@ distribute_py_test(
|
||||
name = "ctl_correctness_test",
|
||||
srcs = ["ctl_correctness_test.py"],
|
||||
main = "ctl_correctness_test.py",
|
||||
shard_count = 10,
|
||||
shard_count = 20,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_gpu_presubmit", # TODO(b/154660040)
|
||||
"noguitar", # b/140755528
|
||||
],
|
||||
deps = [
|
||||
|
tensorflow/python/distribute/collective_all_reduce_strategy.py
@ -330,8 +330,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
communication=self._communication)
|
||||
super(CollectiveAllReduceExtended, self)._initialize_single_worker(
|
||||
local_devices)
|
||||
host_device = device_util.get_host_for_device(self._worker_device)
|
||||
self._input_workers = input_lib.InputWorkers(
|
||||
[(self._worker_device, self.worker_devices)])
|
||||
[(host_device, self.worker_devices)])
|
||||
|
||||
# Add a default device so that ops without specified devices will not end up
|
||||
# on other workers.
|
||||
|
tensorflow/python/distribute/input_lib.py
@ -33,6 +33,7 @@ from tensorflow.python.distribute import input_ops
|
||||
from tensorflow.python.distribute import reduce_util
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device as tf_device
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -41,6 +42,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -143,9 +145,10 @@ class InputWorkers(object):
|
||||
worker_device_pairs: A sequence of pairs:
|
||||
`(input device, a tuple of compute devices fed by that input device)`.
|
||||
"""
|
||||
self._input_worker_devices = tuple(d for d, _ in worker_device_pairs)
|
||||
self._worker_device_pairs = worker_device_pairs
|
||||
self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
|
||||
self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
|
||||
for _, f in worker_device_pairs)
|
||||
for _, f in self._worker_device_pairs)
|
||||
|
||||
@property
|
||||
def num_workers(self):
|
||||
@ -165,6 +168,12 @@ class InputWorkers(object):
|
||||
for i in range(len(devices)))
|
||||
return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)
|
||||
|
||||
def serialize(self):
|
||||
return self._worker_device_pairs
|
||||
|
||||
def deserialize(self, worker_device_pairs):
|
||||
return InputWorkers(worker_device_pairs)
|
||||
|
||||
|
||||
def _get_next_as_optional(iterator, strategy, name=None):
|
||||
"""Returns an empty dataset indicator and the next input from the iterator."""
|
||||
@ -208,7 +217,7 @@ def _get_next_as_optional(iterator, strategy, name=None):
|
||||
|
||||
|
||||
def _is_statically_shaped(tensor_class, shape):
|
||||
"""Test if an iteratort output is statically shaped.
|
||||
"""Test if an iterator output is statically shaped.
|
||||
|
||||
For sparse and ragged tensors this only tests the batch dimension.
|
||||
|
||||
@ -231,13 +240,12 @@ def _is_statically_shaped(tensor_class, shape):
|
||||
return shape.is_fully_defined()
|
||||
|
||||
|
||||
class DistributedIterator(object):
|
||||
"""Common implementation for all input iterators."""
|
||||
|
||||
def __init__(self, input_workers, iterators, strategy):
|
||||
def _get_static_shape(iterators):
|
||||
"""Returns a boolean indicating if the input is fully defined."""
|
||||
static_shape = True
|
||||
for iterator in iterators:
|
||||
if not isinstance(iterator, _SingleWorkerDatasetIterator):
|
||||
if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator,
|
||||
_SingleWorkerDatasetIterator)):
|
||||
continue
|
||||
flattened = zip(nest.flatten(iterator.output_shapes),
|
||||
nest.flatten(iterator.output_classes))
|
||||
@ -245,6 +253,14 @@ class DistributedIterator(object):
|
||||
if not _is_statically_shaped(output_class, output_shape):
|
||||
static_shape = False
|
||||
break
|
||||
return static_shape
|
||||
|
||||
|
||||
class DistributedIteratorBase(object):
|
||||
"""Common implementation for all input iterators."""
|
||||
|
||||
def __init__(self, input_workers, iterators, strategy):
|
||||
static_shape = _get_static_shape(iterators)
|
||||
|
||||
# TODO(b/133073708): we currently need a flag to control the usage because
|
||||
# there is a performance difference between get_next() and
|
||||
@ -360,6 +376,10 @@ class DistributedIterator(object):
|
||||
|
||||
return values.regroup(replicas)
|
||||
|
||||
|
||||
class DistributedIteratorV1(DistributedIteratorBase):
|
||||
"""Input Iterator for a distributed dataset."""
|
||||
|
||||
# We need a private initializer method for re-initializing multidevice
|
||||
# iterators when used with Keras training loops. If we don't reinitialize the
|
||||
# iterator we run into memory leak issues (b/123315763).
|
||||
@ -370,23 +390,14 @@ class DistributedIterator(object):
|
||||
init_ops.extend(it.initialize())
|
||||
return control_flow_ops.group(init_ops)
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this iterator."""
|
||||
return self._element_spec
|
||||
|
||||
|
||||
class DistributedIteratorV1(DistributedIterator):
|
||||
"""Input Iterator for a distributed dataset instance."""
|
||||
|
||||
@deprecated(None, "Use the iterator's `initializer` property instead.")
|
||||
def initialize(self):
|
||||
"""Initialze underlying iterators.
|
||||
"""Initialize underlying iterators.
|
||||
|
||||
Returns:
|
||||
A list of any initializer ops that should be run.
|
||||
"""
|
||||
return super(DistributedIteratorV1, self)._initializer
|
||||
return self._initializer
|
||||
|
||||
@property
|
||||
def initializer(self):
|
||||
@ -415,6 +426,161 @@ class DistributedIteratorV1(DistributedIterator):
|
||||
return self._iterators[i]
|
||||
return None
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this iterator."""
|
||||
return self._element_spec
|
||||
|
||||
|
||||
class DistributedIteratorSpec(type_spec.TypeSpec):
|
||||
"""Type specification for `DistributedIterator`."""
|
||||
|
||||
__slots__ = ["_input_workers", "_element_spec", "_strategy"]
|
||||
|
||||
def __init__(self, input_workers, element_spec, strategy):
|
||||
# We don't want to allow deserialization of this class because we don't
|
||||
# serialize the strategy object. Currently the only places where
|
||||
# _deserialize is called is when we save/restore using SavedModels.
|
||||
if isinstance(input_workers, tuple):
|
||||
raise NotImplementedError("DistributedIteratorSpec does not have support "
|
||||
"for deserialization.")
|
||||
else:
|
||||
self._input_workers = input_workers
|
||||
self._element_spec = element_spec
|
||||
self._strategy = strategy
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
return DistributedIterator
|
||||
|
||||
def _serialize(self):
|
||||
# We cannot serialize the strategy object so we convert it to an id that we
|
||||
# can use for comparison.
|
||||
return (self._input_workers.serialize(),
|
||||
self._element_spec, id(self._strategy))
|
||||
|
||||
def _deserialize(self):
|
||||
raise ValueError("Deserialization is currently unsupported for "
|
||||
"DistributedIteratorSpec.")
|
||||
|
||||
@staticmethod
|
||||
def _is_compatible(a, b):
|
||||
"""Returns true if the given type serializations compatible."""
|
||||
if type(a) is not type(b):
|
||||
return False
|
||||
if isinstance(a, tuple):
|
||||
return (len(a) == len(b) and
|
||||
all(DistributedIteratorSpec._is_compatible(x, y) for (x, y) in
|
||||
zip(a, b)))
|
||||
if isinstance(a, dict):
|
||||
return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
|
||||
DistributedIteratorSpec._is_compatible(a[k], b[k]) for k in a.keys()))
|
||||
if isinstance(a, (type_spec.TypeSpec, tensor_shape.TensorShape,
|
||||
dtypes.DType)):
|
||||
return a.is_compatible_with(b)
|
||||
return a == b
|
||||
|
||||
# Overriding this method so that we can merge and reconstruct the spec object
|
||||
def most_specific_compatible_type(self, other):
|
||||
"""Returns the most specific TypeSpec compatible with `self` and `other`.
|
||||
|
||||
Args:
|
||||
other: A `TypeSpec`.
|
||||
|
||||
Raises:
|
||||
ValueError: If there is no TypeSpec that is compatible with both `self`
|
||||
and `other`.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
if type(self) is not type(other):
|
||||
raise ValueError("No TypeSpec is compatible with both %s and %s" %
|
||||
(self, other))
|
||||
if not self._is_compatible(self._input_workers.serialize(),
|
||||
other._input_workers.serialize()):
|
||||
raise ValueError("_input_workers is not compatible with both %s "
|
||||
"and %s" % (self, other))
|
||||
if self._element_spec != other._element_spec:
|
||||
raise ValueError("_element_spec is not compatible with both %s "
|
||||
"and %s" % (self, other))
|
||||
if id(self._strategy) != id(other._strategy):
|
||||
raise ValueError("tf.distribute strategy is not compatible with both %s "
|
||||
"and %s" % (self, other))
|
||||
return DistributedIteratorSpec(self._input_workers, self._element_spec,
|
||||
self._strategy)
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
specs = []
|
||||
worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access
|
||||
for i in range(len(worker_device_pairs)):
|
||||
input_device, compute_devices = worker_device_pairs[i]
|
||||
specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
|
||||
compute_devices,
|
||||
element_spec=
|
||||
self._element_spec))
|
||||
return specs
|
||||
|
||||
def _to_components(self, value):
|
||||
return value._iterators # pylint: disable=protected-access
|
||||
|
||||
def _from_components(self, components):
|
||||
return DistributedIterator(input_workers=self._input_workers,
|
||||
iterators=None,
|
||||
components=components,
|
||||
element_spec=self._element_spec,
|
||||
strategy=self._strategy)
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
# pylint: disable=protected-access
|
||||
return DistributedIteratorSpec(value._input_workers, value._element_spec,
|
||||
value._strategy)
|
||||
|
||||
|
||||
class DistributedIterator(DistributedIteratorBase,
|
||||
composite_tensor.CompositeTensor):
|
||||
"""Input Iterator for a distributed dataset."""
|
||||
|
||||
def __init__(self, input_workers=None, iterators=None, strategy=None,
|
||||
components=None, element_spec=None):
|
||||
if input_workers is None:
|
||||
raise ValueError("`input_workers` should be "
|
||||
"provided.")
|
||||
|
||||
error_message = ("Either `input_workers` or "
|
||||
"both `components` and `element_spec` need to be "
|
||||
"provided.")
|
||||
|
||||
if iterators is None:
|
||||
if (components is None or element_spec is None):
|
||||
raise ValueError(error_message)
|
||||
self._element_spec = element_spec
|
||||
self._input_workers = input_workers
|
||||
self._iterators = components
|
||||
static_shape = _get_static_shape(self._iterators)
|
||||
self._strategy = strategy
|
||||
if getattr(
|
||||
strategy.extended, "experimental_enable_get_next_as_optional", False):
|
||||
self._enable_get_next_as_optional = not static_shape
|
||||
else:
|
||||
self._enable_get_next_as_optional = False
|
||||
else:
|
||||
if (components is not None and element_spec is not None):
|
||||
raise ValueError(error_message)
|
||||
|
||||
super(DistributedIterator, self).__init__(input_workers, iterators,
|
||||
strategy)
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return DistributedIteratorSpec(self._input_workers,
|
||||
self.element_spec,
|
||||
self._strategy)
|
||||
|
||||
|
||||
class _IterableInput(object):
|
||||
"""Base class for iterable inputs for distribution strategies."""
|
||||
@ -482,7 +648,6 @@ class DistributedDataset(_IterableInput):
|
||||
`num_input_pipelines` in the `InputContext`.
|
||||
"""
|
||||
super(DistributedDataset, self).__init__(input_workers=input_workers)
|
||||
|
||||
# We clone and shard the dataset on each worker. The current setup tries to
|
||||
# shard the dataset by files if possible so that each worker sees a
|
||||
# different subset of files. If that is not possible, will attempt to shard
|
||||
@ -541,8 +706,18 @@ class DistributedDataset(_IterableInput):
|
||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||
"or when eager execution is enabled.")
|
||||
|
||||
# This is an optional flag that can be used to turn off using
|
||||
# OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators
|
||||
# as a stop gap solution that will allow us to roll out this change.
|
||||
enable_legacy_iterators = getattr(self._strategy,
|
||||
"_enable_legacy_iterators", False)
|
||||
worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
|
||||
self._input_workers)
|
||||
self._input_workers,
|
||||
enable_legacy_iterators)
|
||||
if enable_legacy_iterators:
|
||||
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
|
||||
self._strategy)
|
||||
else:
|
||||
iterator = DistributedIterator(self._input_workers, worker_iterators,
|
||||
self._strategy)
|
||||
iterator._element_spec = self.element_spec # pylint: disable=protected-access
|
||||
@ -615,12 +790,21 @@ class DistributedDatasetV1(DistributedDataset):
|
||||
|
||||
def _get_iterator(self):
|
||||
worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
|
||||
self._input_workers)
|
||||
self._input_workers,
|
||||
True)
|
||||
iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
|
||||
self._strategy)
|
||||
iterator._element_spec = self.element_spec # pylint: disable=protected-access
|
||||
return iterator
|
||||
|
||||
def __iter__(self):
|
||||
if (ops.executing_eagerly_outside_functions() or
|
||||
ops.get_default_graph().building_function):
|
||||
return self._get_iterator()
|
||||
|
||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||
"or when eager execution is enabled.")
|
||||
|
||||
|
||||
# TODO(priyag): Add other replication modes.
|
||||
class DistributedDatasetsFromFunction(_IterableInput):
|
||||
@ -653,14 +837,27 @@ class DistributedDatasetsFromFunction(_IterableInput):
|
||||
self._strategy = strategy
|
||||
self._element_spec = None
|
||||
|
||||
super(DistributedDatasetsFromFunction, self).__init__(
|
||||
input_workers=input_workers)
|
||||
|
||||
def __iter__(self):
|
||||
if not (context.executing_eagerly() or
|
||||
if (ops.executing_eagerly_outside_functions() or
|
||||
ops.get_default_graph().building_function):
|
||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||
"or when eager execution is enabled.")
|
||||
# This is an optional flag that can be used to turn off using
|
||||
# OwnedMultiDeviceIterators and instead use the legacy
|
||||
# MultiDeviceIterators as a stop gap solution that will allow us to roll
|
||||
# out this change.
|
||||
enable_legacy_iterators = getattr(self._strategy,
|
||||
"_enable_legacy_iterators", False)
|
||||
|
||||
iterators, element_spec = _create_iterators_per_worker_with_input_context(
|
||||
self._input_contexts, self._input_workers, self._dataset_fn)
|
||||
self._input_contexts, self._input_workers, self._dataset_fn,
|
||||
enable_legacy_iterators)
|
||||
|
||||
if enable_legacy_iterators:
|
||||
iterator = DistributedIteratorV1(self._input_workers, iterators,
|
||||
self._strategy)
|
||||
else:
|
||||
iterator = DistributedIterator(self._input_workers, iterators,
|
||||
self._strategy)
|
||||
self._element_spec = _create_distributed_tensor_spec(self._strategy,
|
||||
@ -668,6 +865,9 @@ class DistributedDatasetsFromFunction(_IterableInput):
|
||||
iterator._element_spec = self._element_spec # pylint: disable=protected-access
|
||||
return iterator
|
||||
|
||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||
"or when eager execution is enabled.")
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this dataset."""
|
||||
@ -705,7 +905,8 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
|
||||
|
||||
def _get_iterator(self):
|
||||
iterators, element_spec = _create_iterators_per_worker_with_input_context(
|
||||
self._input_contexts, self._input_workers, self._dataset_fn)
|
||||
self._input_contexts, self._input_workers, self._dataset_fn,
|
||||
True)
|
||||
iterator = DistributedIteratorV1(self._input_workers, iterators,
|
||||
self._strategy)
|
||||
self._element_spec = _create_distributed_tensor_spec(self._strategy,
|
||||
@ -713,6 +914,14 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
|
||||
iterator._element_spec = self._element_spec # pylint: disable=protected-access
|
||||
return iterator
|
||||
|
||||
def __iter__(self):
|
||||
if (ops.executing_eagerly_outside_functions() or
|
||||
ops.get_default_graph().building_function):
|
||||
return self._get_iterator()
|
||||
|
||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||
"or when eager execution is enabled.")
|
||||
|
||||
|
||||
# TODO(anjalisridhar): This class will be soon be removed in favor of newer
|
||||
# APIs.
|
||||
@ -797,7 +1006,7 @@ class DatasetIterator(DistributedIteratorV1):
|
||||
split_batch_by=split_batch_by,
|
||||
input_context=input_context)
|
||||
worker_iterators = _create_iterators_per_worker(
|
||||
dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access
|
||||
dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access
|
||||
super(DatasetIterator, self).__init__(
|
||||
input_workers,
|
||||
worker_iterators, # pylint: disable=protected-access
|
||||
@ -808,18 +1017,18 @@ class DatasetIterator(DistributedIteratorV1):
|
||||
def _dummy_tensor_fn(value_structure):
|
||||
"""A function to create dummy tensors from `value_structure`."""
|
||||
|
||||
def create_dummy_tensor(type_spec):
|
||||
def create_dummy_tensor(spec):
|
||||
"""Create a dummy tensor with possible batch dimensions set to 0."""
|
||||
if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
|
||||
if isinstance(spec, ragged_tensor.RaggedTensorSpec):
|
||||
# Splice out the ragged dimensions.
|
||||
# pylint: disable=protected-access
|
||||
feature_shape = type_spec._shape[:1].concatenate(
|
||||
type_spec._shape[(1 + type_spec._ragged_rank):])
|
||||
feature_type = type_spec._dtype
|
||||
feature_shape = spec._shape[:1].concatenate(
|
||||
spec._shape[(1 + spec._ragged_rank):])
|
||||
feature_type = spec._dtype
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
feature_shape = type_spec.shape
|
||||
feature_type = type_spec.dtype
|
||||
feature_shape = spec.shape
|
||||
feature_type = spec.dtype
|
||||
# Ideally we should set the batch dimension to 0, however as in
|
||||
# DistributionStrategy we don't know the batch dimension, we try to
|
||||
# guess it as much as possible. If the feature has unknown dimensions, we
|
||||
@ -827,11 +1036,11 @@ def _dummy_tensor_fn(value_structure):
|
||||
# first dimension as batch dimension and set it to 0.
|
||||
dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()]
|
||||
if feature_shape else [])
|
||||
if dims and (isinstance(type_spec, ragged_tensor.RaggedTensorSpec) or
|
||||
if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or
|
||||
feature_shape.is_fully_defined()):
|
||||
dims[0] = tensor_shape.Dimension(0)
|
||||
|
||||
if isinstance(type_spec, sparse_tensor.SparseTensorSpec):
|
||||
if isinstance(spec, sparse_tensor.SparseTensorSpec):
|
||||
return sparse_tensor.SparseTensor(
|
||||
values=array_ops.zeros(0, feature_type),
|
||||
indices=array_ops.zeros((0, len(dims)), dtypes.int64),
|
||||
@ -839,26 +1048,26 @@ def _dummy_tensor_fn(value_structure):
|
||||
|
||||
# Create the dummy tensor.
|
||||
dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type)
|
||||
if isinstance(type_spec, ragged_tensor.RaggedTensorSpec):
|
||||
if isinstance(spec, ragged_tensor.RaggedTensorSpec):
|
||||
# Reinsert the ragged dimensions with size 0.
|
||||
# pylint: disable=protected-access
|
||||
row_splits = array_ops.zeros(1, type_spec._row_splits_dtype)
|
||||
row_splits = array_ops.zeros(1, spec._row_splits_dtype)
|
||||
dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits(
|
||||
dummy_tensor, (row_splits,) * type_spec._ragged_rank, validate=False)
|
||||
dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False)
|
||||
# pylint: enable=protected-access
|
||||
return dummy_tensor
|
||||
|
||||
return nest.map_structure(create_dummy_tensor, value_structure)
|
||||
|
||||
|
||||
class _SingleWorkerDatasetIterator(object):
|
||||
class _SingleWorkerDatasetIteratorBase(object):
|
||||
"""Iterator for a single `tf.data.Dataset`."""
|
||||
|
||||
def __init__(self, dataset, worker, devices):
|
||||
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
|
||||
|
||||
`MultiDeviceIterator` is used to prefetch input to the devices on the
|
||||
given worker.
|
||||
A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch
|
||||
input to the devices on the given worker.
|
||||
|
||||
Args:
|
||||
dataset: A `tf.data.Dataset` instance.
|
||||
@ -868,13 +1077,11 @@ class _SingleWorkerDatasetIterator(object):
|
||||
self._dataset = dataset
|
||||
self._worker = worker
|
||||
self._devices = devices
|
||||
self._element_spec = dataset.element_spec
|
||||
self._make_iterator()
|
||||
|
||||
def _make_iterator(self):
|
||||
"""Make appropriate iterator on the dataset."""
|
||||
with ops.device(self._worker):
|
||||
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
self._dataset, self._devices)
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
def get_next(self, device, name=None):
|
||||
"""Get next element for the given device."""
|
||||
@ -923,9 +1130,9 @@ class _SingleWorkerDatasetIterator(object):
|
||||
# Place the condition op in the same device as the data so the data
|
||||
# doesn't need to be sent back to the worker.
|
||||
with ops.device(self._devices[i]):
|
||||
# As MultiDeviceIterator will fetch data in order, so we only need to
|
||||
# check if the first replica has value to see whether there is data
|
||||
# left for this single worker.
|
||||
# Data will be fetched in order, so we only need to check if the first
|
||||
# replica has value to see whether there is data left for this single
|
||||
# worker.
|
||||
if i == 0:
|
||||
worker_has_value = data.has_value()
|
||||
|
||||
@ -943,8 +1150,159 @@ class _SingleWorkerDatasetIterator(object):
|
||||
|
||||
return worker_has_value, result
|
||||
|
||||
|
||||
class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
|
||||
"""Type specification for `_SingleWorkerOwnedDatasetIterator`."""
|
||||
|
||||
__slots__ = ["_worker", "_devices", "_element_spec"]
|
||||
|
||||
def __init__(self, worker, devices, element_spec):
|
||||
self._worker = worker
|
||||
self._devices = devices
|
||||
self._element_spec = element_spec
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
return _SingleWorkerOwnedDatasetIterator
|
||||
|
||||
def _serialize(self):
|
||||
return (self._worker, tuple(self._devices), self._element_spec)
|
||||
|
||||
@property
|
||||
def _component_specs(self):
|
||||
specs = []
|
||||
specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec(
|
||||
self._devices, self._worker, element_spec=self._element_spec))
|
||||
return specs
|
||||
|
||||
def _to_components(self, value):
|
||||
return [value._iterator] # pylint: disable=protected-access
|
||||
|
||||
def _from_components(self, components):
|
||||
return _SingleWorkerOwnedDatasetIterator(
|
||||
dataset=None,
|
||||
worker=self._worker,
|
||||
devices=self._devices,
|
||||
components=components,
|
||||
element_spec=self._element_spec)
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
# pylint: disable=protected-access
|
||||
return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
|
||||
value._element_spec)
|
||||
|
||||
|
||||
class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
|
||||
composite_tensor.CompositeTensor):
|
||||
"""Iterator for a DistributedDataset instance."""
|
||||
|
||||
def __init__(self, dataset=None, worker=None, devices=None, components=None,
|
||||
element_spec=None):
|
||||
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
|
||||
|
||||
`OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
|
||||
given worker. The lifetime of this iterator is tied to the encompassing
|
||||
python object. Once we go out of scope of the python object or return from
|
||||
a tf.function the underlying iterator resource is deleted.
|
||||
|
||||
Args:
|
||||
dataset: A `tf.data.Dataset` instance.
|
||||
worker: Worker on which ops should be created.
|
||||
devices: Distribute data from `dataset` to these devices.
|
||||
components: Tensor components to construct the
|
||||
_SingleWorkerOwnedDatasetIterator from.
|
||||
element_spec: A nested structure of `TypeSpec` objects that represents the
|
||||
type specification of elements of the iterator.
|
||||
"""
|
||||
if worker is None or devices is None:
|
||||
raise ValueError("Both `worker` and `devices` should be provided")
|
||||
|
||||
error_message = ("Either `dataset` or both `components` and `element_spec` "
|
||||
"need to be provided.")
|
||||
|
||||
if dataset is None:
|
||||
if (components is None or element_spec is None):
|
||||
raise ValueError(error_message)
|
||||
self._element_spec = element_spec
|
||||
self._worker = worker
|
||||
self._devices = devices
|
||||
self._iterator = components[0]
|
||||
else:
|
||||
if (components is not None or element_spec is not None):
|
||||
raise ValueError(error_message)
|
||||
super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker,
|
||||
devices)
|
||||
|
||||
def _make_iterator(self):
|
||||
"""Make appropriate iterator on the dataset."""
|
||||
if not self._worker:
|
||||
raise ValueError("Worked device must be specified when creating an "
|
||||
"owned iterator.")
|
||||
host_device = device_util.get_host_for_device(self._worker)
|
||||
with ops.device(self._worker):
|
||||
self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
|
||||
self._dataset, self._devices, source_device=host_device)
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
return self._element_spec
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
"""Returns the class of each component of an element of this iterator.
|
||||
|
||||
The expected values are `tf.Tensor` and `tf.SparseTensor`.
|
||||
|
||||
Returns:
|
||||
A nested structure of Python `type` objects corresponding to each
|
||||
component of an element of this dataset.
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
"""Returns the shape of each component of an element of this iterator.
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.TensorShape` objects corresponding to each
|
||||
component of an element of this dataset.
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access
|
||||
self._element_spec)
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
"""Returns the type of each component of an element of this iterator.
|
||||
|
||||
Returns:
|
||||
A nested structure of `tf.DType` objects corresponding to each component
|
||||
of an element of this dataset.
|
||||
"""
|
||||
return nest.map_structure(
|
||||
lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access
|
||||
self._element_spec)
|
||||
|
||||
|
||||
class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
|
||||
"""Iterator for a single DistributedDatasetV1 instance."""
|
||||
|
||||
def _make_iterator(self):
|
||||
"""Make appropriate iterator on the dataset."""
|
||||
with ops.device(self._worker):
|
||||
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
self._dataset, self._devices)
|
||||
|
||||
def initialize(self):
|
||||
"""Initialze underlying iterator.
|
||||
"""Initialize underlying iterator.
|
||||
|
||||
In eager execution, this simply recreates the underlying iterator.
|
||||
In graph execution, it returns the initializer ops for the underlying
|
||||
@ -1005,7 +1363,8 @@ class _SingleWorkerCallableIterator(object):
|
||||
return []
|
||||
|
||||
|
||||
def _create_iterators_per_worker(worker_datasets, input_workers):
|
||||
def _create_iterators_per_worker(worker_datasets, input_workers,
|
||||
enable_legacy_iterators):
|
||||
"""Create a multidevice iterator on each of the workers."""
|
||||
assert isinstance(input_workers, InputWorkers)
|
||||
|
||||
@ -1014,6 +1373,10 @@ def _create_iterators_per_worker(worker_datasets, input_workers):
|
||||
for i, worker in enumerate(input_workers.worker_devices):
|
||||
with ops.device(worker):
|
||||
worker_devices = input_workers.compute_devices_for_worker(i)
|
||||
if tf2.enabled() and not enable_legacy_iterators:
|
||||
iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker,
|
||||
worker_devices)
|
||||
else:
|
||||
iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
|
||||
worker_devices)
|
||||
iterators.append(iterator)
|
||||
@ -1022,15 +1385,23 @@ def _create_iterators_per_worker(worker_datasets, input_workers):
|
||||
|
||||
def _create_iterators_per_worker_with_input_context(input_contexts,
|
||||
input_workers,
|
||||
dataset_fn):
|
||||
dataset_fn,
|
||||
enable_legacy_iterators):
|
||||
"""Create a multidevice iterator per workers given a dataset function."""
|
||||
iterators = []
|
||||
element_specs = []
|
||||
for i, ctx in enumerate(input_contexts):
|
||||
worker = input_workers.worker_devices[i]
|
||||
with ops.device(worker):
|
||||
dataset = dataset_fn(ctx)
|
||||
element_specs.append(dataset.element_spec)
|
||||
devices = input_workers.compute_devices_for_worker(i)
|
||||
iterator = _SingleWorkerDatasetIterator(dataset, worker, devices)
|
||||
if tf2.enabled() and not enable_legacy_iterators:
|
||||
iterator = _SingleWorkerOwnedDatasetIterator(dataset, worker,
|
||||
devices)
|
||||
else:
|
||||
iterator = _SingleWorkerDatasetIterator(dataset, worker,
|
||||
devices)
|
||||
iterators.append(iterator)
|
||||
return iterators, dataset.element_spec
|
||||
|
||||
|
tensorflow/python/distribute/input_lib_test.py
@ -45,6 +45,7 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -105,15 +106,16 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
split_batch_by,
|
||||
strategy,
|
||||
input_context=None):
|
||||
if isinstance(dataset, (dataset_ops.Dataset, dataset_ops.DatasetV1Adapter)):
|
||||
return input_lib.DistributedDatasetV1(
|
||||
if input_type == "dataset":
|
||||
if tf2.enabled():
|
||||
return input_lib.DistributedDataset(
|
||||
dataset,
|
||||
input_workers,
|
||||
strategy,
|
||||
split_batch_by=split_batch_by,
|
||||
input_context=input_context)
|
||||
elif input_type == "dataset":
|
||||
return input_lib.DistributedDataset(
|
||||
else:
|
||||
return input_lib.DistributedDatasetV1(
|
||||
dataset,
|
||||
input_workers,
|
||||
strategy,
|
||||
@ -139,6 +141,9 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
|
||||
self.skipTest("unsupported test combination.")
|
||||
|
||||
if api_type == "wrap_into_iterator" and input_type == "input_fn":
|
||||
self.skipTest("unsupported test combination.")
|
||||
|
||||
devices = nest.flatten([ds for _, ds in worker_device_pairs])
|
||||
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||
|
||||
@ -161,7 +166,7 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
strategy,
|
||||
input_context=input_context)
|
||||
|
||||
if context.executing_eagerly():
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
iterator = iter(dataset)
|
||||
else:
|
||||
if isinstance(dataset, input_lib.DistributedDatasetV1):
|
||||
@ -171,7 +176,7 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
|
||||
if iteration_type == "get_next":
|
||||
evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
|
||||
if isinstance(iterator, input_lib.DistributedIteratorV1):
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
evaluate(control_flow_ops.group(iterator.initializer))
|
||||
|
||||
for expected_value in expected_values:
|
||||
@ -190,10 +195,13 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
next_element) for r in range(len(devices))])
|
||||
|
||||
# After re-initializing the iterator, should be able to iterate again.
|
||||
if isinstance(iterator, input_lib.DistributedIteratorV1):
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
evaluate(control_flow_ops.group(iterator.initializer))
|
||||
else:
|
||||
evaluate(control_flow_ops.group(iterator._initializer))
|
||||
if api_type == "wrap_into_iterator":
|
||||
self.skipTest("unsupported test combination")
|
||||
else:
|
||||
iterator = iter(dataset)
|
||||
|
||||
for expected_value in expected_values:
|
||||
next_element = iterator.get_next()
|
||||
@ -225,6 +233,48 @@ class DistributedIteratorTestBase(test.TestCase):
|
||||
class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
input_type=["input_fn", "dataset"],
|
||||
distribution=[
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu
|
||||
]))
|
||||
def testDisablingOwnedIteratorsInTF2(self, distribution, input_type):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("unsupported test combination")
|
||||
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
|
||||
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
|
||||
dataset_or_input_fn = self._create_dataset_or_input_fn(
|
||||
input_type, dataset_fn)
|
||||
|
||||
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||
if input_type == "dataset":
|
||||
dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn,
|
||||
input_workers,
|
||||
distribution)
|
||||
else:
|
||||
dist_dataset = input_lib.get_distributed_datasets_from_function(
|
||||
dataset_or_input_fn, input_workers, [distribute_lib.InputContext()],
|
||||
distribution)
|
||||
|
||||
# Default Iterator types in TF2.
|
||||
iterator = iter(dist_dataset)
|
||||
self.assertIsInstance(iterator, input_lib.DistributedIterator)
|
||||
self.assertIsInstance(iterator._iterators[0],
|
||||
input_lib._SingleWorkerOwnedDatasetIterator)
|
||||
|
||||
# Disable creating owned iterators by setting a property on the strategy.
|
||||
distribution._enable_legacy_iterators = True
|
||||
iterator = iter(dist_dataset)
|
||||
self.assertIsInstance(iterator, input_lib.DistributedIteratorV1)
|
||||
self.assertIsInstance(iterator._iterators[0],
|
||||
input_lib._SingleWorkerDatasetIterator)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
@ -234,7 +284,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
def testMultiDeviceIterInitialize(self, distribution):
|
||||
if tf2.enabled():
|
||||
self.skipTest("Only V1 is supported.")
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
|
||||
"/device:CPU:0"])]
|
||||
dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
|
||||
|
||||
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||
@ -250,25 +301,6 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
|
||||
init_func_for_iter()
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["graph"],
|
||||
distribution=[
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu
|
||||
]))
|
||||
def testDatasetV2IterError(self, distribution):
|
||||
worker_device_pairs = [("", ["/device:CPU:0"])]
|
||||
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2)
|
||||
|
||||
dist_dataset = input_lib.get_distributed_dataset(
|
||||
dataset_fn(distribute_lib.InputContext()), input_workers, distribution)
|
||||
|
||||
with self.assertRaisesRegexp(RuntimeError,
|
||||
"or when eager execution is enabled"):
|
||||
iter(dist_dataset)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["graph", "eager"],
|
||||
@ -282,11 +314,11 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution,
|
||||
enable_get_next_as_optional):
|
||||
worker_device_pairs = [("", ["/device:CPU:0"])]
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
|
||||
if tf2.enabled():
|
||||
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
|
||||
else:
|
||||
dataset_fn = lambda _: dataset_ops.Dataset.range(10)
|
||||
dataset_fn = lambda _: dataset_ops.DatasetV1.range(10)
|
||||
dataset_or_input_fn = self._create_dataset_or_input_fn(
|
||||
input_type, dataset_fn)
|
||||
|
||||
@ -316,7 +348,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type,
|
||||
distribution, enable_get_next_as_optional):
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
|
||||
"/device:CPU:0"])]
|
||||
if tf2.enabled():
|
||||
dataset_fn = lambda _: dataset_ops.DatasetV2.range(10)
|
||||
else:
|
||||
@ -386,7 +419,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testTupleDataset(self, input_type, api_type, iteration_type, distribution,
|
||||
enable_get_next_as_optional):
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
|
||||
"/device:CPU:0"])]
|
||||
|
||||
def dataset_fn(ctx):
|
||||
del ctx
|
||||
@ -422,7 +456,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu
|
||||
]))
|
||||
def testIterableIterator(self, distribution):
|
||||
worker_device_pairs = [("", ["/device:CPU:0"])]
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])]
|
||||
input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||
|
||||
dataset = dataset_ops.DatasetV2.range(10)
|
||||
@ -446,7 +480,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
]))
|
||||
def testUnevenDatasetBatches(self, input_type, api_type, iteration_type,
|
||||
drop_remainder, distribution):
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
|
||||
"/device:CPU:0"])]
|
||||
if tf2.enabled():
|
||||
dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch( # pylint: disable=g-long-lambda
|
||||
2, drop_remainder=drop_remainder)
|
||||
@ -486,7 +521,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
|
||||
def testBatchSplitting(self, input_type, api_type, iteration_type,
|
||||
split_batch_by, distribution,
|
||||
enable_get_next_as_optional):
|
||||
worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])]
|
||||
worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0",
|
||||
"/device:CPU:0"])]
|
||||
batch_size = 10
|
||||
if tf2.enabled():
|
||||
dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size)
|
||||
@ -1075,68 +1111,5 @@ class DistributedIteratorMultiWorkerTest(
|
||||
strategy,
|
||||
sess=sess)
|
||||
|
||||
|
||||
class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
input_type=["dataset", "dataset_fn"],
|
||||
))
|
||||
def testInputSignatureForPerReplicaValues(self, distribution, input_type):
|
||||
def dataset_fn(ctx):
|
||||
del ctx # unused
|
||||
return dataset_ops.DatasetV2.from_tensor_slices(
|
||||
np.ones([10, 12]).astype(np.float32)).batch(4)
|
||||
|
||||
if input_type == "dataset":
|
||||
ds = distribution.experimental_distribute_dataset(
|
||||
dataset_fn(distribute_lib.InputContext()))
|
||||
type_spec = ds.element_spec
|
||||
else:
|
||||
ds = distribution.experimental_distribute_datasets_from_function(
|
||||
dataset_fn)
|
||||
iterator = iter(ds)
|
||||
type_spec = iterator.element_spec
|
||||
|
||||
@def_function.function(input_signature=[type_spec])
|
||||
def process_inputs(inputs):
|
||||
distribution.run(lambda inputs: inputs, args=(inputs,))
|
||||
|
||||
for x in ds:
|
||||
process_inputs(x)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
))
|
||||
def testInputSignatureForNestedPerReplicaValues(self, distribution):
|
||||
a = np.ones((10, 2)) * 5
|
||||
b = np.ones((10, 3)) * 6
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2)
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
|
||||
@def_function.function(input_signature=[dist_dataset.element_spec])
|
||||
def process_inputs(inputs):
|
||||
distribution.run(lambda inputs: inputs, args=(inputs,))
|
||||
|
||||
for x in dist_dataset:
|
||||
process_inputs(x)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
tensorflow/python/distribute/input_lib_type_spec_test.py (new file, 366 lines)
@ -0,0 +1,366 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the input_lib library which tests iterator type specs."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
|
||||
|
||||
|
||||
class DistributedIteratorTest(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
input_type=["dataset"],
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testTypeSpec(self, input_type, distribution,
|
||||
enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator has CompositeTensor support in "
|
||||
"TF 2 only.")
|
||||
dataset = dataset_ops.DatasetV2.range(10).batch(2)
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
iterator = iter(dist_dataset)
|
||||
|
||||
spec = iterator._type_spec
|
||||
self.assertEqual(spec._input_workers, iterator._input_workers)
|
||||
self.assertEqual(spec._element_spec._value_specs,
|
||||
(tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64,
|
||||
name=None),
|
||||
tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64,
|
||||
name=None)))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
input_type=["dataset"],
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testTypeSpecRoundTrip(self, input_type,
|
||||
distribution, enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator CompositeTensor support is only "
|
||||
"present in TF 2.0 only.")
|
||||
|
||||
dataset = dataset_ops.DatasetV2.range(10).batch(2)
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
iterator = iter(dist_dataset)
|
||||
|
||||
spec = iterator._type_spec
|
||||
|
||||
tensor_list = spec._to_components(iterator)
|
||||
re_iterator = spec._from_components(tensor_list)
|
||||
|
||||
self.assertEqual(iterator._input_workers, re_iterator._input_workers)
|
||||
self.assertAllEqual(iterator._iterators, re_iterator._iterators)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
input_type=["dataset"],
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testDoesNotTriggerFunctionTracing(self, input_type, distribution,
|
||||
enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator CompositeTensor support is only "
|
||||
"present in TF 2.0 only.")
|
||||
|
||||
trace_count = [0]
|
||||
|
||||
@def_function.function
|
||||
def f(iterator):
|
||||
trace_count[0] += 1
|
||||
counter = np.int64(0)
|
||||
for _ in range(5):
|
||||
next(iterator)
|
||||
counter += 1
|
||||
return counter
|
||||
|
||||
dataset = dataset_ops.DatasetV2.range(10).batch(2)
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
for _ in range(3):
|
||||
iterator = iter(dist_dataset)
|
||||
counter = f(iterator)
|
||||
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
self.assertEqual(counter, 5)
|
||||
|
||||
|
||||
class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
input_type=["dataset", "dataset_fn"],
|
||||
))
|
||||
def testInputSignatureForPerReplicaValues(self, distribution, input_type):
|
||||
def dataset_fn(ctx):
|
||||
del ctx # unused
|
||||
return dataset_ops.DatasetV2.from_tensor_slices(
|
||||
np.ones([10, 12]).astype(np.float32)).batch(4)
|
||||
|
||||
if input_type == "dataset":
|
||||
ds = distribution.experimental_distribute_dataset(
|
||||
dataset_fn(distribute_lib.InputContext()))
|
||||
type_spec = ds.element_spec
|
||||
else:
|
||||
ds = distribution.experimental_distribute_datasets_from_function(
|
||||
dataset_fn)
|
||||
iterator = iter(ds)
|
||||
type_spec = iterator.element_spec
|
||||
|
||||
@def_function.function(input_signature=[type_spec])
|
||||
def process_inputs(inputs):
|
||||
distribution.run(lambda inputs: inputs, args=(inputs,))
|
||||
|
||||
for x in ds:
|
||||
process_inputs(x)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.one_device_strategy,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
))
|
||||
def testInputSignatureForNestedPerReplicaValues(self, distribution):
|
||||
a = np.ones((10, 2)) * 5
|
||||
b = np.ones((10, 3)) * 6
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2)
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
|
||||
@def_function.function(input_signature=[dist_dataset.element_spec])
|
||||
def process_inputs(inputs):
|
||||
distribution.run(lambda inputs: inputs, args=(inputs,))
|
||||
|
||||
for x in dist_dataset:
|
||||
process_inputs(x)
|
||||
|
||||
|
||||
class RaggedTensorDistributedIteratorTest(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testTypeSpec(self, distribution, enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator has CompositeTensor support in "
|
||||
"TF 2.0 only.")
|
||||
ctx = distribute_lib.InputContext()
|
||||
batch_size = ctx.get_per_replica_batch_size(8)
|
||||
# Use 20 which isn't divisible by 8 to test partial batch behavior.
|
||||
row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
|
||||
ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
|
||||
np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices({
|
||||
"dense": ragged_tensor.to_tensor(),
|
||||
"ragged": ragged_tensor,
|
||||
"sparse": ragged_tensor.to_sparse(),
|
||||
})
|
||||
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
|
||||
dataset = dataset.batch(batch_size)
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
iterator = iter(dist_dataset)
|
||||
|
||||
spec = iterator._type_spec
|
||||
self.assertEqual(spec._input_workers, iterator._input_workers)
|
||||
self.assertEqual(
|
||||
spec._element_spec, {
|
||||
"sparse":
|
||||
values.PerReplicaSpec(
|
||||
sparse_tensor.SparseTensorSpec(
|
||||
tensor_shape.TensorShape([None, 3]), dtypes.float32),
|
||||
sparse_tensor.SparseTensorSpec(
|
||||
tensor_shape.TensorShape([None, 3]), dtypes.float32)),
|
||||
"dense":
|
||||
values.PerReplicaSpec(
|
||||
tensor_spec.TensorSpec(
|
||||
shape=(None, 3), dtype=dtypes.float32, name=None),
|
||||
tensor_spec.TensorSpec(
|
||||
shape=(None, 3), dtype=dtypes.float32, name=None)),
|
||||
"ragged":
|
||||
values.PerReplicaSpec(
|
||||
ragged_tensor_lib.RaggedTensorSpec(
|
||||
tensor_shape.TensorShape([None, None]), dtypes.float32,
|
||||
1, dtypes.int64),
|
||||
ragged_tensor_lib.RaggedTensorSpec(
|
||||
tensor_shape.TensorShape([None, None]), dtypes.float32,
|
||||
1, dtypes.int64))
|
||||
})
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator CompositeTensor support is only "
|
||||
"present in TF 2.0 only.")
|
||||
|
||||
ctx = distribute_lib.InputContext()
|
||||
batch_size = ctx.get_per_replica_batch_size(8)
|
||||
# Use 20 which isn't divisible by 8 to test partial batch behavior.
|
||||
row_lengths = np.mod(np.arange(20), 4).astype(np.int64)
|
||||
ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
|
||||
np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths)
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices({
|
||||
"dense": ragged_tensor.to_tensor(),
|
||||
"ragged": ragged_tensor,
|
||||
"sparse": ragged_tensor.to_sparse(),
|
||||
})
|
||||
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
|
||||
dataset = dataset.batch(batch_size)
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
iterator = iter(dist_dataset)
|
||||
|
||||
spec = iterator._type_spec
|
||||
|
||||
tensor_list = spec._to_components(iterator)
|
||||
re_iterator = spec._from_components(tensor_list)
|
||||
|
||||
self.assertEqual(iterator._input_workers, re_iterator._input_workers)
|
||||
self.assertAllEqual(iterator._iterators, re_iterator._iterators)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testDoesNotTriggerFunctionTracing(self, distribution,
|
||||
enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator CompositeTensor support is only "
|
||||
"present in TF 2.0 only.")
|
||||
|
||||
trace_count = [0]
|
||||
|
||||
@def_function.function
|
||||
def f(iterator):
|
||||
trace_count[0] += 1
|
||||
counter = np.int64(0)
|
||||
for _ in range(5):
|
||||
next(iterator)
|
||||
counter += 1
|
||||
return counter
|
||||
|
||||
ctx = distribute_lib.InputContext()
|
||||
batch_size = ctx.get_per_replica_batch_size(8)
|
||||
# Use 20 which isn't divisible by 8 to test partial batch behavior.
|
||||
row_lengths = np.mod(np.arange(50), 4).astype(np.int64)
|
||||
ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths(
|
||||
np.repeat(np.arange(50, dtype=np.float32), row_lengths), row_lengths)
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices({
|
||||
"dense": ragged_tensor.to_tensor(),
|
||||
"ragged": ragged_tensor,
|
||||
"sparse": ragged_tensor.to_sparse(),
|
||||
})
|
||||
dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
|
||||
dataset = dataset.batch(batch_size)
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
|
||||
dist_dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
with distribution.scope():
|
||||
for _ in range(3):
|
||||
iterator = iter(dist_dataset)
|
||||
counter = f(iterator)
|
||||
|
||||
self.assertEqual(trace_count[0], 1)
|
||||
self.assertEqual(counter, 5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
tensorflow/python/keras/engine/data_adapter.py
@ -1161,11 +1161,6 @@ class DataHandler(object):
|
||||
if self._insufficient_data: # Set by `catch_stop_iteration`.
|
||||
break
|
||||
if self._adapter.should_recreate_iterator():
|
||||
if ds_context.has_strategy():
|
||||
# TODO(b/138326910): remove this when MultiDeviceIterator is a
|
||||
# CompositeTensor (unless this is more efficient)
|
||||
data_iterator._initializer # pylint: disable=pointless-statement, protected-access
|
||||
else:
|
||||
data_iterator = iter(self._dataset)
|
||||
yield epoch, data_iterator
|
||||
self._adapter.on_epoch_end()
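
For reference, the element_spec exposed by distributed datasets and iterators (exercised by InputTypeSpecTest in the new test file above) can also be used as an explicit input_signature, so a tf.function that accepts per-replica values is traced once up front. A minimal sketch, again assuming a TF 2.x runtime; process_inputs is an illustrative name:

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices(
    np.ones([10, 12]).astype(np.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function(input_signature=[dist_dataset.element_spec])
def process_inputs(inputs):
  # Run an identity step on each replica; a real model would run its train
  # step here instead.
  strategy.run(lambda x: x, args=(inputs,))

for batch in dist_dataset:
  process_inputs(batch)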