[tf.data] Co-locate placement of iterator ops with the final dataset transformation.
PiperOrigin-RevId: 330841511 Change-Id: I3486226f7b50015396317cf8175fe185fd2ac7fd
This commit is contained in:
parent
f56f3a7391
commit
eb461280fe
tensorflow/python/data
@ -150,6 +150,17 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertDatasetProduces(device_dataset, list(range(10)))
|
self.assertDatasetProduces(device_dataset, list(range(10)))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
|
def testPrefetchToDeviceCorrectPlacement(self):
|
||||||
|
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
|
dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||||
|
|
||||||
|
self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
|
||||||
|
|
||||||
@combinations.generate(test_base.graph_only_combinations())
|
@combinations.generate(test_base.graph_only_combinations())
|
||||||
def testPrefetchToDeviceWithReInit(self):
|
def testPrefetchToDeviceWithReInit(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
@ -213,6 +224,48 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(device_dataset._variant_tensor.device,
|
self.assertEqual(device_dataset._variant_tensor.device,
|
||||||
"/job:localhost/replica:0/task:0/device:GPU:0")
|
"/job:localhost/replica:0/task:0/device:GPU:0")
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testIteratorOnDeviceEagerMode(self):
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
|
dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||||
|
iterator = iter(dataset)
|
||||||
|
data = next(iterator)
|
||||||
|
|
||||||
|
self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
|
||||||
|
self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
|
||||||
|
self.assertIn("gpu:0", data.device.lower())
|
||||||
|
|
||||||
|
@combinations.generate(test_base.graph_only_combinations())
|
||||||
|
def testIteratorOnDeviceGraphModeOneShotIterator(self):
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
|
dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||||
|
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
||||||
|
data = iterator.get_next()
|
||||||
|
|
||||||
|
self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
|
||||||
|
self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
|
||||||
|
self.assertIn("gpu:0", data.device.lower())
|
||||||
|
|
||||||
|
@combinations.generate(test_base.graph_only_combinations())
|
||||||
|
def testIteratorOnDeviceGraphModeInitializableIterator(self):
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
dataset = dataset_ops.Dataset.range(10)
|
||||||
|
dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||||
|
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||||
|
data = iterator.get_next()
|
||||||
|
|
||||||
|
self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
|
||||||
|
self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
|
||||||
|
self.assertIn("gpu:0", data.device.lower())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -417,7 +417,8 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
|||||||
RuntimeError: If not inside of tf.function and not executing eagerly.
|
RuntimeError: If not inside of tf.function and not executing eagerly.
|
||||||
"""
|
"""
|
||||||
if context.executing_eagerly() or ops.inside_function():
|
if context.executing_eagerly() or ops.inside_function():
|
||||||
return iterator_ops.OwnedIterator(self)
|
with ops.device(self._variant_tensor.device):
|
||||||
|
return iterator_ops.OwnedIterator(self)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||||
"or when eager execution is enabled.")
|
"or when eager execution is enabled.")
|
||||||
@ -2271,7 +2272,8 @@ class DatasetV1(DatasetV2):
|
|||||||
|
|
||||||
def _make_one_shot_iterator(self): # pylint: disable=missing-docstring
|
def _make_one_shot_iterator(self): # pylint: disable=missing-docstring
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
return iterator_ops.OwnedIterator(self)
|
with ops.device(self._variant_tensor.device):
|
||||||
|
return iterator_ops.OwnedIterator(self)
|
||||||
|
|
||||||
_ensure_same_dataset_graph(self)
|
_ensure_same_dataset_graph(self)
|
||||||
# Now that we create datasets at python object creation time, the capture
|
# Now that we create datasets at python object creation time, the capture
|
||||||
@ -2313,12 +2315,13 @@ class DatasetV1(DatasetV2):
|
|||||||
else:
|
else:
|
||||||
six.reraise(ValueError, err)
|
six.reraise(ValueError, err)
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
with ops.device(self._variant_tensor.device):
|
||||||
return iterator_ops.Iterator(
|
# pylint: disable=protected-access
|
||||||
gen_dataset_ops.one_shot_iterator(
|
return iterator_ops.Iterator(
|
||||||
dataset_factory=_make_dataset, **self._flat_structure), None,
|
gen_dataset_ops.one_shot_iterator(
|
||||||
get_legacy_output_types(self), get_legacy_output_shapes(self),
|
dataset_factory=_make_dataset, **self._flat_structure), None,
|
||||||
get_legacy_output_classes(self))
|
get_legacy_output_types(self), get_legacy_output_shapes(self),
|
||||||
|
get_legacy_output_classes(self))
|
||||||
|
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
None, "This is a deprecated API that should only be used in TF 1 graph "
|
None, "This is a deprecated API that should only be used in TF 1 graph "
|
||||||
@ -2378,16 +2381,20 @@ class DatasetV1(DatasetV2):
|
|||||||
dataset = self._apply_options()
|
dataset = self._apply_options()
|
||||||
if shared_name is None:
|
if shared_name is None:
|
||||||
shared_name = ""
|
shared_name = ""
|
||||||
iterator_resource = gen_dataset_ops.iterator_v2(
|
|
||||||
container="", shared_name=shared_name, **self._flat_structure)
|
with ops.device(self._variant_tensor.device):
|
||||||
with ops.colocate_with(iterator_resource):
|
iterator_resource = gen_dataset_ops.iterator_v2(
|
||||||
|
container="", shared_name=shared_name, **self._flat_structure)
|
||||||
|
|
||||||
initializer = gen_dataset_ops.make_iterator(
|
initializer = gen_dataset_ops.make_iterator(
|
||||||
dataset._variant_tensor, # pylint: disable=protected-access
|
dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
iterator_resource)
|
iterator_resource)
|
||||||
# pylint: disable=protected-access
|
|
||||||
return iterator_ops.Iterator(
|
# pylint: disable=protected-access
|
||||||
iterator_resource, initializer, get_legacy_output_types(dataset),
|
return iterator_ops.Iterator(iterator_resource, initializer,
|
||||||
get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
|
get_legacy_output_types(dataset),
|
||||||
|
get_legacy_output_shapes(dataset),
|
||||||
|
get_legacy_output_classes(dataset))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
|
@ -371,9 +371,11 @@ class Iterator(trackable.Trackable):
|
|||||||
raise TypeError("Expected output shapes compatible with %r but got "
|
raise TypeError("Expected output shapes compatible with %r but got "
|
||||||
"dataset with output shapes %r." %
|
"dataset with output shapes %r." %
|
||||||
(self.output_shapes, dataset_output_shapes))
|
(self.output_shapes, dataset_output_shapes))
|
||||||
with ops.colocate_with(self._iterator_resource):
|
|
||||||
|
with ops.device(self._iterator_resource.device):
|
||||||
|
# pylint: disable=protected-access
|
||||||
return gen_dataset_ops.make_iterator(
|
return gen_dataset_ops.make_iterator(
|
||||||
dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access
|
dataset._variant_tensor, self._iterator_resource, name=name)
|
||||||
|
|
||||||
def get_next(self, name=None):
|
def get_next(self, name=None):
|
||||||
"""Returns a nested structure of `tf.Tensor`s representing the next element.
|
"""Returns a nested structure of `tf.Tensor`s representing the next element.
|
||||||
@ -423,13 +425,14 @@ class Iterator(trackable.Trackable):
|
|||||||
if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
|
if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
|
||||||
warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
|
warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
with ops.device(self._iterator_resource.device):
|
||||||
flat_ret = gen_dataset_ops.iterator_get_next(
|
# pylint: disable=protected-access
|
||||||
self._iterator_resource,
|
flat_ret = gen_dataset_ops.iterator_get_next(
|
||||||
output_types=self._flat_tensor_types,
|
self._iterator_resource,
|
||||||
output_shapes=self._flat_tensor_shapes,
|
output_types=self._flat_tensor_types,
|
||||||
name=name)
|
output_shapes=self._flat_tensor_shapes,
|
||||||
return structure.from_tensor_list(self._element_spec, flat_ret)
|
name=name)
|
||||||
|
return structure.from_tensor_list(self._element_spec, flat_ret)
|
||||||
|
|
||||||
def get_next_as_optional(self):
|
def get_next_as_optional(self):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
Loading…
Reference in New Issue
Block a user