diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
index 611fbab4b8b..70fb64554a3 100644
--- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
@@ -150,6 +150,17 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     self.assertDatasetProduces(device_dataset, list(range(10)))
 
+  @combinations.generate(test_base.default_test_combinations())
+  def testPrefetchToDeviceCorrectPlacement(self):
+
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+
   @combinations.generate(test_base.graph_only_combinations())
   def testPrefetchToDeviceWithReInit(self):
     host_dataset = dataset_ops.Dataset.range(10)
@@ -213,6 +224,48 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertEqual(device_dataset._variant_tensor.device,
                      "/job:localhost/replica:0/task:0/device:GPU:0")
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testIteratorOnDeviceEagerMode(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = iter(dataset)
+    data = next(iterator)
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
+  @combinations.generate(test_base.graph_only_combinations())
+  def testIteratorOnDeviceGraphModeOneShotIterator(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = dataset_ops.make_one_shot_iterator(dataset)
+    data = iterator.get_next()
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
+  @combinations.generate(test_base.graph_only_combinations())
+  def testIteratorOnDeviceGraphModeInitializableIterator(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0"))
+    iterator = dataset_ops.make_initializable_iterator(dataset)
+    data = iterator.get_next()
+
+    self.assertIn("gpu:0", dataset._variant_tensor.device.lower())
+    self.assertIn("gpu:0", iterator._iterator_resource.device.lower())
+    self.assertIn("gpu:0", data.device.lower())
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 33e0333f493..e3cce9c55cd 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -417,7 +417,8 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
       RuntimeError: If not inside of tf.function and not executing eagerly.
     """
     if context.executing_eagerly() or ops.inside_function():
-      return iterator_ops.OwnedIterator(self)
+      with ops.device(self._variant_tensor.device):
+        return iterator_ops.OwnedIterator(self)
     else:
       raise RuntimeError("__iter__() is only supported inside of tf.function "
                          "or when eager execution is enabled.")
@@ -2271,7 +2272,8 @@ class DatasetV1(DatasetV2):
 
   def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring
     if context.executing_eagerly():
-      return iterator_ops.OwnedIterator(self)
+      with ops.device(self._variant_tensor.device):
+        return iterator_ops.OwnedIterator(self)
 
     _ensure_same_dataset_graph(self)
     # Now that we create datasets at python object creation time, the capture
@@ -2313,12 +2315,13 @@ class DatasetV1(DatasetV2):
       else:
         six.reraise(ValueError, err)
 
-    # pylint: disable=protected-access
-    return iterator_ops.Iterator(
-        gen_dataset_ops.one_shot_iterator(
-            dataset_factory=_make_dataset, **self._flat_structure), None,
-        get_legacy_output_types(self), get_legacy_output_shapes(self),
-        get_legacy_output_classes(self))
+    with ops.device(self._variant_tensor.device):
+      # pylint: disable=protected-access
+      return iterator_ops.Iterator(
+          gen_dataset_ops.one_shot_iterator(
+              dataset_factory=_make_dataset, **self._flat_structure), None,
+          get_legacy_output_types(self), get_legacy_output_shapes(self),
+          get_legacy_output_classes(self))
 
   @deprecation.deprecated(
       None, "This is a deprecated API that should only be used in TF 1 graph "
@@ -2378,16 +2381,20 @@ class DatasetV1(DatasetV2):
     dataset = self._apply_options()
     if shared_name is None:
       shared_name = ""
-    iterator_resource = gen_dataset_ops.iterator_v2(
-        container="", shared_name=shared_name, **self._flat_structure)
-    with ops.colocate_with(iterator_resource):
+
+    with ops.device(self._variant_tensor.device):
+      iterator_resource = gen_dataset_ops.iterator_v2(
+          container="", shared_name=shared_name, **self._flat_structure)
+
       initializer = gen_dataset_ops.make_iterator(
           dataset._variant_tensor,  # pylint: disable=protected-access
           iterator_resource)
-    # pylint: disable=protected-access
-    return iterator_ops.Iterator(
-        iterator_resource, initializer, get_legacy_output_types(dataset),
-        get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
+
+      # pylint: disable=protected-access
+      return iterator_ops.Iterator(iterator_resource, initializer,
+                                   get_legacy_output_types(dataset),
+                                   get_legacy_output_shapes(dataset),
+                                   get_legacy_output_classes(dataset))
 
   @property
   @deprecation.deprecated(
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 479c8d337a0..20adaf2f887 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -371,9 +371,11 @@ class Iterator(trackable.Trackable):
           raise TypeError("Expected output shapes compatible with %r but got "
                           "dataset with output shapes %r." %
                           (self.output_shapes, dataset_output_shapes))
-    with ops.colocate_with(self._iterator_resource):
+
+    with ops.device(self._iterator_resource.device):
+      # pylint: disable=protected-access
       return gen_dataset_ops.make_iterator(
-          dataset._variant_tensor, self._iterator_resource, name=name)  # pylint: disable=protected-access
+          dataset._variant_tensor, self._iterator_resource, name=name)
 
   def get_next(self, name=None):
     """Returns a nested structure of `tf.Tensor`s representing the next element.
@@ -423,13 +425,14 @@ class Iterator(trackable.Trackable):
     if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
       warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
 
-    # pylint: disable=protected-access
-    flat_ret = gen_dataset_ops.iterator_get_next(
-        self._iterator_resource,
-        output_types=self._flat_tensor_types,
-        output_shapes=self._flat_tensor_shapes,
-        name=name)
-    return structure.from_tensor_list(self._element_spec, flat_ret)
+    with ops.device(self._iterator_resource.device):
+      # pylint: disable=protected-access
+      flat_ret = gen_dataset_ops.iterator_get_next(
+          self._iterator_resource,
+          output_types=self._flat_tensor_types,
+          output_shapes=self._flat_tensor_shapes,
+          name=name)
+      return structure.from_tensor_list(self._element_spec, flat_ret)
 
   def get_next_as_optional(self):
     # pylint: disable=protected-access