Merge pull request #30739 from feihugis:Refactor_DatasetOps_9

PiperOrigin-RevId: 258562390
This commit is contained in:
TensorFlower Gardener 2019-07-17 07:20:17 -07:00
commit 4fd6623585
7 changed files with 752 additions and 350 deletions

View File

@ -1181,6 +1181,7 @@ tf_cc_test(
tf_kernel_library(
name = "optimize_dataset_op",
srcs = ["optimize_dataset_op.cc"],
hdrs = ["optimize_dataset_op.h"],
deps = [
":dataset_utils",
"//tensorflow/core:core_cpu_internal",
@ -1284,6 +1285,7 @@ tf_kernel_library(
tf_kernel_library(
name = "map_defun_op",
srcs = ["map_defun_op.cc"],
hdrs = ["map_defun_op.h"],
deps = [
":dataset_utils",
"//tensorflow/core:core_cpu_internal",
@ -1293,3 +1295,24 @@ tf_kernel_library(
"//tensorflow/core:lib_internal",
],
)
# Unit test target for the MapDefun kernel: builds map_defun_op_test.cc against
# the ":map_defun_op" kernel library and the ":dataset_test_base" test helpers.
tf_cc_test(
name = "map_defun_op_test",
size = "small",
srcs = ["map_defun_op_test.cc"],
deps = [
":dataset_test_base",
":dataset_utils",
":map_defun_op",
":stats_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops",
],
)

View File

@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/map_defun_op.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -27,299 +28,283 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace {
// This op runs a given defun on slices of the input arguments. The function
// given by "f" is assumed to be stateless, and is executed concurrently
// on all the slices; up to batch_size (i.e. the 0th dimension of each argument)
// functions will be scheduled at once.
//
// The "max_intra_op_parallelism" attr, which defaults to 1, can be used to
// limit the intra op parallelism. To limit inter-op parallelism, a user
// can set a private threadpool on the dataset using `tf.data.Options`'s
// `ThreadingOptions`.
//
// Note that this op is not exposed to users directly, but is invoked in
// tf.data rewrites.
class MapDefunOp : public AsyncOpKernel {
// Out-of-line definitions for the `static constexpr` string constants that are
// declared (with their initializers) on MapDefunOp in map_defun_op.h.
// NOTE(review): presumably required so the constants can be ODR-used when the
// code is built in a pre-C++17 mode -- confirm against the toolchain settings.
/* static */ constexpr const char* const MapDefunOp::kArguments;
/* static */ constexpr const char* const MapDefunOp::kCapturedInputs;
/* static */ constexpr const char* const MapDefunOp::kTarguments;
/* static */ constexpr const char* const MapDefunOp::kTcaptured;
/* static */ constexpr const char* const MapDefunOp::kOutputTypes;
/* static */ constexpr const char* const MapDefunOp::kOutputShapes;
/* static */ constexpr const char* const MapDefunOp::kFunc;
/* static */ constexpr const char* const MapDefunOp::kMaxIntraOpParallelism;
// Name of the op's output list; passed to `ctx->output_list` in SetupOutputs.
constexpr char kOutput[] = "output";
// State for a single MapDefunOp::ComputeAsync invocation. One instance is
// shared by every MapFunctionCallFrame created for that invocation, so it has
// to outlive all of the function calls it feeds.
struct MapDefunOp::ComputeOptions {
  // Inputs of the invocation. `arg_shapes` and `batch_size` are fixed at
  // construction time.
  OpInputList args;
  const std::vector<TensorShape> arg_shapes;
  OpInputList captured_inputs;
  const int64 batch_size;
  // Only assigned when a positive parallelism limit is requested.
  std::function<void(std::function<void()>)> runner;

  // Results of the invocation; both are protected by `mu`.
  std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
  OpOutputList output GUARDED_BY(mu);
  mutex mu;

  // `output_shapes_attr` is copied rather than referenced: each `Compute` may
  // end up with a different concrete output shape, so the per-invocation copy
  // is the one that gets refined.
  ComputeOptions(OpKernelContext* ctx, OpInputList args,
                 OpInputList captured_inputs,
                 std::vector<TensorShape> arg_shapes, int64 batch_size,
                 const std::vector<PartialTensorShape>& output_shapes_attr,
                 int max_parallelism)
      : args(args),
        arg_shapes(std::move(arg_shapes)),
        captured_inputs(captured_inputs),
        batch_size(batch_size),
        output_shapes(output_shapes_attr) {
    // Without a positive limit, `runner` stays empty.
    if (max_parallelism < 1) return;
    runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
  }
};
class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface {
public:
explicit MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
auto func_lib = ctx->function_library();
OP_REQUIRES(ctx, func_lib != nullptr,
errors::Internal("No function library."));
const NameAttrList* func;
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func));
OP_REQUIRES_OK(ctx,
func_lib->Instantiate(func->name(), AttrSlice(&func->attr()),
&func_handle_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
&max_intra_op_parallelism_));
MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
size_t iter)
: compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
OP_REQUIRES(ctx, ctx->num_inputs() >= 0,
errors::InvalidArgument("Must have at least one input."));
OP_REQUIRES(ctx, ctx->num_outputs() >= 0,
errors::InvalidArgument("Must have at least one output."));
OP_REQUIRES(ctx, ctx->num_outputs() == output_shapes_.size(),
errors::InvalidArgument(
"Length of output_shapes and output_types must match."));
~MapFunctionCallFrame() override = default;
size_t num_args() const override { return compute_opts_->args.size(); }
size_t num_retvals() const override {
return static_cast<size_t>(kernel_->num_outputs());
}
~MapDefunOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
ComputeOptions* compute_opts = nullptr;
OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
Status s = SetupOutputs(ctx, compute_opts);
if (!s.ok()) delete compute_opts;
OP_REQUIRES_OK_ASYNC(ctx, s, done);
FunctionLibraryRuntime::Options opts;
SetRunOptions(ctx, &opts, compute_opts, /*always_collect_stats=*/false);
// Run loop
StatusCallback callback = std::bind(
[](OpKernelContext* ctx, ComputeOptions* compute_opts,
DoneCallback& done, const Status& status) {
delete compute_opts;
ctx->SetStatus(status);
done();
},
ctx, compute_opts, std::move(done), std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
CancellationManager* parent_mgr = ctx->cancellation_manager();
for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
// We use a different cancellation manager each time the function is run
// to avoid the race condition between a function run error and other
// functions being cancelled as a result.
CancellationManager* c_mgr = new CancellationManager();
CancellationToken token = parent_mgr->get_cancellation_token();
const bool success = parent_mgr->RegisterCallback(
token, [c_mgr]() { c_mgr->StartCancel(); });
opts.cancellation_manager = c_mgr;
if (!success) {
delete c_mgr;
refcounted->UpdateStatus(errors::Cancelled(
"MapDefunOp functions cancelled because parent graph cancelled"));
break;
}
auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
refcounted->Ref();
ctx->function_library()->Run(opts, func_handle_, call_frame,
[call_frame, refcounted, c_mgr, parent_mgr,
token](const Status& func_status) {
parent_mgr->DeregisterCallback(token);
delete c_mgr;
delete call_frame;
refcounted->UpdateStatus(func_status);
refcounted->Unref();
});
Status GetArg(int index, Tensor* val) const override {
if (index < 0 || index >= compute_opts_->args.size() +
compute_opts_->captured_inputs.size()) {
return errors::InvalidArgument("Mismatch in number of function inputs.");
}
// Unref 1 because refcounted is initialized with refcount = 1
refcounted->Unref();
}
private:
struct ComputeOptions {
// These vary per MapDefunOp::ComputeAsync call, but must persist until
// all calls to the function are complete. This struct also encapsulates
// all the components that need to be passed to each MapFunctionCallFrame.
OpInputList args;
const std::vector<TensorShape> arg_shapes;
OpInputList captured_inputs;
const int64 batch_size;
std::function<void(std::function<void()>)> runner;
// Output of a compute call
std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
OpOutputList output GUARDED_BY(mu);
mutex mu;
// Create a copy of output_shapes because every `Compute` may expect a
// different output shape.
ComputeOptions(OpKernelContext* ctx, OpInputList args,
OpInputList captured_inputs,
std::vector<TensorShape> arg_shapes, int64 batch_size,
const std::vector<PartialTensorShape>& output_shapes_attr,
int max_parallelism)
: args(args),
arg_shapes(std::move(arg_shapes)),
captured_inputs(captured_inputs),
batch_size(batch_size),
output_shapes(output_shapes_attr) {
if (max_parallelism >= 1) {
runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
}
}
};
class MapFunctionCallFrame : public CallFrameInterface {
public:
MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
size_t iter)
: compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
~MapFunctionCallFrame() override {}
size_t num_args() const override { return compute_opts_->args.size(); }
size_t num_retvals() const override {
return static_cast<size_t>(kernel_->num_outputs());
}
Status GetArg(int index, Tensor* val) const override {
if (index < 0 || index >= compute_opts_->args.size() +
compute_opts_->captured_inputs.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
if (index >= compute_opts_->args.size()) {
// The function is calling for a captured input
*val =
compute_opts_->captured_inputs[index - compute_opts_->args.size()];
return Status::OK();
}
bool result =
val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
} else if (!val->IsAligned()) {
// Ensure alignment
*val = tensor::DeepCopy(*val);
}
if (index >= compute_opts_->args.size()) {
// The function is calling for a captured input
*val = compute_opts_->captured_inputs[index - compute_opts_->args.size()];
return Status::OK();
}
Status SetRetval(int index, const Tensor& val) override {
if (index < 0 || index >= kernel_->num_outputs()) {
return errors::InvalidArgument(
"Mismatch in number of function outputs.");
}
if (val.dtype() != kernel_->output_type(index)) {
return errors::InvalidArgument(
"Mismatch in function return type and expected output type for "
"output: ",
index);
}
Tensor* out;
{ // Locking scope
mutex_lock l(compute_opts_->mu);
if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
val.shape())) {
return errors::InvalidArgument(
"Mismatch in function retval shape, ", val.shape(),
", and expected output shape, ",
compute_opts_->output_shapes.at(index).DebugString(), ".");
}
if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
// Given val, we have new information about the output shape at
// this index. Store the shape and allocate the output accordingly.
compute_opts_->output_shapes.at(index) = val.shape();
TensorShape actual_shape = val.shape();
actual_shape.InsertDim(0, compute_opts_->batch_size);
TF_RETURN_IF_ERROR(
compute_opts_->output.allocate(index, actual_shape, &out));
} else {
out = (compute_opts_->output)[index];
}
}
return batch_util::CopyElementToSlice(val, out, iter_);
}
private:
ComputeOptions* const compute_opts_; // Not owned
const OpKernel* kernel_;
const size_t iter_;
}; // MapFunctionCallFrame
void SetRunOptions(OpKernelContext* ctx,
FunctionLibraryRuntime::Options* opts,
ComputeOptions* compute_opts, bool always_collect_stats) {
opts->rendezvous = ctx->rendezvous();
if (always_collect_stats) {
opts->stats_collector = ctx->stats_collector();
}
if (max_intra_op_parallelism_ >= 1) {
opts->runner = &compute_opts->runner;
} else {
opts->runner = ctx->runner();
}
}
// Get inputs to Compute and check that they are valid.
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
OpInputList arguments;
TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
OpInputList captured_inputs;
TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
for (size_t i = 0; i < arguments.size(); ++i) {
if (arguments[i].dims() == 0) {
return errors::InvalidArgument(
"All inputs must have rank at least 1. Input ", i,
" has a rank of 0.");
} else if (arguments[i].dim_size(0) != batch_size) {
return errors::InvalidArgument(
"All inputs must have the same dimension 0. Input ", i,
" has leading dimension ", ctx->input(i).dim_size(0),
", while all previous inputs have leading dimension ", batch_size);
}
}
std::vector<TensorShape> arg_shapes;
arg_shapes.reserve(arguments.size());
for (size_t i = 0; i < arguments.size(); ++i) {
arg_shapes.push_back(arguments[i].shape());
arg_shapes.at(i).RemoveDim(0);
}
*compute_opts = new ComputeOptions(
ctx, arguments, captured_inputs, std::move(arg_shapes), batch_size,
output_shapes_, max_intra_op_parallelism_);
return Status::OK();
}
Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
mutex_lock l(opts->mu);
TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
for (size_t i = 0; i < output_types().size(); ++i) {
if (output_shapes_.at(i).IsFullyDefined()) {
Tensor* out = nullptr;
TensorShape output_shape;
output_shapes_.at(i).AsTensorShape(&output_shape);
output_shape.InsertDim(0, opts->batch_size);
TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
}
bool result =
val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
} else if (!val->IsAligned()) {
// Ensure alignment
*val = tensor::DeepCopy(*val);
}
return Status::OK();
}
FunctionLibraryRuntime::Handle func_handle_;
std::vector<PartialTensorShape> output_shapes_;
// If this value is positive, limit the max intra op parallelism when the
// function is run on slices of the input.
int max_intra_op_parallelism_;
}; // MapDefunOp
// Stores return value `index` of the function call for slice `iter_` into row
// `iter_` of the matching batched output tensor.
//
// Returns InvalidArgument when `index` is out of range, when the dtype of
// `val` differs from the kernel's declared output type, or when the shape of
// `val` is incompatible with the expected output shape. Allocation and copy
// errors are propagated unchanged.
Status SetRetval(int index, const Tensor& val) override {
  if (index < 0 || index >= kernel_->num_outputs()) {
    return errors::InvalidArgument("Mismatch in number of function outputs.");
  }

  if (val.dtype() != kernel_->output_type(index)) {
    return errors::InvalidArgument(
        "Mismatch in function return type and expected output type for "
        "output: ",
        index);
  }

  Tensor* out = nullptr;
  {  // Locking scope
    mutex_lock l(compute_opts_->mu);
    PartialTensorShape& expected_shape =
        compute_opts_->output_shapes.at(index);
    if (!expected_shape.IsCompatibleWith(val.shape())) {
      return errors::InvalidArgument(
          "Mismatch in function retval shape, ", val.shape(),
          ", and expected output shape, ", expected_shape.DebugString(), ".");
    }
    if (!expected_shape.IsFullyDefined()) {
      // Given val, we have new information about the output shape at this
      // index. Allocate the output accordingly and only then record the
      // shape: a fully defined entry is what tells other calls that the
      // output already exists, so it must not be published if the
      // allocation failed.
      TensorShape actual_shape = val.shape();
      actual_shape.InsertDim(0, compute_opts_->batch_size);
      TF_RETURN_IF_ERROR(
          compute_opts_->output.allocate(index, actual_shape, &out));
      expected_shape = val.shape();
    } else {
      out = (compute_opts_->output)[index];
    }
  }
  return batch_util::CopyElementToSlice(val, out, iter_);
}
private:
ComputeOptions* const compute_opts_; // Not owned
const OpKernel* kernel_;
const size_t iter_;
};
// Builds the kernel: instantiates the function named by the `f` attr, reads
// the `output_shapes` and `max_intra_op_parallelism` attrs, and validates the
// number of inputs and outputs. Any failure is reported through `ctx`.
MapDefunOp::MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  auto func_lib = ctx->function_library();
  OP_REQUIRES(ctx, func_lib != nullptr,
              errors::Internal("No function library."));
  const NameAttrList* func = nullptr;
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kFunc, &func));
  OP_REQUIRES_OK(ctx,
                 func_lib->Instantiate(func->name(), AttrSlice(&func->attr()),
                                       &func_handle_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(
      ctx, ctx->GetAttr(kMaxIntraOpParallelism, &max_intra_op_parallelism_));

  // Strictly positive counts are required; a `>= 0` comparison would accept
  // zero and never produce the "at least one" errors below.
  OP_REQUIRES(ctx, ctx->num_inputs() > 0,
              errors::InvalidArgument("Must have at least one input."));
  OP_REQUIRES(ctx, ctx->num_outputs() > 0,
              errors::InvalidArgument("Must have at least one output."));
  // num_outputs() is known to be positive here, so the cast cannot wrap.
  OP_REQUIRES(
      ctx, static_cast<size_t>(ctx->num_outputs()) == output_shapes_.size(),
      errors::InvalidArgument(
          "Length of output_shapes and output_types must match."));
}
// Runs the instantiated function once per slice along dimension 0 of the
// arguments and reports a single combined status through `done`.
//
// `compute_opts` is heap-allocated by SetupArgs and is owned by this call: it
// is deleted here if SetupOutputs fails, otherwise by the final status
// callback once every scheduled function call has finished.
void MapDefunOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
ComputeOptions* compute_opts = nullptr;
OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
Status s = SetupOutputs(ctx, compute_opts);
// Free the options before the macro below bails out on error.
if (!s.ok()) delete compute_opts;
OP_REQUIRES_OK_ASYNC(ctx, s, done);
FunctionLibraryRuntime::Options opts;
SetRunOptions(ctx, &opts, compute_opts, /*always_collect_stats=*/false);
// Run loop
// Final callback: releases the per-call state, publishes the combined status
// on `ctx`, and signals completion. `done` is moved into the bound state.
StatusCallback callback = std::bind(
[](OpKernelContext* ctx, ComputeOptions* compute_opts, DoneCallback& done,
const Status& status) {
delete compute_opts;
ctx->SetStatus(status);
done();
},
ctx, compute_opts, std::move(done), std::placeholders::_1);
// NOTE(review): presumably ReffedStatusCallback invokes `callback` with the
// accumulated status when its last reference is dropped -- confirm.
auto* refcounted = new ReffedStatusCallback(std::move(callback));
CancellationManager* parent_mgr = ctx->cancellation_manager();
for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
// We use a different cancellation manager each time the function is run
// to avoid the race condition between a function run error and other
// functions being cancelled as a result.
CancellationManager* c_mgr = new CancellationManager();
CancellationToken token = parent_mgr->get_cancellation_token();
// Forward cancellation of the parent to this call's private manager.
const bool success = parent_mgr->RegisterCallback(
token, [c_mgr]() { c_mgr->StartCancel(); });
opts.cancellation_manager = c_mgr;
if (!success) {
// Registration failed: record a Cancelled status and stop scheduling
// further slices. Calls that were already started still complete.
delete c_mgr;
refcounted->UpdateStatus(errors::Cancelled(
"MapDefunOp functions cancelled because parent graph cancelled"));
break;
}
// The call frame and `c_mgr` are owned by the completion lambda below.
auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
// One reference per in-flight call; released in the completion lambda.
refcounted->Ref();
ctx->function_library()->Run(opts, func_handle_, call_frame,
[call_frame, refcounted, c_mgr, parent_mgr,
token](const Status& func_status) {
parent_mgr->DeregisterCallback(token);
delete c_mgr;
delete call_frame;
refcounted->UpdateStatus(func_status);
refcounted->Unref();
});
}
// Unref 1 because refcounted is initialized with refcount = 1
refcounted->Unref();
}
// Fills in `opts` for the function calls of one ComputeAsync invocation.
//
// The rendezvous always comes from `ctx`; the stats collector is forwarded
// only when `always_collect_stats` is set. When a positive
// `max_intra_op_parallelism_` is configured the per-invocation runner stored
// in `compute_opts` is used, otherwise the context's own runner.
void MapDefunOp::SetRunOptions(OpKernelContext* ctx,
                               FunctionLibraryRuntime::Options* opts,
                               ComputeOptions* compute_opts,
                               bool always_collect_stats) {
  opts->rendezvous = ctx->rendezvous();
  if (always_collect_stats) {
    opts->stats_collector = ctx->stats_collector();
  }
  const bool limit_parallelism = max_intra_op_parallelism_ >= 1;
  opts->runner = limit_parallelism ? &compute_opts->runner : ctx->runner();
}
// Reads the `arguments` and `captured_inputs` input lists, validates them and
// packages them into a newly allocated ComputeOptions.
//
// Every argument must have rank >= 1 and all arguments must agree on the size
// of dimension 0, which becomes the batch size. On success `*compute_opts`
// points to a heap-allocated object that the caller must delete; on failure
// `*compute_opts` is left untouched and InvalidArgument (or the input-list
// lookup error) is returned.
Status MapDefunOp::SetupArgs(OpKernelContext* ctx,
                             ComputeOptions** compute_opts) {
  OpInputList arguments;
  TF_RETURN_IF_ERROR(ctx->input_list(kArguments, &arguments));
  OpInputList captured_inputs;
  TF_RETURN_IF_ERROR(ctx->input_list(kCapturedInputs, &captured_inputs));

  // Guard the `arguments[0]` access below against an empty argument list.
  const int num_arguments = arguments.size();
  if (num_arguments == 0) {
    return errors::InvalidArgument("Must have at least one argument.");
  }

  // A rank-0 first argument yields -1 here and is rejected inside the loop.
  const int64 batch_size =
      arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;

  std::vector<TensorShape> arg_shapes;
  arg_shapes.reserve(num_arguments);
  for (int i = 0; i < num_arguments; ++i) {
    const Tensor& argument = arguments[i];
    if (argument.dims() == 0) {
      return errors::InvalidArgument(
          "All inputs must have rank at least 1. Input ", i,
          " has a rank of 0.");
    }
    if (argument.dim_size(0) != batch_size) {
      // Report the dimension of the tensor that was actually checked.
      return errors::InvalidArgument(
          "All inputs must have the same dimension 0. Input ", i,
          " has leading dimension ", argument.dim_size(0),
          ", while all previous inputs have leading dimension ", batch_size);
    }
    // Shape of a single slice: the argument's shape without the batch dim.
    arg_shapes.push_back(argument.shape());
    arg_shapes.back().RemoveDim(0);
  }

  *compute_opts =
      new ComputeOptions(ctx, arguments, captured_inputs, std::move(arg_shapes),
                         batch_size, output_shapes_, max_intra_op_parallelism_);
  return Status::OK();
}
// Binds the op's output list to `opts` and eagerly allocates every output
// whose shape attr is fully defined, prepending the batch dimension. Outputs
// with a partially known shape are not allocated here.
Status MapDefunOp::SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
  mutex_lock lock(opts->mu);
  TF_RETURN_IF_ERROR(ctx->output_list(kOutput, &opts->output));

  const size_t num_outputs = output_types().size();
  for (size_t idx = 0; idx < num_outputs; ++idx) {
    const PartialTensorShape& declared_shape = output_shapes_.at(idx);
    if (!declared_shape.IsFullyDefined()) continue;

    TensorShape batched_shape;
    declared_shape.AsTensorShape(&batched_shape);
    batched_shape.InsertDim(0, opts->batch_size);
    Tensor* allocated = nullptr;
    TF_RETURN_IF_ERROR(opts->output.allocate(idx, batched_shape, &allocated));
  }
  return Status::OK();
}
namespace {
// Registers MapDefunOp as the CPU kernel for the "MapDefun" op.
REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
} // namespace
} // namespace data

