[tf.data] Add experimental transformation for applying map()
on GPU devices.
PiperOrigin-RevId: 221265601
This commit is contained in:
parent
a466bbdb04
commit
a326fdb402
@ -0,0 +1,5 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "ExperimentalMapDataset"
|
||||||
|
summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -46,7 +46,11 @@ namespace tensorflow {
|
|||||||
|
|
||||||
// A few string constant used throughout this module.
|
// A few string constant used throughout this module.
|
||||||
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
|
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
|
||||||
|
static constexpr const char* const kDeviceArgOp =
|
||||||
|
FunctionLibraryDefinition::kDeviceArgOp;
|
||||||
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
|
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
|
||||||
|
static constexpr const char* const kDeviceRetOp =
|
||||||
|
FunctionLibraryDefinition::kDeviceRetOp;
|
||||||
static constexpr const char* const kGradientOp =
|
static constexpr const char* const kGradientOp =
|
||||||
FunctionLibraryDefinition::kGradientOp;
|
FunctionLibraryDefinition::kGradientOp;
|
||||||
static constexpr const char* const kNodeLabel = "Func";
|
static constexpr const char* const kNodeLabel = "Func";
|
||||||
@ -1633,9 +1637,9 @@ FunctionBody::FunctionBody(const FunctionDef& f, DataTypeSlice arg_t,
|
|||||||
this->ret_nodes.resize(ret_types.size());
|
this->ret_nodes.resize(ret_types.size());
|
||||||
for (Node* n : this->graph->op_nodes()) {
|
for (Node* n : this->graph->op_nodes()) {
|
||||||
gtl::InlinedVector<Node*, 4>* node_vec;
|
gtl::InlinedVector<Node*, 4>* node_vec;
|
||||||
if (n->type_string() == kRetOp) {
|
if (n->type_string() == kRetOp || n->type_string() == kDeviceRetOp) {
|
||||||
node_vec = &this->ret_nodes;
|
node_vec = &this->ret_nodes;
|
||||||
} else if (n->type_string() == kArgOp) {
|
} else if (n->type_string() == kArgOp || n->type_string() == kDeviceArgOp) {
|
||||||
node_vec = &this->arg_nodes;
|
node_vec = &this->arg_nodes;
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
|
@ -149,8 +149,8 @@ class FunctionInstantiationHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Builds index for nodes that can be used as node's input arguments.
|
// Builds index for nodes that can be used as node's input arguments.
|
||||||
Status BuildInputArgIndex(const OpDef::ArgDef& arg_def,
|
Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values,
|
||||||
AttrSlice attr_values) {
|
bool ints_on_device) {
|
||||||
bool is_type_list;
|
bool is_type_list;
|
||||||
DataTypeVector dtypes;
|
DataTypeVector dtypes;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -169,7 +169,11 @@ class FunctionInstantiationHelper {
|
|||||||
strings::StrAppend(&name, "_", i);
|
strings::StrAppend(&name, "_", i);
|
||||||
}
|
}
|
||||||
NodeDef* gnode = AddNode(name);
|
NodeDef* gnode = AddNode(name);
|
||||||
gnode->set_op(FunctionLibraryDefinition::kArgOp);
|
if (ints_on_device && dtypes[i] == DataType::DT_INT32) {
|
||||||
|
gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp);
|
||||||
|
} else {
|
||||||
|
gnode->set_op(FunctionLibraryDefinition::kArgOp);
|
||||||
|
}
|
||||||
AddAttr("T", dtypes[i], gnode);
|
AddAttr("T", dtypes[i], gnode);
|
||||||
AddAttr("index", arg_index, gnode);
|
AddAttr("index", arg_index, gnode);
|
||||||
result_.arg_types.push_back(dtypes[i]);
|
result_.arg_types.push_back(dtypes[i]);
|
||||||
@ -564,9 +568,11 @@ string Print(gtl::ArraySlice<const NodeDef*> nodes) {
|
|||||||
std::vector<const NodeDef*> ret;
|
std::vector<const NodeDef*> ret;
|
||||||
std::vector<const NodeDef*> body;
|
std::vector<const NodeDef*> body;
|
||||||
for (const NodeDef* n : nodes) {
|
for (const NodeDef* n : nodes) {
|
||||||
if (n->op() == FunctionLibraryDefinition::kArgOp) {
|
if (n->op() == FunctionLibraryDefinition::kArgOp ||
|
||||||
|
n->op() == FunctionLibraryDefinition::kDeviceArgOp) {
|
||||||
arg.push_back(n);
|
arg.push_back(n);
|
||||||
} else if (n->op() == FunctionLibraryDefinition::kRetOp) {
|
} else if (n->op() == FunctionLibraryDefinition::kRetOp ||
|
||||||
|
n->op() == FunctionLibraryDefinition::kDeviceRetOp) {
|
||||||
ret.push_back(n);
|
ret.push_back(n);
|
||||||
} else {
|
} else {
|
||||||
body.push_back(n);
|
body.push_back(n);
|
||||||
@ -638,10 +644,13 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
|||||||
const OpDef& sig = fdef.signature();
|
const OpDef& sig = fdef.signature();
|
||||||
TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
|
TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values));
|
||||||
|
|
||||||
|
bool ints_on_device = fdef.attr().count("experimental_ints_on_device") != 0 &&
|
||||||
|
fdef.attr().at("experimental_ints_on_device").b();
|
||||||
|
|
||||||
FunctionInstantiationHelper helper(get_function, result);
|
FunctionInstantiationHelper helper(get_function, result);
|
||||||
Status s;
|
Status s;
|
||||||
for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
|
for (const OpDef::ArgDef& arg_def : sig.input_arg()) {
|
||||||
s = helper.BuildInputArgIndex(arg_def, attr_values);
|
s = helper.BuildInputArgIndex(arg_def, attr_values, ints_on_device);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
errors::AppendToMessage(&s, "In ", Print(arg_def));
|
errors::AppendToMessage(&s, "In ", Print(arg_def));
|
||||||
return s;
|
return s;
|
||||||
@ -693,9 +702,6 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ints_on_device = fdef.attr().count("experimental_ints_on_device") != 0 &&
|
|
||||||
fdef.attr().at("experimental_ints_on_device").b();
|
|
||||||
|
|
||||||
// Emits nodes for the function's return values.
|
// Emits nodes for the function's return values.
|
||||||
int ret_index = 0;
|
int ret_index = 0;
|
||||||
for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
|
for (const OpDef::ArgDef& ret_def : sig.output_arg()) {
|
||||||
|
@ -379,6 +379,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
|||||||
// Ops created for function arguments bear the name given by `kArgOp`; those
|
// Ops created for function arguments bear the name given by `kArgOp`; those
|
||||||
// created for return values bear the name given by `kRetOp`.
|
// created for return values bear the name given by `kRetOp`.
|
||||||
static constexpr const char* const kArgOp = "_Arg";
|
static constexpr const char* const kArgOp = "_Arg";
|
||||||
|
static constexpr const char* const kDeviceArgOp = "_DeviceArg";
|
||||||
static constexpr const char* const kRetOp = "_Retval";
|
static constexpr const char* const kRetOp = "_Retval";
|
||||||
static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
|
static constexpr const char* const kDeviceRetOp = "_DeviceRetval";
|
||||||
|
|
||||||
|
@ -244,6 +244,11 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
|
REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("ExperimentalMapDataset")
|
||||||
|
.Device(DEVICE_GPU)
|
||||||
|
.HostMemory("input_dataset")
|
||||||
|
.HostMemory("handle"),
|
||||||
|
MapDatasetOp);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
@ -69,6 +69,7 @@ void RetvalOp::Compute(OpKernelContext* ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
|
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
|
||||||
|
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
|
||||||
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
|
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
|
||||||
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);
|
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);
|
||||||
|
|
||||||
@ -105,6 +106,8 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
|
|||||||
.HostMemory("output")
|
.HostMemory("output")
|
||||||
.TypeConstraint<int32>("T"),
|
.TypeConstraint<int32>("T"),
|
||||||
ArgOp);
|
ArgOp);
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
|
||||||
#undef REGISTER
|
#undef REGISTER
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name(kArgOp)
|
REGISTER_KERNEL_BUILDER(Name(kArgOp)
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
static const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
|
static const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
|
||||||
|
static const char* const kDeviceArgOp = FunctionLibraryDefinition::kDeviceArgOp;
|
||||||
static const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
|
static const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
|
||||||
static const char* const kDeviceRetOp = FunctionLibraryDefinition::kDeviceRetOp;
|
static const char* const kDeviceRetOp = FunctionLibraryDefinition::kDeviceRetOp;
|
||||||
|
|
||||||
|
@ -75,6 +75,17 @@ REGISTER_OP("ExperimentalIgnoreErrorsDataset")
|
|||||||
.Attr("output_shapes: list(shape) >= 1")
|
.Attr("output_shapes: list(shape) >= 1")
|
||||||
.SetShapeFn(shape_inference::ScalarShape);
|
.SetShapeFn(shape_inference::ScalarShape);
|
||||||
|
|
||||||
|
REGISTER_OP("ExperimentalMapDataset")
|
||||||
|
.Input("input_dataset: variant")
|
||||||
|
.Input("other_arguments: Targuments")
|
||||||
|
.Output("handle: variant")
|
||||||
|
.Attr("f: func")
|
||||||
|
.Attr("Targuments: list(type) >= 0")
|
||||||
|
.Attr("output_types: list(type) >= 1")
|
||||||
|
.Attr("output_shapes: list(shape) >= 1")
|
||||||
|
.Attr("use_inter_op_parallelism: bool = true")
|
||||||
|
.SetShapeFn(shape_inference::ScalarShape);
|
||||||
|
|
||||||
REGISTER_OP("ExperimentalNonSerializableDataset")
|
REGISTER_OP("ExperimentalNonSerializableDataset")
|
||||||
.Input("input_dataset: variant")
|
.Input("input_dataset: variant")
|
||||||
.Output("handle: variant")
|
.Output("handle: variant")
|
||||||
|
@ -35,6 +35,22 @@ output: The argument.
|
|||||||
index: This argument is the index-th argument of the function.
|
index: This argument is the index-th argument of the function.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
REGISTER_SYSTEM_OP("_DeviceArg")
|
||||||
|
.Output("output: T")
|
||||||
|
.Attr("T: type")
|
||||||
|
.Attr("index: int >= 0")
|
||||||
|
.SetIsStateful()
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* context) {
|
||||||
|
context->set_output(0, context->UnknownShape());
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.Doc(R"doc(
|
||||||
|
A graph node which represents an argument to a function.
|
||||||
|
|
||||||
|
output: The argument.
|
||||||
|
index: This argument is the index-th argument of the function.
|
||||||
|
)doc");
|
||||||
|
|
||||||
REGISTER_SYSTEM_OP("_Retval")
|
REGISTER_SYSTEM_OP("_Retval")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Attr("T: type")
|
.Attr("T: type")
|
||||||
|
@ -38,6 +38,7 @@ cuda_py_test(
|
|||||||
"//tensorflow/python:dtypes",
|
"//tensorflow/python:dtypes",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:framework_test_lib",
|
"//tensorflow/python:framework_test_lib",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python/compat:compat",
|
"//tensorflow/python/compat:compat",
|
||||||
"//tensorflow/python/data/ops:dataset_ops",
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
"//tensorflow/python/data/ops:iterator_ops",
|
"//tensorflow/python/data/ops:iterator_ops",
|
||||||
|
@ -28,7 +28,9 @@ from tensorflow.python.framework import errors
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.util import compat as util_compat
|
||||||
|
|
||||||
|
|
||||||
class CopyToDeviceTest(test_base.DatasetTestBase):
|
class CopyToDeviceTest(test_base.DatasetTestBase):
|
||||||
@ -294,6 +296,42 @@ class CopyToDeviceTest(test_base.DatasetTestBase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(next_element)
|
sess.run(next_element)
|
||||||
|
|
||||||
|
def testCopyToDeviceGpuWithMap(self):
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
def generator():
|
||||||
|
for i in range(10):
|
||||||
|
yield i, float(i), str(i)
|
||||||
|
|
||||||
|
host_dataset = dataset_ops.Dataset.from_generator(
|
||||||
|
generator, output_types=(dtypes.int32, dtypes.float32, dtypes.string))
|
||||||
|
device_dataset = host_dataset.apply(
|
||||||
|
prefetching_ops.copy_to_device("/gpu:0"))
|
||||||
|
|
||||||
|
def gpu_map_func(x, y, z):
|
||||||
|
return math_ops.square(x), math_ops.square(y), z
|
||||||
|
|
||||||
|
device_dataset = device_dataset.apply(
|
||||||
|
prefetching_ops.map_on_gpu(gpu_map_func))
|
||||||
|
options = dataset_ops.Options()
|
||||||
|
options.experimental_autotune = False
|
||||||
|
device_dataset = device_dataset.with_options(options)
|
||||||
|
|
||||||
|
with ops.device("/gpu:0"):
|
||||||
|
iterator = device_dataset.make_initializable_iterator()
|
||||||
|
next_element = iterator.get_next()
|
||||||
|
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
sess.run(iterator.initializer)
|
||||||
|
for i in range(10):
|
||||||
|
x, y, z = sess.run(next_element)
|
||||||
|
self.assertEqual(i**2, x)
|
||||||
|
self.assertEqual(float(i**2), y)
|
||||||
|
self.assertEqual(util_compat.as_bytes(str(i)), z)
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(next_element)
|
||||||
|
|
||||||
def testCopyToDeviceGpuInt32(self):
|
def testCopyToDeviceGpuInt32(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
self.skipTest("No GPU available")
|
self.skipTest("No GPU available")
|
||||||
|
@ -540,3 +540,71 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
|
|||||||
@property
|
@property
|
||||||
def output_classes(self):
|
def output_classes(self):
|
||||||
return self._input_dataset.output_classes
|
return self._input_dataset.output_classes
|
||||||
|
|
||||||
|
|
||||||
|
class _MapOnGpuDataset(dataset_ops.UnaryDataset):
|
||||||
|
"""A `Dataset` that maps a function over elements in its using a GPU."""
|
||||||
|
|
||||||
|
def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True):
|
||||||
|
"""See `Dataset.map()` for details."""
|
||||||
|
super(_MapOnGpuDataset, self).__init__(input_dataset)
|
||||||
|
self._input_dataset = input_dataset
|
||||||
|
self._use_inter_op_parallelism = use_inter_op_parallelism
|
||||||
|
|
||||||
|
wrapped_func = dataset_ops.StructuredFunctionWrapper(
|
||||||
|
map_func,
|
||||||
|
self._transformation_name(),
|
||||||
|
dataset=input_dataset,
|
||||||
|
defun_kwargs={"experimental_ints_on_device": True})
|
||||||
|
self._output_classes = wrapped_func.output_classes
|
||||||
|
self._output_shapes = wrapped_func.output_shapes
|
||||||
|
self._output_types = wrapped_func.output_types
|
||||||
|
self._map_func = wrapped_func.function
|
||||||
|
|
||||||
|
def _as_variant_tensor(self):
|
||||||
|
input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
|
||||||
|
return ged_ops.experimental_map_dataset(
|
||||||
|
input_t,
|
||||||
|
self._map_func.captured_inputs,
|
||||||
|
f=self._map_func,
|
||||||
|
use_inter_op_parallelism=self._use_inter_op_parallelism,
|
||||||
|
**dataset_ops.flat_structure(self))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_classes(self):
|
||||||
|
return self._output_classes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shapes(self):
|
||||||
|
return self._output_shapes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_types(self):
|
||||||
|
return self._output_types
|
||||||
|
|
||||||
|
def _transformation_name(self):
|
||||||
|
return "map_on_gpu()"
|
||||||
|
|
||||||
|
|
||||||
|
def map_on_gpu(map_func):
|
||||||
|
"""Maps `map_func` across the elements of this dataset.
|
||||||
|
|
||||||
|
NOTE: This is a highly experimental version of `tf.data.Dataset.map` that runs
|
||||||
|
`map_func` on GPU. It must be used after applying the
|
||||||
|
`tf.data.experimental.copy_to_device` transformation with a GPU device
|
||||||
|
argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
map_func: A function mapping a nested structure of tensors (having shapes
|
||||||
|
and types defined by `self.output_shapes` and `self.output_types`) to
|
||||||
|
another nested structure of tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Dataset` transformation function, which can be passed to
|
||||||
|
`tf.data.Dataset.apply`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _apply_fn(dataset):
|
||||||
|
return _MapOnGpuDataset(dataset, map_func)
|
||||||
|
|
||||||
|
return _apply_fn
|
||||||
|
@ -61,6 +61,7 @@ class Dataset(object):
|
|||||||
collection of elements (nested structures of tensors) and a "logical
|
collection of elements (nested structures of tensors) and a "logical
|
||||||
plan" of transformations that act on those elements.
|
plan" of transformations that act on those elements.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -201,6 +202,7 @@ class Dataset(object):
|
|||||||
# a 0-argument function.
|
# a 0-argument function.
|
||||||
@function.Defun(capture_by_value=True)
|
@function.Defun(capture_by_value=True)
|
||||||
def _make_dataset():
|
def _make_dataset():
|
||||||
|
"""Factory function for a dataset."""
|
||||||
# NOTE(mrry): `Defun` does not capture the graph-level seed from the
|
# NOTE(mrry): `Defun` does not capture the graph-level seed from the
|
||||||
# enclosing graph, so if a graph-level seed is present we set the local
|
# enclosing graph, so if a graph-level seed is present we set the local
|
||||||
# graph seed based on a combination of the graph- and op-level seeds.
|
# graph seed based on a combination of the graph- and op-level seeds.
|
||||||
@ -1777,7 +1779,8 @@ class StructuredFunctionWrapper(object):
|
|||||||
input_shapes=None,
|
input_shapes=None,
|
||||||
input_types=None,
|
input_types=None,
|
||||||
add_to_graph=True,
|
add_to_graph=True,
|
||||||
experimental_nested_dataset_support=False):
|
experimental_nested_dataset_support=False,
|
||||||
|
defun_kwargs=None):
|
||||||
"""Creates a new `StructuredFunctionWrapper` for the given function.
|
"""Creates a new `StructuredFunctionWrapper` for the given function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1798,6 +1801,9 @@ class StructuredFunctionWrapper(object):
|
|||||||
default graph.
|
default graph.
|
||||||
experimental_nested_dataset_support: (Optional.) If `True`, the function
|
experimental_nested_dataset_support: (Optional.) If `True`, the function
|
||||||
will support `tf.data.Dataset` objects as arguments and return values.
|
will support `tf.data.Dataset` objects as arguments and return values.
|
||||||
|
defun_kwargs: (Optional.) A dictionary mapping string argument names to
|
||||||
|
values. If supplied, will be passed to `function.Defun()` as keyword
|
||||||
|
arguments.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If an invalid combination of `dataset`, `input_classes`,
|
ValueError: If an invalid combination of `dataset`, `input_classes`,
|
||||||
@ -1832,7 +1838,11 @@ class StructuredFunctionWrapper(object):
|
|||||||
# TODO(b/110122868): Enable this support for all `tf.data` functions.
|
# TODO(b/110122868): Enable this support for all `tf.data` functions.
|
||||||
self._nested_dataset_support = experimental_nested_dataset_support
|
self._nested_dataset_support = experimental_nested_dataset_support
|
||||||
|
|
||||||
@function.Defun(*self._defun_args(), func_name=self._func_name)
|
if defun_kwargs is None:
|
||||||
|
defun_kwargs = {}
|
||||||
|
|
||||||
|
@function.Defun(
|
||||||
|
*self._defun_args(), func_name=self._func_name, **defun_kwargs)
|
||||||
def tf_data_structured_function_wrapper(*args):
|
def tf_data_structured_function_wrapper(*args):
|
||||||
"""Wrapper for passing nested structures to and from tf.data functions."""
|
"""Wrapper for passing nested structures to and from tf.data functions."""
|
||||||
flat_args = []
|
flat_args = []
|
||||||
|
Loading…
Reference in New Issue
Block a user