[tf.data] Co-locate placement of iterator ops with the final dataset transformation.

PiperOrigin-RevId: 330841511
Change-Id: I3486226f7b50015396317cf8175fe185fd2ac7fd
Jiri Simsa 2020-09-09 18:25:24 -07:00 committed by TensorFlower Gardener
parent f56f3a7391
commit eb461280fe
3 changed files with 87 additions and 24 deletions
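
In practice, this means that iterating over a dataset whose final transformation was placed on an accelerator no longer creates the iterator ops on the default device. A minimal sketch of the intended behavior using the public tf.data API (assumes TF 2.x eager mode and an available GPU):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(10)
    dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))

    # The iterator resource and get-next ops are now created on the same
    # device as the dataset's variant tensor (GPU:0 here), so each element
    # should report a GPU device string.
    for element in dataset:
      print(element.device)  # expected to contain "GPU:0"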


@@ -150,6 +150,17 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(device_dataset, list(range(10)))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrefetchToDeviceCorrectPlacement(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+
   @combinations.generate(test_base.graph_only_combinations())
   def testPrefetchToDeviceWithReInit(self):
     host_dataset = dataset_ops.Dataset.range(10)
@@ -213,6 +224,48 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual(device_dataset._variant_tensor.device,
                      "/job:localhost/replica:0/task:0/device:GPU:0")
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testIteratorOnDeviceEagerMode(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = iter(dataset)
+    data = next(iterator)
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
+  @combinations.generate(test_base.graph_only_combinations())
+  def testIteratorOnDeviceGraphModeOneShotIterator(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = dataset_ops.make_one_shot_iterator(dataset)
+    data = iterator.get_next()
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
+  @combinations.generate(test_base.graph_only_combinations())
+  def testIteratorOnDeviceGraphModeInitializableIterator(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = dataset_ops.make_initializable_iterator(dataset)
+    data = iterator.get_next()
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
 
 if __name__ == "__main__":
   test.main()


@@ -417,7 +417,8 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
       RuntimeError: If not inside of tf.function and not executing eagerly.
     """
     if context.executing_eagerly() or ops.inside_function():
-      return iterator_ops.OwnedIterator(self)
+      with ops.device(self._variant_tensor.device):
+        return iterator_ops.OwnedIterator(self)
     else:
       raise RuntimeError("__iter__() is only supported inside of tf.function "
                          "or when eager execution is enabled.")
@@ -2271,7 +2272,8 @@ class DatasetV1(DatasetV2):
 
   def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring
     if context.executing_eagerly():
-      return iterator_ops.OwnedIterator(self)
+      with ops.device(self._variant_tensor.device):
+        return iterator_ops.OwnedIterator(self)
 
     _ensure_same_dataset_graph(self)
     # Now that we create datasets at python object creation time, the capture
@@ -2313,12 +2315,13 @@ class DatasetV1(DatasetV2):
       else:
         six.reraise(ValueError, err)
 
-    # pylint: disable=protected-access
-    return iterator_ops.Iterator(
-        gen_dataset_ops.one_shot_iterator(
-            dataset_factory=_make_dataset, **self._flat_structure), None,
-        get_legacy_output_types(self), get_legacy_output_shapes(self),
-        get_legacy_output_classes(self))
+    with ops.device(self._variant_tensor.device):
+      # pylint: disable=protected-access
+      return iterator_ops.Iterator(
+          gen_dataset_ops.one_shot_iterator(
+              dataset_factory=_make_dataset, **self._flat_structure), None,
+          get_legacy_output_types(self), get_legacy_output_shapes(self),
+          get_legacy_output_classes(self))
 
   @deprecation.deprecated(
       None, "This is a deprecated API that should only be used in TF 1 graph "
@@ -2378,16 +2381,20 @@ class DatasetV1(DatasetV2):
     dataset = self._apply_options()
     if shared_name is None:
       shared_name = ""
-    iterator_resource = gen_dataset_ops.iterator_v2(
-        container="", shared_name=shared_name, **self._flat_structure)
-    with ops.colocate_with(iterator_resource):
+
+    with ops.device(self._variant_tensor.device):
+      iterator_resource = gen_dataset_ops.iterator_v2(
+          container="", shared_name=shared_name, **self._flat_structure)
+
       initializer = gen_dataset_ops.make_iterator(
           dataset._variant_tensor,  # pylint: disable=protected-access
           iterator_resource)
-    # pylint: disable=protected-access
-    return iterator_ops.Iterator(
-        iterator_resource, initializer, get_legacy_output_types(dataset),
-        get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
+
+      # pylint: disable=protected-access
+      return iterator_ops.Iterator(iterator_resource, initializer,
+                                   get_legacy_output_types(dataset),
+                                   get_legacy_output_shapes(dataset),
+                                   get_legacy_output_classes(dataset))
 
   @property
   @deprecation.deprecated(


@@ -371,9 +371,11 @@ class Iterator(trackable.Trackable):
       raise TypeError("Expected output shapes compatible with %r but got "
                       "dataset with output shapes %r." %
                       (self.output_shapes, dataset_output_shapes))
-    with ops.colocate_with(self._iterator_resource):
+
+    with ops.device(self._iterator_resource.device):
+      # pylint: disable=protected-access
       return gen_dataset_ops.make_iterator(
-          dataset._variant_tensor, self._iterator_resource, name=name)  # pylint: disable=protected-access
+          dataset._variant_tensor, self._iterator_resource, name=name)
 
   def get_next(self, name=None):
     """Returns a nested structure of `tf.Tensor`s representing the next element.
@@ -423,13 +425,14 @@
 
     if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
       warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
-    # pylint: disable=protected-access
-    flat_ret = gen_dataset_ops.iterator_get_next(
-        self._iterator_resource,
-        output_types=self._flat_tensor_types,
-        output_shapes=self._flat_tensor_shapes,
-        name=name)
-    return structure.from_tensor_list(self._element_spec, flat_ret)
+    with ops.device(self._iterator_resource.device):
+      # pylint: disable=protected-access
+      flat_ret = gen_dataset_ops.iterator_get_next(
+          self._iterator_resource,
+          output_types=self._flat_tensor_types,
+          output_shapes=self._flat_tensor_shapes,
+          name=name)
+      return structure.from_tensor_list(self._element_spec, flat_ret)
 
   def get_next_as_optional(self):
     # pylint: disable=protected-access
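
With get_next now created under the iterator's device, graph-mode callers should see the next-element ops placed alongside the iterator rather than on the default device. A sketch mirroring the new tests (assumes an available GPU):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    dataset = tf.data.Dataset.range(10).apply(
        tf.data.experimental.prefetch_to_device("/gpu:0"))
    iterator = tf.data.make_initializable_iterator(dataset)
    next_element = iterator.get_next()
    print(next_element.device)  # expected to reference GPU:0

    with tf.Session() as sess:
      sess.run(iterator.initializer)
      print(sess.run(next_element))  # 0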