diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 6aad2535598..aa303b9e298 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -847,6 +847,38 @@ distribute_py_test( srcs = ["input_lib_test.py"], main = "input_lib_test.py", shard_count = 10, + tags = [ + "multi_and_single_gpu", + "no_gpu_presubmit", # TODO(b/154660040) + ], + deps = [ + ":collective_all_reduce_strategy", + ":combinations", + ":input_lib", + ":mirrored_strategy", + ":multi_worker_test_base", + ":reduce_util", + ":strategy_combinations", + ":values", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "//tensorflow/python/ops/ragged:ragged_tensor", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +distribute_py_test( + name = "input_lib_type_spec_test", + srcs = ["input_lib_type_spec_test.py"], + main = "input_lib_type_spec_test.py", + shard_count = 10, tags = [ "multi_and_single_gpu", ], @@ -1453,9 +1485,10 @@ distribute_py_test( name = "ctl_correctness_test", srcs = ["ctl_correctness_test.py"], main = "ctl_correctness_test.py", - shard_count = 10, + shard_count = 20, tags = [ "multi_and_single_gpu", + "no_gpu_presubmit", # TODO(b/154660040) "noguitar", # b/140755528 ], deps = [ diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index d35bb85cd1b..7c7f521af98 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -330,8 +330,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): communication=self._communication) super(CollectiveAllReduceExtended, self)._initialize_single_worker( local_devices) + host_device = device_util.get_host_for_device(self._worker_device) self._input_workers = input_lib.InputWorkers( - [(self._worker_device, self.worker_devices)]) + [(host_device, self.worker_devices)]) # Add a default device so that ops without specified devices will not end up # on other workers. 
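The `InputWorkers` change above anchors each worker's input pipeline on the worker's host (CPU) device instead of the worker device itself. A minimal sketch of the device mapping this relies on, assuming `device_util.get_host_for_device` rewrites an arbitrary worker device string to its CPU:0 host:

from tensorflow.python.distribute import device_util

worker_device = "/job:worker/replica:0/task:0/device:GPU:1"
host_device = device_util.get_host_for_device(worker_device)
# Expected (under the assumption above): "/job:worker/replica:0/task:0/device:CPU:0".
# The dataset and multi-device iterator resources are placed on this host,
# while the compute devices fed by it remain the worker's GPUs.
print(host_device)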
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 9919884642e..0bc183649e6 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -33,6 +33,7 @@ from tensorflow.python.distribute import input_ops from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import values from tensorflow.python.eager import context +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes @@ -41,6 +42,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -143,9 +145,10 @@ class InputWorkers(object): worker_device_pairs: A sequence of pairs: `(input device, a tuple of compute devices fed by that input device)`. """ - self._input_worker_devices = tuple(d for d, _ in worker_device_pairs) + self._worker_device_pairs = worker_device_pairs + self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs) self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f) - for _, f in worker_device_pairs) + for _, f in self._worker_device_pairs) @property def num_workers(self): @@ -165,6 +168,12 @@ class InputWorkers(object): for i in range(len(devices))) return "%s:{\n%s}" % (self.__class__.__name__, debug_repr) + def serialize(self): + return self._worker_device_pairs + + def deserialize(self, worker_device_pairs): + return InputWorkers(worker_device_pairs) + def _get_next_as_optional(iterator, strategy, name=None): """Returns an empty dataset indicator and the next input from the iterator.""" @@ -208,7 +217,7 @@ def _get_next_as_optional(iterator, strategy, name=None): def _is_statically_shaped(tensor_class, shape): - """Test if an iteratort output is statically shaped. + """Test if an iterator output is statically shaped. For sparse and ragged tensors this only tests the batch dimension. 
@@ -231,20 +240,27 @@ def _is_statically_shaped(tensor_class, shape): return shape.is_fully_defined() -class DistributedIterator(object): +def _get_static_shape(iterators): + """Returns a boolean indicating if the input is fully defined.""" + static_shape = True + for iterator in iterators: + if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator, + _SingleWorkerDatasetIterator)): + continue + flattened = zip(nest.flatten(iterator.output_shapes), + nest.flatten(iterator.output_classes)) + for output_shape, output_class in flattened: + if not _is_statically_shaped(output_class, output_shape): + static_shape = False + break + return static_shape + + +class DistributedIteratorBase(object): """Common implementation for all input iterators.""" def __init__(self, input_workers, iterators, strategy): - static_shape = True - for iterator in iterators: - if not isinstance(iterator, _SingleWorkerDatasetIterator): - continue - flattened = zip(nest.flatten(iterator.output_shapes), - nest.flatten(iterator.output_classes)) - for output_shape, output_class in flattened: - if not _is_statically_shaped(output_class, output_shape): - static_shape = False - break + static_shape = _get_static_shape(iterators) # TODO(b/133073708): we currently need a flag to control the usage because # there is a performance difference between get_next() and @@ -360,6 +376,10 @@ class DistributedIterator(object): return values.regroup(replicas) + +class DistributedIteratorV1(DistributedIteratorBase): + """Input Iterator for a distributed dataset.""" + # We need a private initializer method for re-initializing multidevice # iterators when used with Keras training loops. If we don't reinitialize the # iterator we run into memory leak issues (b/123315763). @@ -370,23 +390,14 @@ class DistributedIterator(object): init_ops.extend(it.initialize()) return control_flow_ops.group(init_ops) - @property - def element_spec(self): - """The type specification of an element of this iterator.""" - return self._element_spec - - -class DistributedIteratorV1(DistributedIterator): - """Input Iterator for a distributed dataset instance.""" - @deprecated(None, "Use the iterator's `initializer` property instead.") def initialize(self): - """Initialze underlying iterators. + """Initialize underlying iterators. Returns: A list of any initializer ops that should be run. """ - return super(DistributedIteratorV1, self)._initializer + return self._initializer @property def initializer(self): @@ -415,6 +426,161 @@ class DistributedIteratorV1(DistributedIterator): return self._iterators[i] return None + @property + def element_spec(self): + """The type specification of an element of this iterator.""" + return self._element_spec + + +class DistributedIteratorSpec(type_spec.TypeSpec): + """Type specification for `DistributedIterator`.""" + + __slots__ = ["_input_workers", "_element_spec", "_strategy"] + + def __init__(self, input_workers, element_spec, strategy): + # We don't want to allow deserialization of this class because we don't + # serialize the strategy object. Currently the only places where + # _deserialize is called is when we save/restore using SavedModels. 
+ if isinstance(input_workers, tuple): + raise NotImplementedError("DistributedIteratorSpec does not have support " + "for deserialization.") + else: + self._input_workers = input_workers + self._element_spec = element_spec + self._strategy = strategy + + @property + def value_type(self): + return DistributedIterator + + def _serialize(self): + # We cannot serialize the strategy object so we convert it to an id that we + # can use for comparison. + return (self._input_workers.serialize(), + self._element_spec, id(self._strategy)) + + def _deserialize(self): + raise ValueError("Deserialization is currently unsupported for " + "DistributedIteratorSpec.") + + @staticmethod + def _is_compatible(a, b): + """Returns true if the given type serializations compatible.""" + if type(a) is not type(b): + return False + if isinstance(a, tuple): + return (len(a) == len(b) and + all(DistributedIteratorSpec._is_compatible(x, y) for (x, y) in + zip(a, b))) + if isinstance(a, dict): + return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all( + DistributedIteratorSpec._is_compatible(a[k], b[k]) for k in a.keys())) + if isinstance(a, (type_spec.TypeSpec, tensor_shape.TensorShape, + dtypes.DType)): + return a.is_compatible_with(b) + return a == b + + # Overriding this method so that we can merge and reconstruct the spec object + def most_specific_compatible_type(self, other): + """Returns the most specific TypeSpec compatible with `self` and `other`. + + Args: + other: A `TypeSpec`. + + Raises: + ValueError: If there is no TypeSpec that is compatible with both `self` + and `other`. + """ + # pylint: disable=protected-access + if type(self) is not type(other): + raise ValueError("No TypeSpec is compatible with both %s and %s" % + (self, other)) + if not self._is_compatible(self._input_workers.serialize(), + other._input_workers.serialize()): + raise ValueError("_input_workers is not compatible with both %s " + "and %s" % (self, other)) + if self._element_spec != other._element_spec: + raise ValueError("_element_spec is not compatible with both %s " + "and %s" % (self, other)) + if id(self._strategy) != id(other._strategy): + raise ValueError("tf.distribute strategy is not compatible with both %s " + "and %s" % (self, other)) + return DistributedIteratorSpec(self._input_workers, self._element_spec, + self._strategy) + + @property + def _component_specs(self): + specs = [] + worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access + for i in range(len(worker_device_pairs)): + input_device, compute_devices = worker_device_pairs[i] + specs.append(_SingleWorkerDatasetIteratorSpec(input_device, + compute_devices, + element_spec= + self._element_spec)) + return specs + + def _to_components(self, value): + return value._iterators # pylint: disable=protected-access + + def _from_components(self, components): + return DistributedIterator(input_workers=self._input_workers, + iterators=None, + components=components, + element_spec=self._element_spec, + strategy=self._strategy) + + @staticmethod + def from_value(value): + # pylint: disable=protected-access + return DistributedIteratorSpec(value._input_workers, value._element_spec, + value._strategy) + + +class DistributedIterator(DistributedIteratorBase, + composite_tensor.CompositeTensor): + """Input Iterator for a distributed dataset.""" + + def __init__(self, input_workers=None, iterators=None, strategy=None, + components=None, element_spec=None): + if input_workers is None: + raise ValueError("`input_workers` 
should be " + "provided.") + + error_message = ("Either `input_workers` or " + "both `components` and `element_spec` need to be " + "provided.") + + if iterators is None: + if (components is None or element_spec is None): + raise ValueError(error_message) + self._element_spec = element_spec + self._input_workers = input_workers + self._iterators = components + static_shape = _get_static_shape(self._iterators) + self._strategy = strategy + if getattr( + strategy.extended, "experimental_enable_get_next_as_optional", False): + self._enable_get_next_as_optional = not static_shape + else: + self._enable_get_next_as_optional = False + else: + if (components is not None and element_spec is not None): + raise ValueError(error_message) + + super(DistributedIterator, self).__init__(input_workers, iterators, + strategy) + + @property + def element_spec(self): + return self._element_spec + + @property + def _type_spec(self): + return DistributedIteratorSpec(self._input_workers, + self.element_spec, + self._strategy) + class _IterableInput(object): """Base class for iterable inputs for distribution strategies.""" @@ -482,7 +648,6 @@ class DistributedDataset(_IterableInput): `num_input_pipelines` in the `InputContext`. """ super(DistributedDataset, self).__init__(input_workers=input_workers) - # We clone and shard the dataset on each worker. The current setup tries to # shard the dataset by files if possible so that each worker sees a # different subset of files. If that is not possible, will attempt to shard @@ -541,10 +706,20 @@ class DistributedDataset(_IterableInput): raise RuntimeError("__iter__() is only supported inside of tf.function " "or when eager execution is enabled.") + # This is an optional flag that can be used to turn off using + # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators + # as a stop gap solution that will allow us to roll out this change. + enable_legacy_iterators = getattr(self._strategy, + "_enable_legacy_iterators", False) worker_iterators = _create_iterators_per_worker(self._cloned_datasets, - self._input_workers) - iterator = DistributedIterator(self._input_workers, worker_iterators, - self._strategy) + self._input_workers, + enable_legacy_iterators) + if enable_legacy_iterators: + iterator = DistributedIteratorV1(self._input_workers, worker_iterators, + self._strategy) + else: + iterator = DistributedIterator(self._input_workers, worker_iterators, + self._strategy) iterator._element_spec = self.element_spec # pylint: disable=protected-access return iterator @@ -615,12 +790,21 @@ class DistributedDatasetV1(DistributedDataset): def _get_iterator(self): worker_iterators = _create_iterators_per_worker(self._cloned_datasets, - self._input_workers) + self._input_workers, + True) iterator = DistributedIteratorV1(self._input_workers, worker_iterators, self._strategy) iterator._element_spec = self.element_spec # pylint: disable=protected-access return iterator + def __iter__(self): + if (ops.executing_eagerly_outside_functions() or + ops.get_default_graph().building_function): + return self._get_iterator() + + raise RuntimeError("__iter__() is only supported inside of tf.function " + "or when eager execution is enabled.") + # TODO(priyag): Add other replication modes. 
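The `_enable_legacy_iterators` attribute read above is a stop-gap, private opt-out: when it is set on the strategy, `__iter__` falls back to the legacy `MultiDeviceIterator`-based `DistributedIteratorV1` instead of the new owned iterators. A minimal sketch mirroring the `testDisablingOwnedIteratorsInTF2` test added below; it assumes eager TF 2 execution:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(8).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Default in TF 2: owned iterators backed by OwnedMultiDeviceIterator.
owned_iterator = iter(dist_dataset)    # DistributedIterator

# Stop-gap opt-out via the private flag introduced in this change.
strategy._enable_legacy_iterators = True
legacy_iterator = iter(dist_dataset)   # DistributedIteratorV1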
class DistributedDatasetsFromFunction(_IterableInput): @@ -653,20 +837,36 @@ class DistributedDatasetsFromFunction(_IterableInput): self._strategy = strategy self._element_spec = None - def __iter__(self): - if not (context.executing_eagerly() or - ops.get_default_graph().building_function): - raise RuntimeError("__iter__() is only supported inside of tf.function " - "or when eager execution is enabled.") + super(DistributedDatasetsFromFunction, self).__init__( + input_workers=input_workers) - iterators, element_spec = _create_iterators_per_worker_with_input_context( - self._input_contexts, self._input_workers, self._dataset_fn) - iterator = DistributedIterator(self._input_workers, iterators, - self._strategy) - self._element_spec = _create_distributed_tensor_spec(self._strategy, - element_spec) - iterator._element_spec = self._element_spec # pylint: disable=protected-access - return iterator + def __iter__(self): + if (ops.executing_eagerly_outside_functions() or + ops.get_default_graph().building_function): + # This is an optional flag that can be used to turn off using + # OwnedMultiDeviceIterators and instead use the legacy + # MultiDeviceIterators as a stop gap solution that will allow us to roll + # out this change. + enable_legacy_iterators = getattr(self._strategy, + "_enable_legacy_iterators", False) + + iterators, element_spec = _create_iterators_per_worker_with_input_context( + self._input_contexts, self._input_workers, self._dataset_fn, + enable_legacy_iterators) + + if enable_legacy_iterators: + iterator = DistributedIteratorV1(self._input_workers, iterators, + self._strategy) + else: + iterator = DistributedIterator(self._input_workers, iterators, + self._strategy) + self._element_spec = _create_distributed_tensor_spec(self._strategy, + element_spec) + iterator._element_spec = self._element_spec # pylint: disable=protected-access + return iterator + + raise RuntimeError("__iter__() is only supported inside of tf.function " + "or when eager execution is enabled.") @property def element_spec(self): @@ -705,7 +905,8 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction): def _get_iterator(self): iterators, element_spec = _create_iterators_per_worker_with_input_context( - self._input_contexts, self._input_workers, self._dataset_fn) + self._input_contexts, self._input_workers, self._dataset_fn, + True) iterator = DistributedIteratorV1(self._input_workers, iterators, self._strategy) self._element_spec = _create_distributed_tensor_spec(self._strategy, @@ -713,6 +914,14 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction): iterator._element_spec = self._element_spec # pylint: disable=protected-access return iterator + def __iter__(self): + if (ops.executing_eagerly_outside_functions() or + ops.get_default_graph().building_function): + return self._get_iterator() + + raise RuntimeError("__iter__() is only supported inside of tf.function " + "or when eager execution is enabled.") + # TODO(anjalisridhar): This class will be soon be removed in favor of newer # APIs. 
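Making `DistributedIterator` a `CompositeTensor` with its own `TypeSpec` means an iterator can be passed into a `tf.function` without forcing a retrace for every fresh iterator. A short sketch of that behavior, mirroring the `testDoesNotTriggerFunctionTracing` test added below (assumes eager TF 2 execution):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.range(10).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def consume(iterator):
  # Python-level loop, unrolled once at trace time.
  for _ in range(3):
    next(iterator)

# Each pass gets a fresh iterator, but iterators of the same distributed
# dataset share a DistributedIteratorSpec, so the function traces only once.
for _ in range(2):
  consume(iter(dist_dataset))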
@@ -797,7 +1006,7 @@ class DatasetIterator(DistributedIteratorV1): split_batch_by=split_batch_by, input_context=input_context) worker_iterators = _create_iterators_per_worker( - dist_dataset._cloned_datasets, input_workers) # pylint: disable=protected-access + dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access super(DatasetIterator, self).__init__( input_workers, worker_iterators, # pylint: disable=protected-access @@ -808,18 +1017,18 @@ class DatasetIterator(DistributedIteratorV1): def _dummy_tensor_fn(value_structure): """A function to create dummy tensors from `value_structure`.""" - def create_dummy_tensor(type_spec): + def create_dummy_tensor(spec): """Create a dummy tensor with possible batch dimensions set to 0.""" - if isinstance(type_spec, ragged_tensor.RaggedTensorSpec): + if isinstance(spec, ragged_tensor.RaggedTensorSpec): # Splice out the ragged dimensions. # pylint: disable=protected-access - feature_shape = type_spec._shape[:1].concatenate( - type_spec._shape[(1 + type_spec._ragged_rank):]) - feature_type = type_spec._dtype + feature_shape = spec._shape[:1].concatenate( + spec._shape[(1 + spec._ragged_rank):]) + feature_type = spec._dtype # pylint: enable=protected-access else: - feature_shape = type_spec.shape - feature_type = type_spec.dtype + feature_shape = spec.shape + feature_type = spec.dtype # Ideally we should set the batch dimension to 0, however as in # DistributionStrategy we don't know the batch dimension, we try to # guess it as much as possible. If the feature has unknown dimensions, we @@ -827,11 +1036,11 @@ def _dummy_tensor_fn(value_structure): # first dimension as batch dimension and set it to 0. dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()] if feature_shape else []) - if dims and (isinstance(type_spec, ragged_tensor.RaggedTensorSpec) or + if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or feature_shape.is_fully_defined()): dims[0] = tensor_shape.Dimension(0) - if isinstance(type_spec, sparse_tensor.SparseTensorSpec): + if isinstance(spec, sparse_tensor.SparseTensorSpec): return sparse_tensor.SparseTensor( values=array_ops.zeros(0, feature_type), indices=array_ops.zeros((0, len(dims)), dtypes.int64), @@ -839,26 +1048,26 @@ def _dummy_tensor_fn(value_structure): # Create the dummy tensor. dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type) - if isinstance(type_spec, ragged_tensor.RaggedTensorSpec): + if isinstance(spec, ragged_tensor.RaggedTensorSpec): # Reinsert the ragged dimensions with size 0. # pylint: disable=protected-access - row_splits = array_ops.zeros(1, type_spec._row_splits_dtype) + row_splits = array_ops.zeros(1, spec._row_splits_dtype) dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits( - dummy_tensor, (row_splits,) * type_spec._ragged_rank, validate=False) + dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False) # pylint: enable=protected-access return dummy_tensor return nest.map_structure(create_dummy_tensor, value_structure) -class _SingleWorkerDatasetIterator(object): +class _SingleWorkerDatasetIteratorBase(object): """Iterator for a single `tf.data.Dataset`.""" def __init__(self, dataset, worker, devices): """Create iterator for the `dataset` to fetch data to worker's `devices` . - `MultiDeviceIterator` is used to prefetch input to the devices on the - given worker. + A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch + input to the devices on the given worker. 
Args: dataset: A `tf.data.Dataset` instance. @@ -868,13 +1077,11 @@ class _SingleWorkerDatasetIterator(object): self._dataset = dataset self._worker = worker self._devices = devices + self._element_spec = dataset.element_spec self._make_iterator() def _make_iterator(self): - """Make appropriate iterator on the dataset.""" - with ops.device(self._worker): - self._iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._devices) + raise NotImplementedError("must be implemented in descendants") def get_next(self, device, name=None): """Get next element for the given device.""" @@ -923,9 +1130,9 @@ class _SingleWorkerDatasetIterator(object): # Place the condition op in the same device as the data so the data # doesn't need to be sent back to the worker. with ops.device(self._devices[i]): - # As MultiDeviceIterator will fetch data in order, so we only need to - # check if the first replica has value to see whether there is data - # left for this single worker. + # Data will be fetched in order, so we only need to check if the first + # replica has value to see whether there is data left for this single + # worker. if i == 0: worker_has_value = data.has_value() @@ -943,8 +1150,159 @@ class _SingleWorkerDatasetIterator(object): return worker_has_value, result + +class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec): + """Type specification for `_SingleWorkerOwnedDatasetIterator`.""" + + __slots__ = ["_worker", "_devices", "_element_spec"] + + def __init__(self, worker, devices, element_spec): + self._worker = worker + self._devices = devices + self._element_spec = element_spec + + @property + def value_type(self): + return _SingleWorkerOwnedDatasetIterator + + def _serialize(self): + return (self._worker, tuple(self._devices), self._element_spec) + + @property + def _component_specs(self): + specs = [] + specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec( + self._devices, self._worker, element_spec=self._element_spec)) + return specs + + def _to_components(self, value): + return [value._iterator] # pylint: disable=protected-access + + def _from_components(self, components): + return _SingleWorkerOwnedDatasetIterator( + dataset=None, + worker=self._worker, + devices=self._devices, + components=components, + element_spec=self._element_spec) + + @staticmethod + def from_value(value): + # pylint: disable=protected-access + return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices, + value._element_spec) + + +class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, + composite_tensor.CompositeTensor): + """Iterator for a DistributedDataset instance.""" + + def __init__(self, dataset=None, worker=None, devices=None, components=None, + element_spec=None): + """Create iterator for the `dataset` to fetch data to worker's `devices` . + + `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the + given worker. The lifetime of this iterator is tied to the encompassing + python object. Once we go out of scope of the python object or return from + a tf.function the underlying iterator resource is deleted. + + Args: + dataset: A `tf.data.Dataset` instance. + worker: Worker on which ops should be created. + devices: Distribute data from `dataset` to these devices. + components: Tensor components to construct the + _SingleWorkerOwnedDatasetIterator from. + element_spec: A nested structure of `TypeSpec` objects that represents the + type specification of elements of the iterator. 
+ """ + if worker is None or devices is None: + raise ValueError("Both `worker` and `devices` should be provided") + + error_message = ("Either `dataset` or both `components` and `element_spec` " + "need to be provided.") + + if dataset is None: + if (components is None or element_spec is None): + raise ValueError(error_message) + self._element_spec = element_spec + self._worker = worker + self._devices = devices + self._iterator = components[0] + else: + if (components is not None or element_spec is not None): + raise ValueError(error_message) + super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker, + devices) + + def _make_iterator(self): + """Make appropriate iterator on the dataset.""" + if not self._worker: + raise ValueError("Worked device must be specified when creating an " + "owned iterator.") + host_device = device_util.get_host_for_device(self._worker) + with ops.device(self._worker): + self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator( + self._dataset, self._devices, source_device=host_device) + + @property + def element_spec(self): + return self._element_spec + + @property + def _type_spec(self): + return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices, + self._element_spec) + + @property + def output_classes(self): + """Returns the class of each component of an element of this iterator. + + The expected values are `tf.Tensor` and `tf.SparseTensor`. + + Returns: + A nested structure of Python `type` objects corresponding to each + component of an element of this dataset. + """ + return nest.map_structure( + lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access + self._element_spec) + + @property + def output_shapes(self): + """Returns the shape of each component of an element of this iterator. + + Returns: + A nested structure of `tf.TensorShape` objects corresponding to each + component of an element of this dataset. + """ + return nest.map_structure( + lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access + self._element_spec) + + @property + def output_types(self): + """Returns the type of each component of an element of this iterator. + + Returns: + A nested structure of `tf.DType` objects corresponding to each component + of an element of this dataset. + """ + return nest.map_structure( + lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access + self._element_spec) + + +class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase): + """Iterator for a single DistributedDatasetV1 instance.""" + + def _make_iterator(self): + """Make appropriate iterator on the dataset.""" + with ops.device(self._worker): + self._iterator = multi_device_iterator_ops.MultiDeviceIterator( + self._dataset, self._devices) + def initialize(self): - """Initialze underlying iterator. + """Initialize underlying iterator. In eager execution, this simply recreates the underlying iterator. 
In graph execution, it returns the initializer ops for the underlying @@ -1005,7 +1363,8 @@ class _SingleWorkerCallableIterator(object): return [] -def _create_iterators_per_worker(worker_datasets, input_workers): +def _create_iterators_per_worker(worker_datasets, input_workers, + enable_legacy_iterators): """Create a multidevice iterator on each of the workers.""" assert isinstance(input_workers, InputWorkers) @@ -1014,23 +1373,35 @@ def _create_iterators_per_worker(worker_datasets, input_workers): for i, worker in enumerate(input_workers.worker_devices): with ops.device(worker): worker_devices = input_workers.compute_devices_for_worker(i) - iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker, - worker_devices) + if tf2.enabled() and not enable_legacy_iterators: + iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker, + worker_devices) + else: + iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker, + worker_devices) iterators.append(iterator) return iterators def _create_iterators_per_worker_with_input_context(input_contexts, input_workers, - dataset_fn): + dataset_fn, + enable_legacy_iterators): """Create a multidevice iterator per workers given a dataset function.""" iterators = [] + element_specs = [] for i, ctx in enumerate(input_contexts): worker = input_workers.worker_devices[i] with ops.device(worker): dataset = dataset_fn(ctx) + element_specs.append(dataset.element_spec) devices = input_workers.compute_devices_for_worker(i) - iterator = _SingleWorkerDatasetIterator(dataset, worker, devices) + if tf2.enabled() and not enable_legacy_iterators: + iterator = _SingleWorkerOwnedDatasetIterator(dataset, worker, + devices) + else: + iterator = _SingleWorkerDatasetIterator(dataset, worker, + devices) iterators.append(iterator) return iterators, dataset.element_spec diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py index be78bf16190..7aa0c804786 100644 --- a/tensorflow/python/distribute/input_lib_test.py +++ b/tensorflow/python/distribute/input_lib_test.py @@ -45,6 +45,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -105,20 +106,21 @@ class DistributedIteratorTestBase(test.TestCase): split_batch_by, strategy, input_context=None): - if isinstance(dataset, (dataset_ops.Dataset, dataset_ops.DatasetV1Adapter)): - return input_lib.DistributedDatasetV1( - dataset, - input_workers, - strategy, - split_batch_by=split_batch_by, - input_context=input_context) - elif input_type == "dataset": - return input_lib.DistributedDataset( - dataset, - input_workers, - strategy, - split_batch_by=split_batch_by, - input_context=input_context) + if input_type == "dataset": + if tf2.enabled(): + return input_lib.DistributedDataset( + dataset, + input_workers, + strategy, + split_batch_by=split_batch_by, + input_context=input_context) + else: + return input_lib.DistributedDatasetV1( + dataset, + input_workers, + strategy, + split_batch_by=split_batch_by, + input_context=input_context) else: return strategy.experimental_distribute_datasets_from_function(dataset) @@ -139,6 +141,9 @@ class DistributedIteratorTestBase(test.TestCase): if api_type == "wrap_into_iterator" and iteration_type 
== "for_loop": self.skipTest("unsupported test combination.") + if api_type == "wrap_into_iterator" and input_type == "input_fn": + self.skipTest("unsupported test combination.") + devices = nest.flatten([ds for _, ds in worker_device_pairs]) input_workers = input_lib.InputWorkers(worker_device_pairs) @@ -161,7 +166,7 @@ class DistributedIteratorTestBase(test.TestCase): strategy, input_context=input_context) - if context.executing_eagerly(): + if ops.executing_eagerly_outside_functions(): iterator = iter(dataset) else: if isinstance(dataset, input_lib.DistributedDatasetV1): @@ -171,7 +176,7 @@ class DistributedIteratorTestBase(test.TestCase): if iteration_type == "get_next": evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) - if isinstance(iterator, input_lib.DistributedIteratorV1): + if not ops.executing_eagerly_outside_functions(): evaluate(control_flow_ops.group(iterator.initializer)) for expected_value in expected_values: @@ -190,10 +195,13 @@ class DistributedIteratorTestBase(test.TestCase): next_element) for r in range(len(devices))]) # After re-initializing the iterator, should be able to iterate again. - if isinstance(iterator, input_lib.DistributedIteratorV1): + if not ops.executing_eagerly_outside_functions(): evaluate(control_flow_ops.group(iterator.initializer)) else: - evaluate(control_flow_ops.group(iterator._initializer)) + if api_type == "wrap_into_iterator": + self.skipTest("unsupported test combination") + else: + iterator = iter(dataset) for expected_value in expected_values: next_element = iterator.get_next() @@ -225,6 +233,48 @@ class DistributedIteratorTestBase(test.TestCase): class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, parameterized.TestCase): + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["input_fn", "dataset"], + distribution=[ + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu + ])) + def testDisablingOwnedIteratorsInTF2(self, distribution, input_type): + if not tf2.enabled(): + self.skipTest("unsupported test combination") + + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] + input_workers = input_lib.InputWorkers(worker_device_pairs) + dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) + dataset_or_input_fn = self._create_dataset_or_input_fn( + input_type, dataset_fn) + + input_workers = input_lib.InputWorkers(worker_device_pairs) + if input_type == "dataset": + dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn, + input_workers, + distribution) + else: + dist_dataset = input_lib.get_distributed_datasets_from_function( + dataset_or_input_fn, input_workers, [distribute_lib.InputContext()], + distribution) + + # Default Iterator types in TF2. + iterator = iter(dist_dataset) + self.assertIsInstance(iterator, input_lib.DistributedIterator) + self.assertIsInstance(iterator._iterators[0], + input_lib._SingleWorkerOwnedDatasetIterator) + + # Disable creating owned iterators by setting a property on the strategy. 
+ distribution._enable_legacy_iterators = True + iterator = iter(dist_dataset) + self.assertIsInstance(iterator, input_lib.DistributedIteratorV1) + self.assertIsInstance(iterator._iterators[0], + input_lib._SingleWorkerDatasetIterator) + @combinations.generate( combinations.combine( mode=["eager"], @@ -234,7 +284,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, def testMultiDeviceIterInitialize(self, distribution): if tf2.enabled(): self.skipTest("Only V1 is supported.") - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", + "/device:CPU:0"])] dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) input_workers = input_lib.InputWorkers(worker_device_pairs) @@ -250,25 +301,6 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, init_func_for_iter() - @combinations.generate( - combinations.combine( - mode=["graph"], - distribution=[ - strategy_combinations.one_device_strategy, - strategy_combinations.mirrored_strategy_with_one_cpu - ])) - def testDatasetV2IterError(self, distribution): - worker_device_pairs = [("", ["/device:CPU:0"])] - input_workers = input_lib.InputWorkers(worker_device_pairs) - dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2) - - dist_dataset = input_lib.get_distributed_dataset( - dataset_fn(distribute_lib.InputContext()), input_workers, distribution) - - with self.assertRaisesRegexp(RuntimeError, - "or when eager execution is enabled"): - iter(dist_dataset) - @combinations.generate( combinations.combine( mode=["graph", "eager"], @@ -282,11 +314,11 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, enable_get_next_as_optional=[True, False])) def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution, enable_get_next_as_optional): - worker_device_pairs = [("", ["/device:CPU:0"])] + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] if tf2.enabled(): dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) else: - dataset_fn = lambda _: dataset_ops.Dataset.range(10) + dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) dataset_or_input_fn = self._create_dataset_or_input_fn( input_type, dataset_fn) @@ -316,7 +348,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, enable_get_next_as_optional=[True, False])) def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type, distribution, enable_get_next_as_optional): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", + "/device:CPU:0"])] if tf2.enabled(): dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) else: @@ -386,7 +419,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, enable_get_next_as_optional=[True, False])) def testTupleDataset(self, input_type, api_type, iteration_type, distribution, enable_get_next_as_optional): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", + "/device:CPU:0"])] def dataset_fn(ctx): del ctx @@ -422,7 +456,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, strategy_combinations.mirrored_strategy_with_one_cpu ])) def testIterableIterator(self, distribution): - worker_device_pairs = [("", ["/device:CPU:0"])] + worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] input_workers = input_lib.InputWorkers(worker_device_pairs) dataset = dataset_ops.DatasetV2.range(10) @@ 
-446,7 +480,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, ])) def testUnevenDatasetBatches(self, input_type, api_type, iteration_type, drop_remainder, distribution): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", + "/device:CPU:0"])] if tf2.enabled(): dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch( # pylint: disable=g-long-lambda 2, drop_remainder=drop_remainder) @@ -486,7 +521,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, def testBatchSplitting(self, input_type, api_type, iteration_type, split_batch_by, distribution, enable_get_next_as_optional): - worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] + worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", + "/device:CPU:0"])] batch_size = 10 if tf2.enabled(): dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size) @@ -1075,68 +1111,5 @@ class DistributedIteratorMultiWorkerTest( strategy, sess=sess) - -class InputTypeSpecTest(test.TestCase, parameterized.TestCase): - - @combinations.generate( - combinations.combine( - mode=["eager"], - distribution=[ - strategy_combinations.one_device_strategy, - strategy_combinations.mirrored_strategy_with_one_cpu, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, - ], - input_type=["dataset", "dataset_fn"], - )) - def testInputSignatureForPerReplicaValues(self, distribution, input_type): - def dataset_fn(ctx): - del ctx # unused - return dataset_ops.DatasetV2.from_tensor_slices( - np.ones([10, 12]).astype(np.float32)).batch(4) - - if input_type == "dataset": - ds = distribution.experimental_distribute_dataset( - dataset_fn(distribute_lib.InputContext())) - type_spec = ds.element_spec - else: - ds = distribution.experimental_distribute_datasets_from_function( - dataset_fn) - iterator = iter(ds) - type_spec = iterator.element_spec - - @def_function.function(input_signature=[type_spec]) - def process_inputs(inputs): - distribution.run(lambda inputs: inputs, args=(inputs,)) - - for x in ds: - process_inputs(x) - - @combinations.generate( - combinations.combine( - mode=["eager"], - distribution=[ - strategy_combinations.one_device_strategy, - strategy_combinations.mirrored_strategy_with_one_cpu, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - strategy_combinations.central_storage_strategy_with_two_gpus, - ], - )) - def testInputSignatureForNestedPerReplicaValues(self, distribution): - a = np.ones((10, 2)) * 5 - b = np.ones((10, 3)) * 6 - dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2) - - dist_dataset = distribution.experimental_distribute_dataset(dataset) - - @def_function.function(input_signature=[dist_dataset.element_spec]) - def process_inputs(inputs): - distribution.run(lambda inputs: inputs, args=(inputs,)) - - for x in dist_dataset: - process_inputs(x) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py new file mode 100644 index 00000000000..0671875b06d --- /dev/null +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -0,0 +1,366 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the input_lib library which tests iterator type specs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python import tf2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.distribute import values +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib + + +class DistributedIteratorTest(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + ], + enable_get_next_as_optional=[True, False])) + def testTypeSpec(self, input_type, distribution, + enable_get_next_as_optional): + if not tf2.enabled(): + self.skipTest("DistributedIterator has CompositeTensor support in " + "TF 2 only.") + dataset = dataset_ops.DatasetV2.range(10).batch(2) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + with distribution.scope(): + iterator = iter(dist_dataset) + + spec = iterator._type_spec + self.assertEqual(spec._input_workers, iterator._input_workers) + self.assertEqual(spec._element_spec._value_specs, + (tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, + name=None), + tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, + name=None))) + + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + ], + enable_get_next_as_optional=[True, False])) + def testTypeSpecRoundTrip(self, input_type, + distribution, enable_get_next_as_optional): + if not tf2.enabled(): + self.skipTest("DistributedIterator CompositeTensor support is only " + "present in TF 2.0 only.") + + dataset = dataset_ops.DatasetV2.range(10).batch(2) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + with distribution.scope(): + iterator = iter(dist_dataset) + + spec = iterator._type_spec + + tensor_list = spec._to_components(iterator) + 
re_iterator = spec._from_components(tensor_list) + + self.assertEqual(iterator._input_workers, re_iterator._input_workers) + self.assertAllEqual(iterator._iterators, re_iterator._iterators) + + @combinations.generate( + combinations.combine( + mode=["eager"], + input_type=["dataset"], + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + ], + enable_get_next_as_optional=[True, False])) + def testDoesNotTriggerFunctionTracing(self, input_type, distribution, + enable_get_next_as_optional): + if not tf2.enabled(): + self.skipTest("DistributedIterator CompositeTensor support is only " + "present in TF 2.0 only.") + + trace_count = [0] + + @def_function.function + def f(iterator): + trace_count[0] += 1 + counter = np.int64(0) + for _ in range(5): + next(iterator) + counter += 1 + return counter + + dataset = dataset_ops.DatasetV2.range(10).batch(2) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + with distribution.scope(): + for _ in range(3): + iterator = iter(dist_dataset) + counter = f(iterator) + + self.assertEqual(trace_count[0], 1) + self.assertEqual(counter, 5) + + +class InputTypeSpecTest(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + strategy_combinations.central_storage_strategy_with_two_gpus, + ], + input_type=["dataset", "dataset_fn"], + )) + def testInputSignatureForPerReplicaValues(self, distribution, input_type): + def dataset_fn(ctx): + del ctx # unused + return dataset_ops.DatasetV2.from_tensor_slices( + np.ones([10, 12]).astype(np.float32)).batch(4) + + if input_type == "dataset": + ds = distribution.experimental_distribute_dataset( + dataset_fn(distribute_lib.InputContext())) + type_spec = ds.element_spec + else: + ds = distribution.experimental_distribute_datasets_from_function( + dataset_fn) + iterator = iter(ds) + type_spec = iterator.element_spec + + @def_function.function(input_signature=[type_spec]) + def process_inputs(inputs): + distribution.run(lambda inputs: inputs, args=(inputs,)) + + for x in ds: + process_inputs(x) + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.one_device_strategy, + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + strategy_combinations.central_storage_strategy_with_two_gpus, + ], + )) + def testInputSignatureForNestedPerReplicaValues(self, distribution): + a = np.ones((10, 2)) * 5 + b = np.ones((10, 3)) * 6 + dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2) + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + + @def_function.function(input_signature=[dist_dataset.element_spec]) + def process_inputs(inputs): + distribution.run(lambda inputs: inputs, args=(inputs,)) + + for x in dist_dataset: + process_inputs(x) + + +class RaggedTensorDistributedIteratorTest(test.TestCase, + parameterized.TestCase): + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + 
enable_get_next_as_optional=[True, False])) + def testTypeSpec(self, distribution, enable_get_next_as_optional): + if not tf2.enabled(): + self.skipTest("DistributedIterator has CompositeTensor support in " + "TF 2.0 only.") + ctx = distribute_lib.InputContext() + batch_size = ctx.get_per_replica_batch_size(8) + # Use 20 which isn't divisible by 8 to test partial batch behavior. + row_lengths = np.mod(np.arange(20), 4).astype(np.int64) + ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( + np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) + dataset = dataset_ops.DatasetV2.from_tensor_slices({ + "dense": ragged_tensor.to_tensor(), + "ragged": ragged_tensor, + "sparse": ragged_tensor.to_sparse(), + }) + dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) + dataset = dataset.batch(batch_size) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + with distribution.scope(): + iterator = iter(dist_dataset) + + spec = iterator._type_spec + self.assertEqual(spec._input_workers, iterator._input_workers) + self.assertEqual( + spec._element_spec, { + "sparse": + values.PerReplicaSpec( + sparse_tensor.SparseTensorSpec( + tensor_shape.TensorShape([None, 3]), dtypes.float32), + sparse_tensor.SparseTensorSpec( + tensor_shape.TensorShape([None, 3]), dtypes.float32)), + "dense": + values.PerReplicaSpec( + tensor_spec.TensorSpec( + shape=(None, 3), dtype=dtypes.float32, name=None), + tensor_spec.TensorSpec( + shape=(None, 3), dtype=dtypes.float32, name=None)), + "ragged": + values.PerReplicaSpec( + ragged_tensor_lib.RaggedTensorSpec( + tensor_shape.TensorShape([None, None]), dtypes.float32, + 1, dtypes.int64), + ragged_tensor_lib.RaggedTensorSpec( + tensor_shape.TensorShape([None, None]), dtypes.float32, + 1, dtypes.int64)) + }) + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + ], + enable_get_next_as_optional=[True, False])) + def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional): + if not tf2.enabled(): + self.skipTest("DistributedIterator CompositeTensor support is only " + "present in TF 2.0 only.") + + ctx = distribute_lib.InputContext() + batch_size = ctx.get_per_replica_batch_size(8) + # Use 20 which isn't divisible by 8 to test partial batch behavior. 
+ row_lengths = np.mod(np.arange(20), 4).astype(np.int64) + ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( + np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) + dataset = dataset_ops.DatasetV2.from_tensor_slices({ + "dense": ragged_tensor.to_tensor(), + "ragged": ragged_tensor, + "sparse": ragged_tensor.to_sparse(), + }) + dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) + dataset = dataset.batch(batch_size) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + with distribution.scope(): + iterator = iter(dist_dataset) + + spec = iterator._type_spec + + tensor_list = spec._to_components(iterator) + re_iterator = spec._from_components(tensor_list) + + self.assertEqual(iterator._input_workers, re_iterator._input_workers) + self.assertAllEqual(iterator._iterators, re_iterator._iterators) + + @combinations.generate( + combinations.combine( + mode=["eager"], + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + ], + enable_get_next_as_optional=[True, False])) + def testDoesNotTriggerFunctionTracing(self, distribution, + enable_get_next_as_optional): + if not tf2.enabled(): + self.skipTest("DistributedIterator CompositeTensor support is only " + "present in TF 2.0 only.") + + trace_count = [0] + + @def_function.function + def f(iterator): + trace_count[0] += 1 + counter = np.int64(0) + for _ in range(5): + next(iterator) + counter += 1 + return counter + + ctx = distribute_lib.InputContext() + batch_size = ctx.get_per_replica_batch_size(8) + # Use 20 which isn't divisible by 8 to test partial batch behavior. + row_lengths = np.mod(np.arange(50), 4).astype(np.int64) + ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( + np.repeat(np.arange(50, dtype=np.float32), row_lengths), row_lengths) + dataset = dataset_ops.DatasetV2.from_tensor_slices({ + "dense": ragged_tensor.to_tensor(), + "ragged": ragged_tensor, + "sparse": ragged_tensor.to_sparse(), + }) + dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) + dataset = dataset.batch(batch_size) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + dist_dataset = distribution.experimental_distribute_dataset(dataset) + with distribution.scope(): + for _ in range(3): + iterator = iter(dist_dataset) + counter = f(iterator) + + self.assertEqual(trace_count[0], 1) + self.assertEqual(counter, 5) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py index 3994db4a541..83375c6ee2b 100644 --- a/tensorflow/python/keras/engine/data_adapter.py +++ b/tensorflow/python/keras/engine/data_adapter.py @@ -1161,12 +1161,7 @@ class DataHandler(object): if self._insufficient_data: # Set by `catch_stop_iteration`. break if self._adapter.should_recreate_iterator(): - if ds_context.has_strategy(): - # TODO(b/138326910): remove this when MultiDeviceIterator is a - # CompositeTensor (unless this is more efficient) - data_iterator._initializer # pylint: disable=pointless-statement, protected-access - else: - data_iterator = iter(self._dataset) + data_iterator = iter(self._dataset) yield epoch, data_iterator self._adapter.on_epoch_end()
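For completeness, a short sketch of the `element_spec`-as-`input_signature` pattern exercised by `testInputSignatureForPerReplicaValues` in the relocated `InputTypeSpecTest` above; the doubling inside the replica function is an arbitrary placeholder:

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
dataset = tf.data.Dataset.from_tensor_slices(
    np.ones([10, 12], dtype=np.float32)).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function(input_signature=[dist_dataset.element_spec])
def step(inputs):
  # The per-replica values produced by the distributed iterator match the
  # PerReplica spec recorded in element_spec, so this signature holds.
  return strategy.run(lambda x: x * 2.0, args=(inputs,))

for batch in dist_dataset:
  step(batch)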