[tf.data] Co-locate placement of iterator ops with the final dataset transformation.
PiperOrigin-RevId: 330841511 Change-Id: I3486226f7b50015396317cf8175fe185fd2ac7fd
This commit is contained in:
parent
f56f3a7391
commit
eb461280fe
@ -150,6 +150,17 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertDatasetProduces(device_dataset, list(range(10)))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPrefetchToDeviceCorrectPlacement(self):
|
||||
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
|
||||
self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchToDeviceWithReInit(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
@ -213,6 +224,48 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(device_dataset._variant_tensor.device,
|
||||
"/job:localhost/replica:0/task:0/device:GPU:0")
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testIteratorOnDeviceEagerMode(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
iterator = iter(dataset)
|
||||
data = next(iterator)
|
||||
|
||||
self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
|
||||
self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
|
||||
self.assertIn("gpu:0", data.device.lower())
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testIteratorOnDeviceGraphModeOneShotIterator(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
||||
data = iterator.get_next()
|
||||
|
||||
self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
|
||||
self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
|
||||
self.assertIn("gpu:0", data.device.lower())
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testIteratorOnDeviceGraphModeInitializableIterator(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
iterator = dataset_ops.make_initializable_iterator(dataset)
|
||||
data = iterator.get_next()
|
||||
|
||||
self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
|
||||
self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
|
||||
self.assertIn("gpu:0", data.device.lower())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -417,7 +417,8 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
|
||||
RuntimeError: If not inside of tf.function and not executing eagerly.
|
||||
"""
|
||||
if context.executing_eagerly() or ops.inside_function():
|
||||
return iterator_ops.OwnedIterator(self)
|
||||
with ops.device(self._variant_tensor.device):
|
||||
return iterator_ops.OwnedIterator(self)
|
||||
else:
|
||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||
"or when eager execution is enabled.")
|
||||
@ -2271,7 +2272,8 @@ class DatasetV1(DatasetV2):
|
||||
|
||||
def _make_one_shot_iterator(self): # pylint: disable=missing-docstring
|
||||
if context.executing_eagerly():
|
||||
return iterator_ops.OwnedIterator(self)
|
||||
with ops.device(self._variant_tensor.device):
|
||||
return iterator_ops.OwnedIterator(self)
|
||||
|
||||
_ensure_same_dataset_graph(self)
|
||||
# Now that we create datasets at python object creation time, the capture
|
||||
@ -2313,12 +2315,13 @@ class DatasetV1(DatasetV2):
|
||||
else:
|
||||
six.reraise(ValueError, err)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
return iterator_ops.Iterator(
|
||||
gen_dataset_ops.one_shot_iterator(
|
||||
dataset_factory=_make_dataset, **self._flat_structure), None,
|
||||
get_legacy_output_types(self), get_legacy_output_shapes(self),
|
||||
get_legacy_output_classes(self))
|
||||
with ops.device(self._variant_tensor.device):
|
||||
# pylint: disable=protected-access
|
||||
return iterator_ops.Iterator(
|
||||
gen_dataset_ops.one_shot_iterator(
|
||||
dataset_factory=_make_dataset, **self._flat_structure), None,
|
||||
get_legacy_output_types(self), get_legacy_output_shapes(self),
|
||||
get_legacy_output_classes(self))
|
||||
|
||||
@deprecation.deprecated(
|
||||
None, "This is a deprecated API that should only be used in TF 1 graph "
|
||||
@ -2378,16 +2381,20 @@ class DatasetV1(DatasetV2):
|
||||
dataset = self._apply_options()
|
||||
if shared_name is None:
|
||||
shared_name = ""
|
||||
iterator_resource = gen_dataset_ops.iterator_v2(
|
||||
container="", shared_name=shared_name, **self._flat_structure)
|
||||
with ops.colocate_with(iterator_resource):
|
||||
|
||||
with ops.device(self._variant_tensor.device):
|
||||
iterator_resource = gen_dataset_ops.iterator_v2(
|
||||
container="", shared_name=shared_name, **self._flat_structure)
|
||||
|
||||
initializer = gen_dataset_ops.make_iterator(
|
||||
dataset._variant_tensor, # pylint: disable=protected-access
|
||||
iterator_resource)
|
||||
# pylint: disable=protected-access
|
||||
return iterator_ops.Iterator(
|
||||
iterator_resource, initializer, get_legacy_output_types(dataset),
|
||||
get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
|
||||
|
||||
# pylint: disable=protected-access
|
||||
return iterator_ops.Iterator(iterator_resource, initializer,
|
||||
get_legacy_output_types(dataset),
|
||||
get_legacy_output_shapes(dataset),
|
||||
get_legacy_output_classes(dataset))
|
||||
|
||||
@property
|
||||
@deprecation.deprecated(
|
||||
|
@ -371,9 +371,11 @@ class Iterator(trackable.Trackable):
|
||||
raise TypeError("Expected output shapes compatible with %r but got "
|
||||
"dataset with output shapes %r." %
|
||||
(self.output_shapes, dataset_output_shapes))
|
||||
with ops.colocate_with(self._iterator_resource):
|
||||
|
||||
with ops.device(self._iterator_resource.device):
|
||||
# pylint: disable=protected-access
|
||||
return gen_dataset_ops.make_iterator(
|
||||
dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access
|
||||
dataset._variant_tensor, self._iterator_resource, name=name)
|
||||
|
||||
def get_next(self, name=None):
|
||||
"""Returns a nested structure of `tf.Tensor`s representing the next element.
|
||||
@ -423,13 +425,14 @@ class Iterator(trackable.Trackable):
|
||||
if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
|
||||
warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
flat_ret = gen_dataset_ops.iterator_get_next(
|
||||
self._iterator_resource,
|
||||
output_types=self._flat_tensor_types,
|
||||
output_shapes=self._flat_tensor_shapes,
|
||||
name=name)
|
||||
return structure.from_tensor_list(self._element_spec, flat_ret)
|
||||
with ops.device(self._iterator_resource.device):
|
||||
# pylint: disable=protected-access
|
||||
flat_ret = gen_dataset_ops.iterator_get_next(
|
||||
self._iterator_resource,
|
||||
output_types=self._flat_tensor_types,
|
||||
output_shapes=self._flat_tensor_shapes,
|
||||
name=name)
|
||||
return structure.from_tensor_list(self._element_spec, flat_ret)
|
||||
|
||||
def get_next_as_optional(self):
|
||||
# pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user