Refactor MapDefunOp
commit 283e298510
parent e313f0b369
tensorflow/core/kernels/data/BUILD

@@ -1264,6 +1264,7 @@ tf_kernel_library(
tf_kernel_library(
    name = "map_defun_op",
    srcs = ["map_defun_op.cc"],
    hdrs = ["map_defun_op.h"],
    deps = [
        ":dataset_utils",
        "//tensorflow/core:core_cpu_internal",
@@ -1273,3 +1274,24 @@ tf_kernel_library(
        "//tensorflow/core:lib_internal",
    ],
)

tf_cc_test(
    name = "map_defun_op_test",
    size = "small",
    srcs = ["map_defun_op_test.cc"],
    deps = [
        ":dataset_test_base",
        ":dataset_utils",
        ":map_defun_op",
        ":stats_utils",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:function_ops",
    ],
)
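With the BUILD changes above, the new test target can be run on its own. Assuming a standard TensorFlow source checkout, the usual invocation would be:

    bazel test //tensorflow/core/kernels/data:map_defun_op_test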
tensorflow/core/kernels/data/map_defun_op.cc

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/map_defun_op.h"

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -27,299 +28,283 @@ limitations under the License.
namespace tensorflow {
namespace data {

/* static */ constexpr const char* const MapDefunOp::kArguments;
/* static */ constexpr const char* const MapDefunOp::kCapturedInputs;
/* static */ constexpr const char* const MapDefunOp::kTarguments;
/* static */ constexpr const char* const MapDefunOp::kTcaptured;
/* static */ constexpr const char* const MapDefunOp::kOutputTypes;
/* static */ constexpr const char* const MapDefunOp::kOutputShapes;
/* static */ constexpr const char* const MapDefunOp::kFunc;
/* static */ constexpr const char* const MapDefunOp::kMaxIntraOpParallelism;

constexpr char kOutput[] = "output";
struct MapDefunOp::ComputeOptions {
  // These vary per MapDefunOp::ComputeAsync call, but must persist until
  // all calls to the function are complete. This struct also encapsulates
  // all the components that need to be passed to each MapFunctionCallFrame.

  OpInputList args;
  const std::vector<TensorShape> arg_shapes;
  OpInputList captured_inputs;
  const int64 batch_size;
  std::function<void(std::function<void()>)> runner;

  // Output of a compute call
  std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
  OpOutputList output GUARDED_BY(mu);
  mutex mu;

  // Create a copy of output_shapes because every `Compute` may expect a
  // different output shape.
  ComputeOptions(OpKernelContext* ctx, OpInputList args,
                 OpInputList captured_inputs,
                 std::vector<TensorShape> arg_shapes, int64 batch_size,
                 const std::vector<PartialTensorShape>& output_shapes_attr,
                 int max_parallelism)
      : args(args),
        arg_shapes(std::move(arg_shapes)),
        captured_inputs(captured_inputs),
        batch_size(batch_size),
        output_shapes(output_shapes_attr) {
    if (max_parallelism >= 1) {
      runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
    }
  }
};
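For orientation, here is a minimal standalone sketch (plain C++, not part of this commit; all names are illustrative) of the lifetime contract the struct's comment describes: one per-ComputeAsync state object is shared by all concurrent function calls and must be destroyed exactly once, after the last of them finishes.

    #include <functional>
    #include <memory>
    #include <thread>
    #include <vector>

    struct PerComputeState {  // stands in for ComputeOptions
      int batch_size = 4;
    };

    void RunSlices(const std::function<void(PerComputeState&, int)>& fn) {
      auto state = std::make_shared<PerComputeState>();
      std::vector<std::thread> workers;
      for (int i = 0; i < state->batch_size; ++i) {
        // Each task keeps the state alive; it is freed after the last one.
        workers.emplace_back([state, &fn, i] { fn(*state, i); });
      }
      for (auto& t : workers) t.join();
    }

The commit expresses the same idea manually: ComputeOptions is heap-allocated in SetupArgs and deleted in the final status callback rather than via shared_ptr.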
class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface {
 public:
  MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
                       size_t iter)
      : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}

  ~MapFunctionCallFrame() override = default;

  size_t num_args() const override { return compute_opts_->args.size(); }

  size_t num_retvals() const override {
    return static_cast<size_t>(kernel_->num_outputs());
  }
  Status GetArg(int index, Tensor* val) const override {
    if (index < 0 || index >= compute_opts_->args.size() +
                                  compute_opts_->captured_inputs.size()) {
      return errors::InvalidArgument("Mismatch in number of function inputs.");
    }

    if (index >= compute_opts_->args.size()) {
      // The function is calling for a captured input
      *val =
          compute_opts_->captured_inputs[index - compute_opts_->args.size()];
      return Status::OK();
    }

    bool result =
        val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
                      compute_opts_->arg_shapes.at(index));
    if (!result) {
      return errors::Internal("GetArg failed.");
    } else if (!val->IsAligned()) {
      // Ensure alignment
      *val = tensor::DeepCopy(*val);
    }
    return Status::OK();
  }
  Status SetRetval(int index, const Tensor& val) override {
    if (index < 0 || index >= kernel_->num_outputs()) {
      return errors::InvalidArgument("Mismatch in number of function outputs.");
    }

    if (val.dtype() != kernel_->output_type(index)) {
      return errors::InvalidArgument(
          "Mismatch in function return type and expected output type for "
          "output: ",
          index);
    }
    Tensor* out;
    {  // Locking scope
      mutex_lock l(compute_opts_->mu);
      if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
              val.shape())) {
        return errors::InvalidArgument(
            "Mismatch in function retval shape, ", val.shape(),
            ", and expected output shape, ",
            compute_opts_->output_shapes.at(index).DebugString(), ".");
      }
      if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
        // Given val, we have new information about the output shape at
        // this index. Store the shape and allocate the output accordingly.
        compute_opts_->output_shapes.at(index) = val.shape();

        TensorShape actual_shape = val.shape();
        actual_shape.InsertDim(0, compute_opts_->batch_size);
        TF_RETURN_IF_ERROR(
            compute_opts_->output.allocate(index, actual_shape, &out));
      } else {
        out = (compute_opts_->output)[index];
      }
    }
    return batch_util::CopyElementToSlice(val, out, iter_);
  }

 private:
  ComputeOptions* const compute_opts_;  // Not owned
  const OpKernel* kernel_;
  const size_t iter_;
};
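The call frame above is where the actual slicing happens: iteration k of the mapped function reads row k of each batched argument with the leading dimension removed (GetArg), and writes its result into row k of the preallocated output (SetRetval). A standalone sketch of that contract, in plain C++ with illustrative names and a 2-D "tensor":

    #include <algorithm>
    #include <cassert>
    #include <vector>

    // A row-major "tensor" of shape [batch, width].
    struct Batched {
      int batch, width;
      std::vector<float> data;
    };

    // What GetArg conceptually hands to iteration `iter`: shape [width].
    std::vector<float> SliceRow(const Batched& arg, int iter) {
      assert(0 <= iter && iter < arg.batch);
      auto begin = arg.data.begin() + static_cast<size_t>(iter) * arg.width;
      return {begin, begin + arg.width};
    }

    // What SetRetval conceptually does with iteration `iter`'s result.
    void WriteRow(Batched* out, int iter, const std::vector<float>& val) {
      assert(static_cast<int>(val.size()) == out->width);
      std::copy(val.begin(), val.end(),
                out->data.begin() + static_cast<size_t>(iter) * out->width);
    }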
MapDefunOp::MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  auto func_lib = ctx->function_library();
  OP_REQUIRES(ctx, func_lib != nullptr,
              errors::Internal("No function library."));
  const NameAttrList* func;
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kFunc, &func));
  OP_REQUIRES_OK(ctx,
                 func_lib->Instantiate(func->name(), AttrSlice(&func->attr()),
                                       &func_handle_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(
      ctx, ctx->GetAttr(kMaxIntraOpParallelism, &max_intra_op_parallelism_));

  OP_REQUIRES(ctx, ctx->num_inputs() >= 0,
              errors::InvalidArgument("Must have at least one input."));
  OP_REQUIRES(ctx, ctx->num_outputs() >= 0,
              errors::InvalidArgument("Must have at least one output."));
  OP_REQUIRES(ctx, ctx->num_outputs() == output_shapes_.size(),
              errors::InvalidArgument(
                  "Length of output_shapes and output_types must match."));
}
void MapDefunOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  ComputeOptions* compute_opts = nullptr;

  OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);

  Status s = SetupOutputs(ctx, compute_opts);
  if (!s.ok()) delete compute_opts;
  OP_REQUIRES_OK_ASYNC(ctx, s, done);

  FunctionLibraryRuntime::Options opts;
  SetRunOptions(ctx, &opts, compute_opts, /*always_collect_stats=*/false);

  // Run loop
  StatusCallback callback = std::bind(
      [](OpKernelContext* ctx, ComputeOptions* compute_opts, DoneCallback& done,
         const Status& status) {
        delete compute_opts;
        ctx->SetStatus(status);
        done();
      },
      ctx, compute_opts, std::move(done), std::placeholders::_1);

  auto* refcounted = new ReffedStatusCallback(std::move(callback));

  CancellationManager* parent_mgr = ctx->cancellation_manager();

  for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
    // We use a different cancellation manager each time the function is run
    // to avoid the race condition between a function run error and other
    // functions being cancelled as a result.
    CancellationManager* c_mgr = new CancellationManager();
    CancellationToken token = parent_mgr->get_cancellation_token();
    const bool success = parent_mgr->RegisterCallback(
        token, [c_mgr]() { c_mgr->StartCancel(); });

    opts.cancellation_manager = c_mgr;
    if (!success) {
      delete c_mgr;
      refcounted->UpdateStatus(errors::Cancelled(
          "MapDefunOp functions cancelled because parent graph cancelled"));
      break;
    }

    auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);

    refcounted->Ref();
    ctx->function_library()->Run(opts, func_handle_, call_frame,
                                 [call_frame, refcounted, c_mgr, parent_mgr,
                                  token](const Status& func_status) {
                                   parent_mgr->DeregisterCallback(token);
                                   delete c_mgr;
                                   delete call_frame;
                                   refcounted->UpdateStatus(func_status);
                                   refcounted->Unref();
                                 });
  }

  // Unref 1 because refcounted is initialized with refcount = 1
  refcounted->Unref();
}
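ComputeAsync relies on ReffedStatusCallback to invoke `done` exactly once, after the last in-flight function call has finished. A simplified standalone sketch of that pattern (illustrative only, not the TF class; a string stands in for Status):

    #include <atomic>
    #include <functional>
    #include <mutex>
    #include <string>

    class RefCountedDone {
     public:
      explicit RefCountedDone(std::function<void(std::string)> done)
          : refs_(1), done_(std::move(done)) {}
      void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); }
      void Unref() {
        if (refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
          done_(status_);
          delete this;
        }
      }
      void UpdateStatus(const std::string& s) {
        std::lock_guard<std::mutex> l(mu_);
        if (status_.empty()) status_ = s;  // first error wins
      }
     private:
      std::atomic<int> refs_;
      std::mutex mu_;
      std::string status_;  // empty means OK
      std::function<void(std::string)> done_;
    };

The trailing Unref at the end of ComputeAsync drops the initial reference, so `done` fires as soon as every scheduled call has completed, including the degenerate case where none were scheduled.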
void MapDefunOp::SetRunOptions(OpKernelContext* ctx,
                               FunctionLibraryRuntime::Options* opts,
                               ComputeOptions* compute_opts,
                               bool always_collect_stats) {
  opts->rendezvous = ctx->rendezvous();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  if (max_intra_op_parallelism_ >= 1) {
    opts->runner = &compute_opts->runner;
  } else {
    opts->runner = ctx->runner();
  }
}
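A runner here is nothing more than a std::function that schedules a closure; SetRunOptions just picks between the context's default runner and the parallelism-capped one built in the ComputeOptions constructor. The trivial case, for illustration (not from the commit):

    #include <functional>

    // An inline runner: executes the closure on the calling thread.
    std::function<void(std::function<void()>)> inline_runner =
        [](std::function<void()> fn) { fn(); };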
Status MapDefunOp::SetupArgs(OpKernelContext* ctx,
                             ComputeOptions** compute_opts) {
  OpInputList arguments;
  TF_RETURN_IF_ERROR(ctx->input_list(kArguments, &arguments));
  OpInputList captured_inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(kCapturedInputs, &captured_inputs));

  int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;

  for (size_t i = 0; i < arguments.size(); ++i) {
    if (arguments[i].dims() == 0) {
      return errors::InvalidArgument(
          "All inputs must have rank at least 1. Input ", i,
          " has a rank of 0.");
    } else if (arguments[i].dim_size(0) != batch_size) {
      return errors::InvalidArgument(
          "All inputs must have the same dimension 0. Input ", i,
          " has leading dimension ", ctx->input(i).dim_size(0),
          ", while all previous inputs have leading dimension ", batch_size);
    }
  }

  std::vector<TensorShape> arg_shapes;
  arg_shapes.reserve(arguments.size());

  for (size_t i = 0; i < arguments.size(); ++i) {
    arg_shapes.push_back(arguments[i].shape());
    arg_shapes.at(i).RemoveDim(0);
  }

  *compute_opts =
      new ComputeOptions(ctx, arguments, captured_inputs, std::move(arg_shapes),
                         batch_size, output_shapes_, max_intra_op_parallelism_);
  return Status::OK();
}
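Concretely, given arguments of shapes [8, 3] and [8, 5, 5], SetupArgs derives batch_size = 8 and per-call shapes [3] and [5, 5]; an argument of shape [7, 3] would be rejected. A plain-C++ sketch of that bookkeeping (illustrative names, no TF types):

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    std::vector<std::vector<int64_t>> PerCallShapes(
        const std::vector<std::vector<int64_t>>& batched_shapes) {
      if (batched_shapes.empty() || batched_shapes[0].empty())
        throw std::invalid_argument("All inputs must have rank at least 1.");
      const int64_t batch = batched_shapes[0][0];
      std::vector<std::vector<int64_t>> result;
      for (const auto& shape : batched_shapes) {
        if (shape.empty() || shape[0] != batch)
          throw std::invalid_argument("Leading dimensions must match.");
        // Equivalent of TensorShape::RemoveDim(0).
        result.emplace_back(shape.begin() + 1, shape.end());
      }
      return result;
    }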
Status MapDefunOp::SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
  mutex_lock l(opts->mu);
  TF_RETURN_IF_ERROR(ctx->output_list(kOutput, &opts->output));

  for (size_t i = 0; i < output_types().size(); ++i) {
    if (output_shapes_.at(i).IsFullyDefined()) {
      Tensor* out = nullptr;
      TensorShape output_shape;
      output_shapes_.at(i).AsTensorShape(&output_shape);
      output_shape.InsertDim(0, opts->batch_size);
      TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
    }
  }
  return Status::OK();
}
namespace {
REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
}  // namespace
}  // namespace data
}  // namespace tensorflow
tensorflow/core/kernels/data/map_defun_op.h (new file, 76 lines)

@@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_

#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {
// This op runs a given defun on slices of the input arguments. The function
// given by "f" is assumed to be stateless and is executed concurrently on
// all the slices; up to batch_size (i.e. the 0th dimension of each argument)
// function invocations will be scheduled at once.
//
// The "max_intra_op_parallelism" attr, which defaults to 1, can be used to
// limit the intra-op parallelism. To limit inter-op parallelism, a user can
// set a private threadpool on the dataset using `tf.data.Options`'s
// `ThreadingOptions`.
//
// Note that this op is not exposed to users directly, but is invoked in
// tf.data rewrites.
class MapDefunOp : public AsyncOpKernel {
 public:
  static constexpr const char* const kArguments = "arguments";
  static constexpr const char* const kCapturedInputs = "captured_inputs";
  static constexpr const char* const kTarguments = "Targuments";
  static constexpr const char* const kTcaptured = "Tcaptured";
  static constexpr const char* const kOutputTypes = "output_types";
  static constexpr const char* const kOutputShapes = "output_shapes";
  static constexpr const char* const kFunc = "f";
  static constexpr const char* const kMaxIntraOpParallelism =
      "max_intra_op_parallelism";

  explicit MapDefunOp(OpKernelConstruction* ctx);

  ~MapDefunOp() override = default;

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;

 private:
  struct ComputeOptions;
  class MapFunctionCallFrame;

  void SetRunOptions(OpKernelContext* ctx,
                     FunctionLibraryRuntime::Options* opts,
                     ComputeOptions* compute_opts, bool always_collect_stats);

  // Get inputs to Compute and check that they are valid.
  Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts);

  Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts);

  FunctionLibraryRuntime::Handle func_handle_;
  std::vector<PartialTensorShape> output_shapes_;
  // If this value is positive, limit the max intra-op parallelism when the
  // function is run on slices of the input.
  int max_intra_op_parallelism_;
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
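The constants exported by this header are what the new map_defun_op_test.cc can reference instead of raw attr strings. As a hedged sketch (assuming the standard NodeDefBuilder/FakeInput kernel-test pattern; the inputs, attrs, and helper name here are illustrative, not taken from the commit, and the exact builder overloads may need adjusting):

    #include "tensorflow/core/framework/fake_input.h"
    #include "tensorflow/core/framework/node_def_builder.h"
    #include "tensorflow/core/kernels/data/map_defun_op.h"

    tensorflow::Status BuildMapDefunNodeDef(
        const tensorflow::NameAttrList& func, tensorflow::NodeDef* node_def) {
      using tensorflow::data::MapDefunOp;
      return tensorflow::NodeDefBuilder("map_defun", "MapDefun")
          .Input(tensorflow::FakeInput(1, tensorflow::DT_INT64))  // arguments
          .Input(tensorflow::FakeInput(0, tensorflow::DT_INT64))  // captured
          .Attr(MapDefunOp::kTarguments, {tensorflow::DT_INT64})
          .Attr(MapDefunOp::kTcaptured, std::vector<tensorflow::DataType>{})
          .Attr(MapDefunOp::kOutputTypes, {tensorflow::DT_INT64})
          .Attr(MapDefunOp::kOutputShapes,
                std::vector<tensorflow::PartialTensorShape>{{}})
          .Attr(MapDefunOp::kFunc, func)  // "f": the mapped function
          .Finalize(node_def);
    }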