Merge pull request #30739 from feihugis:Refactor_DatasetOps_9
PiperOrigin-RevId: 258562390
This commit is contained in:
commit
4fd6623585
@ -1181,6 +1181,7 @@ tf_cc_test(
|
||||
tf_kernel_library(
|
||||
name = "optimize_dataset_op",
|
||||
srcs = ["optimize_dataset_op.cc"],
|
||||
hdrs = ["optimize_dataset_op.h"],
|
||||
deps = [
|
||||
":dataset_utils",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
@ -1284,6 +1285,7 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "map_defun_op",
|
||||
srcs = ["map_defun_op.cc"],
|
||||
hdrs = ["map_defun_op.h"],
|
||||
deps = [
|
||||
":dataset_utils",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
@ -1293,3 +1295,24 @@ tf_kernel_library(
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "map_defun_op_test",
|
||||
size = "small",
|
||||
srcs = ["map_defun_op_test.cc"],
|
||||
deps = [
|
||||
":dataset_test_base",
|
||||
":dataset_utils",
|
||||
":map_defun_op",
|
||||
":stats_utils",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:cwise_op",
|
||||
"//tensorflow/core/kernels:function_ops",
|
||||
],
|
||||
)
|
||||
|
@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/map_defun_op.h"
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -27,299 +28,283 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
// This op runs a given defun on slices of the input arguments. The function
|
||||
// given by "f" is assumed to be stateless, and is executed concurrently
|
||||
// on all the slices; up to batch_size (i.e. the 0th dimension of each argument)
|
||||
// functions will be scheduled at once.
|
||||
//
|
||||
// The "max_intra_op_parallelism" attr, which defaults to 1, can be used to
|
||||
// limit the intra op parallelism. To limit inter-op parallelism, a user
|
||||
// can set a private threadpool on the dataset using `tf.data.Options`'s
|
||||
// `ThreadingOptions`.
|
||||
//
|
||||
// Note that this op is not exposed to users directly, but is invoked in
|
||||
// tf.data rewrites.
|
||||
class MapDefunOp : public AsyncOpKernel {
|
||||
/* static */ constexpr const char* const MapDefunOp::kArguments;
|
||||
/* static */ constexpr const char* const MapDefunOp::kCapturedInputs;
|
||||
/* static */ constexpr const char* const MapDefunOp::kTarguments;
|
||||
/* static */ constexpr const char* const MapDefunOp::kTcaptured;
|
||||
/* static */ constexpr const char* const MapDefunOp::kOutputTypes;
|
||||
/* static */ constexpr const char* const MapDefunOp::kOutputShapes;
|
||||
/* static */ constexpr const char* const MapDefunOp::kFunc;
|
||||
/* static */ constexpr const char* const MapDefunOp::kMaxIntraOpParallelism;
|
||||
|
||||
constexpr char kOutput[] = "output";
|
||||
|
||||
struct MapDefunOp::ComputeOptions {
|
||||
// These vary per MapDefunOp::ComputeAsync call, but must persist until
|
||||
// all calls to the function are complete. This struct also encapsulates
|
||||
// all the components that need to be passed to each MapFunctionCallFrame.
|
||||
OpInputList args;
|
||||
const std::vector<TensorShape> arg_shapes;
|
||||
OpInputList captured_inputs;
|
||||
const int64 batch_size;
|
||||
std::function<void(std::function<void()>)> runner;
|
||||
|
||||
// Output of a compute call
|
||||
std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
|
||||
OpOutputList output GUARDED_BY(mu);
|
||||
mutex mu;
|
||||
|
||||
// Create a copy of output_shapes because every `Compute` may expect a
|
||||
// different output shape.
|
||||
ComputeOptions(OpKernelContext* ctx, OpInputList args,
|
||||
OpInputList captured_inputs,
|
||||
std::vector<TensorShape> arg_shapes, int64 batch_size,
|
||||
const std::vector<PartialTensorShape>& output_shapes_attr,
|
||||
int max_parallelism)
|
||||
: args(args),
|
||||
arg_shapes(std::move(arg_shapes)),
|
||||
captured_inputs(captured_inputs),
|
||||
batch_size(batch_size),
|
||||
output_shapes(output_shapes_attr) {
|
||||
if (max_parallelism >= 1) {
|
||||
runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface {
|
||||
public:
|
||||
explicit MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
|
||||
auto func_lib = ctx->function_library();
|
||||
OP_REQUIRES(ctx, func_lib != nullptr,
|
||||
errors::Internal("No function library."));
|
||||
const NameAttrList* func;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
func_lib->Instantiate(func->name(), AttrSlice(&func->attr()),
|
||||
&func_handle_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
|
||||
&max_intra_op_parallelism_));
|
||||
MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
|
||||
size_t iter)
|
||||
: compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
|
||||
|
||||
OP_REQUIRES(ctx, ctx->num_inputs() >= 0,
|
||||
errors::InvalidArgument("Must have at least one input."));
|
||||
OP_REQUIRES(ctx, ctx->num_outputs() >= 0,
|
||||
errors::InvalidArgument("Must have at least one output."));
|
||||
OP_REQUIRES(ctx, ctx->num_outputs() == output_shapes_.size(),
|
||||
errors::InvalidArgument(
|
||||
"Length of output_shapes and output_types must match."));
|
||||
~MapFunctionCallFrame() override = default;
|
||||
|
||||
size_t num_args() const override { return compute_opts_->args.size(); }
|
||||
|
||||
size_t num_retvals() const override {
|
||||
return static_cast<size_t>(kernel_->num_outputs());
|
||||
}
|
||||
|
||||
~MapDefunOp() override {}
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
ComputeOptions* compute_opts = nullptr;
|
||||
|
||||
OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
|
||||
|
||||
Status s = SetupOutputs(ctx, compute_opts);
|
||||
if (!s.ok()) delete compute_opts;
|
||||
OP_REQUIRES_OK_ASYNC(ctx, s, done);
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
SetRunOptions(ctx, &opts, compute_opts, /*always_collect_stats=*/false);
|
||||
|
||||
// Run loop
|
||||
StatusCallback callback = std::bind(
|
||||
[](OpKernelContext* ctx, ComputeOptions* compute_opts,
|
||||
DoneCallback& done, const Status& status) {
|
||||
delete compute_opts;
|
||||
ctx->SetStatus(status);
|
||||
done();
|
||||
},
|
||||
ctx, compute_opts, std::move(done), std::placeholders::_1);
|
||||
|
||||
auto* refcounted = new ReffedStatusCallback(std::move(callback));
|
||||
|
||||
CancellationManager* parent_mgr = ctx->cancellation_manager();
|
||||
|
||||
for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
|
||||
// We use a different cancellation manager each time the function is run
|
||||
// to avoid the race condition between a function run error and other
|
||||
// functions being cancelled as a result.
|
||||
CancellationManager* c_mgr = new CancellationManager();
|
||||
CancellationToken token = parent_mgr->get_cancellation_token();
|
||||
const bool success = parent_mgr->RegisterCallback(
|
||||
token, [c_mgr]() { c_mgr->StartCancel(); });
|
||||
|
||||
opts.cancellation_manager = c_mgr;
|
||||
if (!success) {
|
||||
delete c_mgr;
|
||||
refcounted->UpdateStatus(errors::Cancelled(
|
||||
"MapDefunOp functions cancelled because parent graph cancelled"));
|
||||
break;
|
||||
}
|
||||
|
||||
auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
|
||||
|
||||
refcounted->Ref();
|
||||
ctx->function_library()->Run(opts, func_handle_, call_frame,
|
||||
[call_frame, refcounted, c_mgr, parent_mgr,
|
||||
token](const Status& func_status) {
|
||||
parent_mgr->DeregisterCallback(token);
|
||||
delete c_mgr;
|
||||
delete call_frame;
|
||||
refcounted->UpdateStatus(func_status);
|
||||
refcounted->Unref();
|
||||
});
|
||||
Status GetArg(int index, Tensor* val) const override {
|
||||
if (index < 0 || index >= compute_opts_->args.size() +
|
||||
compute_opts_->captured_inputs.size()) {
|
||||
return errors::InvalidArgument("Mismatch in number of function inputs.");
|
||||
}
|
||||
|
||||
// Unref 1 because refcounted is initialized with refcount = 1
|
||||
refcounted->Unref();
|
||||
}
|
||||
|
||||
private:
|
||||
struct ComputeOptions {
|
||||
// These vary per MapDefunOp::ComputeAsync call, but must persist until
|
||||
// all calls to the function are complete. This struct also encapsulates
|
||||
// all the components that need to be passed to each MapFunctionCallFrame.
|
||||
|
||||
OpInputList args;
|
||||
const std::vector<TensorShape> arg_shapes;
|
||||
OpInputList captured_inputs;
|
||||
const int64 batch_size;
|
||||
std::function<void(std::function<void()>)> runner;
|
||||
|
||||
// Output of a compute call
|
||||
std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
|
||||
OpOutputList output GUARDED_BY(mu);
|
||||
mutex mu;
|
||||
|
||||
// Create a copy of output_shapes because every `Compute` may expect a
|
||||
// different output shape.
|
||||
ComputeOptions(OpKernelContext* ctx, OpInputList args,
|
||||
OpInputList captured_inputs,
|
||||
std::vector<TensorShape> arg_shapes, int64 batch_size,
|
||||
const std::vector<PartialTensorShape>& output_shapes_attr,
|
||||
int max_parallelism)
|
||||
: args(args),
|
||||
arg_shapes(std::move(arg_shapes)),
|
||||
captured_inputs(captured_inputs),
|
||||
batch_size(batch_size),
|
||||
output_shapes(output_shapes_attr) {
|
||||
if (max_parallelism >= 1) {
|
||||
runner = RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class MapFunctionCallFrame : public CallFrameInterface {
|
||||
public:
|
||||
MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
|
||||
size_t iter)
|
||||
: compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
|
||||
|
||||
~MapFunctionCallFrame() override {}
|
||||
|
||||
size_t num_args() const override { return compute_opts_->args.size(); }
|
||||
|
||||
size_t num_retvals() const override {
|
||||
return static_cast<size_t>(kernel_->num_outputs());
|
||||
}
|
||||
|
||||
Status GetArg(int index, Tensor* val) const override {
|
||||
if (index < 0 || index >= compute_opts_->args.size() +
|
||||
compute_opts_->captured_inputs.size()) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatch in number of function inputs.");
|
||||
}
|
||||
|
||||
if (index >= compute_opts_->args.size()) {
|
||||
// The function is calling for a captured input
|
||||
*val =
|
||||
compute_opts_->captured_inputs[index - compute_opts_->args.size()];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool result =
|
||||
val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
|
||||
compute_opts_->arg_shapes.at(index));
|
||||
if (!result) {
|
||||
return errors::Internal("GetArg failed.");
|
||||
} else if (!val->IsAligned()) {
|
||||
// Ensure alignment
|
||||
*val = tensor::DeepCopy(*val);
|
||||
}
|
||||
if (index >= compute_opts_->args.size()) {
|
||||
// The function is calling for a captured input
|
||||
*val = compute_opts_->captured_inputs[index - compute_opts_->args.size()];
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SetRetval(int index, const Tensor& val) override {
|
||||
if (index < 0 || index >= kernel_->num_outputs()) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatch in number of function outputs.");
|
||||
}
|
||||
|
||||
if (val.dtype() != kernel_->output_type(index)) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatch in function return type and expected output type for "
|
||||
"output: ",
|
||||
index);
|
||||
}
|
||||
Tensor* out;
|
||||
{ // Locking scope
|
||||
mutex_lock l(compute_opts_->mu);
|
||||
if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
|
||||
val.shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatch in function retval shape, ", val.shape(),
|
||||
", and expected output shape, ",
|
||||
compute_opts_->output_shapes.at(index).DebugString(), ".");
|
||||
}
|
||||
if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
|
||||
// Given val, we have new information about the output shape at
|
||||
// this index. Store the shape and allocate the output accordingly.
|
||||
compute_opts_->output_shapes.at(index) = val.shape();
|
||||
|
||||
TensorShape actual_shape = val.shape();
|
||||
actual_shape.InsertDim(0, compute_opts_->batch_size);
|
||||
TF_RETURN_IF_ERROR(
|
||||
compute_opts_->output.allocate(index, actual_shape, &out));
|
||||
} else {
|
||||
out = (compute_opts_->output)[index];
|
||||
}
|
||||
}
|
||||
return batch_util::CopyElementToSlice(val, out, iter_);
|
||||
}
|
||||
|
||||
private:
|
||||
ComputeOptions* const compute_opts_; // Not owned
|
||||
const OpKernel* kernel_;
|
||||
const size_t iter_;
|
||||
}; // MapFunctionCallFrame
|
||||
|
||||
void SetRunOptions(OpKernelContext* ctx,
|
||||
FunctionLibraryRuntime::Options* opts,
|
||||
ComputeOptions* compute_opts, bool always_collect_stats) {
|
||||
opts->rendezvous = ctx->rendezvous();
|
||||
if (always_collect_stats) {
|
||||
opts->stats_collector = ctx->stats_collector();
|
||||
}
|
||||
if (max_intra_op_parallelism_ >= 1) {
|
||||
opts->runner = &compute_opts->runner;
|
||||
} else {
|
||||
opts->runner = ctx->runner();
|
||||
}
|
||||
}
|
||||
|
||||
// Get inputs to Compute and check that they are valid.
|
||||
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
|
||||
OpInputList arguments;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
|
||||
OpInputList captured_inputs;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
|
||||
|
||||
int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
|
||||
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
if (arguments[i].dims() == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs must have rank at least 1. Input ", i,
|
||||
" has a rank of 0.");
|
||||
} else if (arguments[i].dim_size(0) != batch_size) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs must have the same dimension 0. Input ", i,
|
||||
" has leading dimension ", ctx->input(i).dim_size(0),
|
||||
", while all previous inputs have leading dimension ", batch_size);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TensorShape> arg_shapes;
|
||||
arg_shapes.reserve(arguments.size());
|
||||
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
arg_shapes.push_back(arguments[i].shape());
|
||||
arg_shapes.at(i).RemoveDim(0);
|
||||
}
|
||||
|
||||
*compute_opts = new ComputeOptions(
|
||||
ctx, arguments, captured_inputs, std::move(arg_shapes), batch_size,
|
||||
output_shapes_, max_intra_op_parallelism_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
|
||||
mutex_lock l(opts->mu);
|
||||
TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
|
||||
|
||||
for (size_t i = 0; i < output_types().size(); ++i) {
|
||||
if (output_shapes_.at(i).IsFullyDefined()) {
|
||||
Tensor* out = nullptr;
|
||||
TensorShape output_shape;
|
||||
output_shapes_.at(i).AsTensorShape(&output_shape);
|
||||
output_shape.InsertDim(0, opts->batch_size);
|
||||
TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
|
||||
}
|
||||
bool result =
|
||||
val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
|
||||
compute_opts_->arg_shapes.at(index));
|
||||
if (!result) {
|
||||
return errors::Internal("GetArg failed.");
|
||||
} else if (!val->IsAligned()) {
|
||||
// Ensure alignment
|
||||
*val = tensor::DeepCopy(*val);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
FunctionLibraryRuntime::Handle func_handle_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
// If this value is positive, limit the max intra op parallelism when the
|
||||
// function is run on slices of the input.
|
||||
int max_intra_op_parallelism_;
|
||||
}; // MapDefunOp
|
||||
Status SetRetval(int index, const Tensor& val) override {
|
||||
if (index < 0 || index >= kernel_->num_outputs()) {
|
||||
return errors::InvalidArgument("Mismatch in number of function outputs.");
|
||||
}
|
||||
|
||||
if (val.dtype() != kernel_->output_type(index)) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatch in function return type and expected output type for "
|
||||
"output: ",
|
||||
index);
|
||||
}
|
||||
Tensor* out;
|
||||
{ // Locking scope
|
||||
mutex_lock l(compute_opts_->mu);
|
||||
if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
|
||||
val.shape())) {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatch in function retval shape, ", val.shape(),
|
||||
", and expected output shape, ",
|
||||
compute_opts_->output_shapes.at(index).DebugString(), ".");
|
||||
}
|
||||
if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
|
||||
// Given val, we have new information about the output shape at
|
||||
// this index. Store the shape and allocate the output accordingly.
|
||||
compute_opts_->output_shapes.at(index) = val.shape();
|
||||
|
||||
TensorShape actual_shape = val.shape();
|
||||
actual_shape.InsertDim(0, compute_opts_->batch_size);
|
||||
TF_RETURN_IF_ERROR(
|
||||
compute_opts_->output.allocate(index, actual_shape, &out));
|
||||
} else {
|
||||
out = (compute_opts_->output)[index];
|
||||
}
|
||||
}
|
||||
return batch_util::CopyElementToSlice(val, out, iter_);
|
||||
}
|
||||
|
||||
private:
|
||||
ComputeOptions* const compute_opts_; // Not owned
|
||||
const OpKernel* kernel_;
|
||||
const size_t iter_;
|
||||
};
|
||||
|
||||
MapDefunOp::MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
|
||||
auto func_lib = ctx->function_library();
|
||||
OP_REQUIRES(ctx, func_lib != nullptr,
|
||||
errors::Internal("No function library."));
|
||||
const NameAttrList* func;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kFunc, &func));
|
||||
OP_REQUIRES_OK(ctx,
|
||||
func_lib->Instantiate(func->name(), AttrSlice(&func->attr()),
|
||||
&func_handle_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr(kMaxIntraOpParallelism, &max_intra_op_parallelism_));
|
||||
|
||||
OP_REQUIRES(ctx, ctx->num_inputs() >= 0,
|
||||
errors::InvalidArgument("Must have at least one input."));
|
||||
OP_REQUIRES(ctx, ctx->num_outputs() >= 0,
|
||||
errors::InvalidArgument("Must have at least one output."));
|
||||
OP_REQUIRES(ctx, ctx->num_outputs() == output_shapes_.size(),
|
||||
errors::InvalidArgument(
|
||||
"Length of output_shapes and output_types must match."));
|
||||
}
|
||||
|
||||
void MapDefunOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
ComputeOptions* compute_opts = nullptr;
|
||||
|
||||
OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
|
||||
|
||||
Status s = SetupOutputs(ctx, compute_opts);
|
||||
if (!s.ok()) delete compute_opts;
|
||||
OP_REQUIRES_OK_ASYNC(ctx, s, done);
|
||||
|
||||
FunctionLibraryRuntime::Options opts;
|
||||
SetRunOptions(ctx, &opts, compute_opts, /*always_collect_stats=*/false);
|
||||
|
||||
// Run loop
|
||||
StatusCallback callback = std::bind(
|
||||
[](OpKernelContext* ctx, ComputeOptions* compute_opts, DoneCallback& done,
|
||||
const Status& status) {
|
||||
delete compute_opts;
|
||||
ctx->SetStatus(status);
|
||||
done();
|
||||
},
|
||||
ctx, compute_opts, std::move(done), std::placeholders::_1);
|
||||
|
||||
auto* refcounted = new ReffedStatusCallback(std::move(callback));
|
||||
|
||||
CancellationManager* parent_mgr = ctx->cancellation_manager();
|
||||
|
||||
for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
|
||||
// We use a different cancellation manager each time the function is run
|
||||
// to avoid the race condition between a function run error and other
|
||||
// functions being cancelled as a result.
|
||||
CancellationManager* c_mgr = new CancellationManager();
|
||||
CancellationToken token = parent_mgr->get_cancellation_token();
|
||||
const bool success = parent_mgr->RegisterCallback(
|
||||
token, [c_mgr]() { c_mgr->StartCancel(); });
|
||||
|
||||
opts.cancellation_manager = c_mgr;
|
||||
if (!success) {
|
||||
delete c_mgr;
|
||||
refcounted->UpdateStatus(errors::Cancelled(
|
||||
"MapDefunOp functions cancelled because parent graph cancelled"));
|
||||
break;
|
||||
}
|
||||
|
||||
auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
|
||||
|
||||
refcounted->Ref();
|
||||
ctx->function_library()->Run(opts, func_handle_, call_frame,
|
||||
[call_frame, refcounted, c_mgr, parent_mgr,
|
||||
token](const Status& func_status) {
|
||||
parent_mgr->DeregisterCallback(token);
|
||||
delete c_mgr;
|
||||
delete call_frame;
|
||||
refcounted->UpdateStatus(func_status);
|
||||
refcounted->Unref();
|
||||
});
|
||||
}
|
||||
|
||||
// Unref 1 because refcounted is initialized with refcount = 1
|
||||
refcounted->Unref();
|
||||
}
|
||||
|
||||
void MapDefunOp::SetRunOptions(OpKernelContext* ctx,
|
||||
FunctionLibraryRuntime::Options* opts,
|
||||
ComputeOptions* compute_opts,
|
||||
bool always_collect_stats) {
|
||||
opts->rendezvous = ctx->rendezvous();
|
||||
if (always_collect_stats) {
|
||||
opts->stats_collector = ctx->stats_collector();
|
||||
}
|
||||
if (max_intra_op_parallelism_ >= 1) {
|
||||
opts->runner = &compute_opts->runner;
|
||||
} else {
|
||||
opts->runner = ctx->runner();
|
||||
}
|
||||
}
|
||||
|
||||
Status MapDefunOp::SetupArgs(OpKernelContext* ctx,
|
||||
ComputeOptions** compute_opts) {
|
||||
OpInputList arguments;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list(kArguments, &arguments));
|
||||
OpInputList captured_inputs;
|
||||
TF_RETURN_IF_ERROR(ctx->input_list(kCapturedInputs, &captured_inputs));
|
||||
|
||||
int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
|
||||
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
if (arguments[i].dims() == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs must have rank at least 1. Input ", i,
|
||||
" has a rank of 0.");
|
||||
} else if (arguments[i].dim_size(0) != batch_size) {
|
||||
return errors::InvalidArgument(
|
||||
"All inputs must have the same dimension 0. Input ", i,
|
||||
" has leading dimension ", ctx->input(i).dim_size(0),
|
||||
", while all previous inputs have leading dimension ", batch_size);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TensorShape> arg_shapes;
|
||||
arg_shapes.reserve(arguments.size());
|
||||
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
arg_shapes.push_back(arguments[i].shape());
|
||||
arg_shapes.at(i).RemoveDim(0);
|
||||
}
|
||||
|
||||
*compute_opts =
|
||||
new ComputeOptions(ctx, arguments, captured_inputs, std::move(arg_shapes),
|
||||
batch_size, output_shapes_, max_intra_op_parallelism_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MapDefunOp::SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
|
||||
mutex_lock l(opts->mu);
|
||||
TF_RETURN_IF_ERROR(ctx->output_list(kOutput, &opts->output));
|
||||
|
||||
for (size_t i = 0; i < output_types().size(); ++i) {
|
||||
if (output_shapes_.at(i).IsFullyDefined()) {
|
||||
Tensor* out = nullptr;
|
||||
TensorShape output_shape;
|
||||
output_shapes_.at(i).AsTensorShape(&output_shape);
|
||||
output_shape.InsertDim(0, opts->batch_size);
|
||||
TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
|
||||
} // namespace
|
||||
} // namespace data
|
||||
|
76
tensorflow/core/kernels/data/map_defun_op.h
Normal file
76
tensorflow/core/kernels/data/map_defun_op.h
Normal file
@ -0,0 +1,76 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// This op runs a given defun on slices of the input arguments. The function
|
||||
// given by "f" is assumed to be stateless, and is executed concurrently
|
||||
// on all the slices; up to batch_size (i.e. the 0th dimension of each argument)
|
||||
// functions will be scheduled at once.
|
||||
//
|
||||
// The "max_intra_op_parallelism" attr, which defaults to 1, can be used to
|
||||
// limit the intra op parallelism. To limit inter-op parallelism, a user
|
||||
// can set a private threadpool on the dataset using `tf.data.Options`'s
|
||||
// `ThreadingOptions`.
|
||||
//
|
||||
// Note that this op is not exposed to users directly, but is invoked in
|
||||
// tf.data rewrites.
|
||||
class MapDefunOp : public AsyncOpKernel {
|
||||
public:
|
||||
static constexpr const char* const kArguments = "arguments";
|
||||
static constexpr const char* const kCapturedInputs = "captured_inputs";
|
||||
static constexpr const char* const kTarguments = "Targuments";
|
||||
static constexpr const char* const kTcaptured = "Tcaptured";
|
||||
static constexpr const char* const kOutputTypes = "output_types";
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
static constexpr const char* const kFunc = "f";
|
||||
static constexpr const char* const kMaxIntraOpParallelism =
|
||||
"max_intra_op_parallelism";
|
||||
|
||||
explicit MapDefunOp(OpKernelConstruction* ctx);
|
||||
|
||||
~MapDefunOp() override = default;
|
||||
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
|
||||
|
||||
private:
|
||||
struct ComputeOptions;
|
||||
class MapFunctionCallFrame;
|
||||
|
||||
void SetRunOptions(OpKernelContext* ctx,
|
||||
FunctionLibraryRuntime::Options* opts,
|
||||
ComputeOptions* compute_opts, bool always_collect_stats);
|
||||
|
||||
// Get inputs to Compute and check that they are valid.
|
||||
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts);
|
||||
|
||||
Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts);
|
||||
|
||||
FunctionLibraryRuntime::Handle func_handle_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
// If this value is positive, limit the max intra op parallelism when the
|
||||
// function is run on slices of the input.
|
||||
int max_intra_op_parallelism_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_MAP_DEFUN_DATASET_OP_H_
|
259
tensorflow/core/kernels/data/map_defun_op_test.cc
Normal file
259
tensorflow/core/kernels/data/map_defun_op_test.cc
Normal file
@ -0,0 +1,259 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/map_defun_op.h"
|
||||
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kNodeName[] = "map_defun";
|
||||
constexpr char kOpName[] = "MapDefun";
|
||||
|
||||
class MapDefunOpTest : public DatasetOpsTestBase {
|
||||
protected:
|
||||
// Creates a new `MapDefun` op kernel
|
||||
Status CreateMapDefunOpKernel(
|
||||
const DataTypeVector& t_arguments, const DataTypeVector& t_captured,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes,
|
||||
const FunctionDefHelper::AttrValueWrapper& func,
|
||||
int max_intra_op_parallelism,
|
||||
std::unique_ptr<OpKernel>* map_defun_kernel) {
|
||||
std::vector<string> input_placeholders;
|
||||
input_placeholders.reserve(t_arguments.size() + t_captured.size());
|
||||
for (int i = 0; i < t_arguments.size(); ++i) {
|
||||
input_placeholders.emplace_back(
|
||||
strings::StrCat(MapDefunOp::kArguments, "_", i));
|
||||
}
|
||||
for (int i = 0; i < t_captured.size(); ++i) {
|
||||
input_placeholders.emplace_back(
|
||||
strings::StrCat(MapDefunOp::kCapturedInputs, "_", i));
|
||||
}
|
||||
|
||||
NodeDef node_def = test::function::NDef(
|
||||
kNodeName, kOpName, input_placeholders,
|
||||
{{MapDefunOp::kTarguments, t_arguments},
|
||||
{MapDefunOp::kTcaptured, t_captured},
|
||||
{MapDefunOp::kOutputTypes, output_types},
|
||||
{MapDefunOp::kOutputShapes, output_shapes},
|
||||
{MapDefunOp::kFunc, func},
|
||||
{MapDefunOp::kMaxIntraOpParallelism, max_intra_op_parallelism}});
|
||||
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, map_defun_kernel));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a new `MapDefun` op kernel context.
|
||||
Status CreateMapDefunContext(OpKernel* const op_kernel,
|
||||
gtl::InlinedVector<TensorValue, 4>* const inputs,
|
||||
std::unique_ptr<OpKernelContext>* context) {
|
||||
TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
|
||||
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
struct TestCase {
|
||||
std::vector<Tensor> arguments;
|
||||
std::vector<Tensor> captured_inputs;
|
||||
DataTypeVector t_arguments;
|
||||
DataTypeVector t_captured;
|
||||
FunctionDefHelper::AttrValueWrapper func;
|
||||
std::vector<FunctionDef> func_lib;
|
||||
int max_intra_op_parallelism;
|
||||
DataTypeVector output_dtypes;
|
||||
std::vector<PartialTensorShape> output_shapes;
|
||||
std::vector<Tensor> expected_outputs;
|
||||
};
|
||||
|
||||
// Test case 1: one input for the map function with no captured inputs.
|
||||
TestCase TestCase1() {
|
||||
return {
|
||||
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
|
||||
/*captured_inputs*/ {},
|
||||
/*t_arguments*/ {DT_INT64},
|
||||
/*t_captured*/ {},
|
||||
/*func*/ {FunctionDefHelper::FunctionRef("XTimesTwo", {{"T", DT_INT64}})},
|
||||
/*func_lib*/ {test::function::XTimesTwo()},
|
||||
/*max_intra_op_parallelism*/ 2,
|
||||
/*output_dtypes*/ {DT_INT64},
|
||||
/*output_shapes*/ {PartialTensorShape({2})},
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
|
||||
{0, 2, 4, 6, 8, 10})}};
|
||||
}
|
||||
|
||||
// Test case 2: two inputs for the map function with no captured inputs.
|
||||
TestCase TestCase2() {
|
||||
return {/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape({3, 2}), {0, 10, 20, 30, 40, 50})},
|
||||
/*captured_inputs*/ {},
|
||||
/*t_arguments*/ {DT_INT64, DT_INT64},
|
||||
/*t_captured*/ {},
|
||||
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
|
||||
/*func_lib*/ {test::function::XAddY()},
|
||||
/*max_intra_op_parallelism*/ 2,
|
||||
/*output_dtypes*/ {DT_INT64},
|
||||
/*output_shapes*/ {PartialTensorShape({2})},
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
|
||||
{0, 11, 22, 33, 44, 55})}};
|
||||
}
|
||||
|
||||
// Test case 3: two inputs for the map function with one captured input.
|
||||
TestCase TestCase3() {
|
||||
return {
|
||||
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
|
||||
/*captured_inputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
|
||||
/*t_arguments*/ {DT_INT64},
|
||||
/*t_captured*/ {DT_INT64},
|
||||
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
|
||||
/*func_lib*/ {test::function::XAddY()},
|
||||
/*max_intra_op_parallelism*/ 2,
|
||||
/*output_dtypes*/ {DT_INT64},
|
||||
/*output_shapes*/ {PartialTensorShape({2})},
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
|
||||
{10, 101, 12, 103, 14, 105})}};
|
||||
}
|
||||
|
||||
TestCase InvalidOutputTypes() {
|
||||
return {
|
||||
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
|
||||
/*captured_inputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
|
||||
/*t_arguments*/ {DT_INT64},
|
||||
/*t_captured*/ {DT_INT64},
|
||||
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
|
||||
/*func_lib*/ {test::function::XAddY()},
|
||||
/*max_intra_op_parallelism*/ 2,
|
||||
/*output_dtypes*/ {DT_FLOAT},
|
||||
/*output_shapes*/ {PartialTensorShape({2})},
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
|
||||
{10, 101, 12, 103, 14, 105})}};
|
||||
}
|
||||
|
||||
TestCase InvalidOutputShapes() {
|
||||
return {
|
||||
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5})},
|
||||
/*captured_inputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
|
||||
/*t_arguments*/ {DT_INT64},
|
||||
/*t_captured*/ {DT_INT64},
|
||||
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
|
||||
/*func_lib*/ {test::function::XAddY()},
|
||||
/*max_intra_op_parallelism*/ 2,
|
||||
/*output_dtypes*/ {DT_INT64},
|
||||
/*output_shapes*/ {PartialTensorShape({2, 2})},
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
|
||||
{10, 101, 12, 103, 14, 105})}};
|
||||
}
|
||||
|
||||
TestCase InvalidInputs() {
|
||||
return {
|
||||
/*arguments*/ {DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape({3, 2}), {0, 1, 2, 3, 4, 5}),
|
||||
DatasetOpsTestBase::CreateTensor<int64>(
|
||||
TensorShape({2, 2}), {0, 1, 2, 3})},
|
||||
/*captured_inputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({2}), {10, 100})},
|
||||
/*t_arguments*/ {DT_INT64, DT_INT64},
|
||||
/*t_captured*/ {DT_INT64},
|
||||
/*func*/ {FunctionDefHelper::FunctionRef("XAddY", {{"T", DT_INT64}})},
|
||||
/*func_lib*/ {test::function::XAddY()},
|
||||
/*max_intra_op_parallelism*/ 2,
|
||||
/*output_dtypes*/ {DT_INT64},
|
||||
/*output_shapes*/ {PartialTensorShape({2})},
|
||||
/*expected_outputs*/
|
||||
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape({3, 2}),
|
||||
{10, 101, 12, 103, 14, 105})}};
|
||||
}
|
||||
|
||||
class ParameterizedMapDefunOpTest
|
||||
: public MapDefunOpTest,
|
||||
public ::testing::WithParamInterface<TestCase> {};
|
||||
|
||||
TEST_P(ParameterizedMapDefunOpTest, NormalTests) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TestCase test_case = GetParam();
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> map_defun_kernel;
|
||||
TF_ASSERT_OK(CreateMapDefunOpKernel(
|
||||
test_case.t_arguments, test_case.t_captured, test_case.output_dtypes,
|
||||
test_case.output_shapes, test_case.func,
|
||||
test_case.max_intra_op_parallelism, &map_defun_kernel));
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto& arg : test_case.arguments) {
|
||||
inputs.emplace_back(&arg);
|
||||
}
|
||||
for (auto& captured_input : test_case.captured_inputs) {
|
||||
inputs.emplace_back(&captured_input);
|
||||
}
|
||||
std::unique_ptr<OpKernelContext> context;
|
||||
TF_ASSERT_OK(
|
||||
CreateMapDefunContext(map_defun_kernel.get(), &inputs, &context));
|
||||
TF_ASSERT_OK(RunOpKernel(map_defun_kernel.get(), context.get()));
|
||||
|
||||
EXPECT_EQ(context->num_outputs(), test_case.expected_outputs.size());
|
||||
for (int i = 0; i < context->num_outputs(); ++i) {
|
||||
TF_EXPECT_OK(ExpectEqual(*context->mutable_output(i),
|
||||
test_case.expected_outputs[i]));
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MapDefunOpTest, ParameterizedMapDefunOpTest,
|
||||
::testing::ValuesIn(std::vector<TestCase>(
|
||||
{TestCase1(), TestCase2(), TestCase3()})));
|
||||
|
||||
TEST_F(MapDefunOpTest, InvalidArguments) {
|
||||
int thread_num = 2, cpu_num = 2;
|
||||
TF_ASSERT_OK(InitThreadPool(thread_num));
|
||||
std::vector<TestCase> test_cases = {InvalidOutputTypes(),
|
||||
InvalidOutputShapes(), InvalidInputs()};
|
||||
for (auto& test_case : test_cases) {
|
||||
TF_ASSERT_OK(InitFunctionLibraryRuntime(test_case.func_lib, cpu_num));
|
||||
|
||||
std::unique_ptr<OpKernel> map_defun_kernel;
|
||||
TF_ASSERT_OK(CreateMapDefunOpKernel(
|
||||
test_case.t_arguments, test_case.t_captured, test_case.output_dtypes,
|
||||
test_case.output_shapes, test_case.func,
|
||||
test_case.max_intra_op_parallelism, &map_defun_kernel));
|
||||
gtl::InlinedVector<TensorValue, 4> inputs;
|
||||
for (auto& arg : test_case.arguments) {
|
||||
inputs.emplace_back(&arg);
|
||||
}
|
||||
for (auto& captured_input : test_case.captured_inputs) {
|
||||
inputs.emplace_back(&captured_input);
|
||||
}
|
||||
std::unique_ptr<OpKernelContext> context;
|
||||
TF_ASSERT_OK(
|
||||
CreateMapDefunContext(map_defun_kernel.get(), &inputs, &context));
|
||||
EXPECT_EQ(RunOpKernel(map_defun_kernel.get(), context.get()).code(),
|
||||
tensorflow::error::INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/optimize_dataset_op.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
@ -23,66 +24,68 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
|
||||
|
||||
// See documentation in ../../ops/dataset_ops.cc for a high-level
|
||||
// description of the following op.
|
||||
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
explicit OptimizeDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr("optimization_configs", &optimization_configs_));
|
||||
|
||||
/* static */ constexpr const char* const OptimizeDatasetOp::kDatasetType;
|
||||
/* static */ constexpr const char* const OptimizeDatasetOp::kInputDataset;
|
||||
/* static */ constexpr const char* const OptimizeDatasetOp::kOptimizations;
|
||||
/* static */ constexpr const char* const OptimizeDatasetOp::kOutputTypes;
|
||||
/* static */ constexpr const char* const OptimizeDatasetOp::kOutputShapes;
|
||||
/* static */ constexpr const char* const
|
||||
OptimizeDatasetOp::kOptimizationConfigs;
|
||||
|
||||
constexpr char kOptimizerName[] = "tf_data_meta_optimizer";
|
||||
constexpr char kOptimizers[] = "optimizers";
|
||||
constexpr char kOptimizerConfigs[] = "optimizer_configs";
|
||||
|
||||
OptimizeDatasetOp::OptimizeDatasetOp(OpKernelConstruction* ctx)
|
||||
: UnaryDatasetOpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->GetAttr(kOptimizationConfigs, &optimization_configs_));
|
||||
}
|
||||
|
||||
void OptimizeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) {
|
||||
std::vector<string> optimizations;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ParseVectorArgument<string>(ctx, kOptimizations, &optimizations));
|
||||
|
||||
auto config_factory = [this, &optimizations]() {
|
||||
return CreateConfig(optimizations, optimization_configs_);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
RewriteDataset(ctx, input, std::move(config_factory),
|
||||
/*optimize_function_library=*/true, output));
|
||||
}
|
||||
|
||||
RewriterConfig OptimizeDatasetOp::CreateConfig(
|
||||
std::vector<string> optimizations,
|
||||
std::vector<string> optimizations_configs) {
|
||||
RewriterConfig rewriter_config;
|
||||
rewriter_config.add_optimizers(kOptimizerName);
|
||||
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
|
||||
rewriter_config.set_fail_on_optimizer_errors(true);
|
||||
auto custom_optimizer = rewriter_config.add_custom_optimizers();
|
||||
custom_optimizer->set_name(kOptimizerName);
|
||||
auto* custom_optimizations_list =
|
||||
(*custom_optimizer->mutable_parameter_map())[kOptimizers].mutable_list();
|
||||
for (const auto& opt : optimizations) {
|
||||
custom_optimizations_list->add_s(opt);
|
||||
}
|
||||
|
||||
protected:
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override {
|
||||
std::vector<string> optimizations;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
|
||||
|
||||
auto config_factory = [this, &optimizations]() {
|
||||
return CreateConfig(optimizations, optimization_configs_);
|
||||
};
|
||||
OP_REQUIRES_OK(ctx,
|
||||
RewriteDataset(ctx, input, std::move(config_factory),
|
||||
/*optimize_function_library=*/true, output));
|
||||
auto* config_list =
|
||||
(*custom_optimizer->mutable_parameter_map())[kOptimizerConfigs]
|
||||
.mutable_list();
|
||||
for (const auto& config : optimizations_configs) {
|
||||
config_list->add_s(config);
|
||||
}
|
||||
return rewriter_config;
|
||||
}
|
||||
|
||||
private:
|
||||
static RewriterConfig CreateConfig(
|
||||
std::vector<string> optimizations,
|
||||
std::vector<string> optimizations_configs) {
|
||||
RewriterConfig rewriter_config;
|
||||
rewriter_config.add_optimizers(kOptimizerName);
|
||||
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
|
||||
rewriter_config.set_fail_on_optimizer_errors(true);
|
||||
auto custom_optimizer = rewriter_config.add_custom_optimizers();
|
||||
custom_optimizer->set_name(kOptimizerName);
|
||||
auto* custom_optimizations_list =
|
||||
(*custom_optimizer->mutable_parameter_map())["optimizers"]
|
||||
.mutable_list();
|
||||
for (const auto& opt : optimizations) {
|
||||
custom_optimizations_list->add_s(opt);
|
||||
}
|
||||
auto* config_list =
|
||||
(*custom_optimizer->mutable_parameter_map())["optimizer_configs"]
|
||||
.mutable_list();
|
||||
for (const auto& config : optimizations_configs) {
|
||||
config_list->add_s(config);
|
||||
}
|
||||
return rewriter_config;
|
||||
}
|
||||
|
||||
std::vector<string> optimization_configs_;
|
||||
};
|
||||
|
||||
namespace {
|
||||
REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
|
||||
OptimizeDatasetOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
49
tensorflow/core/kernels/data/optimize_dataset_op.h
Normal file
49
tensorflow/core/kernels/data/optimize_dataset_op.h
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
class OptimizeDatasetOp : public UnaryDatasetOpKernel {
|
||||
public:
|
||||
static constexpr const char* const kDatasetType = "Optimize";
|
||||
static constexpr const char* const kInputDataset = "input_dataset";
|
||||
static constexpr const char* const kOptimizations = "optimizations";
|
||||
static constexpr const char* const kOutputTypes = "output_types";
|
||||
static constexpr const char* const kOutputShapes = "output_shapes";
|
||||
static constexpr const char* const kOptimizationConfigs =
|
||||
"optimization_configs";
|
||||
|
||||
explicit OptimizeDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
protected:
|
||||
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
|
||||
DatasetBase** output) override;
|
||||
|
||||
private:
|
||||
static RewriterConfig CreateConfig(std::vector<string> optimizations,
|
||||
std::vector<string> optimizations_configs);
|
||||
|
||||
std::vector<string> optimization_configs_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIMIZE_DATASET_OP_H_
|
@ -9,15 +9,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/data/optimize_dataset_op.h"
|
||||
|
||||
#include "tensorflow/core/kernels/data/dataset_test_base.h"
|
||||
#include "tensorflow/core/kernels/data/range_dataset_op.h"
|
||||
#include "tensorflow/core/kernels/data/take_dataset_op.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
|
||||
constexpr char kNodeName[] = "optimize_dataset";
|
||||
constexpr char kOpName[] = "OptimizeDataset";
|
||||
constexpr char kNoopElimination[] = "noop_elimination";
|
||||
constexpr char kIteratorPrefix[] = "Iterator";
|
||||
|
||||
class OptimizeDatasetOpTest : public DatasetOpsTestBase {
|
||||
protected:
|
||||
@ -28,10 +32,11 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
|
||||
const std::vector<string>& optimization_configs,
|
||||
std::unique_ptr<OpKernel>* optimize_dataset_op_kernel) {
|
||||
NodeDef node_def = test::function::NDef(
|
||||
kNodeName, kOpName, {"input_dataset", "optimizations"},
|
||||
{{"output_types", output_types},
|
||||
{"output_shapes", output_shapes},
|
||||
{"optimization_configs", optimization_configs}});
|
||||
kNodeName, name_utils::OpName(OptimizeDatasetOp::kDatasetType),
|
||||
{OptimizeDatasetOp::kInputDataset, OptimizeDatasetOp::kOptimizations},
|
||||
{{OptimizeDatasetOp::kOutputTypes, output_types},
|
||||
{OptimizeDatasetOp::kOutputShapes, output_shapes},
|
||||
{OptimizeDatasetOp::kOptimizationConfigs, optimization_configs}});
|
||||
TF_RETURN_IF_ERROR(CreateOpKernel(node_def, optimize_dataset_op_kernel));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -55,12 +60,13 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
|
||||
GraphConstructorOptions graph_opts;
|
||||
graph_opts.allow_internal_ops = true;
|
||||
graph_opts.expect_device_spec = false;
|
||||
TF_RETURN_IF_ERROR(RunFunction(
|
||||
test::function::MakeRangeDataset(),
|
||||
/*attrs*/
|
||||
{{"output_types", output_types}, {"output_shapes", output_shapes}},
|
||||
/*inputs*/ {start, stop, step}, graph_opts,
|
||||
/*rets*/ {range_dataset}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunFunction(test::function::MakeRangeDataset(),
|
||||
/*attrs*/
|
||||
{{RangeDatasetOp::kOutputTypes, output_types},
|
||||
{RangeDatasetOp::kOutputShapes, output_shapes}},
|
||||
/*inputs*/ {start, stop, step}, graph_opts,
|
||||
/*rets*/ {range_dataset}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -74,12 +80,13 @@ class OptimizeDatasetOpTest : public DatasetOpsTestBase {
|
||||
graph_opts.expect_device_spec = false;
|
||||
|
||||
Tensor count_tensor = CreateTensor<int64>(TensorShape({}), {count});
|
||||
TF_RETURN_IF_ERROR(RunFunction(
|
||||
test::function::MakeTakeDataset(),
|
||||
/*attrs*/
|
||||
{{"output_types", output_types}, {"output_shapes", output_shapes}},
|
||||
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
|
||||
/*rets*/ {take_dataset}));
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunFunction(test::function::MakeTakeDataset(),
|
||||
/*attrs*/
|
||||
{{TakeDatasetOp::kOutputTypes, output_types},
|
||||
{TakeDatasetOp::kOutputShapes, output_shapes}},
|
||||
/*inputs*/ {input_dataset, count_tensor}, graph_opts,
|
||||
/*rets*/ {take_dataset}));
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
@ -109,7 +116,7 @@ TEST_F(OptimizeDatasetOpTest, NoopElimination) {
|
||||
/*optimization_configs*/ {},
|
||||
&optimize_dataset_kernel));
|
||||
Tensor optimizations =
|
||||
CreateTensor<string>(TensorShape({1}), {"noop_elimination"});
|
||||
CreateTensor<string>(TensorShape({1}), {kNoopElimination});
|
||||
gtl::InlinedVector<TensorValue, 4> inputs(
|
||||
{TensorValue(&take_dataset_tensor), TensorValue(&optimizations)});
|
||||
std::unique_ptr<OpKernelContext> optimize_dataset_context;
|
||||
@ -127,7 +134,7 @@ TEST_F(OptimizeDatasetOpTest, NoopElimination) {
|
||||
CreateIteratorContext(optimize_dataset_context.get(), &iterator_context));
|
||||
std::unique_ptr<IteratorBase> iterator;
|
||||
TF_ASSERT_OK(optimize_dataset->MakeIterator(iterator_context.get(),
|
||||
"Iterator", &iterator));
|
||||
kIteratorPrefix, &iterator));
|
||||
|
||||
bool end_of_sequence = false;
|
||||
std::vector<Tensor> out_tensors;
|
||||
|
Loading…
x
Reference in New Issue
Block a user