Remove extra code introduced by callback

Michael Kuchnik 2020-11-02 16:32:50 -05:00
parent 02a8bb1da9
commit a57ecacbcb
13 changed files with 18 additions and 167 deletions

View File

@@ -122,20 +122,6 @@ class Executor {
n.WaitForNotification();
return ret;
}
// Synchronous wrapper for RunAsync() with callback support.
// Chains the callback to enable custom processing, e.g., to collect stats.
virtual Status Run(const Args& args, DoneCallback done) {
Status ret;
Notification n;
RunAsync(args, [&ret, &n, done = std::move(done)](const Status& s) {
ret = s;
done(s);
n.Notify();
});
n.WaitForNotification();
return ret;
}
};
// Creates an Executor that computes the given "graph".
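For context, the Run() overload deleted above is a synchronous wrapper over RunAsync() that additionally chains a caller-supplied callback. Below is a minimal, self-contained sketch of that pattern; the Notification class, the int-based Status, and the thread-backed RunAsync are simplified stand-ins for the TensorFlow types, assumed here purely for illustration.

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <utility>

// One-shot latch standing in for tensorflow::Notification.
class Notification {
 public:
  void Notify() {
    std::lock_guard<std::mutex> l(mu_);
    notified_ = true;
    cv_.notify_all();
  }
  void WaitForNotification() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return notified_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool notified_ = false;
};

using Status = int;  // 0 == OK; stand-in for tensorflow::Status.
using DoneCallback = std::function<void(Status)>;

// Simulated asynchronous execution on a detached worker thread.
void RunAsync(DoneCallback done) {
  std::thread([done = std::move(done)] { done(/*status=*/0); }).detach();
}

// The deleted pattern: block until the async work completes, but invoke a
// caller-supplied callback first so custom processing can be hooked in.
Status Run(DoneCallback done) {
  Status ret;
  Notification n;
  RunAsync([&ret, &n, done = std::move(done)](Status s) {
    ret = s;
    done(s);     // chained callback, e.g. to collect stats
    n.Notify();  // unblock the synchronous caller
  });
  n.WaitForNotification();
  return ret;
}

int main() {
  Status s = Run([](Status st) { std::cout << "callback saw " << st << "\n"; });
  std::cout << "Run returned " << s << "\n";
}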

View File

@@ -169,15 +169,9 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets) override;
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, DoneCallback done) override;
Status RunSync(Options opts, Handle handle,
CallFrameInterface* frame) override;
Status RunSync(Options opts, Handle handle,
CallFrameInterface* frame, DoneCallback done) override;
Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
OpKernel** kernel) override;
@@ -254,26 +248,11 @@ Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
return base_flr_->RunSync(std::move(opts), handle, args, rets);
}
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
DoneCallback done) {
return base_flr_->RunSync(std::move(opts), handle, args, rets,
std::move(done));
}
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame) {
return base_flr_->RunSync(std::move(opts), handle, call_frame);
}
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame,
DoneCallback done) {
return base_flr_->RunSync(std::move(opts), handle, call_frame,
std::move(done));
}
Status FunctionLibraryRuntimeOverlay::CreateKernel(
const std::shared_ptr<const NodeProperties>&, OpKernel**) {
// We don't have access to base_lib_def_ in base function library runtime (aka
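The RunSync overloads deleted above were pure forwarders: FunctionLibraryRuntimeOverlay delegates every execution call to the wrapped base runtime. A stripped-down sketch of that delegation shape, with an illustrative Runtime interface standing in for the real FunctionLibraryRuntime API:

#include <iostream>
#include <string>

// Illustrative interface; not the real FunctionLibraryRuntime.
class Runtime {
 public:
  virtual ~Runtime() = default;
  virtual int RunSync(const std::string& fn) = 0;
};

class RuntimeImpl : public Runtime {
 public:
  int RunSync(const std::string& fn) override {
    std::cout << "running " << fn << "\n";
    return 0;
  }
};

// The overlay implements the same interface and forwards calls verbatim.
class RuntimeOverlay : public Runtime {
 public:
  explicit RuntimeOverlay(Runtime* base) : base_(base) {}
  int RunSync(const std::string& fn) override { return base_->RunSync(fn); }

 private:
  Runtime* base_;  // not owned
};

int main() {
  RuntimeImpl impl;
  RuntimeOverlay overlay(&impl);
  return overlay.RunSync("f");
}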
@@ -371,12 +350,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
DoneCallback done) override;
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets) override;
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, DoneCallback done) override;
Status RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame) override;
Status RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame, DoneCallback done) override;
bool IsStateful(const string& function) const override;
@@ -1312,27 +1287,6 @@ Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
}
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
DoneCallback done) {
Item* item = nullptr;
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
if (item == nullptr) {
return parent_->RunSync(opts, handle, args, rets, done);
}
Executor::Args exec_args;
const FunctionBody* fbody = GetFunctionBody(handle);
FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
TF_RETURN_IF_ERROR(frame.SetArgs(args));
ExecutorArgsFromOptions(opts, &frame, &exec_args);
TF_RETURN_IF_ERROR(item->exec->Run(exec_args, done));
return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
}
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame) {
Item* item = nullptr;
@@ -1347,21 +1301,6 @@ Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
return item->exec->Run(exec_args);
}
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame,
DoneCallback done) {
Item* item = nullptr;
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
if (item == nullptr) {
return parent_->RunSync(opts, handle, call_frame, done);
}
Executor::Args exec_args;
ExecutorArgsFromOptions(opts, call_frame, &exec_args);
return item->exec->Run(exec_args, done);
}
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
const OpDef* op_def;
const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
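Both the surviving and the deleted RunSync bodies share one call-frame flow: load the arguments into a frame, run the executor against the frame, then consume the return values. A toy sketch of that flow under assumed stand-ins (string-valued Tensor, trivial Status, and a hand-rolled CallFrame rather than the real FunctionCallFrame):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

using Tensor = std::string;  // stand-in for tensorflow::Tensor
struct Status {
  bool ok = true;
  std::string msg;
};

// Minimal call-frame: the caller loads arguments, the executor fills
// return values, and the caller consumes them afterwards.
class CallFrame {
 public:
  explicit CallFrame(int num_rets) : rets_(num_rets) {}
  Status SetArgs(std::vector<Tensor> args) {
    args_ = std::move(args);
    return {};
  }
  const std::vector<Tensor>& args() const { return args_; }
  void SetRet(int i, Tensor t) { rets_[i] = std::move(t); }
  Status ConsumeRetvals(std::vector<Tensor>* out) {
    *out = std::move(rets_);
    return {};
  }

 private:
  std::vector<Tensor> args_;
  std::vector<Tensor> rets_;
};

int main() {
  CallFrame frame(/*num_rets=*/1);
  frame.SetArgs({"input"});
  // Stand-in for item->exec->Run(exec_args): reads args, writes rets.
  frame.SetRet(0, frame.args()[0] + ":processed");
  std::vector<Tensor> rets;
  frame.ConsumeRetvals(&rets);
  std::cout << rets[0] << "\n";
}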

View File

@@ -1611,23 +1611,6 @@ Status ProcessFunctionLibraryRuntime::RunSync(
return s;
}
Status ProcessFunctionLibraryRuntime::RunSync(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
Notification n;
Status s;
Run(opts, handle, args, rets,
[&n, &s, done = std::move(done)](const Status& status) {
s.Update(status);
done(s);
n.Notify();
});
n.WaitForNotification();
return s;
}
Status ProcessFunctionLibraryRuntime::RunSync(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
@@ -1641,22 +1624,6 @@ Status ProcessFunctionLibraryRuntime::RunSync(
return s;
}
Status ProcessFunctionLibraryRuntime::RunSync(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
FunctionLibraryRuntime::DoneCallback done) const {
Notification n;
Status s;
Run(opts, handle, frame,
[&n, &s, done = std::move(done)](const Status& status) {
s.Update(status);
done(s);
n.Notify();
});
n.WaitForNotification();
return s;
}
void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
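The deleted wrappers fold the callback's status into the returned one via s.Update(status). In tensorflow::Status, Update() only overwrites a status that is still OK, so the first error wins. A tiny sketch of that rule with a simplified Status type:

#include <iostream>
#include <string>

// Simplified Status with the Update() rule used above: keep the first
// non-OK status and ignore any later ones.
struct Status {
  int code = 0;  // 0 == OK
  std::string message;
  bool ok() const { return code == 0; }
  void Update(const Status& other) {
    if (ok()) *this = other;
  }
};

int main() {
  Status s;
  s.Update({0, ""});                  // still OK
  s.Update({3, "invalid argument"});  // first error wins
  s.Update({5, "not found"});         // ignored
  std::cout << s.code << ": " << s.message << "\n";  // 3: invalid argument
}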

View File

@@ -197,17 +197,9 @@ class ProcessFunctionLibraryRuntime {
Status RunSync(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const;
Status RunSync(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const;
Status RunSync(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
CallFrameInterface* frame) const;
Status RunSync(const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle,
CallFrameInterface* frame,
FunctionLibraryRuntime::DoneCallback done) const;
const DeviceMgr* device_mgr() { return device_mgr_; }

View File

@@ -779,15 +779,8 @@ class FunctionLibraryRuntime {
virtual Status RunSync(Options opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets) = 0;
virtual Status RunSync(Options opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
DoneCallback done) = 0;
virtual Status RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame) = 0;
virtual Status RunSync(Options opts, Handle handle,
CallFrameInterface* call_frame,
DoneCallback done) = 0;
// Creates a "kernel" for the given NodeProperties "props".
//

View File

@@ -850,33 +850,24 @@ Status InstantiatedCapturedFunction::Run(
profiler::TraceMeLevel::kInfo);
if (collect_usage) {
// Resource usage for function execution is gathered from the executor.
// NOTE(mkuchnik): RecordStop and RecordStart have to be called around
// this function to prevent double-counting resource usage.
auto callback = std::bind(
[this, node, collect_usage](
IteratorContext* ctx,
const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
// Begin unbound arguments.
Status s) {
if (node) {
// TODO(b/129085499) Utilize the `node_name`, which would be more unique
// than the prefix, for the function execution time statistics.
// prefix_with_func_name would then be node_name + func_name.
if (ctx->stats_aggregator()) {
string prefix_with_func_name =
strings::StrCat(node->name(), stats_utils::kDelimiter,
captured_func_->func().name());
ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
node->num_elements());
}
node->add_processing_time(stats_collector->processing_time());
node->record_stop(EnvTime::NowNanos());
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
node->record_start(EnvTime::NowNanos());
if (node) {
// TODO(b/129085499) Utilize the `node_name`, which would be more unique
// than the prefix, for the function execution time statistics.
// prefix_with_func_name would then be node_name + func_name.
if (ctx->stats_aggregator()) {
string prefix_with_func_name =
strings::StrCat(node->name(), stats_utils::kDelimiter,
captured_func_->func().name());
ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
node->num_elements());
}
},
ctx, std::move(stats_collector), std::placeholders::_1);
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame,
std::move(callback)));
node->add_processing_time(stats_collector->processing_time());
}
} else {
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
}
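The node->record_stop / node->record_start bracketing retained above guards against double counting: the model node times its own processing, while the executed function's cost is reported separately through the stats collector, so the node's clock is paused for the duration of the call. A toy illustration, where the Node class and NowNanos helper are assumptions made for the sketch:

#include <chrono>
#include <cstdint>
#include <iostream>
#include <thread>

// Toy model node that accumulates self-time, in the spirit of model::Node.
class Node {
 public:
  void record_start(uint64_t now) { start_ = now; }
  void record_stop(uint64_t now) { total_ += now - start_; }
  uint64_t total() const { return total_; }

 private:
  uint64_t start_ = 0;
  uint64_t total_ = 0;
};

uint64_t NowNanos() {
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

void RunFunction() {
  std::this_thread::sleep_for(std::chrono::milliseconds(5));
}

int main() {
  Node node;
  node.record_start(NowNanos());
  // Pause the node's clock around the call; the callee's own stats
  // collector attributes this interval, so it must not be counted here.
  node.record_stop(NowNanos());
  RunFunction();
  node.record_start(NowNanos());
  node.record_stop(NowNanos());
  std::cout << "self time (ns): " << node.total() << "\n";  // excludes callee
}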

View File

@@ -243,11 +243,9 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
if (states_.find(key) == states_.end()) {
// Run the init function to create the initial state.
std::vector<Tensor> init_func_output;
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_init_func_->Run(
ctx, std::move(key_func_output), &init_func_output,
model_node()));
RecordStart(ctx);
states_[key] = init_func_output;
}
@@ -260,10 +258,8 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
std::back_inserter(args));
std::vector<Tensor> reduce_func_output;
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
ctx, std::move(args), &reduce_func_output, model_node()));
RecordStart(ctx);
states_[key] = reduce_func_output;
} else {
keys_.resize(states_.size());

View File

@@ -254,11 +254,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// Run the window size function on the key to identify its
// window size.
std::vector<Tensor> window_size_func_output;
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run(
ctx, std::move(key_func_output), &window_size_func_output,
model_node()));
RecordStart(ctx);
if (window_size_func_output.size() != 1 ||
window_size_func_output[0].dtype() != DT_INT64 ||
@@ -489,11 +487,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> args(
{std::move(key_arg), std::move(group_dataset_arg)});
std::vector<Tensor> return_values;
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(ctx, std::move(args),
&return_values,
model_node()));
RecordStart(ctx);
if (!(return_values.size() == 1 &&
return_values[0].dtype() == DT_VARIANT &&

View File

@@ -315,10 +315,8 @@ class LoadDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor));
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_captured_func_->Run(
ctx, std::move(reader_input), &reader_output, model_node()));
RecordStart(ctx);
if (reader_output.size() != 1) {
return errors::InvalidArgument(
"reader_func returns more than one argument.");

View File

@@ -200,11 +200,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
state_and_output.reserve(dataset()->state_types_.size() +
output_dtypes().size());
RecordStop(ctx);
Status s = instantiated_captured_func_->Run(ctx, std::move(args),
&state_and_output,
model_node());
RecordStart(ctx);
DCHECK(state_and_output.size() <=
dataset()->state_types_.size() + output_dtypes().size());
if (s.ok()) {

View File

@@ -571,10 +571,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor));
RecordStop(ctx);
TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
ctx, std::move(reader_input), &reader_output, model_node()));
RecordStart(ctx);
if (reader_output.size() != 1) {
return errors::InvalidArgument(
"reader_func returns more than one argument.");

View File

@@ -158,11 +158,9 @@ class MapDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
RecordStop(ctx);
Status s =
instantiated_captured_func_->Run(ctx, std::move(args), out_tensors,
model_node());
RecordStart(ctx);
if (errors::IsOutOfRange(s)) {
if (dataset()->preserve_cardinality_) {
// To guarantee that the transformation preserves the cardinality of

View File

@@ -454,11 +454,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
RecordStop(ctx.get());
(*ctx->runner())(
[this, ctx, fn = std::move(fn), done = std::move(done)]() {
Status s = fn();
RecordStart(ctx.get());
auto cleanup =
gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
done(s);
done(fn());
});
RecordStart(ctx.get());
}
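The deleted lines above used gtl::MakeCleanup to guarantee that RecordStop still ran after done(), whatever path the lambda took. A minimal scope-guard sketch in that spirit (not the real gtl implementation):

#include <iostream>
#include <utility>

// Minimal scope guard: runs the given callable when the guard leaves
// scope, including on early returns.
template <typename F>
class Cleanup {
 public:
  explicit Cleanup(F f) : f_(std::move(f)) {}
  Cleanup(Cleanup&& other) : f_(std::move(other.f_)), active_(other.active_) {
    other.active_ = false;  // the moved-from guard must not fire
  }
  ~Cleanup() {
    if (active_) f_();
  }

 private:
  F f_;
  bool active_ = true;
};

template <typename F>
Cleanup<F> MakeCleanup(F f) {
  return Cleanup<F>(std::move(f));
}

int main() {
  std::cout << "RecordStart\n";
  {
    auto cleanup = MakeCleanup([] { std::cout << "RecordStop\n"; });
    std::cout << "done(status)\n";
  }  // guard fires here, mirroring the deleted RecordStop-after-done ordering
}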