[tf.data] Co-locate placement of iterator ops with the final dataset transformation.
PiperOrigin-RevId: 330841511 Change-Id: I3486226f7b50015396317cf8175fe185fd2ac7fd
This commit is contained in:
		
							parent
							
								
									f56f3a7391
								
							
						
					
					
						commit
						eb461280fe
					
				| @ -150,6 +150,17 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): | ||||
| 
 | ||||
|     self.assertDatasetProduces(device_dataset, list(range(10))) | ||||
| 
 | ||||
|   @combinations.generate(test_base.default_test_combinations()) | ||||
|   def testPrefetchToDeviceCorrectPlacement(self): | ||||
| 
 | ||||
|     if not test_util.is_gpu_available(): | ||||
|       self.skipTest("No GPU available") | ||||
| 
 | ||||
|     dataset = dataset_ops.Dataset.range(10) | ||||
|     dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) | ||||
| 
 | ||||
|     self.assertIn("gpu:0", dataset._variant_tensor.device.lower()) | ||||
| 
 | ||||
|   @combinations.generate(test_base.graph_only_combinations()) | ||||
|   def testPrefetchToDeviceWithReInit(self): | ||||
|     host_dataset = dataset_ops.Dataset.range(10) | ||||
| @ -213,6 +224,48 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): | ||||
|     self.assertEqual(device_dataset._variant_tensor.device, | ||||
|                      "/job:localhost/replica:0/task:0/device:GPU:0") | ||||
| 
 | ||||
|   @combinations.generate(test_base.eager_only_combinations()) | ||||
|   def testIteratorOnDeviceEagerMode(self): | ||||
|     if not test_util.is_gpu_available(): | ||||
|       self.skipTest("No GPU available") | ||||
| 
 | ||||
|     dataset = dataset_ops.Dataset.range(10) | ||||
|     dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) | ||||
|     iterator = iter(dataset) | ||||
|     data = next(iterator) | ||||
| 
 | ||||
|     self.assertIn("gpu:0", dataset._variant_tensor.device.lower()) | ||||
|     self.assertIn("gpu:0", iterator._iterator_resource.device.lower()) | ||||
|     self.assertIn("gpu:0", data.device.lower()) | ||||
| 
 | ||||
|   @combinations.generate(test_base.graph_only_combinations()) | ||||
|   def testIteratorOnDeviceGraphModeOneShotIterator(self): | ||||
|     if not test_util.is_gpu_available(): | ||||
|       self.skipTest("No GPU available") | ||||
| 
 | ||||
|     dataset = dataset_ops.Dataset.range(10) | ||||
|     dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) | ||||
|     iterator = dataset_ops.make_one_shot_iterator(dataset) | ||||
|     data = iterator.get_next() | ||||
| 
 | ||||
|     self.assertIn("gpu:0", dataset._variant_tensor.device.lower()) | ||||
|     self.assertIn("gpu:0", iterator._iterator_resource.device.lower()) | ||||
|     self.assertIn("gpu:0", data.device.lower()) | ||||
| 
 | ||||
|   @combinations.generate(test_base.graph_only_combinations()) | ||||
|   def testIteratorOnDeviceGraphModeInitializableIterator(self): | ||||
|     if not test_util.is_gpu_available(): | ||||
|       self.skipTest("No GPU available") | ||||
| 
 | ||||
|     dataset = dataset_ops.Dataset.range(10) | ||||
|     dataset = dataset.apply(prefetching_ops.prefetch_to_device("/gpu:0")) | ||||
|     iterator = dataset_ops.make_initializable_iterator(dataset) | ||||
|     data = iterator.get_next() | ||||
| 
 | ||||
