diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py index b7392e5d6d6..69ecccbd596 100644 --- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py +++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from absl.testing import parameterized +import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -427,6 +428,35 @@ class MultiDeviceIteratorV2Test(test_base.DatasetTestBase): for i, el in enumerate(multi_device_iterator): self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()]) + @test_util.run_v2_only + def testLimitedRetracing(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + trace_count = [0] + + @def_function.function + def f(iterator): + trace_count[0] += 1 + counter = np.int64(0) + for _ in range(5): + elem = next(iterator) + counter += elem[0] + counter += elem[1] + return counter + + dataset = dataset_ops.Dataset.range(10) + dataset2 = dataset_ops.Dataset.range(20) + + for _ in range(10): + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIteratorV2( + dataset, ["/cpu:0", "/gpu:0"]) + self.assertEqual(self.evaluate(f(multi_device_iterator)), 45) + multi_device_iterator2 = multi_device_iterator_ops.MultiDeviceIteratorV2( + dataset2, ["/cpu:0", "/gpu:0"]) + self.assertEqual(self.evaluate(f(multi_device_iterator2)), 45) + self.assertEqual(trace_count[0], 1) + if __name__ == "__main__": ops.enable_eager_execution( diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py index fdbbcae4c47..0a5fd456645 100644 --- a/tensorflow/python/data/ops/multi_device_iterator_ops.py +++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py @@ -22,10 +22,12 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import structure from tensorflow.python.eager import context from tensorflow.python.eager import function +from tensorflow.python.framework import composite_tensor from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops @@ -402,15 +404,67 @@ class MultiDeviceIteratorResourceDeleter(object): deleter=self._deleter) -class MultiDeviceIteratorV2(object): +class MultiDeviceIteratorSpec(type_spec.TypeSpec): + """Type specification for `MultiDeviceIteratorV2`.""" + + __slots__ = ["_devices", "_source_device", "_element_spec"] + + def __init__(self, devices, source_device, element_spec): + self._devices = devices + self._source_device = source_device + self._element_spec = element_spec + + @property + def value_type(self): + return MultiDeviceIteratorV2 + + def _serialize(self): + return (tuple(self._devices), self._source_device, self._element_spec) + + @property + def _component_specs(self): + specs = [ + tensor_spec.TensorSpec([], dtypes.resource), + tensor_spec.TensorSpec([], dtypes.scalar) + ] + for _ in range(len(self._devices)): + specs.append(iterator_ops.IteratorSpec(self._element_spec)) + return specs + + def _to_components(self, value): + # pylint: disable=protected-access + c = [value._multi_device_iterator_resource, value._deleter] + c.extend(value._device_iterators) + return c + + def _from_components(self, components): + return MultiDeviceIteratorV2( + dataset=None, + devices=self._devices, + source_device=self._source_device, + components=components, + element_spec=self._element_spec) + + @staticmethod + def from_value(value): + # pylint: disable=protected-access + return MultiDeviceIteratorSpec( + value._devices, + value._source_device, + value.element_spec) + + +class MultiDeviceIteratorV2(composite_tensor.CompositeTensor): """An iterator over multiple devices.""" def __init__(self, - dataset, - devices, + dataset=None, + devices=None, max_buffer_size=1, prefetch_buffer_size=1, - source_device="/cpu:0"): + source_device="/cpu:0", + components=None, + element_spec=None): """Constructs a MultiDeviceIteratorV2 object. Args: @@ -422,6 +476,9 @@ class MultiDeviceIteratorV2(object): source_device: The host device to place the `dataset` on. In order to prevent deadlocks, if the prefetch_buffer_size is greater than the max_buffer_size, we set the max_buffer_size to prefetch_buffer_size. + components: Tensor components to construct the MultiDeviceIterator from. + element_spec: A nested structure of `TypeSpec` objects that + represents the type specification of elements of the iterator. Raises: RuntimeError: If executed in graph mode or outside of function building @@ -431,62 +488,79 @@ class MultiDeviceIteratorV2(object): not ops.get_default_graph()._building_function): # pylint: disable=protected-access raise RuntimeError("MultiDeviceIteratorV2 is only supported inside of " "tf.function or when eager execution is enabled.") - options = dataset_ops.Options() - options.experimental_distribute.num_devices = len(devices) - dataset = dataset.with_options(options) - self._dataset = dataset._apply_options() # pylint: disable=protected-access - self._experimental_slack = dataset.options().experimental_slack - self._devices = devices - self._source_device = source_device - self._source_device_tensor = ops.convert_to_tensor(source_device) - self._max_buffer_size = max_buffer_size - self._prefetch_buffer_size = prefetch_buffer_size + if devices is None: + raise ValueError("`devices` must be provided") + error_message = "Either `dataset` or both `components` and " + "`element_spec` need to be provided." - if self._prefetch_buffer_size > self._max_buffer_size: - self._max_buffer_size = self._prefetch_buffer_size + if dataset is None: + if (components is None or element_spec is None): + raise ValueError(error_message) + self._element_spec = element_spec + self._devices = devices + self._source_device = source_device + self._multi_device_iterator_resource = components[0] + self._deleter = components[1] + self._device_iterators = components[2:] + iterator_handles = [] + for it in self._device_iterators: + iterator_handles.append(it._iterator_resource) # pylint: disable=protected-access + else: + if (components is not None or element_spec is not None): + raise ValueError(error_message) + options = dataset_ops.Options() + options.experimental_distribute.num_devices = len(devices) + dataset = dataset.with_options(options) + dataset = dataset._apply_options() # pylint: disable=protected-access + self._element_spec = dataset.element_spec + experimental_slack = dataset.options().experimental_slack + self._devices = devices + self._source_device = source_device + source_device_tensor = ops.convert_to_tensor(self._source_device) - # Create the MultiDeviceIterator. - with ops.device(self._source_device): - self._multi_device_iterator_resource, self._deleter = ( - gen_dataset_ops.anonymous_multi_device_iterator( - devices=self._devices, **self._dataset._flat_structure)) # pylint: disable=protected-access + if prefetch_buffer_size > max_buffer_size: + max_buffer_size = prefetch_buffer_size - # The incarnation ID is used to ensure consistency between the per-device - # iterators and the multi-device iterator. - self._incarnation_id = gen_dataset_ops.multi_device_iterator_init( - self._dataset._variant_tensor, # pylint: disable=protected-access - self._multi_device_iterator_resource, - max_buffer_size=self._max_buffer_size) + # Create the MultiDeviceIterator. + with ops.device(self._source_device): + self._multi_device_iterator_resource, self._deleter = ( + gen_dataset_ops.anonymous_multi_device_iterator( + devices=self._devices, **dataset._flat_structure)) # pylint: disable=protected-access - self._prototype_device_datasets = [] - for i, device in enumerate(self._devices): - with ops.device(device): - ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource, - self._incarnation_id, - self._source_device_tensor, - self._dataset.element_spec) - self._prototype_device_datasets.append(ds) + # The incarnation ID is used to ensure consistency between the + # per-device iterators and the multi-device iterator. + incarnation_id = gen_dataset_ops.multi_device_iterator_init( + dataset._variant_tensor, # pylint: disable=protected-access + self._multi_device_iterator_resource, + max_buffer_size=max_buffer_size) - # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to - # initialize the device side of the pipeline. This would allow the - # MultiDeviceIterator to choose, for example, to move some transformations - # into the device side from its input. It might be useful in rewriting. - # Create the per device iterators. - self._device_iterators = [] - self._iterator_handles = [] - for i, device in enumerate(self._devices): - with ops.device(device): - ds = _create_device_dataset(self._prototype_device_datasets[i], - self._incarnation_id, - self._prefetch_buffer_size, - self._experimental_slack) - iterator = iter(ds) - self._device_iterators.append(iterator) - self._iterator_handles.append(iterator._iterator_resource) # pylint: disable=protected-access + prototype_device_datasets = [] + for i, device in enumerate(self._devices): + with ops.device(device): + ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource, + incarnation_id, source_device_tensor, + dataset.element_spec) + prototype_device_datasets.append(ds) + + # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to + # initialize the device side of the pipeline. This would allow the + # MultiDeviceIterator to choose, for example, to move some transformations + # into the device side from its input. It might be useful in rewriting. + # Create the per device iterators. + self._device_iterators = [] + iterator_handles = [] + for i, device in enumerate(self._devices): + with ops.device(device): + ds = _create_device_dataset(prototype_device_datasets[i], + incarnation_id, prefetch_buffer_size, + experimental_slack) + iterator = iter(ds) + self._device_iterators.append(iterator) + iterator_handles.append(iterator._iterator_resource) # pylint: disable=protected-access self._resource_deleter = MultiDeviceIteratorResourceDeleter( multi_device_iterator=self._multi_device_iterator_resource, - iterators=self._iterator_handles, + iterators=iterator_handles, device=self._source_device, deleter=self._deleter) @@ -524,4 +598,9 @@ class MultiDeviceIteratorV2(object): @property def element_spec(self): - return self._dataset.element_spec + return self._element_spec + + @property + def _type_spec(self): + return MultiDeviceIteratorSpec(self._devices, self._source_device, + self._element_spec)