View File

@ -0,0 +1,76 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
// This op runs a given defun on slices of the input arguments. The function
// given by "f" is assumed to be stateless, and is executed concurrently
// on all the slices; up to batch_size (i.e. the 0th dimension of each argument)
// functions will be scheduled at once.
//
// The "max_intra_op_parallelism" attr, which defaults to 1, can be used to
// limit the intra op parallelism. To limit inter-op parallelism, a user
// can set a private threadpool on the dataset using `tf.data.Options`'s
// `ThreadingOptions`.
//
// Note that this op is not exposed to users directly, but is invoked in
// tf.data rewrites.
class MapDefunOp : public AsyncOpKernel {
public:
// Names of the op's inputs and attrs, exposed so that code outside the kernel
// (e.g. tests building a NodeDef) does not have to repeat the string literals.
static constexpr const char* const kArguments = "arguments";
static constexpr const char* const kCapturedInputs = "captured_inputs";
static constexpr const char* const kTarguments = "Targuments";
static constexpr const char* const kTcaptured = "Tcaptured";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kFunc = "f";
static constexpr const char* const kMaxIntraOpParallelism =
"max_intra_op_parallelism";
// Reads the attrs named above from `ctx`; defined in the .cc file.
explicit MapDefunOp(OpKernelConstruction* ctx);
~MapDefunOp() override = default;
// Asynchronous entry point; `done` is the completion callback.
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
private:
// Both nested types are only forward-declared here and defined in the .cc
// file.
struct ComputeOptions;
class MapFunctionCallFrame;
// Populates `opts` for the function calls issued by one ComputeAsync call.
void SetRunOptions(OpKernelContext* ctx,
FunctionLibraryRuntime::Options* opts,
ComputeOptions* compute_opts, bool always_collect_stats);
// Get inputs to Compute and check that they are valid.
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts);
// Prepares the output list held by `opts`.
Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts);
// Handle of the instantiated function that is run on every slice.
FunctionLibraryRuntime::Handle func_handle_;
// Value of the `output_shapes` attr.
std::vector<PartialTensorShape> output_shapes_;
// If this value is positive, limit the max intra op parallelism when the
// function is run on slices of the input.
int max_intra_op_parallelism_;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_

