[tf.data] Add tf.contrib.data.prefetch_to_device(), which supports prefetching to GPU memory.

PiperOrigin-RevId: 190158272
This commit is contained in:
Derek Murray 2018-03-22 18:20:09 -07:00 committed by TensorFlower Gardener
parent e07e70a414
commit dbea93d7f1
8 changed files with 214 additions and 4 deletions

View File

@ -21,6 +21,7 @@ py_library(
deps = [
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/contrib/data/python/ops:prefetching_ops",
"//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/contrib/data/python/ops:shuffle_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",

View File

@ -36,6 +36,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
@@map_and_batch
@@padded_batch_and_drop_remainder
@@parallel_interleave
@@prefetch_to_device
@@read_batch_features
@@rejection_resample
@@scan
@ -67,6 +68,7 @@ from tensorflow.contrib.data.python.ops.grouping import group_by_window
from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
from tensorflow.contrib.data.python.ops.readers import SqlDataset

View File

@ -406,4 +406,25 @@ REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
FunctionBufferingResourceGetNextOp);
#endif // TENSORFLOW_USE_SYCL
class IteratorGetDeviceOp : public OpKernel {
public:
using OpKernel::OpKernel;
void Compute(OpKernelContext* ctx) override {
// NOTE(mrry): We do not currently Validate that the handle
// corresponds to a real IteratorResource, because that symbol is
// not exposed from the framework library.
Tensor* device_name_t;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, TensorShape({}), &device_name_t));
// NOTE(mrry): Since the operation's input is a resource, we must be
// colocated with it, and so we can simply return the current device's
// name without looking at the input.
device_name_t->scalar<string>()() = ctx->device()->name();
}
};
REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
IteratorGetDeviceOp);
} // namespace tensorflow

View File

@ -37,6 +37,14 @@ REGISTER_OP("UniqueDataset")
Creates a dataset that contains the unique elements of `input_dataset`.
)doc");
REGISTER_OP("IteratorGetDevice")
.Input("resource: resource")
.Output("device: string")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Returns the name of the device on which `resource` has been placed.
)doc");
REGISTER_OP("FunctionBufferingResource")
.Input("string_arg: string")
.Input("target_device: string")

View File

@ -479,10 +479,6 @@ py_test(
size = "small",
srcs = ["prefetching_ops_test.py"],
srcs_version = "PY2AND3",
tags = [
"manual",
"no_oss", # b/68785503
],
deps = [
"//tensorflow/contrib/data/python/ops:prefetching_ops",
"//tensorflow/core:protos_all_py",

View File

@ -26,6 +26,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -107,6 +108,53 @@ class StagingAreaOpsTest(test.TestCase):
self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0",
"/job:localhost/replica:0/task:0/gpu:0")
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/cpu:1"))
# NOTE(mrry): This device block creates the "host" dataset and iterator on
# /cpu:0, and ensures that the prefetching is across devices. In typical use
# this would not be necessary, because the GPU device would not support any
# of the dataset-related ops.
with ops.device("/cpu:0"):
iterator = device_dataset.make_one_shot_iterator()
self.assertEqual(host_dataset.output_types, device_dataset.output_types)
self.assertEqual(host_dataset.output_types, iterator.output_types)
self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
self.assertEqual(host_dataset.output_classes, iterator.output_classes)
next_element = iterator.get_next()
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
with self.test_session(config=worker_config) as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
def testPrefetchToDeviceGpu(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
host_dataset = dataset_ops.Dataset.range(10)
device_dataset = host_dataset.apply(
prefetching_ops.prefetch_to_device("/gpu:0"))
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with self.test_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
if __name__ == "__main__":
test.main()

View File

@ -173,6 +173,10 @@ py_library(
srcs = ["prefetching_ops.py"],
deps = [
":contrib_op_loader",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
],
)

View File

@ -17,8 +17,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import warnings
from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
# TODO(rohanj): Add a python class that constructs resource in the __init__
@ -51,3 +60,124 @@ def function_buffering_resource_get_next(function_buffer_resource,
function_buffer_resource=function_buffer_resource,
output_types=output_types,
name=name)
# pylint: disable=protected-access
class _PrefetchToDeviceIterator(object):
"""A replacement for @{tf.data.Iterator} that prefetches to another device."""
def __init__(self, input_dataset, device, buffer_size):
self._input_dataset = input_dataset
self._get_next_call_count = 0
input_iterator = input_dataset.make_one_shot_iterator()
input_iterator_handle = input_iterator.string_handle()
@function.Defun(dtypes.string)
def _prefetch_fn(handle):
remote_iterator = iterator_ops.Iterator.from_string_handle(
handle, input_iterator.output_types, input_iterator.output_shapes,
input_iterator.output_classes)
return remote_iterator.get_next()
with ops.device(device):
self._buffering_resource = function_buffering_resource(
f=_prefetch_fn,
target_device=gen_dataset_ops.iterator_get_device(
input_iterator._iterator_resource),
string_arg=input_iterator_handle,
buffer_size=buffer_size,
thread_pool_size=0)
def get_next(self, name=None):
"""See @{tf.data.Iterator.get_next}."""
self._get_next_call_count += 1
if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
self._buffering_resource,
output_types=nest.flatten(sparse.as_dense_types(
self.output_types, self.output_classes)), name=name)
ret = sparse.deserialize_sparse_tensors(
nest.pack_sequence_as(self.output_types, flat_ret),
self.output_types, self.output_shapes, self.output_classes)
for tensor, shape in zip(
nest.flatten(ret), nest.flatten(self.output_shapes)):
if isinstance(tensor, ops.Tensor):
tensor.set_shape(shape)
return ret
@property
def output_classes(self):
return self._input_dataset.output_classes
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_types(self):
return self._input_dataset.output_types
# pylint: enable=protected-access
class _PrefetchToDeviceDataset(dataset_ops.Dataset):
"""A `Dataset` whose iterator prefetches elements to another device."""
def __init__(self, input_dataset, device, buffer_size):
self._input_dataset = input_dataset
self._device = device
self._buffer_size = buffer_size if buffer_size is not None else 1
def make_one_shot_iterator(self):
return _PrefetchToDeviceIterator(self._input_dataset, self._device,
self._buffer_size)
def make_initializable_iterator(self, shared_name=None):
raise NotImplementedError("`prefetch_to_device()` is not currently "
"compatible with initializable iterators. Use "
"`make_one_shot_iterator()` instead.")
def _as_variant_tensor(self):
# TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
# transformation methods is called.
# TODO(mrry): Investigate support for chaining further transformations after
# the prefetch, including GPU support.
raise NotImplementedError("`prefetch_to_device()` must be the last "
"transformation in a dataset pipeline.")
@property
def output_types(self):
return self._input_dataset.output_types
@property
def output_shapes(self):
return self._input_dataset.output_shapes
@property
def output_classes(self):
return self._input_dataset.output_classes
def prefetch_to_device(device, buffer_size=None):
"""A transformation that prefetches dataset values to the given `device`.
NOTE: Although the transformation creates a @{tf.data.Dataset}, the
transformation must be the final `Dataset` in the input pipeline.
Args:
device: A string. The name of a device to which elements will be prefetched.
buffer_size: (Optional.) The number of elements to buffer on `device`.
Defaults to an automatically chosen value.
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
return _PrefetchToDeviceDataset(dataset, device, buffer_size)
return _apply_fn