Remove extra code introduced by callback

This commit is contained in:
Michael Kuchnik 2020-11-02 16:32:50 -05:00
parent 02a8bb1da9
commit a57ecacbcb
13 changed files with 18 additions and 167 deletions

View File

@ -122,20 +122,6 @@ class Executor {
n.WaitForNotification(); n.WaitForNotification();
return ret; return ret;
} }
// Synchronous wrapper for RunAsync() with callback support.
// Chains the callback to enable custom processing e.g., to collect stats.
virtual Status Run(const Args& args, DoneCallback done) {
Status ret;
Notification n;
RunAsync(args, [&ret, &n, done = std::move(done)](const Status& s) {
ret = s;
done(s);
n.Notify();
});
n.WaitForNotification();
return ret;
}
}; };
// Creates an Executor that computes the given "graph". // Creates an Executor that computes the given "graph".

View File

@ -169,15 +169,9 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args, Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets) override; std::vector<Tensor>* rets) override;
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, DoneCallback done) override;
Status RunSync(Options opts, Handle handle, Status RunSync(Options opts, Handle handle,
CallFrameInterface* frame) override; CallFrameInterface* frame) override;
Status RunSync(Options opts, Handle handle,
CallFrameInterface* frame, DoneCallback done) override;
Status CreateKernel(const std::shared_ptr<const NodeProperties>& props, Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
OpKernel** kernel) override; OpKernel** kernel) override;
@ -254,26 +248,11 @@ Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
return base_flr_->RunSync(std::move(opts), handle, args, rets); return base_flr_->RunSync(std::move(opts), handle, args, rets);
} }
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
DoneCallback done) {
return base_flr_->RunSync(std::move(opts), handle, args, rets,
std::move(done));
}
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle, Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame) { CallFrameInterface* call_frame) {
return base_flr_->RunSync(std::move(opts), handle, call_frame); return base_flr_->RunSync(std::move(opts), handle, call_frame);
} }
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame,
DoneCallback done) {
return base_flr_->RunSync(std::move(opts), handle, call_frame,
std::move(done));
}
Status FunctionLibraryRuntimeOverlay::CreateKernel( Status FunctionLibraryRuntimeOverlay::CreateKernel(
const std::shared_ptr<const NodeProperties>&, OpKernel**) { const std::shared_ptr<const NodeProperties>&, OpKernel**) {
// We don't have access to base_lib_def_ in base function library runtime (aka // We don't have access to base_lib_def_ in base function library runtime (aka
@ -371,12 +350,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
DoneCallback done) override; DoneCallback done) override;
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args, Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets) override; std::vector<Tensor>* rets) override;
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, DoneCallback done) override;
Status RunSync(Options opts, Handle handle, Status RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame) override; CallFrameInterface* call_frame) override;
Status RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame, DoneCallback done) override;
bool IsStateful(const string& function) const override; bool IsStateful(const string& function) const override;
@ -1312,27 +1287,6 @@ Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
return frame.ConsumeRetvals(rets, opts.allow_dead_tensors); return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
} }
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
DoneCallback done) {
Item* item = nullptr;
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
if (item == nullptr) {
return parent_->RunSync(opts, handle, args, rets, done);
}
Executor::Args exec_args;
const FunctionBody* fbody = GetFunctionBody(handle);
FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
TF_RETURN_IF_ERROR(frame.SetArgs(args));
ExecutorArgsFromOptions(opts, &frame, &exec_args);
TF_RETURN_IF_ERROR(item->exec->Run(exec_args, done));
return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
}
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle, Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame) { CallFrameInterface* call_frame) {
Item* item = nullptr; Item* item = nullptr;
@ -1347,21 +1301,6 @@ Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
return item->exec->Run(exec_args); return item->exec->Run(exec_args);
} }
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame,
DoneCallback done) {
Item* item = nullptr;
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
if (item == nullptr) {
return parent_->RunSync(opts, handle, call_frame, done);
}
Executor::Args exec_args;
ExecutorArgsFromOptions(opts, call_frame, &exec_args);
return item->exec->Run(exec_args, done);
}
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const { bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
const OpDef* op_def; const OpDef* op_def;
const Status s = base_lib_def_->LookUpOpDef(func, &op_def); const Status s = base_lib_def_->LookUpOpDef(func, &op_def);

View File

@ -1611,23 +1611,6 @@ Status ProcessFunctionLibraryRuntime::RunSync(
return s; return s;
} }
Status ProcessFunctionLibraryRuntime::RunSync(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
Notification n;
Status s;
Run(opts, handle, args, rets,
[&n, &s, done = std::move(done)](const Status& status) {
s.Update(status);
done(s);
n.Notify();
});
n.WaitForNotification();
return s;
}
Status ProcessFunctionLibraryRuntime::RunSync( Status ProcessFunctionLibraryRuntime::RunSync(
const FunctionLibraryRuntime::Options& opts, const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const { FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
@ -1641,22 +1624,6 @@ Status ProcessFunctionLibraryRuntime::RunSync(
return s; return s;
} }
Status ProcessFunctionLibraryRuntime::RunSync(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
FunctionLibraryRuntime::DoneCallback done) const {
Notification n;
Status s;
Run(opts, handle, frame,
[&n, &s, done = std::move(done)](const Status& status) {
s.Update(status);
done(s);
n.Notify();
});
n.WaitForNotification();
return s;
}
void ProcessFunctionLibraryRuntime::Run( void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts, const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args, FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,

View File

@ -197,17 +197,9 @@ class ProcessFunctionLibraryRuntime {
Status RunSync(const FunctionLibraryRuntime::Options& opts, Status RunSync(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, FunctionLibraryRuntime::Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const; gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const;
Status RunSync(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const;
Status RunSync(const FunctionLibraryRuntime::Options& opts, Status RunSync(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, FunctionLibraryRuntime::Handle handle,
CallFrameInterface* frame) const; CallFrameInterface* frame) const;
Status RunSync(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
CallFrameInterface* frame,
FunctionLibraryRuntime::DoneCallback done) const;
const DeviceMgr* device_mgr() { return device_mgr_; } const DeviceMgr* device_mgr() { return device_mgr_; }

View File

@ -779,15 +779,8 @@ class FunctionLibraryRuntime {
virtual Status RunSync(Options opts, Handle handle, virtual Status RunSync(Options opts, Handle handle,
gtl::ArraySlice<Tensor> args, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets) = 0; std::vector<Tensor>* rets) = 0;
virtual Status RunSync(Options opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
DoneCallback done) = 0;
virtual Status RunSync(Options opts, Handle handle, virtual Status RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame) = 0; CallFrameInterface* call_frame) = 0;
virtual Status RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame,
DoneCallback done) = 0;
// Creates a "kernel" for the given NodeProperties "props". // Creates a "kernel" for the given NodeProperties "props".
// //

View File

@ -850,14 +850,9 @@ Status InstantiatedCapturedFunction::Run(
profiler::TraceMeLevel::kInfo); profiler::TraceMeLevel::kInfo);
if (collect_usage) { if (collect_usage) {
// Resource usage for function execution is gathered from the executor. // Resource usage for function execution is gathered from the executor.
// NOTE(mkuchnik): RecordStop and RecordStart have to be called around node->record_stop(EnvTime::NowNanos());
// this function to prevent double-counting resource usage. TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
auto callback = std::bind( node->record_start(EnvTime::NowNanos());
[this, node, collect_usage](
IteratorContext* ctx,
const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
// Begin unbound arguments.
Status s) {
if (node) { if (node) {
// TODO(b/129085499) Utilize the `node_name` which would be unique // TODO(b/129085499) Utilize the `node_name` which would be unique
// than the prefix for the function execution time statistics. // than the prefix for the function execution time statistics.
@ -873,10 +868,6 @@ Status InstantiatedCapturedFunction::Run(
} }
node->add_processing_time(stats_collector->processing_time()); node->add_processing_time(stats_collector->processing_time());
} }
},
ctx, std::move(stats_collector), std::placeholders::_1);
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame,
std::move(callback)));
} else { } else {
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
} }

View File

@ -243,11 +243,9 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
if (states_.find(key) == states_.end()) { if (states_.find(key) == states_.end()) {
// Run the init function to create the initial state. // Run the init function to create the initial state.
std::vector<Tensor> init_func_output; std::vector<Tensor> init_func_output;
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_init_func_->Run( TF_RETURN_IF_ERROR(instantiated_init_func_->Run(
ctx, std::move(key_func_output), &init_func_output, ctx, std::move(key_func_output), &init_func_output,
model_node())); model_node()));
RecordStart(ctx);
states_[key] = init_func_output; states_[key] = init_func_output;
} }
@ -260,10 +258,8 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
std::back_inserter(args)); std::back_inserter(args));
std::vector<Tensor> reduce_func_output; std::vector<Tensor> reduce_func_output;
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run( TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
ctx, std::move(args), &reduce_func_output, model_node())); ctx, std::move(args), &reduce_func_output, model_node()));
RecordStart(ctx);
states_[key] = reduce_func_output; states_[key] = reduce_func_output;
} else { } else {
keys_.resize(states_.size()); keys_.resize(states_.size());

View File

@ -254,11 +254,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// Run the window size function on the key to identify its // Run the window size function on the key to identify its
// window size. // window size.
std::vector<Tensor> window_size_func_output; std::vector<Tensor> window_size_func_output;
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run( TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run(
ctx, std::move(key_func_output), &window_size_func_output, ctx, std::move(key_func_output), &window_size_func_output,
model_node())); model_node()));
RecordStart(ctx);
if (window_size_func_output.size() != 1 || if (window_size_func_output.size() != 1 ||
window_size_func_output[0].dtype() != DT_INT64 || window_size_func_output[0].dtype() != DT_INT64 ||
@ -489,11 +487,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> args( std::vector<Tensor> args(
{std::move(key_arg), std::move(group_dataset_arg)}); {std::move(key_arg), std::move(group_dataset_arg)});
std::vector<Tensor> return_values; std::vector<Tensor> return_values;
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(ctx, std::move(args), TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(ctx, std::move(args),
&return_values, &return_values,
model_node())); model_node()));
RecordStart(ctx);
if (!(return_values.size() == 1 && if (!(return_values.size() == 1 &&
return_values[0].dtype() == DT_VARIANT && return_values[0].dtype() == DT_VARIANT &&

View File

@ -315,10 +315,8 @@ class LoadDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> reader_output; std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor)); reader_input.push_back(std::move(input_dataset_tensor));
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_captured_func_->Run( TF_RETURN_IF_ERROR(instantiated_captured_func_->Run(
ctx, std::move(reader_input), &reader_output, model_node())); ctx, std::move(reader_input), &reader_output, model_node()));
RecordStart(ctx);
if (reader_output.size() != 1) { if (reader_output.size() != 1) {
return errors::InvalidArgument( return errors::InvalidArgument(
"reader_func returns more than one argument."); "reader_func returns more than one argument.");

View File

@ -200,11 +200,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
state_and_output.reserve(dataset()->state_types_.size() + state_and_output.reserve(dataset()->state_types_.size() +
output_dtypes().size()); output_dtypes().size());
RecordStop(ctx);
Status s = instantiated_captured_func_->Run(ctx, std::move(args), Status s = instantiated_captured_func_->Run(ctx, std::move(args),
&state_and_output, &state_and_output,
model_node()); model_node());
RecordStart(ctx);
DCHECK(state_and_output.size() <= DCHECK(state_and_output.size() <=
dataset()->state_types_.size() + output_dtypes().size()); dataset()->state_types_.size() + output_dtypes().size());
if (s.ok()) { if (s.ok()) {

View File

@ -571,10 +571,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
std::vector<Tensor> reader_output; std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor)); reader_input.push_back(std::move(input_dataset_tensor));
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_reader_func_->Run( TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
ctx, std::move(reader_input), &reader_output, model_node())); ctx, std::move(reader_input), &reader_output, model_node()));
RecordStart(ctx);
if (reader_output.size() != 1) { if (reader_output.size() != 1) {
return errors::InvalidArgument( return errors::InvalidArgument(
"reader_func returns more than one argument."); "reader_func returns more than one argument.");

View File

@ -158,11 +158,9 @@ class MapDatasetOp::Dataset : public DatasetBase {
return Status::OK(); return Status::OK();
} }
RecordStop(ctx);
Status s = Status s =
instantiated_captured_func_->Run(ctx, std::move(args), out_tensors, instantiated_captured_func_->Run(ctx, std::move(args), out_tensors,
model_node()); model_node());
RecordStart(ctx);
if (errors::IsOutOfRange(s)) { if (errors::IsOutOfRange(s)) {
if (dataset()->preserve_cardinality_) { if (dataset()->preserve_cardinality_) {
// To guarantee that the transformation preserves the cardinality of // To guarantee that the transformation preserves the cardinality of

View File

@ -454,11 +454,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
RecordStop(ctx.get()); RecordStop(ctx.get());
(*ctx->runner())( (*ctx->runner())(
[this, ctx, fn = std::move(fn), done = std::move(done)]() { [this, ctx, fn = std::move(fn), done = std::move(done)]() {
Status s = fn();
RecordStart(ctx.get()); RecordStart(ctx.get());
auto cleanup = auto cleanup =
gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
done(s); done(fn());
}); });
RecordStart(ctx.get()); RecordStart(ctx.get());
} }