diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index 6f862b613f4..d590ae0f711 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -122,20 +122,6 @@ class Executor { n.WaitForNotification(); return ret; } - - // Synchronous wrapper for RunAsync() with callback support. - // Chains the callback to enable custom processing e.g., to collect stats. - virtual Status Run(const Args& args, DoneCallback done) { - Status ret; - Notification n; - RunAsync(args, [&ret, &n, done = std::move(done)](const Status& s) { - ret = s; - done(s); - n.Notify(); - }); - n.WaitForNotification(); - return ret; - } }; // Creates an Executor that computes the given "graph". diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index d231e09ceab..b5a2e0a9ef9 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -169,15 +169,9 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime { Status RunSync(Options opts, Handle handle, gtl::ArraySlice args, std::vector* rets) override; - Status RunSync(Options opts, Handle handle, gtl::ArraySlice args, - std::vector* rets, DoneCallback done) override; - Status RunSync(Options opts, Handle handle, CallFrameInterface* frame) override; - Status RunSync(Options opts, Handle handle, - CallFrameInterface* frame, DoneCallback done) override; - Status CreateKernel(const std::shared_ptr& props, OpKernel** kernel) override; @@ -254,26 +248,11 @@ Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle, return base_flr_->RunSync(std::move(opts), handle, args, rets); } -Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle, - gtl::ArraySlice args, - std::vector* rets, - DoneCallback done) { - return base_flr_->RunSync(std::move(opts), handle, args, rets, - std::move(done)); -} - Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle, CallFrameInterface* call_frame) { return base_flr_->RunSync(std::move(opts), handle, call_frame); } -Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle, - CallFrameInterface* call_frame, - DoneCallback done) { - return base_flr_->RunSync(std::move(opts), handle, call_frame, - std::move(done)); -} - Status FunctionLibraryRuntimeOverlay::CreateKernel( const std::shared_ptr&, OpKernel**) { // We don't have access to base_lib_def_ in base function library runtime (aka @@ -371,12 +350,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { DoneCallback done) override; Status RunSync(Options opts, Handle handle, gtl::ArraySlice args, std::vector* rets) override; - Status RunSync(Options opts, Handle handle, gtl::ArraySlice args, - std::vector* rets, DoneCallback done) override; Status RunSync(Options opts, Handle handle, CallFrameInterface* call_frame) override; - Status RunSync(Options opts, Handle handle, - CallFrameInterface* call_frame, DoneCallback done) override; bool IsStateful(const string& function) const override; @@ -1312,27 +1287,6 @@ Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, return frame.ConsumeRetvals(rets, opts.allow_dead_tensors); } -Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, - gtl::ArraySlice args, - std::vector* rets, - DoneCallback done) { - Item* item = nullptr; - std::unique_ptr rendezvous; - TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous)); - if (item == nullptr) { - return parent_->RunSync(opts, handle, args, rets, done); - } - - Executor::Args exec_args; - const FunctionBody* fbody = GetFunctionBody(handle); - FunctionCallFrame frame(fbody->arg_types, fbody->ret_types); - TF_RETURN_IF_ERROR(frame.SetArgs(args)); - ExecutorArgsFromOptions(opts, &frame, &exec_args); - - TF_RETURN_IF_ERROR(item->exec->Run(exec_args, done)); - return frame.ConsumeRetvals(rets, opts.allow_dead_tensors); -} - Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, CallFrameInterface* call_frame) { Item* item = nullptr; @@ -1347,21 +1301,6 @@ Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, return item->exec->Run(exec_args); } -Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, - CallFrameInterface* call_frame, - DoneCallback done) { - Item* item = nullptr; - std::unique_ptr rendezvous; - TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous)); - if (item == nullptr) { - return parent_->RunSync(opts, handle, call_frame, done); - } - - Executor::Args exec_args; - ExecutorArgsFromOptions(opts, call_frame, &exec_args); - return item->exec->Run(exec_args, done); -} - bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const { const OpDef* op_def; const Status s = base_lib_def_->LookUpOpDef(func, &op_def); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index b4fe757d559..50f3b52e4c6 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -1611,23 +1611,6 @@ Status ProcessFunctionLibraryRuntime::RunSync( return s; } -Status ProcessFunctionLibraryRuntime::RunSync( - const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, gtl::ArraySlice args, - std::vector* rets, - FunctionLibraryRuntime::DoneCallback done) const { - Notification n; - Status s; - Run(opts, handle, args, rets, - [&n, &s, done = std::move(done)](const Status& status) { - s.Update(status); - done(s); - n.Notify(); - }); - n.WaitForNotification(); - return s; -} - Status ProcessFunctionLibraryRuntime::RunSync( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const { @@ -1641,22 +1624,6 @@ Status ProcessFunctionLibraryRuntime::RunSync( return s; } -Status ProcessFunctionLibraryRuntime::RunSync( - const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame, - FunctionLibraryRuntime::DoneCallback done) const { - Notification n; - Status s; - Run(opts, handle, frame, - [&n, &s, done = std::move(done)](const Status& status) { - s.Update(status); - done(s); - n.Notify(); - }); - n.WaitForNotification(); - return s; -} - void ProcessFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args, diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 14accf9e55c..54d59f35ff3 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -197,17 +197,9 @@ class ProcessFunctionLibraryRuntime { Status RunSync(const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, gtl::ArraySlice args, std::vector* rets) const; - Status RunSync(const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, - gtl::ArraySlice args, std::vector* rets, - FunctionLibraryRuntime::DoneCallback done) const; Status RunSync(const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const; - Status RunSync(const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, - CallFrameInterface* frame, - FunctionLibraryRuntime::DoneCallback done) const; const DeviceMgr* device_mgr() { return device_mgr_; } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 4a2d14d0fe2..3c048161b7d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -779,15 +779,8 @@ class FunctionLibraryRuntime { virtual Status RunSync(Options opts, Handle handle, gtl::ArraySlice args, std::vector* rets) = 0; - virtual Status RunSync(Options opts, Handle handle, - gtl::ArraySlice args, - std::vector* rets, - DoneCallback done) = 0; virtual Status RunSync(Options opts, Handle handle, CallFrameInterface* call_frame) = 0; - virtual Status RunSync(Options opts, Handle handle, - CallFrameInterface* call_frame, - DoneCallback done) = 0; // Creates a "kernel" for the given NodeProperties "props". // diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index bf8d5a4311c..b95e46e414d 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -850,33 +850,24 @@ Status InstantiatedCapturedFunction::Run( profiler::TraceMeLevel::kInfo); if (collect_usage) { // Resource usage is for function execution is gathered from the executor. - // NOTE(mkuchnik): RecordStop and RecordStart have to be called around - // this function to prevent double-counting resource usage. - auto callback = std::bind( - [this, node, collect_usage]( - IteratorContext* ctx, - const std::shared_ptr& stats_collector, - // Begin unbound arguments. - Status s) { - if (node) { - // TODO(b/129085499) Utilize the `node_name` which would be unique - // than the prefix for the function execution time statistics. - // prefix_with_func_name would then be node_name + func_name. - if (ctx->stats_aggregator()) { - string prefix_with_func_name = - strings::StrCat(node->name(), stats_utils::kDelimiter, - captured_func_->func().name()); - ctx->stats_aggregator()->AddToHistogram( - stats_utils::ExecutionTimeHistogramName(prefix_with_func_name), - {static_cast(stats_collector->processing_time())}, - node->num_elements()); - } - node->add_processing_time(stats_collector->processing_time()); + node->record_stop(EnvTime::NowNanos()); + TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); + node->record_start(EnvTime::NowNanos()); + if (node) { + // TODO(b/129085499) Utilize the `node_name` which would be unique + // than the prefix for the function execution time statistics. + // prefix_with_func_name would then be node_name + func_name. + if (ctx->stats_aggregator()) { + string prefix_with_func_name = + strings::StrCat(node->name(), stats_utils::kDelimiter, + captured_func_->func().name()); + ctx->stats_aggregator()->AddToHistogram( + stats_utils::ExecutionTimeHistogramName(prefix_with_func_name), + {static_cast(stats_collector->processing_time())}, + node->num_elements()); } - }, - ctx, std::move(stats_collector), std::placeholders::_1); - TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame, - std::move(callback))); + node->add_processing_time(stats_collector->processing_time()); + } } else { TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); } diff --git a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc index 57c97333fc1..0572259cac6 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc @@ -243,11 +243,9 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { if (states_.find(key) == states_.end()) { // Run the init function to create the initial state. std::vector init_func_output; - RecordStop(ctx); TF_RETURN_IF_ERROR(instantiated_init_func_->Run( ctx, std::move(key_func_output), &init_func_output, model_node())); - RecordStart(ctx); states_[key] = init_func_output; } @@ -260,10 +258,8 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { std::back_inserter(args)); std::vector reduce_func_output; - RecordStop(ctx); TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run( ctx, std::move(args), &reduce_func_output, model_node())); - RecordStart(ctx); states_[key] = reduce_func_output; } else { keys_.resize(states_.size()); diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index d66158ec242..8ac4b17a4cb 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -254,11 +254,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { // Run the window size function on the key to identify its // window size. std::vector window_size_func_output; - RecordStop(ctx); TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run( ctx, std::move(key_func_output), &window_size_func_output, model_node())); - RecordStart(ctx); if (window_size_func_output.size() != 1 || window_size_func_output[0].dtype() != DT_INT64 || @@ -489,11 +487,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { std::vector args( {std::move(key_arg), std::move(group_dataset_arg)}); std::vector return_values; - RecordStop(ctx); TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(ctx, std::move(args), &return_values, model_node())); - RecordStart(ctx); if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT && diff --git a/tensorflow/core/kernels/data/experimental/io_ops.cc b/tensorflow/core/kernels/data/experimental/io_ops.cc index b4fc16a8e42..10512d8a296 100644 --- a/tensorflow/core/kernels/data/experimental/io_ops.cc +++ b/tensorflow/core/kernels/data/experimental/io_ops.cc @@ -315,10 +315,8 @@ class LoadDatasetOp::Dataset : public DatasetBase { std::vector reader_output; reader_input.push_back(std::move(input_dataset_tensor)); - RecordStop(ctx); TF_RETURN_IF_ERROR(instantiated_captured_func_->Run( ctx, std::move(reader_input), &reader_output, model_node())); - RecordStart(ctx); if (reader_output.size() != 1) { return errors::InvalidArgument( "reader_func returns more than one argument."); diff --git a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc index ccf399cf0fa..0a7ae61fa78 100644 --- a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc @@ -200,11 +200,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { state_and_output.reserve(dataset()->state_types_.size() + output_dtypes().size()); - RecordStop(ctx); Status s = instantiated_captured_func_->Run(ctx, std::move(args), &state_and_output, model_node()); - RecordStart(ctx); DCHECK(state_and_output.size() <= dataset()->state_types_.size() + output_dtypes().size()); if (s.ok()) { diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index 88eed32c551..3575c3c5c93 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -571,10 +571,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize( std::vector reader_output; reader_input.push_back(std::move(input_dataset_tensor)); - RecordStop(ctx); TF_RETURN_IF_ERROR(instantiated_reader_func_->Run( ctx, std::move(reader_input), &reader_output, model_node())); - RecordStart(ctx); if (reader_output.size() != 1) { return errors::InvalidArgument( "reader_func returns more than one argument."); diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 7674e7e5a41..37fd40673d4 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -158,11 +158,9 @@ class MapDatasetOp::Dataset : public DatasetBase { return Status::OK(); } - RecordStop(ctx); Status s = instantiated_captured_func_->Run(ctx, std::move(args), out_tensors, model_node()); - RecordStart(ctx); if (errors::IsOutOfRange(s)) { if (dataset()->preserve_cardinality_) { // To guarantee that the transformation preserves the cardinality of diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 727e6e82dec..42244ca61fd 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -454,11 +454,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { RecordStop(ctx.get()); (*ctx->runner())( [this, ctx, fn = std::move(fn), done = std::move(done)]() { - Status s = fn(); RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); - done(s); + done(fn()); }); RecordStart(ctx.get()); }