Refactor MapDefunOp

Fei Hu 2019-07-15 23:16:25 -07:00
parent e313f0b369
commit 283e298510
3 changed files with 361 additions and 278 deletions

tensorflow/core/kernels/data/BUILD

@@ -1264,6 +1264,7 @@ tf_kernel_library(
tf_kernel_library(
    name = "map_defun_op",
    srcs = ["map_defun_op.cc"],
    hdrs = ["map_defun_op.h"],
    deps = [
        ":dataset_utils",
        "//tensorflow/core:core_cpu_internal",
@@ -1273,3 +1274,24 @@ tf_kernel_library(
        "//tensorflow/core:lib_internal",
    ],
)

tf_cc_test(
    name = "map_defun_op_test",
    size = "small",
    srcs = ["map_defun_op_test.cc"],
    deps = [
        ":dataset_test_base",
        ":dataset_utils",
        ":map_defun_op",
        ":stats_utils",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:function_ops",
    ],
)
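The new test target lives next to the kernel target above. Assuming this hunk is in tensorflow/core/kernels/data/BUILD (the package the map_defun_op target already belongs to), the test can be run in isolation with a standard Bazel invocation:

bazel test //tensorflow/core/kernels/data:map_defun_op_test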

tensorflow/core/kernels/data/map_defun_op.cc

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/map_defun_op.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -27,299 +28,283 @@ limitations under the License.
namespace tensorflow {
namespace data {

/* static */ constexpr const char* const MapDefunOp::kArguments;
/* static */ constexpr const char* const MapDefunOp::kCapturedInputs;
/* static */ constexpr const char* const MapDefunOp::kTarguments;
/* static */ constexpr const char* const MapDefunOp::kTcaptured;
/* static */ constexpr const char* const MapDefunOp::kOutputTypes;
/* static */ constexpr const char* const MapDefunOp::kOutputShapes;
/* static */ constexpr const char* const MapDefunOp::kFunc;
/* static */ constexpr const char* const MapDefunOp::kMaxIntraOpParallelism;

constexpr char kOutput[] = "output";

struct MapDefunOp::ComputeOptions {
  // These vary per MapDefunOp::ComputeAsync call, but must persist until
  // all calls to the function are complete. This struct also encapsulates
  // all the components that need to be passed to each MapFunctionCallFrame.
  OpInputList args;
  const std::vector<TensorShape> arg_shapes;
  OpInputList captured_inputs;
  const int64 batch_size;
  std::function<void(std::function<void()>)> runner;

  // Output of a compute call
  std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
  OpOutputList output GUARDED_BY(mu);
  mutex mu;

  // Create a copy of output_shapes because every `Compute` may expect a
  // different output shape.
  ComputeOptions(OpKernelContext* ctx, OpInputList args,
                 OpInputList captured_inputs,
                 std::vector<TensorShape> arg_shapes, int64 batch_size,
                 const std::vector<PartialTensorShape>& output_shapes_attr,
                 int max_parallelism)
      : args(args),
        arg_shapes(std::move(arg_shapes)),
        captured_inputs(captured_inputs),
        batch_size(batch_size),
        output_shapes(output_shapes_attr) {
    if (max_parallelism >= 1) {
      runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
    }
  }
};

class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface {
 public:
  MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
                       size_t iter)
      : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}

  ~MapFunctionCallFrame() override = default;

  size_t num_args() const override { return compute_opts_->args.size(); }

  size_t num_retvals() const override {
    return static_cast<size_t>(kernel_->num_outputs());
  }

  Status GetArg(int index, Tensor* val) const override {
    if (index < 0 || index >= compute_opts_->args.size() +
                                  compute_opts_->captured_inputs.size()) {
      return errors::InvalidArgument("Mismatch in number of function inputs.");
    }

    if (index >= compute_opts_->args.size()) {
      // The function is calling for a captured input
      *val =
          compute_opts_->captured_inputs[index - compute_opts_->args.size()];
      return Status::OK();
    }

    bool result =
        val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
                      compute_opts_->arg_shapes.at(index));
    if (!result) {
      return errors::Internal("GetArg failed.");
    } else if (!val->IsAligned()) {
      // Ensure alignment
      *val = tensor::DeepCopy(*val);
    }
    return Status::OK();
  }

  Status SetRetval(int index, const Tensor& val) override {
    if (index < 0 || index >= kernel_->num_outputs()) {
      return errors::InvalidArgument("Mismatch in number of function outputs.");
    }
    if (val.dtype() != kernel_->output_type(index)) {
      return errors::InvalidArgument(
          "Mismatch in function return type and expected output type for "
          "output: ",
          index);
    }
    Tensor* out;
    {  // Locking scope
      mutex_lock l(compute_opts_->mu);
      if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
              val.shape())) {
        return errors::InvalidArgument(
            "Mismatch in function retval shape, ", val.shape(),
            ", and expected output shape, ",
            compute_opts_->output_shapes.at(index).DebugString(), ".");
      }
      if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
        // Given val, we have new information about the output shape at
        // this index. Store the shape and allocate the output accordingly.
        compute_opts_->output_shapes.at(index) = val.shape();

        TensorShape actual_shape = val.shape();
        actual_shape.InsertDim(0, compute_opts_->batch_size);
        TF_RETURN_IF_ERROR(
            compute_opts_->output.allocate(index, actual_shape, &out));
      } else {
        out = (compute_opts_->output)[index];
      }
    }
    return batch_util::CopyElementToSlice(val, out, iter_);
  }

 private:
  ComputeOptions* const compute_opts_;  // Not owned
  const OpKernel* kernel_;
  const size_t iter_;
};

MapDefunOp::MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  auto func_lib = ctx->function_library();
  OP_REQUIRES(ctx, func_lib != nullptr,
              errors::Internal("No function library."));
  const NameAttrList* func;
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kFunc, &func));
  OP_REQUIRES_OK(ctx,
                 func_lib->Instantiate(func->name(), AttrSlice(&func->attr()),
                                       &func_handle_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(
      ctx, ctx->GetAttr(kMaxIntraOpParallelism, &max_intra_op_parallelism_));

  OP_REQUIRES(ctx, ctx->num_inputs() >= 0,
              errors::InvalidArgument("Must have at least one input."));
  OP_REQUIRES(ctx, ctx->num_outputs() >= 0,
              errors::InvalidArgument("Must have at least one output."));
  OP_REQUIRES(ctx, ctx->num_outputs() == output_shapes_.size(),
              errors::InvalidArgument(
                  "Length of output_shapes and output_types must match."));
}

void MapDefunOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  ComputeOptions* compute_opts = nullptr;

  OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);

  Status s = SetupOutputs(ctx, compute_opts);
  if (!s.ok()) delete compute_opts;
  OP_REQUIRES_OK_ASYNC(ctx, s, done);

  FunctionLibraryRuntime::Options opts;
  SetRunOptions(ctx, &opts, compute_opts, /*always_collect_stats=*/false);

  // Run loop
  StatusCallback callback = std::bind(
      [](OpKernelContext* ctx, ComputeOptions* compute_opts,
         DoneCallback& done, const Status& status) {
        delete compute_opts;
        ctx->SetStatus(status);
        done();
      },
      ctx, compute_opts, std::move(done), std::placeholders::_1);

  auto* refcounted = new ReffedStatusCallback(std::move(callback));

  CancellationManager* parent_mgr = ctx->cancellation_manager();

  for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
    // We use a different cancellation manager each time the function is run
    // to avoid the race condition between a function run error and other
    // functions being cancelled as a result.
    CancellationManager* c_mgr = new CancellationManager();
    CancellationToken token = parent_mgr->get_cancellation_token();
    const bool success = parent_mgr->RegisterCallback(
        token, [c_mgr]() { c_mgr->StartCancel(); });

    opts.cancellation_manager = c_mgr;
    if (!success) {
      delete c_mgr;
      refcounted->UpdateStatus(errors::Cancelled(
          "MapDefunOp functions cancelled because parent graph cancelled"));
      break;
    }

    auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);

    refcounted->Ref();
    ctx->function_library()->Run(opts, func_handle_, call_frame,
                                 [call_frame, refcounted, c_mgr, parent_mgr,
                                  token](const Status& func_status) {
                                   parent_mgr->DeregisterCallback(token);
                                   delete c_mgr;
                                   delete call_frame;
                                   refcounted->UpdateStatus(func_status);
                                   refcounted->Unref();
                                 });
  }

  // Unref 1 because refcounted is initialized with refcount = 1
  refcounted->Unref();
}

void MapDefunOp::SetRunOptions(OpKernelContext* ctx,
                               FunctionLibraryRuntime::Options* opts,
                               ComputeOptions* compute_opts,
                               bool always_collect_stats) {
  opts->rendezvous = ctx->rendezvous();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  if (max_intra_op_parallelism_ >= 1) {
    opts->runner = &compute_opts->runner;
  } else {
    opts->runner = ctx->runner();
  }
}

// Get inputs to Compute and check that they are valid.
Status MapDefunOp::SetupArgs(OpKernelContext* ctx,
                             ComputeOptions** compute_opts) {
  OpInputList arguments;
  TF_RETURN_IF_ERROR(ctx->input_list(kArguments, &arguments));
  OpInputList captured_inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(kCapturedInputs, &captured_inputs));

  int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;

  for (size_t i = 0; i < arguments.size(); ++i) {
    if (arguments[i].dims() == 0) {
      return errors::InvalidArgument(
          "All inputs must have rank at least 1. Input ", i,
          " has a rank of 0.");
    } else if (arguments[i].dim_size(0) != batch_size) {
      return errors::InvalidArgument(
          "All inputs must have the same dimension 0. Input ", i,
          " has leading dimension ", ctx->input(i).dim_size(0),
          ", while all previous inputs have leading dimension ", batch_size);
    }
  }

  std::vector<TensorShape> arg_shapes;
  arg_shapes.reserve(arguments.size());

  for (size_t i = 0; i < arguments.size(); ++i) {
    arg_shapes.push_back(arguments[i].shape());
    arg_shapes.at(i).RemoveDim(0);
  }

  *compute_opts =
      new ComputeOptions(ctx, arguments, captured_inputs, std::move(arg_shapes),
                         batch_size, output_shapes_, max_intra_op_parallelism_);
  return Status::OK();
}

Status MapDefunOp::SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
  mutex_lock l(opts->mu);
  TF_RETURN_IF_ERROR(ctx->output_list(kOutput, &opts->output));

  for (size_t i = 0; i < output_types().size(); ++i) {
    if (output_shapes_.at(i).IsFullyDefined()) {
      Tensor* out = nullptr;
      TensorShape output_shape;
      output_shapes_.at(i).AsTensorShape(&output_shape);
      output_shape.InsertDim(0, opts->batch_size);
      TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
    }
  }
  return Status::OK();
}

namespace {
REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow
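The per-slice cancellation wiring in ComputeAsync above is the subtlest part of the kernel, so here is a minimal sketch of the same pattern in isolation. It is not part of this commit; RunOneSliceWithChildCancellation is a hypothetical helper, while CancellationManager, get_cancellation_token, RegisterCallback, DeregisterCallback, and StartCancel are the real tensorflow/core/framework/cancellation.h API.

// Minimal sketch (not part of this commit) of the parent/child
// CancellationManager pattern used once per slice in ComputeAsync above.
#include "tensorflow/core/framework/cancellation.h"

namespace example {

void RunOneSliceWithChildCancellation(tensorflow::CancellationManager* parent) {
  // Each run gets its own manager so a failed run can be cancelled without
  // cancelling its sibling runs; only a parent-graph cancellation fans out.
  auto* child = new tensorflow::CancellationManager();
  tensorflow::CancellationToken token = parent->get_cancellation_token();
  const bool registered =
      parent->RegisterCallback(token, [child]() { child->StartCancel(); });
  if (!registered) {
    // The parent was already cancelled before we could register.
    delete child;
    return;
  }

  // ... run the function for this slice with `child` installed as
  //     opts.cancellation_manager ...

  // On completion, unhook from the parent before destroying the child.
  parent->DeregisterCallback(token);
  delete child;
}

}  // namespace example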

tensorflow/core/kernels/data/map_defun_op.h

@@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_

#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {

// This op runs a given defun on slices of the input arguments. The function
// given by "f" is assumed to be stateless, and is executed concurrently
// on all the slices; up to batch_size (i.e. the 0th dimension of each argument)
// functions will be scheduled at once.
//
// The "max_intra_op_parallelism" attr, which defaults to 1, can be used to
// limit the intra op parallelism. To limit inter-op parallelism, a user
// can set a private threadpool on the dataset using `tf.data.Options`'s
// `ThreadingOptions`.
//
// Note that this op is not exposed to users directly, but is invoked in
// tf.data rewrites.
class MapDefunOp : public AsyncOpKernel {
 public:
  static constexpr const char* const kArguments = "arguments";
  static constexpr const char* const kCapturedInputs = "captured_inputs";
  static constexpr const char* const kTarguments = "Targuments";
  static constexpr const char* const kTcaptured = "Tcaptured";
  static constexpr const char* const kOutputTypes = "output_types";
  static constexpr const char* const kOutputShapes = "output_shapes";
  static constexpr const char* const kFunc = "f";
  static constexpr const char* const kMaxIntraOpParallelism =
      "max_intra_op_parallelism";

  explicit MapDefunOp(OpKernelConstruction* ctx);

  ~MapDefunOp() override = default;

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 private:
  struct ComputeOptions;
  class MapFunctionCallFrame;

  void SetRunOptions(OpKernelContext* ctx,
                     FunctionLibraryRuntime::Options* opts,
                     ComputeOptions* compute_opts, bool always_collect_stats);

  // Get inputs to Compute and check that they are valid.
  Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts);

  Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts);

  FunctionLibraryRuntime::Handle func_handle_;
  std::vector<PartialTensorShape> output_shapes_;
  // If this value is positive, limit the max intra op parallelism when the
  // function is run on slices of the input.
  int max_intra_op_parallelism_;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
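
For context, a hedged sketch of how a MapDefun node can be described using the attr-name constants the header now exposes. This is not taken from the commit: BuildMapDefunNode is a hypothetical helper, the dtypes and shapes are arbitrary, and FakeInput (from tensorflow/core/framework/fake_input.h) is the usual stand-in for inputs in kernel-level tests such as the new map_defun_op_test.

// Hedged sketch (not from this commit): describing a "MapDefun" node with
// NodeDefBuilder, using the attr names declared in map_defun_op.h. The
// function `func` is whatever stateless defun should be mapped over the
// 0th dimension of the arguments.
#include <vector>

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/kernels/data/map_defun_op.h"

namespace example {

tensorflow::Status BuildMapDefunNode(const tensorflow::NameAttrList& func,
                                     tensorflow::NodeDef* node) {
  using tensorflow::DataTypeVector;
  using tensorflow::DT_INT64;
  using tensorflow::PartialTensorShape;
  using tensorflow::data::MapDefunOp;

  // One vectorized int64 argument, no captured inputs, one int64 output
  // whose per-slice shape is [2]; all of these choices are illustrative.
  return tensorflow::NodeDefBuilder("map_defun", "MapDefun")
      .Input(tensorflow::FakeInput(1, DT_INT64))  // "arguments"
      .Input(tensorflow::FakeInput(0, DT_INT64))  // "captured_inputs"
      .Attr(MapDefunOp::kTarguments, DataTypeVector{DT_INT64})
      .Attr(MapDefunOp::kTcaptured, DataTypeVector{})
      .Attr(MapDefunOp::kOutputTypes, DataTypeVector{DT_INT64})
      .Attr(MapDefunOp::kOutputShapes,
            std::vector<PartialTensorShape>{PartialTensorShape({2})})
      .Attr(MapDefunOp::kFunc, func)
      .Attr(MapDefunOp::kMaxIntraOpParallelism, 2)
      .Finalize(node);
}

}  // namespace example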