Internal change
PiperOrigin-RevId: 308645783 Change-Id: Iaa11d70b2cac141b4eb067465e78ef4c45384171
This commit is contained in:
parent
c1ceeb28c3
commit
5ab3af7a7b
|
@ -150,21 +150,6 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
|
||||
self.assertDatasetProduces(device_dataset, list(range(10)))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPrefetchToDeviceCorrectPlacement(self):
|
||||
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
|
||||
self.assertTrue(("" == host_dataset._variant_tensor.device or
|
||||
"cpu:0" in host_dataset._variant_tensor.device.lower()))
|
||||
|
||||
self.assertTrue("gpu:0" in device_dataset._variant_tensor.device.lower())
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testPrefetchToDeviceWithReInit(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
|
|
|
@ -374,7 +374,6 @@ cuda_py_test(
|
|||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat",
|
||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:iterator_ops",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
|
|
|
@ -25,7 +25,6 @@ import numpy as np
|
|||
from tensorflow.core.protobuf import cluster_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
|
@ -1017,78 +1016,6 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||
self.evaluate(counter_var.initializer)
|
||||
self.assertEqual(self.evaluate(fn()), 10)
|
||||
|
||||
def assert_dataset_placement(self, host_dataset, host_iterator, host_tensor,
|
||||
device_dataset, device_iterator, device_tensor):
|
||||
|
||||
self.assertTrue("cpu:0" in host_dataset._variant_tensor.device.lower() or
|
||||
host_dataset._variant_tensor.device == "")
|
||||
self.assertTrue(
|
||||
"cpu:0" in host_iterator._iterator_resource.device.lower() or
|
||||
host_iterator._iterator_resource.device == "")
|
||||
self.assertTrue("cpu:0" in host_tensor.device.lower() or
|
||||
host_tensor.device == "")
|
||||
|
||||
self.assertIn("gpu:0", device_dataset._variant_tensor.device.lower())
|
||||
self.assertIn("gpu:0", device_iterator._iterator_resource.device.lower())
|
||||
self.assertIn("gpu:0", device_tensor.device.lower())
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testIteratorOnDeviceEagerMode(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
|
||||
host_iterator = iter(host_dataset)
|
||||
device_iterator = iter(device_dataset)
|
||||
|
||||
host_tensor = next(host_iterator)
|
||||
device_tensor = next(device_iterator)
|
||||
|
||||
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
|
||||
device_dataset, device_iterator,
|
||||
device_tensor)
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testIteratorOnDeviceGraphModeOneShotIterator(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
|
||||
host_iterator = dataset_ops.make_one_shot_iterator(host_dataset)
|
||||
device_iterator = dataset_ops.make_one_shot_iterator(device_dataset)
|
||||
|
||||
host_tensor = host_iterator.get_next()
|
||||
device_tensor = device_iterator.get_next()
|
||||
|
||||
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
|
||||
device_dataset, device_iterator,
|
||||
device_tensor)
|
||||
|
||||
@combinations.generate(test_base.graph_only_combinations())
|
||||
def testIteratorOnDeviceGraphModeInitializableIterator(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
|
||||
host_iterator = dataset_ops.make_initializable_iterator(host_dataset)
|
||||
device_iterator = dataset_ops.make_initializable_iterator(device_dataset)
|
||||
|
||||
host_tensor = host_iterator.get_next()
|
||||
device_tensor = device_iterator.get_next()
|
||||
|
||||
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
|
||||
device_dataset, device_iterator,
|
||||
device_tensor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
|
@ -406,8 +406,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||
RuntimeError: If not inside of tf.function and not executing eagerly.
|
||||
"""
|
||||
if context.executing_eagerly() or ops.inside_function():
|
||||
with ops.device(self._variant_tensor.device):
|
||||
return iterator_ops.OwnedIterator(self)
|
||||
return iterator_ops.OwnedIterator(self)
|
||||
else:
|
||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||
"or when eager execution is enabled.")
|
||||
|
@ -482,15 +481,13 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||
if not context.executing_eagerly():
|
||||
raise RuntimeError("as_numpy_iterator() is not supported while tracing "
|
||||
"functions")
|
||||
|
||||
for component_spec in nest.flatten(self.element_spec):
|
||||
if not isinstance(component_spec, tensor_spec.TensorSpec):
|
||||
raise TypeError(
|
||||
"Dataset.as_numpy_iterator() does not support datasets containing "
|
||||
+ str(component_spec.value_type))
|
||||
|
||||
with ops.device(self._variant_tensor.device):
|
||||
return _NumpyIterator(self)
|
||||
return _NumpyIterator(self)
|
||||
|
||||
@property
|
||||
def _flat_shapes(self):
|
||||
|
@ -2161,10 +2158,8 @@ class DatasetV1(DatasetV2):
|
|||
return self._make_one_shot_iterator()
|
||||
|
||||
def _make_one_shot_iterator(self): # pylint: disable=missing-docstring
|
||||
|
||||
if context.executing_eagerly():
|
||||
with ops.device(self._variant_tensor.device):
|
||||
return iterator_ops.OwnedIterator(self)
|
||||
return iterator_ops.OwnedIterator(self)
|
||||
|
||||
_ensure_same_dataset_graph(self)
|
||||
# Now that we create datasets at python object creation time, the capture
|
||||
|
@ -2206,13 +2201,12 @@ class DatasetV1(DatasetV2):
|
|||
else:
|
||||
six.reraise(ValueError, err)
|
||||
|
||||
with ops.device(self._variant_tensor.device):
|
||||
# pylint: disable=protected-access
|
||||
return iterator_ops.Iterator(
|
||||
gen_dataset_ops.one_shot_iterator(
|
||||
dataset_factory=_make_dataset, **self._flat_structure), None,
|
||||
get_legacy_output_types(self), get_legacy_output_shapes(self),
|
||||
get_legacy_output_classes(self))
|
||||
# pylint: disable=protected-access
|
||||
return iterator_ops.Iterator(
|
||||
gen_dataset_ops.one_shot_iterator(
|
||||
dataset_factory=_make_dataset, **self._flat_structure), None,
|
||||
get_legacy_output_types(self), get_legacy_output_shapes(self),
|
||||
get_legacy_output_classes(self))
|
||||
|
||||
@deprecation.deprecated(
|
||||
None, "This is a deprecated API that should only be used in TF 1 graph "
|
||||
|
@ -2272,20 +2266,16 @@ class DatasetV1(DatasetV2):
|
|||
dataset = self._apply_options()
|
||||
if shared_name is None:
|
||||
shared_name = ""
|
||||
|
||||
with ops.device(self._variant_tensor.device):
|
||||
iterator_resource = gen_dataset_ops.iterator_v2(
|
||||
container="", shared_name=shared_name, **self._flat_structure)
|
||||
|
||||
iterator_resource = gen_dataset_ops.iterator_v2(
|
||||
container="", shared_name=shared_name, **self._flat_structure)
|
||||
with ops.colocate_with(iterator_resource):
|
||||
initializer = gen_dataset_ops.make_iterator(
|
||||
dataset._variant_tensor, # pylint: disable=protected-access
|
||||
iterator_resource)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
return iterator_ops.Iterator(iterator_resource, initializer,
|
||||
get_legacy_output_types(dataset),
|
||||
get_legacy_output_shapes(dataset),
|
||||
get_legacy_output_classes(dataset))
|
||||
# pylint: disable=protected-access
|
||||
return iterator_ops.Iterator(
|
||||
iterator_resource, initializer, get_legacy_output_types(dataset),
|
||||
get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
|
||||
|
||||
@property
|
||||
@deprecation.deprecated(
|
||||
|
@ -4339,14 +4329,11 @@ class PrefetchDataset(UnaryUnchangedStructureDataset):
|
|||
buffer_size = -1 # This is the sentinel for auto-tuning.
|
||||
self._buffer_size = ops.convert_to_tensor(
|
||||
buffer_size, dtype=dtypes.int64, name="buffer_size")
|
||||
|
||||
with ops.device(input_dataset._variant_tensor.device):
|
||||
variant_tensor = gen_dataset_ops.prefetch_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
slack_period=slack_period,
|
||||
**self._flat_structure)
|
||||
|
||||
variant_tensor = gen_dataset_ops.prefetch_dataset(
|
||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||
buffer_size=self._buffer_size,
|
||||
slack_period=slack_period,
|
||||
**self._flat_structure)
|
||||
super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
|
||||
|
||||
|
||||
|
|
|
@ -368,8 +368,7 @@ class Iterator(trackable.Trackable):
|
|||
raise TypeError("Expected output shapes compatible with %r but got "
|
||||
"dataset with output shapes %r." %
|
||||
(self.output_shapes, dataset_output_shapes))
|
||||
|
||||
with ops.device(dataset._variant_tensor.device):
|
||||
with ops.colocate_with(self._iterator_resource):
|
||||
return gen_dataset_ops.make_iterator(
|
||||
dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access
|
||||
|
||||
|
@ -421,14 +420,13 @@ class Iterator(trackable.Trackable):
|
|||
if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
|
||||
warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
|
||||
|
||||
with ops.device(self._iterator_resource.device):
|
||||
# pylint: disable=protected-access
|
||||
flat_ret = gen_dataset_ops.iterator_get_next(
|
||||
self._iterator_resource,
|
||||
output_types=self._flat_tensor_types,
|
||||
output_shapes=self._flat_tensor_shapes,
|
||||
name=name)
|
||||
return structure.from_tensor_list(self._element_spec, flat_ret)
|
||||
# pylint: disable=protected-access
|
||||
flat_ret = gen_dataset_ops.iterator_get_next(
|
||||
self._iterator_resource,
|
||||
output_types=self._flat_tensor_types,
|
||||
output_shapes=self._flat_tensor_shapes,
|
||||
name=name)
|
||||
return structure.from_tensor_list(self._element_spec, flat_ret)
|
||||
|
||||
def string_handle(self, name=None):
|
||||
"""Returns a string-valued `tf.Tensor` that represents this iterator.
|
||||
|
|
Loading…
Reference in New Issue