diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMapDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMapDataset.pbtxt
new file mode 100644
index 00000000000..e9619edcac1
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMapDataset.pbtxt
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "ExperimentalMapDataset"
+  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 6775695fa2d..7eb622dc117 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -46,7 +46,11 @@ namespace tensorflow {

 // A few string constant used throughout this module.
 static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static constexpr const char* const kDeviceArgOp =
+    FunctionLibraryDefinition::kDeviceArgOp;
 static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
+static constexpr const char* const kDeviceRetOp =
+    FunctionLibraryDefinition::kDeviceRetOp;
 static constexpr const char* const kGradientOp =
     FunctionLibraryDefinition::kGradientOp;
 static constexpr const char* const kNodeLabel = "Func";
@@ -1633,9 +1637,9 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
   this->ret_nodes.resize(ret_types.size());
   for (Node* n : this->graph->op_nodes()) {
     gtl::InlinedVector<Node*, 4>* node_vec;
-    if (n->type_string() == kRetOp) {
+    if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) {
       node_vec = &this->ret_nodes;
-    } else if (n->type_string() == kArgOp) {
+    } else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) {
       node_vec = &this->arg_nodes;
     } else {
       continue;
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index abd0930ca9e..9d0933e680d 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -149,8 +149,8 @@ class FunctionInstantiationHelper {
   }

   // Builds index for nodes that can be used as node's input arguments.
-  Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
-                            AttrSlice attr_values) {
+  Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
+                            bool ints_on_device) {
     bool is_type_list;
     DataTypeVector dtypes;
     TF_RETURN_IF_ERROR(
@@ -169,7 +169,11 @@ class FunctionInstantiationHelper {
         strings::StrAppend(&name, "_", i);
       }
       NodeDef* gnode = AddNode(name);
-      gnode->set_op(FunctionLibraryDefinition::kArgOp);
+      if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
+        gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
+      } else {
+        gnode->set_op(FunctionLibraryDefinition::kArgOp);
+      }
       AddAttr("T", dtypes[i], gnode);
       AddAttr("index", arg_index, gnode);
       result_.arg_types.push_back(dtypes[i]);
@@ -564,9 +568,11 @@ string Print(gtl::ArraySlice<const NodeDef*> nodes) {
   std::vector<const NodeDef*> ret;
   std::vector<const NodeDef*> body;
   for (const NodeDef* n : nodes) {
-    if (n->op() == FunctionLibraryDefinition::kArgOp) {
+    if (n->op() == FunctionLibraryDefinition::kArgOp ||
+        n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
       arg.push_back(n);
-    } else if (n->op() == FunctionLibraryDefinition::kRetOp) {
+    } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
+               n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
       ret.push_back(n);
     } else {
       body.push_back(n);
@@ -638,10 +644,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
   const OpDef& sig = fdef.signature();
   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));

+  bool ints_on_device = fdef.attr().count("experimental_ints_on_device") != 0 &&
+                        fdef.attr().at("experimental_ints_on_device").b();
+
   FunctionInstantiationHelper helper(get_function, result);
   Status s;
   for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
-    s = helper.BuildInputArgIndex(arg_def, attr_values);
+    s = helper.BuildInputArgIndex(arg_def, attr_values, ints_on_device);
     if (!s.ok()) {
       errors::AppendToMessage(&s, "In ", Print(arg_def));
       return s;
@@ -693,9 +702,6 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
     }
   }

-  bool ints_on_device = fdef.attr().count("experimental_ints_on_device") != 0 &&
-                        fdef.attr().at("experimental_ints_on_device").b();
-
   // Emits nodes for the function's return values.
   int ret_index = 0;
   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
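Aside on the framework change above: instantiation now keys off an `experimental_ints_on_device` attr on the `FunctionDef`, and the Python side of this patch sets that attr by forwarding `defun_kwargs` to `function.Defun()`. A minimal sketch of that mechanism, assuming (as the `defun_kwargs` plumbing later in this patch relies on) that extra keyword arguments to the TF 1.x `Defun` decorator are recorded as function attrs; the function name and input type here are illustrative, not taken from the patch:

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.ops import math_ops


# The extra keyword is assumed to land on the FunctionDef as the
# experimental_ints_on_device attr; InstantiateFunction() then emits
# _DeviceArg nodes (instead of host-pinned _Arg nodes) for DT_INT32 arguments.
@function.Defun(dtypes.int32, experimental_ints_on_device=True)
def square_on_device(x):
  return math_ops.square(x)
```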
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 40ace6ef815..4cc1b858e3a 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -379,6 +379,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {

   // Ops created for function arguments bear the name given by `kArgOp`; those
   // created for return values bear the name given by `kRetOp`.
   static constexpr const char* const kArgOp = "_Arg";
+  static constexpr const char* const kDeviceArgOp = "_DeviceArg";
   static constexpr const char* const kRetOp = "_Retval";
   static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index d64114e70e5..ab20b832986 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -244,6 +244,11 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
 };

 REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalMapDataset")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("input_dataset")
+                            .HostMemory("handle"),
+                        MapDatasetOp);

 }  // namespace
 }  // namespace data
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index d5c09712b25..cca3cfbd7c0 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -69,6 +69,7 @@ void RetvalOp::Compute(OpKernelContext* ctx) {
 }

 REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
+REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
 REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
 REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

@@ -105,6 +106,8 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                    .HostMemory("output")
                                                    .TypeConstraint<int32>("T"),
                                                ArgOp);
+REGISTER_KERNEL_BUILDER(
+    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
 #undef REGISTER

 REGISTER_KERNEL_BUILDER(Name(kArgOp)
diff --git a/tensorflow/core/kernels/function_ops.h b/tensorflow/core/kernels/function_ops.h
index 0f51eca1638..9ddd4956039 100644
--- a/tensorflow/core/kernels/function_ops.h
+++ b/tensorflow/core/kernels/function_ops.h
@@ -22,6 +22,7 @@ limitations under the License.
 namespace tensorflow {

 static const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static const char* const kDeviceArgOp = FunctionLibraryDefinition::kDeviceArgOp;
 static const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
 static const char* const kDeviceRetOp = FunctionLibraryDefinition::kDeviceRetOp;

diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index 088d1865ddf..9733cf27768 100644
--- a/tensorflow/core/ops/experimental_dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -75,6 +75,17 @@ REGISTER_OP("ExperimentalIgnoreErrorsDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);

+REGISTER_OP("ExperimentalMapDataset")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .Attr("use_inter_op_parallelism: bool = true")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("ExperimentalNonSerializableDataset")
     .Input("input_dataset: variant")
     .Output("handle: variant")
diff --git a/tensorflow/core/ops/function_ops.cc b/tensorflow/core/ops/function_ops.cc
index 6edd86b3ad0..8e86dd9f780 100644
--- a/tensorflow/core/ops/function_ops.cc
+++ b/tensorflow/core/ops/function_ops.cc
@@ -35,6 +35,22 @@ output: The argument.
 index: This argument is the index-th argument of the function.
)doc"); +REGISTER_SYSTEM_OP("_DeviceArg") + .Output("output: T") + .Attr("T: type") + .Attr("index: int >= 0") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext* context) { + context->set_output(0, context->UnknownShape()); + return Status::OK(); + }) + .Doc(R"doc( +A graph node which represents an argument to a function. + +output: The argument. +index: This argument is the index-th argument of the function. +)doc"); + REGISTER_SYSTEM_OP("_Retval") .Input("input: T") .Attr("T: type") diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 0141ac730fa..c9b11a2c381 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -38,6 +38,7 @@ cuda_py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python/compat:compat", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", diff --git a/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py index adfacf1c9f8..cea8bd6f0b7 100644 --- a/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py @@ -28,7 +28,9 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test +from tensorflow.python.util import compat as util_compat class CopyToDeviceTest(test_base.DatasetTestBase): @@ -294,6 +296,42 @@ class CopyToDeviceTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testCopyToDeviceGpuWithMap(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + def generator(): + for i in range(10): + yield i, float(i), str(i) + + host_dataset = dataset_ops.Dataset.from_generator( + generator, output_types=(dtypes.int32, dtypes.float32, dtypes.string)) + device_dataset = host_dataset.apply( + prefetching_ops.copy_to_device("/gpu:0")) + + def gpu_map_func(x, y, z): + return math_ops.square(x), math_ops.square(y), z + + device_dataset = device_dataset.apply( + prefetching_ops.map_on_gpu(gpu_map_func)) + options = dataset_ops.Options() + options.experimental_autotune = False + device_dataset = device_dataset.with_options(options) + + with ops.device("/gpu:0"): + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + x, y, z = sess.run(next_element) + self.assertEqual(i**2, x) + self.assertEqual(float(i**2), y) + self.assertEqual(util_compat.as_bytes(str(i)), z) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testCopyToDeviceGpuInt32(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py index a55b8bfb769..d34f9f25bda 100644 --- a/tensorflow/python/data/experimental/ops/prefetching_ops.py +++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py 
@@ -540,3 +540,71 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
   @property
   def output_classes(self):
     return self._input_dataset.output_classes
+
+
+class _MapOnGpuDataset(dataset_ops.UnaryDataset):
+  """A `Dataset` that maps a function over elements in its input using a GPU."""
+
+  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
+    """See `Dataset.map()` for details."""
+    super(_MapOnGpuDataset, self).__init__(input_dataset)
+    self._input_dataset = input_dataset
+    self._use_inter_op_parallelism = use_inter_op_parallelism
+
+    wrapped_func = dataset_ops.StructuredFunctionWrapper(
+        map_func,
+        self._transformation_name(),
+        dataset=input_dataset,
+        defun_kwargs={"experimental_ints_on_device": True})
+    self._output_classes = wrapped_func.output_classes
+    self._output_shapes = wrapped_func.output_shapes
+    self._output_types = wrapped_func.output_types
+    self._map_func = wrapped_func.function
+
+  def _as_variant_tensor(self):
+    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
+    return ged_ops.experimental_map_dataset(
+        input_t,
+        self._map_func.captured_inputs,
+        f=self._map_func,
+        use_inter_op_parallelism=self._use_inter_op_parallelism,
+        **dataset_ops.flat_structure(self))
+
+  @property
+  def output_classes(self):
+    return self._output_classes
+
+  @property
+  def output_shapes(self):
+    return self._output_shapes
+
+  @property
+  def output_types(self):
+    return self._output_types
+
+  def _transformation_name(self):
+    return "map_on_gpu()"
+
+
+def map_on_gpu(map_func):
+  """Maps `map_func` across the elements of this dataset.
+
+  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that
+  runs `map_func` on the GPU. It must be used after applying the
+  `tf.data.experimental.copy_to_device` transformation with a GPU device
+  argument.
+
+  Args:
+    map_func: A function mapping a nested structure of tensors (having shapes
+      and types defined by `self.output_shapes` and `self.output_types`) to
+      another nested structure of tensors.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    `tf.data.Dataset.apply`.
+  """
+
+  def _apply_fn(dataset):
+    return _MapOnGpuDataset(dataset, map_func)
+
+  return _apply_fn
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 59389a24f73..3836a68e7d4 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -61,6 +61,7 @@ class Dataset(object):
   collection of elements (nested structures of tensors) and a "logical plan" of
   transformations that act on those elements.
   """
+
   def __init__(self):
     pass

@@ -201,6 +202,7 @@ class Dataset(object):
     # a 0-argument function.
     @function.Defun(capture_by_value=True)
     def _make_dataset():
+      """Factory function for a dataset."""
       # NOTE(mrry): `Defun` does not capture the graph-level seed from the
       # enclosing graph, so if a graph-level seed is present we set the local
       # graph seed based on a combination of the graph- and op-level seeds.
@@ -1777,7 +1779,8 @@ class StructuredFunctionWrapper(object):
                input_shapes=None,
                input_types=None,
                add_to_graph=True,
-               experimental_nested_dataset_support=False):
+               experimental_nested_dataset_support=False,
+               defun_kwargs=None):
     """Creates a new `StructuredFunctionWrapper` for the given function.

     Args:
@@ -1798,6 +1801,9 @@ class StructuredFunctionWrapper(object):
         default graph.
       experimental_nested_dataset_support: (Optional.) If `True`, the function
         will support `tf.data.Dataset` objects as arguments and return values.
+      defun_kwargs: (Optional.) A dictionary mapping string argument names to
+        values. If supplied, will be passed to `function.Defun()` as keyword
+        arguments.

     Raises:
       ValueError: If an invalid combination of `dataset`, `input_classes`,
@@ -1832,7 +1838,11 @@ class StructuredFunctionWrapper(object):
       # TODO(b/110122868): Enable this support for all `tf.data` functions.
       self._nested_dataset_support = experimental_nested_dataset_support

-    @function.Defun(*self._defun_args(), func_name=self._func_name)
+    if defun_kwargs is None:
+      defun_kwargs = {}
+
+    @function.Defun(
+        *self._defun_args(), func_name=self._func_name, **defun_kwargs)
     def tf_data_structured_function_wrapper(*args):
       """Wrapper for passing nested structures to and from tf.data functions."""
       flat_args = []
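A minimal end-to-end usage sketch of the new `map_on_gpu()` transformation, modeled on the `testCopyToDeviceGpuWithMap` test added above. It is illustrative only: the `/gpu:0` device string, the generator contents, and the standalone TF 1.x `Session` driver are assumptions, not part of the change.

```python
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops


def gen():
  for i in range(10):
    yield i, float(i)


host_dataset = dataset_ops.Dataset.from_generator(
    gen, output_types=(dtypes.int32, dtypes.float32))

# map_on_gpu() expects its input to already live on the GPU, so stage the
# elements there first with copy_to_device().
device_dataset = host_dataset.apply(prefetching_ops.copy_to_device("/gpu:0"))

# The map function runs on the GPU; experimental_ints_on_device makes the
# instantiated function receive its int32 argument via _DeviceArg instead of
# a host-pinned _Arg.
device_dataset = device_dataset.apply(
    prefetching_ops.map_on_gpu(
        lambda x, y: (math_ops.square(x), math_ops.square(y))))

# The test above disables autotuning for this pipeline; mirror that here.
options = dataset_ops.Options()
options.experimental_autotune = False
device_dataset = device_dataset.with_options(options)

with ops.device("/gpu:0"):
  iterator = device_dataset.make_initializable_iterator()
  next_element = iterator.get_next()

with session.Session() as sess:
  sess.run(iterator.initializer)
  for i in range(10):
    x, y = sess.run(next_element)
    assert x == i * i and y == float(i * i)
```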