[tf.data] Add experimental transformation for applying map()
on GPU devices.
PiperOrigin-RevId: 221265601
commit a326fdb402
parent a466bbdb04
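Before the diff itself, a minimal usage sketch assembled from the `map_on_gpu` docstring and the new test further down. Everything here is illustrative: the import path for `prefetching_ops` and the `/gpu:0` device string are assumptions, and the test additionally disables autotune via `dataset_ops.Options` before iterating.

import tensorflow as tf
from tensorflow.python.data.experimental.ops import prefetching_ops

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: tf.cast(x, tf.float32))  # runs on the host
# Per the docstring below, map_on_gpu must follow copy_to_device with a
# GPU device argument.
dataset = dataset.apply(prefetching_ops.copy_to_device("/gpu:0"))
dataset = dataset.apply(prefetching_ops.map_on_gpu(lambda x: x * x))

with tf.device("/gpu:0"):
  iterator = dataset.make_initializable_iterator()
  next_element = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  for _ in range(10):
    print(sess.run(next_element))  # 0.0, 1.0, 4.0, ..., 81.0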
@@ -0,0 +1,5 @@
+op {
+  graph_op_name: "ExperimentalMapDataset"
+  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+  visibility: HIDDEN
+}
@@ -46,7 +46,11 @@ namespace tensorflow {

 // A few string constant used throughout this module.
 static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static constexpr const char* const kDeviceArgOp =
+    FunctionLibraryDefinition::kDeviceArgOp;
 static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
+static constexpr const char* const kDeviceRetOp =
+    FunctionLibraryDefinition::kDeviceRetOp;
 static constexpr const char* const kGradientOp =
     FunctionLibraryDefinition::kGradientOp;
 static constexpr const char* const kNodeLabel = "Func";
@@ -1633,9 +1637,9 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
   this->ret_nodes.resize(ret_types.size());
   for (Node* n : this->graph->op_nodes()) {
     gtl::InlinedVector<Node*, 4>* node_vec;
-    if (n->type_string() == kRetOp) {
+    if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) {
       node_vec = &this->ret_nodes;
-    } else if (n->type_string() == kArgOp) {
+    } else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) {
       node_vec = &this->arg_nodes;
     } else {
       continue;
@@ -149,8 +149,8 @@ class FunctionInstantiationHelper {
   }

   // Builds index for nodes that can be used as node's input arguments.
-  Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
-                            AttrSlice attr_values) {
+  Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
+                            bool ints_on_device) {
     bool is_type_list;
     DataTypeVector dtypes;
     TF_RETURN_IF_ERROR(
@@ -169,7 +169,11 @@ class FunctionInstantiationHelper {
        strings::StrAppend(&name, "_", i);
      }
      NodeDef* gnode = AddNode(name);
-     gnode->set_op(FunctionLibraryDefinition::kArgOp);
+     if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
+       gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
+     } else {
+       gnode->set_op(FunctionLibraryDefinition::kArgOp);
+     }
      AddAttr("T", dtypes[i], gnode);
      AddAttr("index", arg_index, gnode);
      result_.arg_types.push_back(dtypes[i]);
@@ -564,9 +568,11 @@ string Print(gtl::ArraySlice<const NodeDef*> nodes) {
   std::vector<const NodeDef*> ret;
   std::vector<const NodeDef*> body;
   for (const NodeDef* n : nodes) {
-    if (n->op() == FunctionLibraryDefinition::kArgOp) {
+    if (n->op() == FunctionLibraryDefinition::kArgOp ||
+        n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
       arg.push_back(n);
-    } else if (n->op() == FunctionLibraryDefinition::kRetOp) {
+    } else if (n->op() == FunctionLibraryDefinition::kRetOp ||
+               n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
       ret.push_back(n);
     } else {
       body.push_back(n);
@@ -638,10 +644,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
   const OpDef& sig = fdef.signature();
   TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));

+  bool ints_on_device = fdef.attr().count("experimental_ints_on_device") != 0 &&
+                        fdef.attr().at("experimental_ints_on_device").b();
+
   FunctionInstantiationHelper helper(get_function, result);
   Status s;
   for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
-    s = helper.BuildInputArgIndex(arg_def, attr_values);
+    s = helper.BuildInputArgIndex(arg_def, attr_values, ints_on_device);
     if (!s.ok()) {
       errors::AppendToMessage(&s, "In ", Print(arg_def));
       return s;
@@ -693,9 +702,6 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
     }
   }

-  bool ints_on_device = fdef.attr().count("experimental_ints_on_device") != 0 &&
-                        fdef.attr().at("experimental_ints_on_device").b();
-
   // Emits nodes for the function's return values.
   int ret_index = 0;
   for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
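With the attribute now computed before the input loop, `BuildInputArgIndex` can route `DT_INT32` arguments to `_DeviceArg` nodes. A rough Python paraphrase of the check above, for illustration only (`fdef` is a `FunctionDef` proto; this helper is hypothetical, not an API):

def ints_on_device(fdef):
  # Mirrors the C++ above: the attr must be present and set to True.
  attr = fdef.attr.get("experimental_ints_on_device")
  return attr is not None and attr.b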
@@ -379,6 +379,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
   // Ops created for function arguments bear the name given by `kArgOp`; those
   // created for return values bear the name given by `kRetOp`.
   static constexpr const char* const kArgOp = "_Arg";
+  static constexpr const char* const kDeviceArgOp = "_DeviceArg";
   static constexpr const char* const kRetOp = "_Retval";
+  static constexpr const char* const kDeviceRetOp = "_DeviceRetval";

@@ -244,6 +244,11 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
 };

 REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalMapDataset")
+                            .Device(DEVICE_GPU)
+                            .HostMemory("input_dataset")
+                            .HostMemory("handle"),
+                        MapDatasetOp);

 }  // namespace
 }  // namespace data
@@ -69,6 +69,8 @@ void RetvalOp::Compute(OpKernelContext* ctx) {
 }

 REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
+REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
 REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
+REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

@@ -105,6 +106,8 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                             .HostMemory("output")
                             .TypeConstraint<int32>("T"),
                         ArgOp);
+REGISTER_KERNEL_BUILDER(
+    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
 #undef REGISTER

 REGISTER_KERNEL_BUILDER(Name(kArgOp)
@@ -22,6 +22,8 @@ limitations under the License.
 namespace tensorflow {

 static const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static const char* const kDeviceArgOp = FunctionLibraryDefinition::kDeviceArgOp;
 static const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
+static const char* const kDeviceRetOp = FunctionLibraryDefinition::kDeviceRetOp;

@@ -75,6 +75,17 @@ REGISTER_OP("ExperimentalIgnoreErrorsDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);

+REGISTER_OP("ExperimentalMapDataset")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .Attr("use_inter_op_parallelism: bool = true")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("ExperimentalNonSerializableDataset")
     .Input("input_dataset: variant")
     .Output("handle: variant")
@@ -35,6 +35,22 @@ output: The argument.
 index: This argument is the index-th argument of the function.
 )doc");

+REGISTER_SYSTEM_OP("_DeviceArg")
+    .Output("output: T")
+    .Attr("T: type")
+    .Attr("index: int >= 0")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* context) {
+      context->set_output(0, context->UnknownShape());
+      return Status::OK();
+    })
+    .Doc(R"doc(
+A graph node which represents an argument to a function.
+
+output: The argument.
+index: This argument is the index-th argument of the function.
+)doc");
+
 REGISTER_SYSTEM_OP("_Retval")
     .Input("input: T")
     .Attr("T: type")
@@ -38,6 +38,7 @@ cuda_py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python/compat:compat",
         "//tensorflow/python/data/ops:dataset_ops",
         "//tensorflow/python/data/ops:iterator_ops",
@@ -28,7 +28,9 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat as util_compat


 class CopyToDeviceTest(test_base.DatasetTestBase):
@@ -294,6 +296,42 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(next_element)

+  def testCopyToDeviceGpuWithMap(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    def generator():
+      for i in range(10):
+        yield i, float(i), str(i)
+
+    host_dataset = dataset_ops.Dataset.from_generator(
+        generator, output_types=(dtypes.int32, dtypes.float32, dtypes.string))
+    device_dataset = host_dataset.apply(
+        prefetching_ops.copy_to_device("/gpu:0"))
+
+    def gpu_map_func(x, y, z):
+      return math_ops.square(x), math_ops.square(y), z
+
+    device_dataset = device_dataset.apply(
+        prefetching_ops.map_on_gpu(gpu_map_func))
+    options = dataset_ops.Options()
+    options.experimental_autotune = False
+    device_dataset = device_dataset.with_options(options)
+
+    with ops.device("/gpu:0"):
+      iterator = device_dataset.make_initializable_iterator()
+      next_element = iterator.get_next()
+
+    with self.cached_session() as sess:
+      sess.run(iterator.initializer)
+      for i in range(10):
+        x, y, z = sess.run(next_element)
+        self.assertEqual(i**2, x)
+        self.assertEqual(float(i**2), y)
+        self.assertEqual(util_compat.as_bytes(str(i)), z)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
   def testCopyToDeviceGpuInt32(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
@@ -540,3 +540,71 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
   @property
   def output_classes(self):
     return self._input_dataset.output_classes
+
+
+class _MapOnGpuDataset(dataset_ops.UnaryDataset):
+  """A `Dataset` that maps a function over elements in its input using a GPU."""
+
+  def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
+    """See `Dataset.map()` for details."""
+    super(_MapOnGpuDataset, self).__init__(input_dataset)
+    self._input_dataset = input_dataset
+    self._use_inter_op_parallelism = use_inter_op_parallelism
+
+    wrapped_func = dataset_ops.StructuredFunctionWrapper(
+        map_func,
+        self._transformation_name(),
+        dataset=input_dataset,
+        defun_kwargs={"experimental_ints_on_device": True})
+    self._output_classes = wrapped_func.output_classes
+    self._output_shapes = wrapped_func.output_shapes
+    self._output_types = wrapped_func.output_types
+    self._map_func = wrapped_func.function
+
+  def _as_variant_tensor(self):
+    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
+    return ged_ops.experimental_map_dataset(
+        input_t,
+        self._map_func.captured_inputs,
+        f=self._map_func,
+        use_inter_op_parallelism=self._use_inter_op_parallelism,
+        **dataset_ops.flat_structure(self))
+
+  @property
+  def output_classes(self):
+    return self._output_classes
+
+  @property
+  def output_shapes(self):
+    return self._output_shapes
+
+  @property
+  def output_types(self):
+    return self._output_types
+
+  def _transformation_name(self):
+    return "map_on_gpu()"
+
+
+def map_on_gpu(map_func):
+  """Maps `map_func` across the elements of this dataset.
+
+  NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
+  `map_func` on GPU. It must be used after applying the
+  `tf.data.experimental.copy_to_device` transformation with a GPU device
+  argument.
+
+  Args:
+    map_func: A function mapping a nested structure of tensors (having shapes
+      and types defined by `self.output_shapes` and `self.output_types`) to
+      another nested structure of tensors.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    `tf.data.Dataset.apply`.
+  """
+
+  def _apply_fn(dataset):
+    return _MapOnGpuDataset(dataset, map_func)
+
+  return _apply_fn
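As the docstring above says, `map_on_gpu` returns a transformation function for `Dataset.apply`. A tiny sketch of that contract (the pipeline is a placeholder; a real one needs `copy_to_device` first, and `dataset_ops`/`prefetching_ops` are the modules from this diff):

ds = dataset_ops.Dataset.range(10)
transformation = prefetching_ops.map_on_gpu(lambda x: x + 1)
# Equivalent under the apply() contract: apply() simply calls the
# transformation with the dataset as its argument.
mapped_a = ds.apply(transformation)
mapped_b = transformation(ds)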
@@ -61,6 +61,7 @@ class Dataset(object):
   collection of elements (nested structures of tensors) and a "logical
   plan" of transformations that act on those elements.
   """
+
   def __init__(self):
     pass

@@ -201,6 +202,7 @@ class Dataset(object):
     # a 0-argument function.
     @function.Defun(capture_by_value=True)
     def _make_dataset():
+      """Factory function for a dataset."""
       # NOTE(mrry): `Defun` does not capture the graph-level seed from the
       # enclosing graph, so if a graph-level seed is present we set the local
       # graph seed based on a combination of the graph- and op-level seeds.
@@ -1777,7 +1779,8 @@ class StructuredFunctionWrapper(object):
                input_shapes=None,
                input_types=None,
                add_to_graph=True,
-               experimental_nested_dataset_support=False):
+               experimental_nested_dataset_support=False,
+               defun_kwargs=None):
     """Creates a new `StructuredFunctionWrapper` for the given function.

     Args:
@@ -1798,6 +1801,9 @@ class StructuredFunctionWrapper(object):
         default graph.
       experimental_nested_dataset_support: (Optional.) If `True`, the function
         will support `tf.data.Dataset` objects as arguments and return values.
+      defun_kwargs: (Optional.) A dictionary mapping string argument names to
+        values. If supplied, will be passed to `function.Defun()` as keyword
+        arguments.

     Raises:
       ValueError: If an invalid combination of `dataset`, `input_classes`,
@@ -1832,7 +1838,11 @@ class StructuredFunctionWrapper(object):
     # TODO(b/110122868): Enable this support for all `tf.data` functions.
     self._nested_dataset_support = experimental_nested_dataset_support

-    @function.Defun(*self._defun_args(), func_name=self._func_name)
+    if defun_kwargs is None:
+      defun_kwargs = {}
+
+    @function.Defun(
+        *self._defun_args(), func_name=self._func_name, **defun_kwargs)
     def tf_data_structured_function_wrapper(*args):
       """Wrapper for passing nested structures to and from tf.data functions."""
       flat_args = []
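This `defun_kwargs` plumbing is what lets `_MapOnGpuDataset` (earlier in this diff) tag its map function with `experimental_ints_on_device`. A minimal sketch of that call, assuming the modules from this diff; the toy function and dataset are placeholders:

input_dataset = dataset_ops.Dataset.range(10)
wrapped = dataset_ops.StructuredFunctionWrapper(
    lambda x: x * 2,
    "map_on_gpu()",
    dataset=input_dataset,
    defun_kwargs={"experimental_ints_on_device": True})
map_fn = wrapped.function  # a function.Defun carrying the extra attr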