Internal change

PiperOrigin-RevId: 308645783
Change-Id: Iaa11d70b2cac141b4eb067465e78ef4c45384171
This commit is contained in:
A. Unique TensorFlower 2020-04-27 10:17:02 -07:00 committed by TensorFlower Gardener
parent c1ceeb28c3
commit 5ab3af7a7b
5 changed files with 29 additions and 133 deletions

View File

@@ -150,21 +150,6 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(device_dataset, list(range(10))) self.assertDatasetProduces(device_dataset, list(range(10)))
@combinations.generate(test_base.default_test_combinations())
def testPrefetchToDeviceCorrectPlacement(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
self.assertTrue(("" == host_dataset._variant_tensor.device or
"cpu:0" in host_dataset._variant_tensor.device.lower()))
self.assertTrue("gpu:0" in device_dataset._variant_tensor.device.lower())
@combinations.generate(test_base.graph_only_combinations()) @combinations.generate(test_base.graph_only_combinations())
def testPrefetchToDeviceWithReInit(self): def testPrefetchToDeviceWithReInit(self):
host_dataset = dataset_ops.Dataset.range(10) host_dataset = dataset_ops.Dataset.range(10)

View File

@@ -374,7 +374,6 @@ cuda_py_test(
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python:variables", "//tensorflow/python:variables",
"//tensorflow/python/compat", "//tensorflow/python/compat",
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers", "//tensorflow/python/data/ops:readers",

View File

@@ -25,7 +25,6 @@ import numpy as np
from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import iterator_ops
@@ -1017,78 +1016,6 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
self.evaluate(counter_var.initializer) self.evaluate(counter_var.initializer)
self.assertEqual(self.evaluate(fn()), 10) self.assertEqual(self.evaluate(fn()), 10)
def assert_dataset_placement(self, host_dataset, host_iterator, host_tensor,
device_dataset, device_iterator, device_tensor):
self.assertTrue("cpu:0" in host_dataset._variant_tensor.device.lower() or
host_dataset._variant_tensor.device == "")
self.assertTrue(
"cpu:0" in host_iterator._iterator_resource.device.lower() or
host_iterator._iterator_resource.device == "")
self.assertTrue("cpu:0" in host_tensor.device.lower() or
host_tensor.device == "")
self.assertIn("gpu:0", device_dataset._variant_tensor.device.lower())
self.assertIn("gpu:0", device_iterator._iterator_resource.device.lower())
self.assertIn("gpu:0", device_tensor.device.lower())
@combinations.generate(test_base.eager_only_combinations())
def testIteratorOnDeviceEagerMode(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
host_iterator = iter(host_dataset)
device_iterator = iter(device_dataset)
host_tensor = next(host_iterator)
device_tensor = next(device_iterator)
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
device_dataset, device_iterator,
device_tensor)
@combinations.generate(test_base.graph_only_combinations())
def testIteratorOnDeviceGraphModeOneShotIterator(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
host_iterator = dataset_ops.make_one_shot_iterator(host_dataset)
device_iterator = dataset_ops.make_one_shot_iterator(device_dataset)
host_tensor = host_iterator.get_next()
device_tensor = device_iterator.get_next()
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
device_dataset, device_iterator,
device_tensor)
@combinations.generate(test_base.graph_only_combinations())
def testIteratorOnDeviceGraphModeInitializableIterator(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
host_iterator = dataset_ops.make_initializable_iterator(host_dataset)
device_iterator = dataset_ops.make_initializable_iterator(device_dataset)
host_tensor = host_iterator.get_next()
device_tensor = device_iterator.get_next()
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
device_dataset, device_iterator,
device_tensor)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@@ -406,8 +406,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
RuntimeError: If not inside of tf.function and not executing eagerly. RuntimeError: If not inside of tf.function and not executing eagerly.
""" """
if context.executing_eagerly() or ops.inside_function(): if context.executing_eagerly() or ops.inside_function():
with ops.device(self._variant_tensor.device): return iterator_ops.OwnedIterator(self)
return iterator_ops.OwnedIterator(self)
else: else:
raise RuntimeError("__iter__() is only supported inside of tf.function " raise RuntimeError("__iter__() is only supported inside of tf.function "
"or when eager execution is enabled.") "or when eager execution is enabled.")
@@ -482,15 +481,13 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
if not context.executing_eagerly(): if not context.executing_eagerly():
raise RuntimeError("as_numpy_iterator() is not supported while tracing " raise RuntimeError("as_numpy_iterator() is not supported while tracing "
"functions") "functions")
for component_spec in nest.flatten(self.element_spec): for component_spec in nest.flatten(self.element_spec):
if not isinstance(component_spec, tensor_spec.TensorSpec): if not isinstance(component_spec, tensor_spec.TensorSpec):
raise TypeError( raise TypeError(
"Dataset.as_numpy_iterator() does not support datasets containing " "Dataset.as_numpy_iterator() does not support datasets containing "
+ str(component_spec.value_type)) + str(component_spec.value_type))
with ops.device(self._variant_tensor.device): return _NumpyIterator(self)
return _NumpyIterator(self)
@property @property
def _flat_shapes(self): def _flat_shapes(self):
@@ -2161,10 +2158,8 @@ class DatasetV1(DatasetV2):
return self._make_one_shot_iterator() return self._make_one_shot_iterator()
def _make_one_shot_iterator(self): # pylint: disable=missing-docstring def _make_one_shot_iterator(self): # pylint: disable=missing-docstring
if context.executing_eagerly(): if context.executing_eagerly():
with ops.device(self._variant_tensor.device): return iterator_ops.OwnedIterator(self)
return iterator_ops.OwnedIterator(self)
_ensure_same_dataset_graph(self) _ensure_same_dataset_graph(self)
# Now that we create datasets at python object creation time, the capture # Now that we create datasets at python object creation time, the capture
@@ -2206,13 +2201,12 @@ class DatasetV1(DatasetV2):
else: else:
six.reraise(ValueError, err) six.reraise(ValueError, err)
with ops.device(self._variant_tensor.device): # pylint: disable=protected-access
# pylint: disable=protected-access return iterator_ops.Iterator(
return iterator_ops.Iterator( gen_dataset_ops.one_shot_iterator(
gen_dataset_ops.one_shot_iterator( dataset_factory=_make_dataset, **self._flat_structure), None,
dataset_factory=_make_dataset, **self._flat_structure), None, get_legacy_output_types(self), get_legacy_output_shapes(self),
get_legacy_output_types(self), get_legacy_output_shapes(self), get_legacy_output_classes(self))
get_legacy_output_classes(self))
@deprecation.deprecated( @deprecation.deprecated(
None, "This is a deprecated API that should only be used in TF 1 graph " None, "This is a deprecated API that should only be used in TF 1 graph "
@@ -2272,20 +2266,16 @@ class DatasetV1(DatasetV2):
dataset = self._apply_options() dataset = self._apply_options()
if shared_name is None: if shared_name is None:
shared_name = "" shared_name = ""
iterator_resource = gen_dataset_ops.iterator_v2(
with ops.device(self._variant_tensor.device): container="", shared_name=shared_name, **self._flat_structure)
iterator_resource = gen_dataset_ops.iterator_v2( with ops.colocate_with(iterator_resource):
container="", shared_name=shared_name, **self._flat_structure)
initializer = gen_dataset_ops.make_iterator( initializer = gen_dataset_ops.make_iterator(
dataset._variant_tensor, # pylint: disable=protected-access dataset._variant_tensor, # pylint: disable=protected-access
iterator_resource) iterator_resource)
# pylint: disable=protected-access
# pylint: disable=protected-access return iterator_ops.Iterator(
return iterator_ops.Iterator(iterator_resource, initializer, iterator_resource, initializer, get_legacy_output_types(dataset),
get_legacy_output_types(dataset), get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
get_legacy_output_shapes(dataset),
get_legacy_output_classes(dataset))
@property @property
@deprecation.deprecated( @deprecation.deprecated(
@@ -4339,14 +4329,11 @@ class PrefetchDataset(UnaryUnchangedStructureDataset):
buffer_size = -1 # This is the sentinel for auto-tuning. buffer_size = -1 # This is the sentinel for auto-tuning.
self._buffer_size = ops.convert_to_tensor( self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size") buffer_size, dtype=dtypes.int64, name="buffer_size")
variant_tensor = gen_dataset_ops.prefetch_dataset(
with ops.device(input_dataset._variant_tensor.device): input_dataset._variant_tensor, # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.prefetch_dataset( buffer_size=self._buffer_size,
input_dataset._variant_tensor, # pylint: disable=protected-access slack_period=slack_period,
buffer_size=self._buffer_size, **self._flat_structure)
slack_period=slack_period,
**self._flat_structure)
super(PrefetchDataset, self).__init__(input_dataset, variant_tensor) super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)

View File

@@ -368,8 +368,7 @@ class Iterator(trackable.Trackable):
raise TypeError("Expected output shapes compatible with %r but got " raise TypeError("Expected output shapes compatible with %r but got "
"dataset with output shapes %r." % "dataset with output shapes %r." %
(self.output_shapes, dataset_output_shapes)) (self.output_shapes, dataset_output_shapes))
with ops.colocate_with(self._iterator_resource):
with ops.device(dataset._variant_tensor.device):
return gen_dataset_ops.make_iterator( return gen_dataset_ops.make_iterator(
dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access
@@ -421,14 +420,13 @@ class Iterator(trackable.Trackable):
if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
with ops.device(self._iterator_resource.device): # pylint: disable=protected-access
# pylint: disable=protected-access flat_ret = gen_dataset_ops.iterator_get_next(
flat_ret = gen_dataset_ops.iterator_get_next( self._iterator_resource,
self._iterator_resource, output_types=self._flat_tensor_types,
output_types=self._flat_tensor_types, output_shapes=self._flat_tensor_shapes,
output_shapes=self._flat_tensor_shapes, name=name)
name=name) return structure.from_tensor_list(self._element_spec, flat_ret)
return structure.from_tensor_list(self._element_spec, flat_ret)
def string_handle(self, name=None): def string_handle(self, name=None):
"""Returns a string-valued `tf.Tensor` that represents this iterator. """Returns a string-valued `tf.Tensor` that represents this iterator.