[tf.data] Add tf.contrib.data.prefetch_to_device(), which supports prefetching to GPU memory.
PiperOrigin-RevId: 190158272
This commit is contained in:
parent
e07e70a414
commit
dbea93d7f1
@ -21,6 +21,7 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:dataset_ops",
|
||||
"//tensorflow/contrib/data/python/ops:iterator_ops",
|
||||
"//tensorflow/contrib/data/python/ops:prefetching_ops",
|
||||
"//tensorflow/contrib/data/python/ops:readers",
|
||||
"//tensorflow/contrib/data/python/ops:shuffle_ops",
|
||||
"//tensorflow/contrib/data/python/ops:transformation_ops",
|
||||
|
||||
@ -36,6 +36,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
|
||||
@@map_and_batch
|
||||
@@padded_batch_and_drop_remainder
|
||||
@@parallel_interleave
|
||||
@@prefetch_to_device
|
||||
@@read_batch_features
|
||||
@@rejection_resample
|
||||
@@scan
|
||||
@ -67,6 +68,7 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window
|
||||
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
|
||||
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
|
||||
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
|
||||
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
|
||||
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
|
||||
from tensorflow.contrib.data.python.ops.readers import read_batch_features
|
||||
from tensorflow.contrib.data.python.ops.readers import SqlDataset
|
||||
|
||||
@ -406,4 +406,25 @@ REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
|
||||
FunctionBufferingResourceGetNextOp);
|
||||
#endif // TENSORFLOW_USE_SYCL
|
||||
|
||||
class IteratorGetDeviceOp : public OpKernel {
|
||||
public:
|
||||
using OpKernel::OpKernel;
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
// NOTE(mrry): We do not currently Validate that the handle
|
||||
// corresponds to a real IteratorResource, because that symbol is
|
||||
// not exposed from the framework library.
|
||||
Tensor* device_name_t;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->allocate_output(0, TensorShape({}), &device_name_t));
|
||||
// NOTE(mrry): Since the operation's input is a resource, we must be
|
||||
// colocated with it, and so we can simply return the current device's
|
||||
// name without looking at the input.
|
||||
device_name_t->scalar<string>()() = ctx->device()->name();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
|
||||
IteratorGetDeviceOp);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -37,6 +37,14 @@ REGISTER_OP("UniqueDataset")
|
||||
Creates a dataset that contains the unique elements of `input_dataset`.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("IteratorGetDevice")
|
||||
.Input("resource: resource")
|
||||
.Output("device: string")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Doc(R"doc(
|
||||
Returns the name of the device on which `resource` has been placed.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("FunctionBufferingResource")
|
||||
.Input("string_arg: string")
|
||||
.Input("target_device: string")
|
||||
|
||||
@ -479,10 +479,6 @@ py_test(
|
||||
size = "small",
|
||||
srcs = ["prefetching_ops_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"manual",
|
||||
"no_oss", # b/68785503
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/contrib/data/python/ops:prefetching_ops",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
|
||||
@ -26,6 +26,7 @@ from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -107,6 +108,53 @@ class StagingAreaOpsTest(test.TestCase):
|
||||
self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0",
|
||||
"/job:localhost/replica:0/task:0/gpu:0")
|
||||
|
||||
def testPrefetchToDevice(self):
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/cpu:1"))
|
||||
|
||||
# NOTE(mrry): This device block creates the "host" dataset and iterator on
|
||||
# /cpu:0, and ensures that the prefetching is across devices. In typical use
|
||||
# this would not be necessary, because the GPU device would not support any
|
||||
# of the dataset-related ops.
|
||||
with ops.device("/cpu:0"):
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
|
||||
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
|
||||
self.assertEqual(host_dataset.output_types, iterator.output_types)
|
||||
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
|
||||
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
|
||||
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
|
||||
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
|
||||
|
||||
next_element = iterator.get_next()
|
||||
self.assertEqual(dtypes.int64, next_element.dtype)
|
||||
self.assertEqual([], next_element.shape)
|
||||
|
||||
worker_config = config_pb2.ConfigProto()
|
||||
worker_config.device_count["CPU"] = 2
|
||||
with self.test_session(config=worker_config) as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
def testPrefetchToDeviceGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPU available")
|
||||
|
||||
host_dataset = dataset_ops.Dataset.range(10)
|
||||
device_dataset = host_dataset.apply(
|
||||
prefetching_ops.prefetch_to_device("/gpu:0"))
|
||||
|
||||
iterator = device_dataset.make_one_shot_iterator()
|
||||
next_element = iterator.get_next()
|
||||
|
||||
with self.test_session() as sess:
|
||||
for i in range(10):
|
||||
self.assertEqual(i, sess.run(next_element))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(next_element)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
||||
@ -173,6 +173,10 @@ py_library(
|
||||
srcs = ["prefetching_ops.py"],
|
||||
deps = [
|
||||
":contrib_op_loader",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//tensorflow/python/data/util:sparse",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -17,8 +17,17 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import warnings
|
||||
|
||||
from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
|
||||
from tensorflow.contrib.data.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import sparse
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
|
||||
# TODO(rohanj): Add a python class that constructs resource in the __init__
|
||||
@ -51,3 +60,124 @@ def function_buffering_resource_get_next(function_buffer_resource,
|
||||
function_buffer_resource=function_buffer_resource,
|
||||
output_types=output_types,
|
||||
name=name)
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
class _PrefetchToDeviceIterator(object):
|
||||
"""A replacement for @{tf.data.Iterator} that prefetches to another device."""
|
||||
|
||||
def __init__(self, input_dataset, device, buffer_size):
|
||||
self._input_dataset = input_dataset
|
||||
self._get_next_call_count = 0
|
||||
input_iterator = input_dataset.make_one_shot_iterator()
|
||||
input_iterator_handle = input_iterator.string_handle()
|
||||
|
||||
@function.Defun(dtypes.string)
|
||||
def _prefetch_fn(handle):
|
||||
remote_iterator = iterator_ops.Iterator.from_string_handle(
|
||||
handle, input_iterator.output_types, input_iterator.output_shapes,
|
||||
input_iterator.output_classes)
|
||||
return remote_iterator.get_next()
|
||||
|
||||
with ops.device(device):
|
||||
self._buffering_resource = function_buffering_resource(
|
||||
f=_prefetch_fn,
|
||||
target_device=gen_dataset_ops.iterator_get_device(
|
||||
input_iterator._iterator_resource),
|
||||
string_arg=input_iterator_handle,
|
||||
buffer_size=buffer_size,
|
||||
thread_pool_size=0)
|
||||
|
||||
def get_next(self, name=None):
|
||||
"""See @{tf.data.Iterator.get_next}."""
|
||||
self._get_next_call_count += 1
|
||||
if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
|
||||
warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
|
||||
|
||||
flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
|
||||
self._buffering_resource,
|
||||
output_types=nest.flatten(sparse.as_dense_types(
|
||||
self.output_types, self.output_classes)), name=name)
|
||||
|
||||
ret = sparse.deserialize_sparse_tensors(
|
||||
nest.pack_sequence_as(self.output_types, flat_ret),
|
||||
self.output_types, self.output_shapes, self.output_classes)
|
||||
|
||||
for tensor, shape in zip(
|
||||
nest.flatten(ret), nest.flatten(self.output_shapes)):
|
||||
if isinstance(tensor, ops.Tensor):
|
||||
tensor.set_shape(shape)
|
||||
|
||||
return ret
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._input_dataset.output_classes
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._input_dataset.output_shapes
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._input_dataset.output_types
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class _PrefetchToDeviceDataset(dataset_ops.Dataset):
|
||||
"""A `Dataset` whose iterator prefetches elements to another device."""
|
||||
|
||||
def __init__(self, input_dataset, device, buffer_size):
|
||||
self._input_dataset = input_dataset
|
||||
self._device = device
|
||||
self._buffer_size = buffer_size if buffer_size is not None else 1
|
||||
|
||||
def make_one_shot_iterator(self):
|
||||
return _PrefetchToDeviceIterator(self._input_dataset, self._device,
|
||||
self._buffer_size)
|
||||
|
||||
def make_initializable_iterator(self, shared_name=None):
|
||||
raise NotImplementedError("`prefetch_to_device()` is not currently "
|
||||
"compatible with initializable iterators. Use "
|
||||
"`make_one_shot_iterator()` instead.")
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
# TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
|
||||
# transformation methods is called.
|
||||
# TODO(mrry): Investigate support for chaining further transformations after
|
||||
# the prefetch, including GPU support.
|
||||
raise NotImplementedError("`prefetch_to_device()` must be the last "
|
||||
"transformation in a dataset pipeline.")
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._input_dataset.output_types
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._input_dataset.output_shapes
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._input_dataset.output_classes
|
||||
|
||||
|
||||
def prefetch_to_device(device, buffer_size=None):
|
||||
"""A transformation that prefetches dataset values to the given `device`.
|
||||
|
||||
NOTE: Although the transformation creates a @{tf.data.Dataset}, the
|
||||
transformation must be the final `Dataset` in the input pipeline.
|
||||
|
||||
Args:
|
||||
device: A string. The name of a device to which elements will be prefetched.
|
||||
buffer_size: (Optional.) The number of elements to buffer on `device`.
|
||||
Defaults to an automatically chosen value.
|
||||
|
||||
Returns:
|
||||
A `Dataset` transformation function, which can be passed to
|
||||
@{tf.data.Dataset.apply}.
|
||||
"""
|
||||
def _apply_fn(dataset):
|
||||
return _PrefetchToDeviceDataset(dataset, device, buffer_size)
|
||||
|
||||
return _apply_fn
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user