Remove extra code introduced by callback
This commit is contained in:
parent
02a8bb1da9
commit
a57ecacbcb
@ -122,20 +122,6 @@ class Executor {
|
||||
n.WaitForNotification();
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Synchronous wrapper for RunAsync() with callback support.
|
||||
// Chains the callback to enable custom processing e.g., to collect stats.
|
||||
virtual Status Run(const Args& args, DoneCallback done) {
|
||||
Status ret;
|
||||
Notification n;
|
||||
RunAsync(args, [&ret, &n, done = std::move(done)](const Status& s) {
|
||||
ret = s;
|
||||
done(s);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
// Creates an Executor that computes the given "graph".
|
||||
|
@ -169,15 +169,9 @@ class FunctionLibraryRuntimeOverlay : public FunctionLibraryRuntime {
|
||||
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) override;
|
||||
|
||||
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets, DoneCallback done) override;
|
||||
|
||||
Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* frame) override;
|
||||
|
||||
Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* frame, DoneCallback done) override;
|
||||
|
||||
Status CreateKernel(const std::shared_ptr<const NodeProperties>& props,
|
||||
OpKernel** kernel) override;
|
||||
|
||||
@ -254,26 +248,11 @@ Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
|
||||
return base_flr_->RunSync(std::move(opts), handle, args, rets);
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets,
|
||||
DoneCallback done) {
|
||||
return base_flr_->RunSync(std::move(opts), handle, args, rets,
|
||||
std::move(done));
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame) {
|
||||
return base_flr_->RunSync(std::move(opts), handle, call_frame);
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeOverlay::RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame,
|
||||
DoneCallback done) {
|
||||
return base_flr_->RunSync(std::move(opts), handle, call_frame,
|
||||
std::move(done));
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeOverlay::CreateKernel(
|
||||
const std::shared_ptr<const NodeProperties>&, OpKernel**) {
|
||||
// We don't have access to base_lib_def_ in base function library runtime (aka
|
||||
@ -371,12 +350,8 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
||||
DoneCallback done) override;
|
||||
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) override;
|
||||
Status RunSync(Options opts, Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets, DoneCallback done) override;
|
||||
Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame) override;
|
||||
Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame, DoneCallback done) override;
|
||||
|
||||
bool IsStateful(const string& function) const override;
|
||||
|
||||
@ -1312,27 +1287,6 @@ Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
|
||||
return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets,
|
||||
DoneCallback done) {
|
||||
Item* item = nullptr;
|
||||
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
|
||||
TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
|
||||
if (item == nullptr) {
|
||||
return parent_->RunSync(opts, handle, args, rets, done);
|
||||
}
|
||||
|
||||
Executor::Args exec_args;
|
||||
const FunctionBody* fbody = GetFunctionBody(handle);
|
||||
FunctionCallFrame frame(fbody->arg_types, fbody->ret_types);
|
||||
TF_RETURN_IF_ERROR(frame.SetArgs(args));
|
||||
ExecutorArgsFromOptions(opts, &frame, &exec_args);
|
||||
|
||||
TF_RETURN_IF_ERROR(item->exec->Run(exec_args, done));
|
||||
return frame.ConsumeRetvals(rets, opts.allow_dead_tensors);
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame) {
|
||||
Item* item = nullptr;
|
||||
@ -1347,21 +1301,6 @@ Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
|
||||
return item->exec->Run(exec_args);
|
||||
}
|
||||
|
||||
Status FunctionLibraryRuntimeImpl::RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame,
|
||||
DoneCallback done) {
|
||||
Item* item = nullptr;
|
||||
std::unique_ptr<PrivateIntraProcessRendezvous> rendezvous;
|
||||
TF_RETURN_IF_ERROR(PrepareRunSync(handle, &opts, &item, &rendezvous));
|
||||
if (item == nullptr) {
|
||||
return parent_->RunSync(opts, handle, call_frame, done);
|
||||
}
|
||||
|
||||
Executor::Args exec_args;
|
||||
ExecutorArgsFromOptions(opts, call_frame, &exec_args);
|
||||
return item->exec->Run(exec_args, done);
|
||||
}
|
||||
|
||||
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) const {
|
||||
const OpDef* op_def;
|
||||
const Status s = base_lib_def_->LookUpOpDef(func, &op_def);
|
||||
|
@ -1611,23 +1611,6 @@ Status ProcessFunctionLibraryRuntime::RunSync(
|
||||
return s;
|
||||
}
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::RunSync(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets,
|
||||
FunctionLibraryRuntime::DoneCallback done) const {
|
||||
Notification n;
|
||||
Status s;
|
||||
Run(opts, handle, args, rets,
|
||||
[&n, &s, done = std::move(done)](const Status& status) {
|
||||
s.Update(status);
|
||||
done(s);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return s;
|
||||
}
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::RunSync(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const {
|
||||
@ -1641,22 +1624,6 @@ Status ProcessFunctionLibraryRuntime::RunSync(
|
||||
return s;
|
||||
}
|
||||
|
||||
Status ProcessFunctionLibraryRuntime::RunSync(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame,
|
||||
FunctionLibraryRuntime::DoneCallback done) const {
|
||||
Notification n;
|
||||
Status s;
|
||||
Run(opts, handle, frame,
|
||||
[&n, &s, done = std::move(done)](const Status& status) {
|
||||
s.Update(status);
|
||||
done(s);
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return s;
|
||||
}
|
||||
|
||||
void ProcessFunctionLibraryRuntime::Run(
|
||||
const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle, const FunctionArgsInterface& args,
|
||||
|
@ -197,17 +197,9 @@ class ProcessFunctionLibraryRuntime {
|
||||
Status RunSync(const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle,
|
||||
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets) const;
|
||||
Status RunSync(const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle,
|
||||
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
|
||||
FunctionLibraryRuntime::DoneCallback done) const;
|
||||
Status RunSync(const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle,
|
||||
CallFrameInterface* frame) const;
|
||||
Status RunSync(const FunctionLibraryRuntime::Options& opts,
|
||||
FunctionLibraryRuntime::Handle handle,
|
||||
CallFrameInterface* frame,
|
||||
FunctionLibraryRuntime::DoneCallback done) const;
|
||||
|
||||
const DeviceMgr* device_mgr() { return device_mgr_; }
|
||||
|
||||
|
@ -779,15 +779,8 @@ class FunctionLibraryRuntime {
|
||||
virtual Status RunSync(Options opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets) = 0;
|
||||
virtual Status RunSync(Options opts, Handle handle,
|
||||
gtl::ArraySlice<Tensor> args,
|
||||
std::vector<Tensor>* rets,
|
||||
DoneCallback done) = 0;
|
||||
virtual Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame) = 0;
|
||||
virtual Status RunSync(Options opts, Handle handle,
|
||||
CallFrameInterface* call_frame,
|
||||
DoneCallback done) = 0;
|
||||
|
||||
// Creates a "kernel" for the given NodeProperties "props".
|
||||
//
|
||||
|
@ -850,33 +850,24 @@ Status InstantiatedCapturedFunction::Run(
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
if (collect_usage) {
|
||||
// Resource usage is for function execution is gathered from the executor.
|
||||
// NOTE(mkuchnik): RecordStop and RecordStart have to be called around
|
||||
// this function to prevent double-counting resource usage.
|
||||
auto callback = std::bind(
|
||||
[this, node, collect_usage](
|
||||
IteratorContext* ctx,
|
||||
const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
|
||||
// Begin unbound arguments.
|
||||
Status s) {
|
||||
if (node) {
|
||||
// TODO(b/129085499) Utilize the `node_name` which would be unique
|
||||
// than the prefix for the function execution time statistics.
|
||||
// prefix_with_func_name would then be node_name + func_name.
|
||||
if (ctx->stats_aggregator()) {
|
||||
string prefix_with_func_name =
|
||||
strings::StrCat(node->name(), stats_utils::kDelimiter,
|
||||
captured_func_->func().name());
|
||||
ctx->stats_aggregator()->AddToHistogram(
|
||||
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
|
||||
{static_cast<float>(stats_collector->processing_time())},
|
||||
node->num_elements());
|
||||
}
|
||||
node->add_processing_time(stats_collector->processing_time());
|
||||
node->record_stop(EnvTime::NowNanos());
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
node->record_start(EnvTime::NowNanos());
|
||||
if (node) {
|
||||
// TODO(b/129085499) Utilize the `node_name` which would be unique
|
||||
// than the prefix for the function execution time statistics.
|
||||
// prefix_with_func_name would then be node_name + func_name.
|
||||
if (ctx->stats_aggregator()) {
|
||||
string prefix_with_func_name =
|
||||
strings::StrCat(node->name(), stats_utils::kDelimiter,
|
||||
captured_func_->func().name());
|
||||
ctx->stats_aggregator()->AddToHistogram(
|
||||
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
|
||||
{static_cast<float>(stats_collector->processing_time())},
|
||||
node->num_elements());
|
||||
}
|
||||
},
|
||||
ctx, std::move(stats_collector), std::placeholders::_1);
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame,
|
||||
std::move(callback)));
|
||||
node->add_processing_time(stats_collector->processing_time());
|
||||
}
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
}
|
||||
|
@ -243,11 +243,9 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
if (states_.find(key) == states_.end()) {
|
||||
// Run the init function to create the initial state.
|
||||
std::vector<Tensor> init_func_output;
|
||||
RecordStop(ctx);
|
||||
TF_RETURN_IF_ERROR(instantiated_init_func_->Run(
|
||||
ctx, std::move(key_func_output), &init_func_output,
|
||||
model_node()));
|
||||
RecordStart(ctx);
|
||||
states_[key] = init_func_output;
|
||||
}
|
||||
|
||||
@ -260,10 +258,8 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::back_inserter(args));
|
||||
|
||||
std::vector<Tensor> reduce_func_output;
|
||||
RecordStop(ctx);
|
||||
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
|
||||
ctx, std::move(args), &reduce_func_output, model_node()));
|
||||
RecordStart(ctx);
|
||||
states_[key] = reduce_func_output;
|
||||
} else {
|
||||
keys_.resize(states_.size());
|
||||
|
@ -254,11 +254,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Run the window size function on the key to identify its
|
||||
// window size.
|
||||
std::vector<Tensor> window_size_func_output;
|
||||
RecordStop(ctx);
|
||||
TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run(
|
||||
ctx, std::move(key_func_output), &window_size_func_output,
|
||||
model_node()));
|
||||
RecordStart(ctx);
|
||||
|
||||
if (window_size_func_output.size() != 1 ||
|
||||
window_size_func_output[0].dtype() != DT_INT64 ||
|
||||
@ -489,11 +487,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
std::vector<Tensor> args(
|
||||
{std::move(key_arg), std::move(group_dataset_arg)});
|
||||
std::vector<Tensor> return_values;
|
||||
RecordStop(ctx);
|
||||
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(ctx, std::move(args),
|
||||
&return_values,
|
||||
model_node()));
|
||||
RecordStart(ctx);
|
||||
|
||||
if (!(return_values.size() == 1 &&
|
||||
return_values[0].dtype() == DT_VARIANT &&
|
||||
|
@ -315,10 +315,8 @@ class LoadDatasetOp::Dataset : public DatasetBase {
|
||||
std::vector<Tensor> reader_output;
|
||||
reader_input.push_back(std::move(input_dataset_tensor));
|
||||
|
||||
RecordStop(ctx);
|
||||
TF_RETURN_IF_ERROR(instantiated_captured_func_->Run(
|
||||
ctx, std::move(reader_input), &reader_output, model_node()));
|
||||
RecordStart(ctx);
|
||||
if (reader_output.size() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"reader_func returns more than one argument.");
|
||||
|
@ -200,11 +200,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
state_and_output.reserve(dataset()->state_types_.size() +
|
||||
output_dtypes().size());
|
||||
|
||||
RecordStop(ctx);
|
||||
Status s = instantiated_captured_func_->Run(ctx, std::move(args),
|
||||
&state_and_output,
|
||||
model_node());
|
||||
RecordStart(ctx);
|
||||
DCHECK(state_and_output.size() <=
|
||||
dataset()->state_types_.size() + output_dtypes().size());
|
||||
if (s.ok()) {
|
||||
|
@ -571,10 +571,8 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
|
||||
std::vector<Tensor> reader_output;
|
||||
reader_input.push_back(std::move(input_dataset_tensor));
|
||||
|
||||
RecordStop(ctx);
|
||||
TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
|
||||
ctx, std::move(reader_input), &reader_output, model_node()));
|
||||
RecordStart(ctx);
|
||||
if (reader_output.size() != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"reader_func returns more than one argument.");
|
||||
|
@ -158,11 +158,9 @@ class MapDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
RecordStop(ctx);
|
||||
Status s =
|
||||
instantiated_captured_func_->Run(ctx, std::move(args), out_tensors,
|
||||
model_node());
|
||||
RecordStart(ctx);
|
||||
if (errors::IsOutOfRange(s)) {
|
||||
if (dataset()->preserve_cardinality_) {
|
||||
// To guarantee that the transformation preserves the cardinality of
|
||||
|
@ -454,11 +454,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
RecordStop(ctx.get());
|
||||
(*ctx->runner())(
|
||||
[this, ctx, fn = std::move(fn), done = std::move(done)]() {
|
||||
Status s = fn();
|
||||
RecordStart(ctx.get());
|
||||
auto cleanup =
|
||||
gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
|
||||
done(s);
|
||||
done(fn());
|
||||
});
|
||||
RecordStart(ctx.get());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user