Internal change
PiperOrigin-RevId: 308645783 Change-Id: Iaa11d70b2cac141b4eb067465e78ef4c45384171
This commit is contained in:
parent
c1ceeb28c3
commit
5ab3af7a7b
|
@ -150,21 +150,6 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
self.assertDatasetProduces(device_dataset, list(range(10)))
|
self.assertDatasetProduces(device_dataset, list(range(10)))
|
||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
|
||||||
def testPrefetchToDeviceCorrectPlacement(self):
|
|
||||||
|
|
||||||
if not test_util.is_gpu_available():
|
|
||||||
self.skipTest("No GPU available")
|
|
||||||
|
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
|
||||||
device_dataset = host_dataset.apply(
|
|
||||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
|
||||||
|
|
||||||
self.assertTrue(("" == host_dataset._variant_tensor.device or
|
|
||||||
"cpu:0" in host_dataset._variant_tensor.device.lower()))
|
|
||||||
|
|
||||||
self.assertTrue("gpu:0" in device_dataset._variant_tensor.device.lower())
|
|
||||||
|
|
||||||
@combinations.generate(test_base.graph_only_combinations())
|
@combinations.generate(test_base.graph_only_combinations())
|
||||||
def testPrefetchToDeviceWithReInit(self):
|
def testPrefetchToDeviceWithReInit(self):
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
host_dataset = dataset_ops.Dataset.range(10)
|
||||||
|
|
|
@ -374,7 +374,6 @@ cuda_py_test(
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:util",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/compat",
|
"//tensorflow/python/compat",
|
||||||
"//tensorflow/python/data/experimental/ops:prefetching_ops",
|
|
||||||
"//tensorflow/python/data/ops:dataset_ops",
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
"//tensorflow/python/data/ops:iterator_ops",
|
"//tensorflow/python/data/ops:iterator_ops",
|
||||||
"//tensorflow/python/data/ops:readers",
|
"//tensorflow/python/data/ops:readers",
|
||||||
|
|
|
@ -25,7 +25,6 @@ import numpy as np
|
||||||
from tensorflow.core.protobuf import cluster_pb2
|
from tensorflow.core.protobuf import cluster_pb2
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.data.experimental.ops import prefetching_ops
|
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
|
@ -1017,78 +1016,6 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
self.evaluate(counter_var.initializer)
|
self.evaluate(counter_var.initializer)
|
||||||
self.assertEqual(self.evaluate(fn()), 10)
|
self.assertEqual(self.evaluate(fn()), 10)
|
||||||
|
|
||||||
def assert_dataset_placement(self, host_dataset, host_iterator, host_tensor,
|
|
||||||
device_dataset, device_iterator, device_tensor):
|
|
||||||
|
|
||||||
self.assertTrue("cpu:0" in host_dataset._variant_tensor.device.lower() or
|
|
||||||
host_dataset._variant_tensor.device == "")
|
|
||||||
self.assertTrue(
|
|
||||||
"cpu:0" in host_iterator._iterator_resource.device.lower() or
|
|
||||||
host_iterator._iterator_resource.device == "")
|
|
||||||
self.assertTrue("cpu:0" in host_tensor.device.lower() or
|
|
||||||
host_tensor.device == "")
|
|
||||||
|
|
||||||
self.assertIn("gpu:0", device_dataset._variant_tensor.device.lower())
|
|
||||||
self.assertIn("gpu:0", device_iterator._iterator_resource.device.lower())
|
|
||||||
self.assertIn("gpu:0", device_tensor.device.lower())
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testIteratorOnDeviceEagerMode(self):
|
|
||||||
if not test_util.is_gpu_available():
|
|
||||||
self.skipTest("No GPU available")
|
|
||||||
|
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
|
||||||
device_dataset = host_dataset.apply(
|
|
||||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
|
||||||
|
|
||||||
host_iterator = iter(host_dataset)
|
|
||||||
device_iterator = iter(device_dataset)
|
|
||||||
|
|
||||||
host_tensor = next(host_iterator)
|
|
||||||
device_tensor = next(device_iterator)
|
|
||||||
|
|
||||||
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
|
|
||||||
device_dataset, device_iterator,
|
|
||||||
device_tensor)
|
|
||||||
|
|
||||||
@combinations.generate(test_base.graph_only_combinations())
|
|
||||||
def testIteratorOnDeviceGraphModeOneShotIterator(self):
|
|
||||||
if not test_util.is_gpu_available():
|
|
||||||
self.skipTest("No GPU available")
|
|
||||||
|
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
|
||||||
device_dataset = host_dataset.apply(
|
|
||||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
|
||||||
|
|
||||||
host_iterator = dataset_ops.make_one_shot_iterator(host_dataset)
|
|
||||||
device_iterator = dataset_ops.make_one_shot_iterator(device_dataset)
|
|
||||||
|
|
||||||
host_tensor = host_iterator.get_next()
|
|
||||||
device_tensor = device_iterator.get_next()
|
|
||||||
|
|
||||||
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
|
|
||||||
device_dataset, device_iterator,
|
|
||||||
device_tensor)
|
|
||||||
|
|
||||||
@combinations.generate(test_base.graph_only_combinations())
|
|
||||||
def testIteratorOnDeviceGraphModeInitializableIterator(self):
|
|
||||||
if not test_util.is_gpu_available():
|
|
||||||
self.skipTest("No GPU available")
|
|
||||||
|
|
||||||
host_dataset = dataset_ops.Dataset.range(10)
|
|
||||||
device_dataset = host_dataset.apply(
|
|
||||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
|
||||||
|
|
||||||
host_iterator = dataset_ops.make_initializable_iterator(host_dataset)
|
|
||||||
device_iterator = dataset_ops.make_initializable_iterator(device_dataset)
|
|
||||||
|
|
||||||
host_tensor = host_iterator.get_next()
|
|
||||||
device_tensor = device_iterator.get_next()
|
|
||||||
|
|
||||||
self.assert_dataset_placement(host_dataset, host_iterator, host_tensor,
|
|
||||||
device_dataset, device_iterator,
|
|
||||||
device_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
|
@ -406,7 +406,6 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||||
RuntimeError: If not inside of tf.function and not executing eagerly.
|
RuntimeError: If not inside of tf.function and not executing eagerly.
|
||||||
"""
|
"""
|
||||||
if context.executing_eagerly() or ops.inside_function():
|
if context.executing_eagerly() or ops.inside_function():
|
||||||
with ops.device(self._variant_tensor.device):
|
|
||||||
return iterator_ops.OwnedIterator(self)
|
return iterator_ops.OwnedIterator(self)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
raise RuntimeError("__iter__() is only supported inside of tf.function "
|
||||||
|
@ -482,14 +481,12 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||||
if not context.executing_eagerly():
|
if not context.executing_eagerly():
|
||||||
raise RuntimeError("as_numpy_iterator() is not supported while tracing "
|
raise RuntimeError("as_numpy_iterator() is not supported while tracing "
|
||||||
"functions")
|
"functions")
|
||||||
|
|
||||||
for component_spec in nest.flatten(self.element_spec):
|
for component_spec in nest.flatten(self.element_spec):
|
||||||
if not isinstance(component_spec, tensor_spec.TensorSpec):
|
if not isinstance(component_spec, tensor_spec.TensorSpec):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Dataset.as_numpy_iterator() does not support datasets containing "
|
"Dataset.as_numpy_iterator() does not support datasets containing "
|
||||||
+ str(component_spec.value_type))
|
+ str(component_spec.value_type))
|
||||||
|
|
||||||
with ops.device(self._variant_tensor.device):
|
|
||||||
return _NumpyIterator(self)
|
return _NumpyIterator(self)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -2161,9 +2158,7 @@ class DatasetV1(DatasetV2):
|
||||||
return self._make_one_shot_iterator()
|
return self._make_one_shot_iterator()
|
||||||
|
|
||||||
def _make_one_shot_iterator(self): # pylint: disable=missing-docstring
|
def _make_one_shot_iterator(self): # pylint: disable=missing-docstring
|
||||||
|
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
with ops.device(self._variant_tensor.device):
|
|
||||||
return iterator_ops.OwnedIterator(self)
|
return iterator_ops.OwnedIterator(self)
|
||||||
|
|
||||||
_ensure_same_dataset_graph(self)
|
_ensure_same_dataset_graph(self)
|
||||||
|
@ -2206,7 +2201,6 @@ class DatasetV1(DatasetV2):
|
||||||
else:
|
else:
|
||||||
six.reraise(ValueError, err)
|
six.reraise(ValueError, err)
|
||||||
|
|
||||||
with ops.device(self._variant_tensor.device):
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return iterator_ops.Iterator(
|
return iterator_ops.Iterator(
|
||||||
gen_dataset_ops.one_shot_iterator(
|
gen_dataset_ops.one_shot_iterator(
|
||||||
|
@ -2272,20 +2266,16 @@ class DatasetV1(DatasetV2):
|
||||||
dataset = self._apply_options()
|
dataset = self._apply_options()
|
||||||
if shared_name is None:
|
if shared_name is None:
|
||||||
shared_name = ""
|
shared_name = ""
|
||||||
|
|
||||||
with ops.device(self._variant_tensor.device):
|
|
||||||
iterator_resource = gen_dataset_ops.iterator_v2(
|
iterator_resource = gen_dataset_ops.iterator_v2(
|
||||||
container="", shared_name=shared_name, **self._flat_structure)
|
container="", shared_name=shared_name, **self._flat_structure)
|
||||||
|
with ops.colocate_with(iterator_resource):
|
||||||
initializer = gen_dataset_ops.make_iterator(
|
initializer = gen_dataset_ops.make_iterator(
|
||||||
dataset._variant_tensor, # pylint: disable=protected-access
|
dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
iterator_resource)
|
iterator_resource)
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return iterator_ops.Iterator(iterator_resource, initializer,
|
return iterator_ops.Iterator(
|
||||||
get_legacy_output_types(dataset),
|
iterator_resource, initializer, get_legacy_output_types(dataset),
|
||||||
get_legacy_output_shapes(dataset),
|
get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset))
|
||||||
get_legacy_output_classes(dataset))
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
|
@ -4339,14 +4329,11 @@ class PrefetchDataset(UnaryUnchangedStructureDataset):
|
||||||
buffer_size = -1 # This is the sentinel for auto-tuning.
|
buffer_size = -1 # This is the sentinel for auto-tuning.
|
||||||
self._buffer_size = ops.convert_to_tensor(
|
self._buffer_size = ops.convert_to_tensor(
|
||||||
buffer_size, dtype=dtypes.int64, name="buffer_size")
|
buffer_size, dtype=dtypes.int64, name="buffer_size")
|
||||||
|
|
||||||
with ops.device(input_dataset._variant_tensor.device):
|
|
||||||
variant_tensor = gen_dataset_ops.prefetch_dataset(
|
variant_tensor = gen_dataset_ops.prefetch_dataset(
|
||||||
input_dataset._variant_tensor, # pylint: disable=protected-access
|
input_dataset._variant_tensor, # pylint: disable=protected-access
|
||||||
buffer_size=self._buffer_size,
|
buffer_size=self._buffer_size,
|
||||||
slack_period=slack_period,
|
slack_period=slack_period,
|
||||||
**self._flat_structure)
|
**self._flat_structure)
|
||||||
|
|
||||||
super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
|
super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -368,8 +368,7 @@ class Iterator(trackable.Trackable):
|
||||||
raise TypeError("Expected output shapes compatible with %r but got "
|
raise TypeError("Expected output shapes compatible with %r but got "
|
||||||
"dataset with output shapes %r." %
|
"dataset with output shapes %r." %
|
||||||
(self.output_shapes, dataset_output_shapes))
|
(self.output_shapes, dataset_output_shapes))
|
||||||
|
with ops.colocate_with(self._iterator_resource):
|
||||||
with ops.device(dataset._variant_tensor.device):
|
|
||||||
return gen_dataset_ops.make_iterator(
|
return gen_dataset_ops.make_iterator(
|
||||||
dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access
|
dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
@ -421,7 +420,6 @@ class Iterator(trackable.Trackable):
|
||||||
if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
|
if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
|
||||||
warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
|
warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
|
||||||
|
|
||||||
with ops.device(self._iterator_resource.device):
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
flat_ret = gen_dataset_ops.iterator_get_next(
|
flat_ret = gen_dataset_ops.iterator_get_next(
|
||||||
self._iterator_resource,
|
self._iterator_resource,
|
||||||
|
|
Loading…
Reference in New Issue