|     self.assertIn("gpu:0", dataset._variant_tensor.device.lower()) | ||||
|     self.assertIn("gpu:0", iterator._iterator_resource.device.lower()) | ||||
|     self.assertIn("gpu:0", data.device.lower()) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   test.main() | ||||
|  | ||||
| @ -417,7 +417,8 @@ class DatasetV2(collections_abc.Iterable, tracking_base.Trackable, | ||||
|       RuntimeError: If not inside of tf.function and not executing eagerly. | ||||
|     """ | ||||
|     if context.executing_eagerly() or ops.inside_function(): | ||||
|       return iterator_ops.OwnedIterator(self) | ||||
|       with ops.device(self._variant_tensor.device): | ||||
|         return iterator_ops.OwnedIterator(self) | ||||
|     else: | ||||
|       raise RuntimeError("__iter__() is only supported inside of tf.function " | ||||
|                          "or when eager execution is enabled.") | ||||
| @ -2271,7 +2272,8 @@ class DatasetV1(DatasetV2): | ||||
| 
 | ||||
|   def _make_one_shot_iterator(self):  # pylint: disable=missing-docstring | ||||
|     if context.executing_eagerly(): | ||||
|       return iterator_ops.OwnedIterator(self) | ||||
|       with ops.device(self._variant_tensor.device): | ||||
|         return iterator_ops.OwnedIterator(self) | ||||
| 
 | ||||
|     _ensure_same_dataset_graph(self) | ||||
|     # Now that we create datasets at python object creation time, the capture | ||||
| @ -2313,12 +2315,13 @@ class DatasetV1(DatasetV2): | ||||
|       else: | ||||
|         six.reraise(ValueError, err) | ||||
| 
 | ||||
|     # pylint: disable=protected-access | ||||
|     return iterator_ops.Iterator( | ||||
|         gen_dataset_ops.one_shot_iterator( | ||||
|             dataset_factory=_make_dataset, **self._flat_structure), None, | ||||
|         get_legacy_output_types(self), get_legacy_output_shapes(self), | ||||
|         get_legacy_output_classes(self)) | ||||
|     with ops.device(self._variant_tensor.device): | ||||
|       # pylint: disable=protected-access | ||||
|       return iterator_ops.Iterator( | ||||
|           gen_dataset_ops.one_shot_iterator( | ||||
|               dataset_factory=_make_dataset, **self._flat_structure), None, | ||||
|           get_legacy_output_types(self), get_legacy_output_shapes(self), | ||||
|           get_legacy_output_classes(self)) | ||||
| 
 | ||||
|   @deprecation.deprecated( | ||||
|       None, "This is a deprecated API that should only be used in TF 1 graph " | ||||
| @ -2378,16 +2381,20 @@ class DatasetV1(DatasetV2): | ||||
|     dataset = self._apply_options() | ||||
|     if shared_name is None: | ||||
|       shared_name = "" | ||||
|     iterator_resource = gen_dataset_ops.iterator_v2( | ||||
|         container="", shared_name=shared_name, **self._flat_structure) | ||||
|     with ops.colocate_with(iterator_resource): | ||||
| 
 | ||||
|     with ops.device(self._variant_tensor.device): | ||||
|       iterator_resource = gen_dataset_ops.iterator_v2( | ||||
|           container="", shared_name=shared_name, **self._flat_structure) | ||||
| 
 | ||||
|       initializer = gen_dataset_ops.make_iterator( | ||||
|           dataset._variant_tensor,  # pylint: disable=protected-access | ||||
|           iterator_resource) | ||||
|     # pylint: disable=protected-access | ||||
|     return iterator_ops.Iterator( | ||||
|         iterator_resource, initializer, get_legacy_output_types(dataset), | ||||
|         get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset)) | ||||
| 
 | ||||
|       # pylint: disable=protected-access | ||||
|       return iterator_ops.Iterator(iterator_resource, initializer, | ||||
|                                    get_legacy_output_types(dataset), | ||||
|                                    get_legacy_output_shapes(dataset), | ||||
|                                    get_legacy_output_classes(dataset)) | ||||
| 
 | ||||
|   @property | ||||
|   @deprecation.deprecated( | ||||
|  | ||||
| @ -371,9 +371,11 @@ class Iterator(trackable.Trackable): | ||||
|           raise TypeError("Expected output shapes compatible with %r but got " | ||||
|                           "dataset with output shapes %r." % | ||||
|                           (self.output_shapes, dataset_output_shapes)) | ||||
|     with ops.colocate_with(self._iterator_resource): | ||||
| 
 | ||||
|     with ops.device(self._iterator_resource.device): | ||||
|       # pylint: disable=protected-access | ||||
|       return gen_dataset_ops.make_iterator( | ||||
|           dataset._variant_tensor, self._iterator_resource, name=name)  # pylint: disable=protected-access | ||||
|           dataset._variant_tensor, self._iterator_resource, name=name) | ||||
| 
 | ||||
|   def get_next(self, name=None): | ||||
|     """Returns a nested structure of `tf.Tensor`s representing the next element. | ||||
| @ -423,13 +425,14 @@ class Iterator(trackable.Trackable): | ||||
|     if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: | ||||
|       warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) | ||||
| 
 | ||||
|     # pylint: disable=protected-access | ||||
|     flat_ret = gen_dataset_ops.iterator_get_next( | ||||
|         self._iterator_resource, | ||||
|         output_types=self._flat_tensor_types, | ||||
|         output_shapes=self._flat_tensor_shapes, | ||||
|         name=name) | ||||
|     return structure.from_tensor_list(self._element_spec, flat_ret) | ||||
|     with ops.device(self._iterator_resource.device): | ||||
|       # pylint: disable=protected-access | ||||
|       flat_ret = gen_dataset_ops.iterator_get_next( | ||||
|           self._iterator_resource, | ||||
|           output_types=self._flat_tensor_types, | ||||
|           output_shapes=self._flat_tensor_shapes, | ||||
|           name=name) | ||||
|       return structure.from_tensor_list(self._element_spec, flat_ret) | ||||
| 
 | ||||
|   def get_next_as_optional(self): | ||||
|     # pylint: disable=protected-access | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user