Adding CompositeTensor support for MultiDeviceIteratorV2.

PiperOrigin-RevId: 258893591
Rohan Jain 2019-07-18 19:47:41 -07:00 committed by TensorFlower Gardener
parent e5ced34f45
commit 11f1d8ccd8
2 changed files with 162 additions and 53 deletions
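A minimal sketch of what this change enables, assuming a TF build that includes this commit and a visible GPU (the device list and `tf.data` usage below are illustrative, not part of the diff): because `MultiDeviceIteratorV2` is now a `CompositeTensor` with an associated `MultiDeviceIteratorSpec`, it can be passed directly as an argument to a `tf.function`.

import tensorflow as tf
from tensorflow.python.data.ops import multi_device_iterator_ops

@tf.function
def sum_next(iterator):
  # `iterator` is a CompositeTensor: tf.function traces against its
  # MultiDeviceIteratorSpec instead of treating it as an opaque Python object.
  elem = next(iterator)
  return elem[0] + elem[1]

dataset = tf.data.Dataset.range(10)
iterator = multi_device_iterator_ops.MultiDeviceIteratorV2(
    dataset, ["/cpu:0", "/gpu:0"])  # assumes a GPU is available
print(sum_next(iterator))  # first elements 0 and 1 -> tf.Tensor(1, dtype=int64)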

tensorflow/python/data/kernel_tests/multi_device_iterator_test.py

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 from absl.testing import parameterized
+import numpy as np

 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
@@ -427,6 +428,35 @@ class MultiDeviceIteratorV2Test(test_base.DatasetTestBase):
     for i, el in enumerate(multi_device_iterator):
       self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])

+  @test_util.run_v2_only
+  def testLimitedRetracing(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    trace_count = [0]
+
+    @def_function.function
+    def f(iterator):
+      trace_count[0] += 1
+      counter = np.int64(0)
+      for _ in range(5):
+        elem = next(iterator)
+        counter += elem[0]
+        counter += elem[1]
+      return counter
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset2 = dataset_ops.Dataset.range(20)
+    for _ in range(10):
+      multi_device_iterator = multi_device_iterator_ops.MultiDeviceIteratorV2(
+          dataset, ["/cpu:0", "/gpu:0"])
+      self.assertEqual(self.evaluate(f(multi_device_iterator)), 45)
+      multi_device_iterator2 = multi_device_iterator_ops.MultiDeviceIteratorV2(
+          dataset2, ["/cpu:0", "/gpu:0"])
+      self.assertEqual(self.evaluate(f(multi_device_iterator2)), 45)
+      self.assertEqual(trace_count[0], 1)
+

 if __name__ == "__main__":
   ops.enable_eager_execution(
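The retracing test above passes because `tf.function` keys its trace cache on the iterator's type spec rather than its Python identity: both iterators serialize to equal `MultiDeviceIteratorSpec`s (same devices, source device, and element spec, since `range(10)` and `range(20)` share an element spec), so one trace serves both. A small sketch of that equality, using the names from the test and the private `_type_spec` property added by this change:

spec1 = multi_device_iterator._type_spec
spec2 = multi_device_iterator2._type_spec
assert spec1 == spec2  # equal specs -> no retrace for the second iterator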
tensorflow/python/data/ops/multi_device_iterator_ops.py

@@ -22,10 +22,12 @@ from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
+from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -402,15 +404,67 @@ class MultiDeviceIteratorResourceDeleter(object):
           deleter=self._deleter)


-class MultiDeviceIteratorV2(object):
+class MultiDeviceIteratorSpec(type_spec.TypeSpec):
+  """Type specification for `MultiDeviceIteratorV2`."""
+
+  __slots__ = ["_devices", "_source_device", "_element_spec"]
+
+  def __init__(self, devices, source_device, element_spec):
+    self._devices = devices
+    self._source_device = source_device
+    self._element_spec = element_spec
+
+  @property
+  def value_type(self):
+    return MultiDeviceIteratorV2
+
+  def _serialize(self):
+    return (tuple(self._devices), self._source_device, self._element_spec)
+
+  @property
+  def _component_specs(self):
+    specs = [
+        tensor_spec.TensorSpec([], dtypes.resource),
+        # The deleter produced by anonymous_multi_device_iterator is a
+        # scalar variant tensor.
+        tensor_spec.TensorSpec([], dtypes.variant)
+    ]
+    for _ in range(len(self._devices)):
+      specs.append(iterator_ops.IteratorSpec(self._element_spec))
+    return specs
+
+  def _to_components(self, value):
+    # pylint: disable=protected-access
+    c = [value._multi_device_iterator_resource, value._deleter]
+    c.extend(value._device_iterators)
+    return c
+
+  def _from_components(self, components):
+    return MultiDeviceIteratorV2(
+        dataset=None,
+        devices=self._devices,
+        source_device=self._source_device,
+        components=components,
+        element_spec=self._element_spec)
+
+  @staticmethod
+  def from_value(value):
+    # pylint: disable=protected-access
+    return MultiDeviceIteratorSpec(
+        value._devices,
+        value._source_device,
+        value.element_spec)
+
+
+class MultiDeviceIteratorV2(composite_tensor.CompositeTensor):
   """An iterator over multiple devices."""

   def __init__(self,
-               dataset,
-               devices,
+               dataset=None,
+               devices=None,
                max_buffer_size=1,
                prefetch_buffer_size=1,
-               source_device="/cpu:0"):
+               source_device="/cpu:0",
+               components=None,
+               element_spec=None):
     """Constructs a MultiDeviceIteratorV2 object.

     Args:
@@ -422,6 +476,9 @@ class MultiDeviceIteratorV2(object):
       source_device: The host device to place the `dataset` on. In order to
         prevent deadlocks, if the prefetch_buffer_size is greater than the
         max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
+      components: Tensor components to construct the MultiDeviceIterator from.
+      element_spec: A nested structure of `TypeSpec` objects that
+        represents the type specification of elements of the iterator.

     Raises:
       RuntimeError: If executed in graph mode or outside of function building
@@ -431,62 +488,79 @@ class MultiDeviceIteratorV2(object):
         not ops.get_default_graph()._building_function):  # pylint: disable=protected-access
       raise RuntimeError("MultiDeviceIteratorV2 is only supported inside of "
                          "tf.function or when eager execution is enabled.")
-
-    options = dataset_ops.Options()
-    options.experimental_distribute.num_devices = len(devices)
-    dataset = dataset.with_options(options)
-
-    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
-    self._experimental_slack = dataset.options().experimental_slack
-    self._devices = devices
-    self._source_device = source_device
-    self._source_device_tensor = ops.convert_to_tensor(source_device)
-    self._max_buffer_size = max_buffer_size
-    self._prefetch_buffer_size = prefetch_buffer_size
+    if devices is None:
+      raise ValueError("`devices` must be provided")
+    error_message = ("Either `dataset` or both `components` and "
+                     "`element_spec` need to be provided.")

-    if self._prefetch_buffer_size > self._max_buffer_size:
-      self._max_buffer_size = self._prefetch_buffer_size
+    if dataset is None:
+      if (components is None or element_spec is None):
+        raise ValueError(error_message)
+      self._element_spec = element_spec
+      self._devices = devices
+      self._source_device = source_device
+      self._multi_device_iterator_resource = components[0]
+      self._deleter = components[1]
+      self._device_iterators = components[2:]
+      iterator_handles = []
+      for it in self._device_iterators:
+        iterator_handles.append(it._iterator_resource)  # pylint: disable=protected-access
+    else:
+      if (components is not None or element_spec is not None):
+        raise ValueError(error_message)
+      options = dataset_ops.Options()
+      options.experimental_distribute.num_devices = len(devices)
+      dataset = dataset.with_options(options)
+      dataset = dataset._apply_options()  # pylint: disable=protected-access
+      self._element_spec = dataset.element_spec
+      experimental_slack = dataset.options().experimental_slack
+      self._devices = devices
+      self._source_device = source_device
+      source_device_tensor = ops.convert_to_tensor(self._source_device)

-    # Create the MultiDeviceIterator.
-    with ops.device(self._source_device):
-      self._multi_device_iterator_resource, self._deleter = (
-          gen_dataset_ops.anonymous_multi_device_iterator(
-              devices=self._devices, **self._dataset._flat_structure))  # pylint: disable=protected-access
+      if prefetch_buffer_size > max_buffer_size:
+        max_buffer_size = prefetch_buffer_size

-      # The incarnation ID is used to ensure consistency between the per-device
-      # iterators and the multi-device iterator.
-      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
-          self._dataset._variant_tensor,  # pylint: disable=protected-access
-          self._multi_device_iterator_resource,
-          max_buffer_size=self._max_buffer_size)
+      # Create the MultiDeviceIterator.
+      with ops.device(self._source_device):
+        self._multi_device_iterator_resource, self._deleter = (
+            gen_dataset_ops.anonymous_multi_device_iterator(
+                devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

-    self._prototype_device_datasets = []
-    for i, device in enumerate(self._devices):
-      with ops.device(device):
-        ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
-                                 self._incarnation_id,
-                                 self._source_device_tensor,
-                                 self._dataset.element_spec)
-        self._prototype_device_datasets.append(ds)
+        # The incarnation ID is used to ensure consistency between the
+        # per-device iterators and the multi-device iterator.
+        incarnation_id = gen_dataset_ops.multi_device_iterator_init(
+            dataset._variant_tensor,  # pylint: disable=protected-access
+            self._multi_device_iterator_resource,
+            max_buffer_size=max_buffer_size)

-    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
-    # initialize the device side of the pipeline. This would allow the
-    # MultiDeviceIterator to choose, for example, to move some transformations
-    # into the device side from its input. It might be useful in rewriting.
-    # Create the per device iterators.
-    self._device_iterators = []
-    self._iterator_handles = []
-    for i, device in enumerate(self._devices):
-      with ops.device(device):
-        ds = _create_device_dataset(self._prototype_device_datasets[i],
-                                    self._incarnation_id,
-                                    self._prefetch_buffer_size,
-                                    self._experimental_slack)
-        iterator = iter(ds)
-        self._device_iterators.append(iterator)
-        self._iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access
+      prototype_device_datasets = []
+      for i, device in enumerate(self._devices):
+        with ops.device(device):
+          ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
+                                   incarnation_id, source_device_tensor,
+                                   dataset.element_spec)
+          prototype_device_datasets.append(ds)
+
+      # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
+      # initialize the device side of the pipeline. This would allow the
+      # MultiDeviceIterator to choose, for example, to move some transformations
+      # into the device side from its input. It might be useful in rewriting.
+      # Create the per device iterators.
+      self._device_iterators = []
+      iterator_handles = []
+      for i, device in enumerate(self._devices):
+        with ops.device(device):
+          ds = _create_device_dataset(prototype_device_datasets[i],
+                                      incarnation_id, prefetch_buffer_size,
+                                      experimental_slack)
+          iterator = iter(ds)
+          self._device_iterators.append(iterator)
+          iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access

     self._resource_deleter = MultiDeviceIteratorResourceDeleter(
         multi_device_iterator=self._multi_device_iterator_resource,
-        iterators=self._iterator_handles,
+        iterators=iterator_handles,
         device=self._source_device,
         deleter=self._deleter)
@@ -524,4 +598,9 @@ class MultiDeviceIteratorV2(object):

   @property
   def element_spec(self):
-    return self._dataset.element_spec
+    return self._element_spec
+
+  @property
+  def _type_spec(self):
+    return MultiDeviceIteratorSpec(self._devices, self._source_device,
+                                   self._element_spec)
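To make the mechanics concrete, a sketch of the component round-trip the new spec implements, given a `MultiDeviceIteratorV2` named `iterator` (a hypothetical variable, not part of the diff); the methods below are the private ones defined in this change:

spec = iterator._type_spec                    # a MultiDeviceIteratorSpec
components = spec._to_components(iterator)    # [resource, deleter, per-device iterators...]
rebuilt = spec._from_components(components)   # functionally equivalent iterator

This decomposition is what lets tf.function flatten the iterator into plain tensors at the call boundary and reconstruct an equivalent iterator inside the traced function.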