Adding CompositeTensor support for MultiDeviceIteratorV2.
PiperOrigin-RevId: 258893591
commit 11f1d8ccd8 (parent e5ced34f45)
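
This change makes `MultiDeviceIteratorV2` a `CompositeTensor`, so an iterator can be passed directly as an argument to a `tf.function`, with the trace keyed on its new `TypeSpec` rather than on object identity. A minimal sketch of the resulting usage, assuming eager mode and an available GPU (it mirrors the test added below; `reduce_five` is a hypothetical name):

import tensorflow as tf
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops

@tf.function
def reduce_five(iterator):
  # Pull five elements; each element is a tuple with one tensor per device.
  total = tf.constant(0, dtype=tf.int64)
  for _ in range(5):
    elem = next(iterator)
    total += elem[0] + elem[1]
  return total

mdi = multi_device_iterator_ops.MultiDeviceIteratorV2(
    dataset_ops.Dataset.range(10), ["/cpu:0", "/gpu:0"])
print(reduce_five(mdi).numpy())  # 45

The first two hunks below add a test exercising exactly this; the remaining hunks implement the `TypeSpec` in multi_device_iterator_ops.py.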
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from absl.testing import parameterized
+import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
@@ -427,6 +428,35 @@ class MultiDeviceIteratorV2Test(test_base.DatasetTestBase):
     for i, el in enumerate(multi_device_iterator):
       self.assertEqual([i * 2, i * 2 + 1], [el[0].numpy(), el[1].numpy()])
 
+  @test_util.run_v2_only
+  def testLimitedRetracing(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    trace_count = [0]
+
+    @def_function.function
+    def f(iterator):
+      trace_count[0] += 1
+      counter = np.int64(0)
+      for _ in range(5):
+        elem = next(iterator)
+        counter += elem[0]
+        counter += elem[1]
+      return counter
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset2 = dataset_ops.Dataset.range(20)
+
+    for _ in range(10):
+      multi_device_iterator = multi_device_iterator_ops.MultiDeviceIteratorV2(
+          dataset, ["/cpu:0", "/gpu:0"])
+      self.assertEqual(self.evaluate(f(multi_device_iterator)), 45)
+      multi_device_iterator2 = multi_device_iterator_ops.MultiDeviceIteratorV2(
+          dataset2, ["/cpu:0", "/gpu:0"])
+      self.assertEqual(self.evaluate(f(multi_device_iterator2)), 45)
+    self.assertEqual(trace_count[0], 1)
+
 
 if __name__ == "__main__":
   ops.enable_eager_execution(
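
Why `trace_count` stays at 1: `tf.function` keys its trace cache on each argument's `TypeSpec`, and both `Dataset.range(10)` and `Dataset.range(20)` produce int64 scalars, so every iterator in the loop maps to an equal `MultiDeviceIteratorSpec` and the first trace is reused. A hypothetical check, not part of the commit, assuming `it1` and `it2` are two such iterators:

# Specs compare equal when devices, source device, and element_spec all
# match; equal specs are what keep the function from retracing.
spec1 = multi_device_iterator_ops.MultiDeviceIteratorSpec.from_value(it1)
spec2 = multi_device_iterator_ops.MultiDeviceIteratorSpec.from_value(it2)
assert spec1 == spec2

The remaining hunks below are in multi_device_iterator_ops.py, which defines the iterator and its new spec.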
@@ -22,10 +22,12 @@ from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.util import structure
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function
+from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
@@ -402,15 +404,67 @@ class MultiDeviceIteratorResourceDeleter(object):
           deleter=self._deleter)
 
 
-class MultiDeviceIteratorV2(object):
+class MultiDeviceIteratorSpec(type_spec.TypeSpec):
+  """Type specification for `MultiDeviceIteratorV2`."""
+
+  __slots__ = ["_devices", "_source_device", "_element_spec"]
+
+  def __init__(self, devices, source_device, element_spec):
+    self._devices = devices
+    self._source_device = source_device
+    self._element_spec = element_spec
+
+  @property
+  def value_type(self):
+    return MultiDeviceIteratorV2
+
+  def _serialize(self):
+    return (tuple(self._devices), self._source_device, self._element_spec)
+
+  @property
+  def _component_specs(self):
+    specs = [
+        tensor_spec.TensorSpec([], dtypes.resource),
+        tensor_spec.TensorSpec([], dtypes.variant)
+    ]
+    for _ in range(len(self._devices)):
+      specs.append(iterator_ops.IteratorSpec(self._element_spec))
+    return specs
+
+  def _to_components(self, value):
+    # pylint: disable=protected-access
+    c = [value._multi_device_iterator_resource, value._deleter]
+    c.extend(value._device_iterators)
+    return c
+
+  def _from_components(self, components):
+    return MultiDeviceIteratorV2(
+        dataset=None,
+        devices=self._devices,
+        source_device=self._source_device,
+        components=components,
+        element_spec=self._element_spec)
+
+  @staticmethod
+  def from_value(value):
+    # pylint: disable=protected-access
+    return MultiDeviceIteratorSpec(
+        value._devices,
+        value._source_device,
+        value.element_spec)
+
+
+class MultiDeviceIteratorV2(composite_tensor.CompositeTensor):
   """An iterator over multiple devices."""
 
   def __init__(self,
-               dataset,
-               devices,
+               dataset=None,
+               devices=None,
                max_buffer_size=1,
                prefetch_buffer_size=1,
-               source_device="/cpu:0"):
+               source_device="/cpu:0",
+               components=None,
+               element_spec=None):
    """Constructs a MultiDeviceIteratorV2 object.
 
     Args:
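
For a concrete picture of what the spec flattens to: with two devices and int64 scalar elements, `_component_specs` lists the shared resource handle, the variant deleter tensor, and one `IteratorSpec` per device. An illustrative snippet (assumed values; private API shown only for exposition):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec

spec = multi_device_iterator_ops.MultiDeviceIteratorSpec(
    devices=["/cpu:0", "/gpu:0"],
    source_device="/cpu:0",
    element_spec=tensor_spec.TensorSpec([], dtypes.int64))
print(spec._component_specs)  # pylint: disable=protected-access
# -> [TensorSpec([], tf.resource), TensorSpec([], tf.variant),
#     IteratorSpec(...), IteratorSpec(...)]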
@@ -422,6 +476,9 @@ class MultiDeviceIteratorV2(object):
       source_device: The host device to place the `dataset` on. In order to
         prevent deadlocks, if the prefetch_buffer_size is greater than the
         max_buffer_size, we set the max_buffer_size to prefetch_buffer_size.
+      components: Tensor components to construct the MultiDeviceIterator from.
+      element_spec: A nested structure of `TypeSpec` objects that
+        represents the type specification of elements of the iterator.
 
     Raises:
       RuntimeError: If executed in graph mode or outside of function building
@@ -431,62 +488,79 @@ class MultiDeviceIteratorV2(object):
                              not ops.get_default_graph()._building_function):  # pylint: disable=protected-access
       raise RuntimeError("MultiDeviceIteratorV2 is only supported inside of "
                          "tf.function or when eager execution is enabled.")
-    options = dataset_ops.Options()
-    options.experimental_distribute.num_devices = len(devices)
-    dataset = dataset.with_options(options)
-    self._dataset = dataset._apply_options()  # pylint: disable=protected-access
-    self._experimental_slack = dataset.options().experimental_slack
-    self._devices = devices
-    self._source_device = source_device
-    self._source_device_tensor = ops.convert_to_tensor(source_device)
-    self._max_buffer_size = max_buffer_size
-    self._prefetch_buffer_size = prefetch_buffer_size
+    if devices is None:
+      raise ValueError("`devices` must be provided")
+    error_message = ("Either `dataset` or both `components` and "
+                     "`element_spec` need to be provided.")
 
-    if self._prefetch_buffer_size > self._max_buffer_size:
-      self._max_buffer_size = self._prefetch_buffer_size
+    if dataset is None:
+      if (components is None or element_spec is None):
+        raise ValueError(error_message)
+      self._element_spec = element_spec
+      self._devices = devices
+      self._source_device = source_device
+      self._multi_device_iterator_resource = components[0]
+      self._deleter = components[1]
+      self._device_iterators = components[2:]
+      iterator_handles = []
+      for it in self._device_iterators:
+        iterator_handles.append(it._iterator_resource)  # pylint: disable=protected-access
+    else:
+      if (components is not None or element_spec is not None):
+        raise ValueError(error_message)
+      options = dataset_ops.Options()
+      options.experimental_distribute.num_devices = len(devices)
+      dataset = dataset.with_options(options)
+      dataset = dataset._apply_options()  # pylint: disable=protected-access
+      self._element_spec = dataset.element_spec
+      experimental_slack = dataset.options().experimental_slack
+      self._devices = devices
+      self._source_device = source_device
+      source_device_tensor = ops.convert_to_tensor(self._source_device)
 
-    # Create the MultiDeviceIterator.
-    with ops.device(self._source_device):
-      self._multi_device_iterator_resource, self._deleter = (
-          gen_dataset_ops.anonymous_multi_device_iterator(
-              devices=self._devices, **self._dataset._flat_structure))  # pylint: disable=protected-access
+      if prefetch_buffer_size > max_buffer_size:
+        max_buffer_size = prefetch_buffer_size
 
-      # The incarnation ID is used to ensure consistency between the per-device
-      # iterators and the multi-device iterator.
-      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
-          self._dataset._variant_tensor,  # pylint: disable=protected-access
-          self._multi_device_iterator_resource,
-          max_buffer_size=self._max_buffer_size)
+      # Create the MultiDeviceIterator.
+      with ops.device(self._source_device):
+        self._multi_device_iterator_resource, self._deleter = (
+            gen_dataset_ops.anonymous_multi_device_iterator(
+                devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access
 
-    self._prototype_device_datasets = []
-    for i, device in enumerate(self._devices):
-      with ops.device(device):
-        ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
-                                 self._incarnation_id,
-                                 self._source_device_tensor,
-                                 self._dataset.element_spec)
-        self._prototype_device_datasets.append(ds)
+        # The incarnation ID is used to ensure consistency between the
+        # per-device iterators and the multi-device iterator.
+        incarnation_id = gen_dataset_ops.multi_device_iterator_init(
+            dataset._variant_tensor,  # pylint: disable=protected-access
+            self._multi_device_iterator_resource,
+            max_buffer_size=max_buffer_size)
 
-    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
-    # initialize the device side of the pipeline. This would allow the
-    # MultiDeviceIterator to choose, for example, to move some transformations
-    # into the device side from its input. It might be useful in rewriting.
-    # Create the per device iterators.
-    self._device_iterators = []
-    self._iterator_handles = []
-    for i, device in enumerate(self._devices):
-      with ops.device(device):
-        ds = _create_device_dataset(self._prototype_device_datasets[i],
-                                    self._incarnation_id,
-                                    self._prefetch_buffer_size,
-                                    self._experimental_slack)
-        iterator = iter(ds)
-        self._device_iterators.append(iterator)
-        self._iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access
+      prototype_device_datasets = []
+      for i, device in enumerate(self._devices):
+        with ops.device(device):
+          ds = _PerDeviceGenerator(i, self._multi_device_iterator_resource,
+                                   incarnation_id, source_device_tensor,
+                                   dataset.element_spec)
+          prototype_device_datasets.append(ds)
 
+      # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
+      # initialize the device side of the pipeline. This would allow the
+      # MultiDeviceIterator to choose, for example, to move some transformations
+      # into the device side from its input. It might be useful in rewriting.
+      # Create the per device iterators.
+      self._device_iterators = []
+      iterator_handles = []
+      for i, device in enumerate(self._devices):
+        with ops.device(device):
+          ds = _create_device_dataset(prototype_device_datasets[i],
+                                      incarnation_id, prefetch_buffer_size,
+                                      experimental_slack)
+          iterator = iter(ds)
+          self._device_iterators.append(iterator)
+          iterator_handles.append(iterator._iterator_resource)  # pylint: disable=protected-access
 
     self._resource_deleter = MultiDeviceIteratorResourceDeleter(
         multi_device_iterator=self._multi_device_iterator_resource,
-        iterators=self._iterator_handles,
+        iterators=iterator_handles,
         device=self._source_device,
         deleter=self._deleter)
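
The constructor now has two mutually exclusive paths: the public one builds everything from `dataset` and `devices`, while the component path reassembles an iterator from already-created tensors, which is what `MultiDeviceIteratorSpec._from_components` does when an iterator crosses a `tf.function` boundary. A hedged sketch of both paths (the component path is internal and shown only for illustration):

# Public path: build from a dataset (assumes eager mode and a GPU).
mdi = multi_device_iterator_ops.MultiDeviceIteratorV2(
    dataset=dataset_ops.Dataset.range(4), devices=["/cpu:0", "/gpu:0"])

# Component path: roughly what tf.function does across the boundary.
components = mdi._type_spec._to_components(mdi)  # pylint: disable=protected-access
rebuilt = multi_device_iterator_ops.MultiDeviceIteratorV2(
    devices=["/cpu:0", "/gpu:0"],
    components=components,
    element_spec=mdi.element_spec)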
@@ -524,4 +598,9 @@ class MultiDeviceIteratorV2(object):
 
   @property
   def element_spec(self):
-    return self._dataset.element_spec
+    return self._element_spec
+
+  @property
+  def _type_spec(self):
+    return MultiDeviceIteratorSpec(self._devices, self._source_device,
+                                   self._element_spec)
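
Because `element_spec` is now stored on the object instead of being read from `self._dataset`, it remains available on iterators rebuilt from components, where no dataset exists. Continuing the sketch above under the same assumptions:

# `rebuilt` carries no dataset, yet its element_spec survives because the
# spec was captured at construction time and round-tripped through the spec.
assert rebuilt.element_spec == mdi.element_spec
# _type_spec lets tf.nest and tf.function treat it as a composite value.
print(rebuilt._type_spec)  # pylint: disable=protected-access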