[tf.data] Add experimental transformation for applying map() on GPU devices.

PiperOrigin-RevId: 221265601
Derek Murray 2018-11-13 07:47:27 -08:00 committed by TensorFlower Gardener
parent a466bbdb04
commit a326fdb402
13 changed files with 182 additions and 13 deletions
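At a glance, the new user-facing entry point is `map_on_gpu()` in the `tf.data.experimental` prefetching module (`prefetching_ops`), backed by a hidden `ExperimentalMapDataset` op with a GPU kernel registration. A minimal usage sketch, distilled from the test added in this commit (assumes a TF 1.x graph-mode session and an available GPU):

from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops

# Build the pipeline on the host, copy it to the GPU, then map on the GPU;
# map_on_gpu() must come after copy_to_device() with a GPU target.
dataset = dataset_ops.Dataset.from_tensor_slices([1.0, 2.0, 3.0])
dataset = dataset.apply(prefetching_ops.copy_to_device("/gpu:0"))
dataset = dataset.apply(prefetching_ops.map_on_gpu(math_ops.square))

# The iterator must be created on the same GPU device.
with ops.device("/gpu:0"):
  iterator = dataset.make_initializable_iterator()
  next_element = iterator.get_next()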

View File

@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "ExperimentalMapDataset"
+  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+  visibility: HIDDEN
+}

View File

@@ -46,7 +46,11 @@ namespace tensorflow {
 // A few string constant used throughout this module.
 static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static constexpr const char* const kDeviceArgOp =
+    FunctionLibraryDefinition::kDeviceArgOp;
 static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
+static constexpr const char* const kDeviceRetOp =
+    FunctionLibraryDefinition::kDeviceRetOp;
 static constexpr const char* const kGradientOp =
     FunctionLibraryDefinition::kGradientOp;
 static constexpr const char* const kNodeLabel = "Func";
@@ -1633,9 +1637,9 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
   this->ret_nodes.resize(ret_types.size());
   for (Node* n : this->graph->op_nodes()) {
     gtl::InlinedVector<Node*, 4>* node_vec;
-    if (n->type_string() == kRetOp) {
+    if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) {
       node_vec = &this->ret_nodes;
-    } else if (n->type_string() == kArgOp) {
+    } else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) {
       node_vec = &this->arg_nodes;
     } else {
       continue;

View File

@@ -149,8 +149,8 @@ class FunctionInstantiationHelper {
   }

   // Builds index for nodes that can be used as node's input arguments.
-  Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
-                            AttrSlice attr_values) {
+  Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
+                            bool ints_on_device) {
     bool is_type_list;
     DataTypeVector dtypes;
     TF_RETURN_IF_ERROR(
@@ -169,7 +169,11 @@ class FunctionInstantiationHelper {
       strings::StrAppend(&name, "_", i);
     }
     NodeDef* gnode = AddNode(name);
-    gnode->set_op(FunctionLibraryDefinition::kArgOp);
+    if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
+      gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
+    } else {
+      gnode->set_op(FunctionLibraryDefinition::kArgOp);
+    }
     AddAttr("T", dtypes[i], gnode);
     AddAttr("index", arg_index, gnode);
     result_.arg_types.push_back(dtypes[i]);
@@ -564,9 +568,11 @@ string Print(gtl::ArraySlice<const NodeDef*> nodes) {
   std::vector<const NodeDef*> ret;
   std::vector<const NodeDef*> body;
   for (const NodeDef* n : nodes) {
-    if (n->op() == FunctionLibraryDefinition::kArgOp) {
+    if (n->op() == FunctionLibraryDefinition::kArgOp ||
+        n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
      arg.push_back(n);
-    } else if (n->op() == FunctionLibraryDefinition::kRetOp) {
+    } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
+               n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
      ret.push_back(n);
    } else {
      body.push_back(n);
@@ -638,10 +644,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
   const OpDef& sig = fdef.signature();
   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));

+  bool ints_on_device = fdef.attr().count("experimental_ints_on_device") != 0 &&
+                        fdef.attr().at("experimental_ints_on_device").b();
+
   FunctionInstantiationHelper helper(get_function, result);
   Status s;
   for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
-    s = helper.BuildInputArgIndex(arg_def, attr_values);
+    s = helper.BuildInputArgIndex(arg_def, attr_values, ints_on_device);
     if (!s.ok()) {
       errors::AppendToMessage(&s, "In ", Print(arg_def));
       return s;
@@ -693,9 +702,6 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
     }
   }

-  bool ints_on_device = fdef.attr().count("experimental_ints_on_device") != 0 &&
-                        fdef.attr().at("experimental_ints_on_device").b();
-
   // Emits nodes for the function's return values.
   int ret_index = 0;
   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {

View File

@@ -379,6 +379,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
   // Ops created for function arguments bear the name given by `kArgOp`; those
   // created for return values bear the name given by `kRetOp`.
   static constexpr const char* const kArgOp = "_Arg";
+  static constexpr const char* const kDeviceArgOp = "_DeviceArg";
   static constexpr const char* const kRetOp = "_Retval";
   static constexpr const char* const kDeviceRetOp = "_DeviceRetval";

View File

@@ -244,6 +244,11 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
 };

 REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalMapDataset")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("input_dataset")
+                            .HostMemory("handle"),
+                        MapDatasetOp);

 }  // namespace
 }  // namespace data

View File

@@ -69,6 +69,7 @@ void RetvalOp::Compute(OpKernelContext* ctx) {
 }

 REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
+REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
 REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
 REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);
@@ -105,6 +106,8 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                    .HostMemory("output")
                                                    .TypeConstraint<int32>("T"),
                                                ArgOp);
+REGISTER_KERNEL_BUILDER(
+    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
 #undef REGISTER

 REGISTER_KERNEL_BUILDER(Name(kArgOp)

View File

@@ -22,6 +22,7 @@ limitations under the License.
 namespace tensorflow {

 static const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static const char* const kDeviceArgOp = FunctionLibraryDefinition::kDeviceArgOp;
 static const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
 static const char* const kDeviceRetOp = FunctionLibraryDefinition::kDeviceRetOp;

View File

@@ -75,6 +75,17 @@ REGISTER_OP("ExperimentalIgnoreErrorsDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);

+REGISTER_OP("ExperimentalMapDataset")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .Attr("use_inter_op_parallelism: bool = true")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("ExperimentalNonSerializableDataset")
     .Input("input_dataset: variant")
     .Output("handle: variant")

View File

@@ -35,6 +35,22 @@ output: The argument.
 index: This argument is the index-th argument of the function.
 )doc");

+REGISTER_SYSTEM_OP("_DeviceArg")
+    .Output("output: T")
+    .Attr("T: type")
+    .Attr("index: int >= 0")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* context) {
+      context->set_output(0, context->UnknownShape());
+      return Status::OK();
+    })
+    .Doc(R"doc(
+A graph node which represents an argument to a function.
+
+output: The argument.
+index: This argument is the index-th argument of the function.
+)doc");
+
 REGISTER_SYSTEM_OP("_Retval")
     .Input("input: T")
     .Attr("T: type")

View File

@@ -38,6 +38,7 @@ cuda_py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python/compat:compat",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/ops:iterator_ops",

View File

@@ -28,7 +28,9 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat as util_compat


 class CopyToDeviceTest(test_base.DatasetTestBase):
@@ -294,6 +296,42 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)

+  def testCopyToDeviceGpuWithMap(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    def generator():
+      for i in range(10):
+        yield i, float(i), str(i)
+
+    host_dataset = dataset_ops.Dataset.from_generator(
+        generator, output_types=(dtypes.int32, dtypes.float32, dtypes.string))
+    device_dataset = host_dataset.apply(
+        prefetching_ops.copy_to_device("/gpu:0"))
+
+    def gpu_map_func(x, y, z):
+      return math_ops.square(x), math_ops.square(y), z
+
+    device_dataset = device_dataset.apply(
+        prefetching_ops.map_on_gpu(gpu_map_func))
+
+    options = dataset_ops.Options()
+    options.experimental_autotune = False
+    device_dataset = device_dataset.with_options(options)
+
+    with ops.device("/gpu:0"):
+      iterator = device_dataset.make_initializable_iterator()
+      next_element = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(iterator.initializer)
+      for i in range(10):
+        x, y, z = sess.run(next_element)
+        self.assertEqual(i**2, x)
+        self.assertEqual(float(i**2), y)
+        self.assertEqual(util_compat.as_bytes(str(i)), z)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
   def testCopyToDeviceGpuInt32(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")

View File

@@ -540,3 +540,71 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
   @property
   def output_classes(self):
     return self._input_dataset.output_classes
+
+
+class _MapOnGpuDataset(dataset_ops.UnaryDataset):
+  """A `Dataset` that maps a function over elements in its input using a GPU."""
+
+  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
+    """See `Dataset.map()` for details."""
+    super(_MapOnGpuDataset, self).__init__(input_dataset)
+    self._input_dataset = input_dataset
+    self._use_inter_op_parallelism = use_inter_op_parallelism
+
+    wrapped_func = dataset_ops.StructuredFunctionWrapper(
+        map_func,
+        self._transformation_name(),
+        dataset=input_dataset,
+        defun_kwargs={"experimental_ints_on_device": True})
+    self._output_classes = wrapped_func.output_classes
+    self._output_shapes = wrapped_func.output_shapes
+    self._output_types = wrapped_func.output_types
+    self._map_func = wrapped_func.function
+
+  def _as_variant_tensor(self):
+    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
+    return ged_ops.experimental_map_dataset(
+        input_t,
+        self._map_func.captured_inputs,
+        f=self._map_func,
+        use_inter_op_parallelism=self._use_inter_op_parallelism,
+        **dataset_ops.flat_structure(self))
+
+  @property
+  def output_classes(self):
+    return self._output_classes
+
+  @property
+  def output_shapes(self):
+    return self._output_shapes
+
+  @property
+  def output_types(self):
+    return self._output_types
+
+  def _transformation_name(self):
+    return "map_on_gpu()"
+
+
+def map_on_gpu(map_func):
+  """Maps `map_func` across the elements of this dataset.
+
+  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that
+  runs `map_func` on GPU. It must be used after applying the
+  `tf.data.experimental.copy_to_device` transformation with a GPU device
+  argument.
+
+  Args:
+    map_func: A function mapping a nested structure of tensors (having shapes
+      and types defined by `self.output_shapes` and `self.output_types`) to
+      another nested structure of tensors.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    `tf.data.Dataset.apply`.
+  """
+
+  def _apply_fn(dataset):
+    return _MapOnGpuDataset(dataset, map_func)
+
+  return _apply_fn

View File

@@ -61,6 +61,7 @@ class Dataset(object):
   collection of elements (nested structures of tensors) and a "logical
   plan" of transformations that act on those elements.
   """
+
   def __init__(self):
     pass
@@ -201,6 +202,7 @@ class Dataset(object):
     # a 0-argument function.
     @function.Defun(capture_by_value=True)
     def _make_dataset():
+      """Factory function for a dataset."""
       # NOTE(mrry): `Defun` does not capture the graph-level seed from the
       # enclosing graph, so if a graph-level seed is present we set the local
       # graph seed based on a combination of the graph- and op-level seeds.
@@ -1777,7 +1779,8 @@ class StructuredFunctionWrapper(object):
                input_shapes=None,
                input_types=None,
                add_to_graph=True,
-               experimental_nested_dataset_support=False):
+               experimental_nested_dataset_support=False,
+               defun_kwargs=None):
     """Creates a new `StructuredFunctionWrapper` for the given function.

     Args:
@@ -1798,6 +1801,9 @@ class StructuredFunctionWrapper(object):
         default graph.
       experimental_nested_dataset_support: (Optional.) If `True`, the function
         will support `tf.data.Dataset` objects as arguments and return values.
+      defun_kwargs: (Optional.) A dictionary mapping string argument names to
+        values. If supplied, will be passed to `function.Defun()` as keyword
+        arguments.

     Raises:
       ValueError: If an invalid combination of `dataset`, `input_classes`,
@@ -1832,7 +1838,11 @@ class StructuredFunctionWrapper(object):
     # TODO(b/110122868): Enable this support for all `tf.data` functions.
     self._nested_dataset_support = experimental_nested_dataset_support

-    @function.Defun(*self._defun_args(), func_name=self._func_name)
+    if defun_kwargs is None:
+      defun_kwargs = {}
+
+    @function.Defun(
+        *self._defun_args(), func_name=self._func_name, **defun_kwargs)
     def tf_data_structured_function_wrapper(*args):
       """Wrapper for passing nested structures to and from tf.data functions."""
       flat_args = []
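Taken together, the dataset_ops.py changes provide the plumbing for the transformation above: `defun_kwargs` flows from `_MapOnGpuDataset` into `function.Defun()`, which records it as an attr on the generated `FunctionDef`, where the `InstantiateFunction()` change above reads it back as `experimental_ints_on_device`. A condensed sketch of that path, using only names from this diff (TF 1.x graph mode assumed):

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops

input_dataset = dataset_ops.Dataset.range(5)
wrapped = dataset_ops.StructuredFunctionWrapper(
    math_ops.square,
    "map_on_gpu()",
    dataset=input_dataset,
    defun_kwargs={"experimental_ints_on_device": True})
# `wrapped.function` is the Defun whose FunctionDef carries the attr that
# selects _DeviceArg/_DeviceRetval during function instantiation.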