From 5ab3af7a7bc7657f6a5de678c686d12e1194683b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Apr 2020 10:17:02 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 308645783 Change-Id: Iaa11d70b2cac141b4eb067465e78ef4c45384171 --- .../kernel_tests/prefetch_to_device_test.py | 15 ---- tensorflow/python/data/kernel_tests/BUILD | 1 - .../python/data/kernel_tests/iterator_test.py | 73 ------------------- tensorflow/python/data/ops/dataset_ops.py | 55 ++++++-------- tensorflow/python/data/ops/iterator_ops.py | 18 ++--- 5 files changed, 29 insertions(+), 133 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index 789b6f21a15..1641243edcc 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -150,21 +150,6 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces(device_dataset, list(range(10))) - @combinations.generate(test_base.default_test_combinations()) - def testPrefetchToDeviceCorrectPlacement(self): - - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) - - self.assertTrue(("" == host_dataset._variant_tensor.device or - "cpu:0" in host_dataset._variant_tensor.device.lower())) - - self.assertTrue("gpu:0" in device_dataset._variant_tensor.device.lower()) - @combinations.generate(test_base.graph_only_combinations()) def testPrefetchToDeviceWithReInit(self): host_dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index af2725d9fe8..2e01021cfd2 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -374,7 +374,6 @@ cuda_py_test( "//tensorflow/python:util", "//tensorflow/python:variables", "//tensorflow/python/compat", - "//tensorflow/python/data/experimental/ops:prefetching_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:readers", diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 8a1358a86f2..36689ed75fb 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -25,7 +25,6 @@ import numpy as np from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session -from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops @@ -1017,78 +1016,6 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.evaluate(counter_var.initializer) self.assertEqual(self.evaluate(fn()), 10) - def assert_dataset_placement(self, host_dataset, host_iterator, host_tensor, - device_dataset, device_iterator, device_tensor): - - self.assertTrue("cpu:0" in host_dataset._variant_tensor.device.lower() or - host_dataset._variant_tensor.device == "") - self.assertTrue( - "cpu:0" in host_iterator._iterator_resource.device.lower() or 
- host_iterator._iterator_resource.device == "") - self.assertTrue("cpu:0" in host_tensor.device.lower() or - host_tensor.device == "") - - self.assertIn("gpu:0", device_dataset._variant_tensor.device.lower()) - self.assertIn("gpu:0", device_iterator._iterator_resource.device.lower()) - self.assertIn("gpu:0", device_tensor.device.lower()) - - @combinations.generate(test_base.eager_only_combinations()) - def testIteratorOnDeviceEagerMode(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) - - host_iterator = iter(host_dataset) - device_iterator = iter(device_dataset) - - host_tensor = next(host_iterator) - device_tensor = next(device_iterator) - - self.assert_dataset_placement(host_dataset, host_iterator, host_tensor, - device_dataset, device_iterator, - device_tensor) - - @combinations.generate(test_base.graph_only_combinations()) - def testIteratorOnDeviceGraphModeOneShotIterator(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) - - host_iterator = dataset_ops.make_one_shot_iterator(host_dataset) - device_iterator = dataset_ops.make_one_shot_iterator(device_dataset) - - host_tensor = host_iterator.get_next() - device_tensor = device_iterator.get_next() - - self.assert_dataset_placement(host_dataset, host_iterator, host_tensor, - device_dataset, device_iterator, - device_tensor) - - @combinations.generate(test_base.graph_only_combinations()) - def testIteratorOnDeviceGraphModeInitializableIterator(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - host_dataset = dataset_ops.Dataset.range(10) - device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) - - host_iterator = dataset_ops.make_initializable_iterator(host_dataset) - device_iterator = dataset_ops.make_initializable_iterator(device_dataset) - - host_tensor = host_iterator.get_next() - device_tensor = device_iterator.get_next() - - self.assert_dataset_placement(host_dataset, host_iterator, host_tensor, - device_dataset, device_iterator, - device_tensor) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index fc662e7f65a..8c6d1f8d454 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -406,8 +406,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): RuntimeError: If not inside of tf.function and not executing eagerly. 
""" if context.executing_eagerly() or ops.inside_function(): - with ops.device(self._variant_tensor.device): - return iterator_ops.OwnedIterator(self) + return iterator_ops.OwnedIterator(self) else: raise RuntimeError("__iter__() is only supported inside of tf.function " "or when eager execution is enabled.") @@ -482,15 +481,13 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): if not context.executing_eagerly(): raise RuntimeError("as_numpy_iterator() is not supported while tracing " "functions") - for component_spec in nest.flatten(self.element_spec): if not isinstance(component_spec, tensor_spec.TensorSpec): raise TypeError( "Dataset.as_numpy_iterator() does not support datasets containing " + str(component_spec.value_type)) - with ops.device(self._variant_tensor.device): - return _NumpyIterator(self) + return _NumpyIterator(self) @property def _flat_shapes(self): @@ -2161,10 +2158,8 @@ class DatasetV1(DatasetV2): return self._make_one_shot_iterator() def _make_one_shot_iterator(self): # pylint: disable=missing-docstring - if context.executing_eagerly(): - with ops.device(self._variant_tensor.device): - return iterator_ops.OwnedIterator(self) + return iterator_ops.OwnedIterator(self) _ensure_same_dataset_graph(self) # Now that we create datasets at python object creation time, the capture @@ -2206,13 +2201,12 @@ class DatasetV1(DatasetV2): else: six.reraise(ValueError, err) - with ops.device(self._variant_tensor.device): - # pylint: disable=protected-access - return iterator_ops.Iterator( - gen_dataset_ops.one_shot_iterator( - dataset_factory=_make_dataset, **self._flat_structure), None, - get_legacy_output_types(self), get_legacy_output_shapes(self), - get_legacy_output_classes(self)) + # pylint: disable=protected-access + return iterator_ops.Iterator( + gen_dataset_ops.one_shot_iterator( + dataset_factory=_make_dataset, **self._flat_structure), None, + get_legacy_output_types(self), get_legacy_output_shapes(self), + get_legacy_output_classes(self)) @deprecation.deprecated( None, "This is a deprecated API that should only be used in TF 1 graph " @@ -2272,20 +2266,16 @@ class DatasetV1(DatasetV2): dataset = self._apply_options() if shared_name is None: shared_name = "" - - with ops.device(self._variant_tensor.device): - iterator_resource = gen_dataset_ops.iterator_v2( - container="", shared_name=shared_name, **self._flat_structure) - + iterator_resource = gen_dataset_ops.iterator_v2( + container="", shared_name=shared_name, **self._flat_structure) + with ops.colocate_with(iterator_resource): initializer = gen_dataset_ops.make_iterator( dataset._variant_tensor, # pylint: disable=protected-access iterator_resource) - - # pylint: disable=protected-access - return iterator_ops.Iterator(iterator_resource, initializer, - get_legacy_output_types(dataset), - get_legacy_output_shapes(dataset), - get_legacy_output_classes(dataset)) + # pylint: disable=protected-access + return iterator_ops.Iterator( + iterator_resource, initializer, get_legacy_output_types(dataset), + get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset)) @property @deprecation.deprecated( @@ -4339,14 +4329,11 @@ class PrefetchDataset(UnaryUnchangedStructureDataset): buffer_size = -1 # This is the sentinel for auto-tuning. 
self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") - - with ops.device(input_dataset._variant_tensor.device): - variant_tensor = gen_dataset_ops.prefetch_dataset( - input_dataset._variant_tensor, # pylint: disable=protected-access - buffer_size=self._buffer_size, - slack_period=slack_period, - **self._flat_structure) - + variant_tensor = gen_dataset_ops.prefetch_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access + buffer_size=self._buffer_size, + slack_period=slack_period, + **self._flat_structure) super(PrefetchDataset, self).__init__(input_dataset, variant_tensor) diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index c0fb6760821..db810de6689 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -368,8 +368,7 @@ class Iterator(trackable.Trackable): raise TypeError("Expected output shapes compatible with %r but got " "dataset with output shapes %r." % (self.output_shapes, dataset_output_shapes)) - - with ops.device(dataset._variant_tensor.device): + with ops.colocate_with(self._iterator_resource): return gen_dataset_ops.make_iterator( dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access @@ -421,14 +420,13 @@ class Iterator(trackable.Trackable): if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) - with ops.device(self._iterator_resource.device): - # pylint: disable=protected-access - flat_ret = gen_dataset_ops.iterator_get_next( - self._iterator_resource, - output_types=self._flat_tensor_types, - output_shapes=self._flat_tensor_shapes, - name=name) - return structure.from_tensor_list(self._element_spec, flat_ret) + # pylint: disable=protected-access + flat_ret = gen_dataset_ops.iterator_get_next( + self._iterator_resource, + output_types=self._flat_tensor_types, + output_shapes=self._flat_tensor_shapes, + name=name) + return structure.from_tensor_list(self._element_spec, flat_ret) def string_handle(self, name=None): """Returns a string-valued `tf.Tensor` that represents this iterator.
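
Note on the user-facing behavior this patch touches: the removed hunks drop the Python-level ops.device(...) pinning of dataset variant tensors and iterator resources, and restore ops.colocate_with(...) around make_iterator, so placement again follows the enclosing device scope and the runtime's colocation constraints. The deleted tests exercised the public prefetch_to_device transformation; the snippet below is a minimal sketch of that API (assuming TF 2.x eager execution and an available /gpu:0), illustrative only and not part of the patch.

    import tensorflow as tf

    host_dataset = tf.data.Dataset.range(10)
    # prefetch_to_device asks tf.data to stage elements into a buffer on the
    # target device; after this change the dataset/iterator resources are no
    # longer explicitly pinned by the Python layer, so their placement is left
    # to the surrounding device context and colocation rather than forced to
    # the dataset's device.
    device_dataset = host_dataset.apply(
        tf.data.experimental.prefetch_to_device("/gpu:0"))

    for element in device_dataset:
      print(element.numpy())

The remaining tests (e.g. testPrefetchToDeviceWithReInit) still cover that prefetch_to_device produces the expected elements; what is no longer asserted is the exact device string on the underlying variant tensor and iterator resource.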