View File

@ -0,0 +1,259 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/map_defun_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace {
// Node name and op type used for every NodeDef built by these tests.
constexpr char kNodeName[] = "map_defun";
constexpr char kOpName[] = "MapDefun";
// Fixture with helpers for building a `MapDefun` kernel and a context to run
// it in.
class MapDefunOpTest : public DatasetOpsTestBase {
 protected:
  // Builds a `MapDefun` op kernel from the given attrs. The node gets one
  // placeholder input per entry of `t_arguments`, followed by one per entry
  // of `t_captured`.
  Status CreateMapDefunOpKernel(
      const DataTypeVector& t_arguments, const DataTypeVector& t_captured,
      const DataTypeVector& output_types,
      const std::vector<PartialTensorShape>& output_shapes,
      const FunctionDefHelper::AttrValueWrapper& func,
      int max_intra_op_parallelism,
      std::unique_ptr<OpKernel>* map_defun_kernel) {
    const int num_arguments = t_arguments.size();
    const int num_captured = t_captured.size();

    std::vector<string> input_names;
    input_names.reserve(num_arguments + num_captured);
    for (int idx = 0; idx < num_arguments; ++idx) {
      input_names.push_back(strings::StrCat(MapDefunOp::kArguments, "_", idx));
    }
    for (int idx = 0; idx < num_captured; ++idx) {
      input_names.push_back(
          strings::StrCat(MapDefunOp::kCapturedInputs, "_", idx));
    }

    NodeDef node_def = test::function::NDef(
        kNodeName, kOpName, input_names,
        {{MapDefunOp::kTarguments, t_arguments},
         {MapDefunOp::kTcaptured, t_captured},
         {MapDefunOp::kOutputTypes, output_types},
         {MapDefunOp::kOutputShapes, output_shapes},
         {MapDefunOp::kFunc, func},
         {MapDefunOp::kMaxIntraOpParallelism, max_intra_op_parallelism}});
    return CreateOpKernel(node_def, map_defun_kernel);
  }

  // Checks `inputs` against the kernel's signature and, if they match,
  // creates a context for running `op_kernel` on them.
  Status CreateMapDefunContext(OpKernel* const op_kernel,
                               gtl::InlinedVector<TensorValue, 4>* const inputs,
                               std::unique_ptr<OpKernelContext>* context) {
    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
    return CreateOpKernelContext(op_kernel, inputs, context);
  }
};
// Parameters and expected results for one MapDefun test. The factories below
// fill this struct positionally, so the field order must not change.
struct TestCase {
// Tensors fed to the op's `arguments` inputs.
std::vector<Tensor> arguments;
// Tensors fed to the op's `captured_inputs` inputs.
std::vector<Tensor> captured_inputs;
// Dtypes for the `Targuments` and `Tcaptured` attrs.
DataTypeVector t_arguments;
DataTypeVector t_captured;
// Function to map, and the library it is looked up in.
FunctionDefHelper::AttrValueWrapper func;
std::vector<FunctionDef> func_lib;
int max_intra_op_parallelism;
// Values for the `output_types` and `output_shapes` attrs.
DataTypeVector output_dtypes;
std::vector<PartialTensorShape> output_shapes;
// Outputs the op is expected to produce.
std::vector<Tensor> expected_outputs;
};
// Test case 1: one input for the map function with no captured inputs.
// XTimesTwo is applied to each of the three rows of a {3, 2} int64 tensor.
TestCase TestCase1() {
return {
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
/*captured_inputs*/ {},
/*t_arguments*/ {DT_INT64},
/*t_captured*/ {},
/*func*/ {FunctionDefHelper::FunctionRef("XTimesTwo", {{"T", DT_INT64}})},
/*func_lib*/ {test::function::XTimesTwo()},
/*max_intra_op_parallelism*/ 2,
/*output_dtypes*/ {DT_INT64},
/*output_shapes*/ {PartialTensorShape({2})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
{0, 2, 4, 6, 8, 10})}};
}
// Test case 2: two inputs for the map function with no captured inputs.
// XAddY is applied row by row to two {3, 2} int64 tensors.
TestCase TestCase2() {
return {/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5}),
DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3, 2}), {0, 10, 20, 30, 40, 50})},
/*captured_inputs*/ {},
/*t_arguments*/ {DT_INT64, DT_INT64},
/*t_captured*/ {},
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
/*func_lib*/ {test::function::XAddY()},
/*max_intra_op_parallelism*/ 2,
/*output_dtypes*/ {DT_INT64},
/*output_shapes*/ {PartialTensorShape({2})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
{0, 11, 22, 33, 44, 55})}};
}
// Test case 3: one input for the map function with one captured input.
// XAddY receives a row of the {3, 2} argument and the whole {2} captured
// tensor on every call.
TestCase TestCase3() {
return {
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
/*captured_inputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
/*t_arguments*/ {DT_INT64},
/*t_captured*/ {DT_INT64},
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
/*func_lib*/ {test::function::XAddY()},
/*max_intra_op_parallelism*/ 2,
/*output_dtypes*/ {DT_INT64},
/*output_shapes*/ {PartialTensorShape({2})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
{10, 101, 12, 103, 14, 105})}};
}
// Invalid case: the declared output dtype is DT_FLOAT while the function is
// instantiated with T = DT_INT64. `expected_outputs` is not checked by the
// InvalidArguments test, which only inspects the returned status code.
TestCase InvalidOutputTypes() {
return {
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
/*captured_inputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
/*t_arguments*/ {DT_INT64},
/*t_captured*/ {DT_INT64},
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
/*func_lib*/ {test::function::XAddY()},
/*max_intra_op_parallelism*/ 2,
/*output_dtypes*/ {DT_FLOAT},
/*output_shapes*/ {PartialTensorShape({2})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
{10, 101, 12, 103, 14, 105})}};
}
// Invalid case: the declared per-slice output shape is {2, 2}, whereas the
// other cases with the same inputs and function declare {2}.
// `expected_outputs` is not checked by the InvalidArguments test, which only
// inspects the returned status code.
TestCase InvalidOutputShapes() {
return {
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
/*captured_inputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
/*t_arguments*/ {DT_INT64},
/*t_captured*/ {DT_INT64},
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
/*func_lib*/ {test::function::XAddY()},
/*max_intra_op_parallelism*/ 2,
/*output_dtypes*/ {DT_INT64},
/*output_shapes*/ {PartialTensorShape({2, 2})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
{10, 101, 12, 103, 14, 105})}};
}
// Invalid case: the two arguments disagree on dimension 0 (3 vs. 2), which
// MapDefunOp::SetupArgs rejects. `expected_outputs` is not checked by the
// InvalidArguments test, which only inspects the returned status code.
TestCase InvalidInputs() {
return {
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5}),
DatasetOpsTestBase::CreateTensor<int64>(
TensorShape({2, 2}), {0, 1, 2, 3})},
/*captured_inputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
/*t_arguments*/ {DT_INT64, DT_INT64},
/*t_captured*/ {DT_INT64},
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
/*func_lib*/ {test::function::XAddY()},
/*max_intra_op_parallelism*/ 2,
/*output_dtypes*/ {DT_INT64},
/*output_shapes*/ {PartialTensorShape({2})},
/*expected_outputs*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
{10, 101, 12, 103, 14, 105})}};
}
// MapDefunOpTest fixture parameterized over a TestCase value.
class ParameterizedMapDefunOpTest
: public MapDefunOpTest,
public ::testing::WithParamInterface<TestCase> {};
// Builds a MapDefun kernel for the parameterized case, runs it on the case's
// arguments followed by its captured inputs, and compares every output with
// the expected tensor.
TEST_P(ParameterizedMapDefunOpTest, NormalTests) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));

  std::unique_ptr<OpKernel> map_defun_kernel;
  TF_ASSERT_OK(CreateMapDefunOpKernel(
      test_case.t_arguments, test_case.t_captured, test_case.output_dtypes,
      test_case.output_shapes, test_case.func,
      test_case.max_intra_op_parallelism, &map_defun_kernel));

  // Kernel inputs: all arguments first, then all captured inputs.
  gtl::InlinedVector<TensorValue, 4> inputs;
  for (Tensor& argument : test_case.arguments) {
    inputs.emplace_back(&argument);
  }
  for (Tensor& captured : test_case.captured_inputs) {
    inputs.emplace_back(&captured);
  }

  std::unique_ptr<OpKernelContext> context;
  TF_ASSERT_OK(
      CreateMapDefunContext(map_defun_kernel.get(), &inputs, &context));
  TF_ASSERT_OK(RunOpKernel(map_defun_kernel.get(), context.get()));

  EXPECT_EQ(context->num_outputs(), test_case.expected_outputs.size());
  for (int out_idx = 0; out_idx < context->num_outputs(); ++out_idx) {
    TF_EXPECT_OK(ExpectEqual(*context->mutable_output(out_idx),
                             test_case.expected_outputs[out_idx]));
  }
}
// Runs the parameterized tests once for each of the three valid test cases.
INSTANTIATE_TEST_SUITE_P(MapDefunOpTest, ParameterizedMapDefunOpTest,
::testing::ValuesIn(std::vector<TestCase>(
{TestCase1(), TestCase2(), TestCase3()})));
// Runs the kernel on each malformed configuration and expects the run to fail
// with INVALID_ARGUMENT. Kernel and context creation must still succeed.
TEST_F(MapDefunOpTest, InvalidArguments) {
  const int thread_num = 2;
  const int cpu_num = 2;
  TF_ASSERT_OK(InitThreadPool(thread_num));

  std::vector<TestCase> invalid_cases = {
      InvalidOutputTypes(), InvalidOutputShapes(), InvalidInputs()};
  for (TestCase& test_case : invalid_cases) {
    TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));

    std::unique_ptr<OpKernel> map_defun_kernel;
    TF_ASSERT_OK(CreateMapDefunOpKernel(
        test_case.t_arguments, test_case.t_captured, test_case.output_dtypes,
        test_case.output_shapes, test_case.func,
        test_case.max_intra_op_parallelism, &map_defun_kernel));

    // Kernel inputs: all arguments first, then all captured inputs.
    gtl::InlinedVector<TensorValue, 4> inputs;
    for (Tensor& argument : test_case.arguments) {
      inputs.emplace_back(&argument);
    }
    for (Tensor& captured : test_case.captured_inputs) {
      inputs.emplace_back(&captured);
    }

    std::unique_ptr<OpKernelContext> context;
    TF_ASSERT_OK(
        CreateMapDefunContext(map_defun_kernel.get(), &inputs, &context));
    EXPECT_EQ(RunOpKernel(map_defun_kernel.get(), context.get()).code(),
              tensorflow::error::INVALID_ARGUMENT);
  }
}
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/optimize_dataset_op.h"
#include <map>
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
@ -23,66 +24,68 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace {
constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
public:
explicit OptimizeDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(
ctx, ctx->GetAttr("optimization_configs", &optimization_configs_));
/* static */ constexpr const char* const OptimizeDatasetOp::kDatasetType;
/* static */ constexpr const char* const OptimizeDatasetOp::kInputDataset;
/* static */ constexpr const char* const OptimizeDatasetOp::kOptimizations;
/* static */ constexpr const char* const OptimizeDatasetOp::kOutputTypes;
/* static */ constexpr const char* const OptimizeDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
OptimizeDatasetOp::kOptimizationConfigs;
constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
constexpr char kOptimizers[] = "optimizers";
constexpr char kOptimizerConfigs[] = "optimizer_configs";
// Constructs the kernel and reads the attr named by kOptimizationConfigs into
// optimization_configs_; a failed read is reported through OP_REQUIRES_OK.
OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx,
ctx->GetAttr(kOptimizationConfigs, &optimization_configs_));
}
// Parses the vector input named by kOptimizations, then rewrites `input` into
// `*output` using a config built from those optimizations and the stored
// optimization configs. Any failure is reported through OP_REQUIRES_OK.
void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
                                    DatasetBase** output) {
  std::vector<string> requested;
  OP_REQUIRES_OK(ctx,
                 ParseVectorArgument<string>(ctx, kOptimizations, &requested));

  // The config is produced lazily by the factory handed to RewriteDataset.
  auto make_config = [this, &requested]() {
    return CreateConfig(requested, optimization_configs_);
  };
  const Status rewrite_status =
      RewriteDataset(ctx, input, std::move(make_config),
                     /*optimize_function_library=*/true, output);
  OP_REQUIRES_OK(ctx, rewrite_status);
}
RewriterConfig OptimizeDatasetOp::CreateConfig(
std::vector<string> optimizations,
std::vector<string> optimizations_configs) {
RewriterConfig rewriter_config;
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
rewriter_config.set_fail_on_optimizer_errors(true);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
auto* custom_optimizations_list =
(*custom_optimizer->mutable_parameter_map())[kOptimizers].mutable_list();
for (const auto& opt : optimizations) {
custom_optimizations_list->add_s(opt);
}
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
std::vector<string> optimizations;
OP_REQUIRES_OK(
ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
auto config_factory = [this, &optimizations]() {
return CreateConfig(optimizations, optimization_configs_);
};
OP_REQUIRES_OK(ctx,
RewriteDataset(ctx, input, std::move(config_factory),
/*optimize_function_library=*/true, output));
auto* config_list =
(*custom_optimizer->mutable_parameter_map())[kOptimizerConfigs]
.mutable_list();
for (const auto& config : optimizations_configs) {
config_list->add_s(config);
}
return rewriter_config;
}
private:
static RewriterConfig CreateConfig(
std::vector<string> optimizations,
std::vector<string> optimizations_configs) {
RewriterConfig rewriter_config;
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
rewriter_config.set_fail_on_optimizer_errors(true);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
auto* custom_optimizations_list =
(*custom_optimizer->mutable_parameter_map())["optimizers"]
.mutable_list();
for (const auto& opt : optimizations) {
custom_optimizations_list->add_s(opt);
}
auto* config_list =
(*custom_optimizer->mutable_parameter_map())["optimizer_configs"]
.mutable_list();
for (const auto& config : optimizations_configs) {
config_list->add_s(config);
}
return rewriter_config;
}
std::vector<string> optimization_configs_;
};
namespace {
// Registers OptimizeDatasetOp as the DEVICE_CPU kernel for the
// "OptimizeDataset" op.
REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
OptimizeDatasetOp);
} // namespace
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
namespace data {
// Unary dataset kernel that rewrites its input dataset using a RewriterConfig
// built from a list of optimizations and a list of optimization configs.
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
public:
// Dataset type name of this op.
static constexpr const char* const kDatasetType = "Optimize";
// Names of the op's inputs.
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kOptimizations = "optimizations";
// Names of the op's attrs.
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
static constexpr const char* const kOptimizationConfigs =
"optimization_configs";
// Reads the kOptimizationConfigs attr into optimization_configs_.
explicit OptimizeDatasetOp(OpKernelConstruction* ctx);
protected:
// Parses the kOptimizations input and stores the rewritten version of
// `input` in `*output`.
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override;
private:
// Builds the RewriterConfig for the rewrite: a single custom optimizer whose
// parameter map carries `optimizations` and `optimizations_configs` as
// string lists.
static RewriterConfig CreateConfig(std::vector<string> optimizations,
std::vector<string> optimizations_configs);
// Value of the kOptimizationConfigs attr, captured at construction.
std::vector<string> optimization_configs_;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_

View File

@ -9,15 +9,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/optimize_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
#include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "tensorflow/core/kernels/data/take_dataset_op.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNodeName[] = "optimize_dataset";
constexpr char kOpName[] = "OptimizeDataset";
constexpr char kNoopElimination[] = "noop_elimination";
constexpr char kIteratorPrefix[] = "Iterator";
class OptimizeDatasetOpTest : public DatasetOpsTestBase {
protected:
@ -28,10 +32,11 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
const std::vector<string>& optimization_configs,
std::unique_ptr<OpKernel>* optimize_dataset_op_kernel) {
NodeDef node_def = test::function::NDef(
kNodeName, kOpName, {"input_dataset", "optimizations"},
{{"output_types", output_types},
{"output_shapes", output_shapes},
{"optimization_configs", optimization_configs}});
kNodeName, name_utils::OpName(OptimizeDatasetOp::kDatasetType),
{OptimizeDatasetOp::kInputDataset, OptimizeDatasetOp::kOptimizations},
{{OptimizeDatasetOp::kOutputTypes, output_types},
{OptimizeDatasetOp::kOutputShapes, output_shapes},
{OptimizeDatasetOp::kOptimizationConfigs, optimization_configs}});
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, optimize_dataset_op_kernel));
return Status::OK();
}
@ -55,12 +60,13 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
GraphConstructorOptions graph_opts;
graph_opts.allow_internal_ops = true;
graph_opts.expect_device_spec = false;
TF_RETURN_IF_ERROR(RunFunction(
test::function::MakeRangeDataset(),
/*attrs*/
{{"output_types", output_types}, {"output_shapes", output_shapes}},
/*inputs*/ {start, stop, step}, graph_opts,
/*rets*/ {range_dataset}));
TF_RETURN_IF_ERROR(
RunFunction(test::function::MakeRangeDataset(),
/*attrs*/
{{RangeDatasetOp::kOutputTypes, output_types},
{RangeDatasetOp::kOutputShapes, output_shapes}},
/*inputs*/ {start, stop, step}, graph_opts,
/*rets*/ {range_dataset}));
return Status::OK();
}
@ -74,12 +80,13 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
graph_opts.expect_device_spec = false;
Tensor count_tensor = CreateTensor<int64>(TensorShape({}), {count});
TF_RETURN_IF_ERROR(RunFunction(
test::function::MakeTakeDataset(),
/*attrs*/
{{"output_types", output_types}, {"output_shapes", output_shapes}},
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
/*rets*/ {take_dataset}));
TF_RETURN_IF_ERROR(
RunFunction(test::function::MakeTakeDataset(),
/*attrs*/
{{TakeDatasetOp::kOutputTypes, output_types},
{TakeDatasetOp::kOutputShapes, output_shapes}},
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
/*rets*/ {take_dataset}));
return Status::OK();
}
};
@ -109,7 +116,7 @@ TEST_F(OptimizeDatasetOpTest, NoopElimination) {
/*optimization_configs*/ {},
&optimize_dataset_kernel));
Tensor optimizations =
CreateTensor<string>(TensorShape({1}), {"noop_elimination"});
CreateTensor<string>(TensorShape({1}), {kNoopElimination});
gtl::InlinedVector<TensorValue, 4> inputs(
{TensorValue(&take_dataset_tensor), TensorValue(&optimizations)});
std::unique_ptr<OpKernelContext> optimize_dataset_context;
@ -127,7 +134,7 @@ TEST_F(OptimizeDatasetOpTest, NoopElimination) {
CreateIteratorContext(optimize_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(optimize_dataset->MakeIterator(iterator_context.get(),
"Iterator", &iterator));
kIteratorPrefix, &iterator));
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;