diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
index 611fbab4b8b..70fb64554a3 100644
--- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
@@ -150,6 +150,17 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertDatasetProduces(device_dataset, list(range(10)))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrefetchToDeviceCorrectPlacement(self):
+
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+
   @combinations.generate(test_base.graph_only_combinations())
   def testPrefetchToDeviceWithReInit(self):
     host_dataset = dataset_ops.Dataset.range(10)
@@ -213,6 +224,48 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual(device_dataset._variant_tensor.device,
                      "/job:localhost/replica:0/task:0/device:GPU:0")
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testIteratorOnDeviceEagerMode(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = iter(dataset)
+    data = next(iterator)
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
+  @combinations.generate(test_base.graph_only_combinations())
+  def testIteratorOnDeviceGraphModeOneShotIterator(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = dataset_ops.make_one_shot_iterator(dataset)
+    data = iterator.get_next()
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
+  @combinations.generate(test_base.graph_only_combinations())
+  def testIteratorOnDeviceGraphModeInitializableIterator(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = dataset_ops.make_initializable_iterator(dataset)
+    data = iterator.get_next()
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
 
 if __name__ == "__main__":
   test.main()
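The new tests assert on placement through the private `_variant_tensor` and `_iterator_resource` attributes. As a rough standalone illustration of the behavior they cover (not part of the patch, written against the public tf.data API, and assuming a GPU-enabled build):

import tensorflow as tf

if tf.config.list_physical_devices("GPU"):
  dataset = tf.data.Dataset.range(10)
  dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))

  # The prefetched dataset's variant tensor lives on the GPU ...
  print(dataset._variant_tensor.device)  # expected to contain "GPU:0"

  # ... and, with the dataset_ops.py / iterator_ops.py changes below, the
  # iterator created by plain Python iteration and its elements do too.
  element = next(iter(dataset))
  print(element.device)  # expected to contain "GPU:0"
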
""" if context.executing_eagerly() or ops.inside_function(): - return iterator_ops.OwnedIterator(self) + with ops.device(self._variant_tensor.device): + return iterator_ops.OwnedIterator(self) else: raise RuntimeError("__iter__() is only supported inside of tf.function " "or when eager execution is enabled.") @@ -2271,7 +2272,8 @@ class DatasetV1(DatasetV2): def _make_one_shot_iterator(self): # pylint: disable=missing-docstring if context.executing_eagerly(): - return iterator_ops.OwnedIterator(self) + with ops.device(self._variant_tensor.device): + return iterator_ops.OwnedIterator(self) _ensure_same_dataset_graph(self) # Now that we create datasets at python object creation time, the capture @@ -2313,12 +2315,13 @@ class DatasetV1(DatasetV2): else: six.reraise(ValueError, err) - # pylint: disable=protected-access - return iterator_ops.Iterator( - gen_dataset_ops.one_shot_iterator( - dataset_factory=_make_dataset, **self._flat_structure), None, - get_legacy_output_types(self), get_legacy_output_shapes(self), - get_legacy_output_classes(self)) + with ops.device(self._variant_tensor.device): + # pylint: disable=protected-access + return iterator_ops.Iterator( + gen_dataset_ops.one_shot_iterator( + dataset_factory=_make_dataset, **self._flat_structure), None, + get_legacy_output_types(self), get_legacy_output_shapes(self), + get_legacy_output_classes(self)) @deprecation.deprecated( None, "This is a deprecated API that should only be used in TF 1 graph " @@ -2378,16 +2381,20 @@ class DatasetV1(DatasetV2): dataset = self._apply_options() if shared_name is None: shared_name = "" - iterator_resource = gen_dataset_ops.iterator_v2( - container="", shared_name=shared_name, **self._flat_structure) - with ops.colocate_with(iterator_resource): + + with ops.device(self._variant_tensor.device): + iterator_resource = gen_dataset_ops.iterator_v2( + container="", shared_name=shared_name, **self._flat_structure) + initializer = gen_dataset_ops.make_iterator( dataset._variant_tensor, # pylint: disable=protected-access iterator_resource) - # pylint: disable=protected-access - return iterator_ops.Iterator( - iterator_resource, initializer, get_legacy_output_types(dataset), - get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset)) + + # pylint: disable=protected-access + return iterator_ops.Iterator(iterator_resource, initializer, + get_legacy_output_types(dataset), + get_legacy_output_shapes(dataset), + get_legacy_output_classes(dataset)) @property @deprecation.deprecated( diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 479c8d337a0..20adaf2f887 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -371,9 +371,11 @@ class Iterator(trackable.Trackable): raise TypeError("Expected output shapes compatible with %r but got " "dataset with output shapes %r." % (self.output_shapes, dataset_output_shapes)) - with ops.colocate_with(self._iterator_resource): + + with ops.device(self._iterator_resource.device): + # pylint: disable=protected-access return gen_dataset_ops.make_iterator( - dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access + dataset._variant_tensor, self._iterator_resource, name=name) def get_next(self, name=None): """Returns a nested structure of `tf.Tensor`s representing the next element. 
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 479c8d337a0..20adaf2f887 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -371,9 +371,11 @@ class Iterator(trackable.Trackable):
       raise TypeError("Expected output shapes compatible with %r but got "
                       "dataset with output shapes %r." %
                       (self.output_shapes, dataset_output_shapes))
-    with ops.colocate_with(self._iterator_resource):
+
+    with ops.device(self._iterator_resource.device):
+      # pylint: disable=protected-access
       return gen_dataset_ops.make_iterator(
-          dataset._variant_tensor, self._iterator_resource, name=name)  # pylint: disable=protected-access
+          dataset._variant_tensor, self._iterator_resource, name=name)
 
   def get_next(self, name=None):
     """Returns a nested structure of `tf.Tensor`s representing the next element.
@@ -423,13 +425,14 @@ class Iterator(trackable.Trackable):
     if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
       warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
 
-    # pylint: disable=protected-access
-    flat_ret = gen_dataset_ops.iterator_get_next(
-        self._iterator_resource,
-        output_types=self._flat_tensor_types,
-        output_shapes=self._flat_tensor_shapes,
-        name=name)
-    return structure.from_tensor_list(self._element_spec, flat_ret)
+    with ops.device(self._iterator_resource.device):
+      # pylint: disable=protected-access
+      flat_ret = gen_dataset_ops.iterator_get_next(
+          self._iterator_resource,
+          output_types=self._flat_tensor_types,
+          output_shapes=self._flat_tensor_shapes,
+          name=name)
+      return structure.from_tensor_list(self._element_spec, flat_ret)
 
   def get_next_as_optional(self):
     # pylint: disable=protected-access
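With `Iterator.get_next()` now built under the iterator's device, the TF1 graph-mode path behaves the way the new graph-mode tests expect. A rough end-to-end sketch (not part of the patch; assumes a GPU-enabled build, and the session fetch copies the value back to host):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

dataset = tf.compat.v1.data.Dataset.range(10)
dataset = dataset.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))

iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
next_element = iterator.get_next()
print(next_element.device)  # expected to contain "GPU:0" after this change

with tf.compat.v1.Session() as sess:
  print(sess.run(next_element))  # 0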