Add similar implementation for various captured function Runs

This commit is contained in:
Michael Kuchnik 2020-11-03 11:29:54 -05:00
parent 4082c29741
commit 3f4d17bff8
9 changed files with 49 additions and 20 deletions

View File

@ -450,8 +450,9 @@ Status MakeIteratorFromInputElement(
std::unique_ptr<IteratorBase>* out_iterator) { std::unique_ptr<IteratorBase>* out_iterator) {
std::vector<Tensor> return_values; std::vector<Tensor> return_values;
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element, TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
&return_values)); ctx, input_element, &return_values,
std::shared_ptr<model::Node>(nullptr)));
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT && if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
TensorShapeUtils::IsScalar(return_values[0].shape()))) { TensorShapeUtils::IsScalar(return_values[0].shape()))) {
@ -848,11 +849,10 @@ Status InstantiatedCapturedFunction::Run(
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#"); "InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
}, },
profiler::TraceMeLevel::kInfo); profiler::TraceMeLevel::kInfo);
if (collect_usage) { if (node) {
// Resource usage is for function execution is gathered from the executor. // Resource usage for function execution is gathered from the executor.
node->record_stop(EnvTime::NowNanos()); if (collect_usage) node->record_stop(EnvTime::NowNanos());
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
node->record_start(EnvTime::NowNanos());
if (ctx->stats_aggregator()) { if (ctx->stats_aggregator()) {
string prefix_with_func_name = string prefix_with_func_name =
strings::StrCat(node->name(), stats_utils::kDelimiter, strings::StrCat(node->name(), stats_utils::kDelimiter,
@ -863,6 +863,7 @@ Status InstantiatedCapturedFunction::Run(
node->num_elements()); node->num_elements());
} }
node->add_processing_time(stats_collector->processing_time()); node->add_processing_time(stats_collector->processing_time());
if (collect_usage) node->record_start(EnvTime::NowNanos());
} else { } else {
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
} }
@ -871,7 +872,7 @@ Status InstantiatedCapturedFunction::Run(
Status InstantiatedCapturedFunction::RunWithBorrowedArgs( Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
IteratorContext* ctx, const std::vector<Tensor>& args, IteratorContext* ctx, const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const { std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info(); auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) { if (!info.indices.empty()) {
return RunShortCircuit(info, args, captured_func_, rets); return RunShortCircuit(info, args, captured_func_, rets);
@ -888,6 +889,14 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
CancellationManager cancellation_manager(ctx->cancellation_manager()); CancellationManager cancellation_manager(ctx->cancellation_manager());
f_opts.cancellation_manager = &cancellation_manager; f_opts.cancellation_manager = &cancellation_manager;
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
if (node || ctx->stats_aggregator()) {
stats_collector = std::make_shared<SimpleStepStatsCollector>();
}
const bool collect_usage =
node && ctx->model() && ctx->model()->collect_resource_usage();
f_opts.stats_collector = stats_collector.get();
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(), BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
ret_types_); ret_types_);
profiler::TraceMe activity( profiler::TraceMe activity(
@ -897,7 +906,24 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
f_opts.step_id, "#"); f_opts.step_id, "#");
}, },
profiler::TraceMeLevel::kInfo); profiler::TraceMeLevel::kInfo);
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame)); if (node) {
// Resource usage for function execution is gathered from the executor.
if (collect_usage) node->record_stop(EnvTime::NowNanos());
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
if (ctx->stats_aggregator()) {
string prefix_with_func_name =
strings::StrCat(node->name(), stats_utils::kDelimiter,
captured_func_->func().name());
ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
node->num_elements());
}
node->add_processing_time(stats_collector->processing_time());
if (collect_usage) node->record_start(EnvTime::NowNanos());
} else {
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
}
return frame.ConsumeRetvals(rets); return frame.ConsumeRetvals(rets);
} }

View File

@ -221,7 +221,8 @@ class InstantiatedCapturedFunction {
// possible. // possible.
Status RunWithBorrowedArgs(IteratorContext* ctx, Status RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args, const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const; std::vector<Tensor>* rets,
const std::shared_ptr<model::Node>& node) const;
// Synchronously runs the captured function on the given `args`, and stores // Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when

View File

@ -229,7 +229,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
// Run the key function on the input element. // Run the key function on the input element.
std::vector<Tensor> key_func_output; std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs( TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
ctx, next_input_element, &key_func_output)); ctx, next_input_element, &key_func_output, model_node()));
if (key_func_output.size() != 1 || if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 || key_func_output[0].dtype() != DT_INT64 ||
@ -275,7 +275,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
return Status::OK(); return Status::OK();
} }
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs( TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
ctx, states_[keys_[keys_index_++]], out_tensors)); ctx, states_[keys_[keys_index_++]], out_tensors, model_node()));
*end_of_sequence = false; *end_of_sequence = false;
return Status::OK(); return Status::OK();
} }

View File

@ -239,7 +239,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// group. // group.
std::vector<Tensor> key_func_output; std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs( TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
ctx, next_input_element, &key_func_output)); ctx, next_input_element, &key_func_output, model_node()));
if (key_func_output.size() != 1 || if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 || key_func_output[0].dtype() != DT_INT64 ||

View File

@ -148,7 +148,8 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx,
} }
std::vector<Tensor> output_tensors; std::vector<Tensor> output_tensors;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
function->RunWithBorrowedArgs(ctx, element, &output_tensors)); function->RunWithBorrowedArgs(ctx, element, &output_tensors,
std::shared_ptr<model::Node>(nullptr)));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 || if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) { output_tensors[0].NumElements() != 1) {

View File

@ -666,7 +666,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
// Run the shard function // Run the shard function
TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs( TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
ctx, tensors, &output_tensors)); ctx, tensors, &output_tensors, model_node()));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 || if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) { output_tensors[0].NumElements() != 1) {

View File

@ -151,7 +151,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
} }
std::vector<Tensor> result; std::vector<Tensor> result;
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs( TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
ctx, *out_tensors, &result)); ctx, *out_tensors, &result, model_node()));
if (result.size() != 1 || result[0].dtype() != DT_BOOL || if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
result[0].NumElements() != 1) { result[0].NumElements() != 1) {

View File

@ -148,7 +148,7 @@ class FilterDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> result; std::vector<Tensor> result;
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs( TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
ctx, *out_tensors, &result)); ctx, *out_tensors, &result, model_node()));
if (result.size() != 1 || result[0].dtype() != DT_BOOL || if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
result[0].NumElements() != 1) { result[0].NumElements() != 1) {

View File

@ -128,7 +128,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
if (!initialized_) { if (!initialized_) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
instantiated_init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); instantiated_init_func_->RunWithBorrowedArgs(
ctx, {}, &state_, model_node()));
initialized_ = true; initialized_ = true;
} }
@ -137,8 +138,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
return Status::OK(); return Status::OK();
} }
Status s = instantiated_next_func_->RunWithBorrowedArgs(ctx, state_, Status s = instantiated_next_func_->RunWithBorrowedArgs(
out_tensors); ctx, state_, out_tensors, model_node());
if (s.ok()) { if (s.ok()) {
*end_of_sequence = false; *end_of_sequence = false;
} else if (errors::IsOutOfRange(s)) { } else if (errors::IsOutOfRange(s)) {
@ -150,7 +151,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
// NOTE(mrry): We ignore any tensors returned by the finalize function. // NOTE(mrry): We ignore any tensors returned by the finalize function.
std::vector<Tensor> ignored; std::vector<Tensor> ignored;
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs( TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
ctx, state_, &ignored)); ctx, state_, &ignored, model_node()));
finalized_ = true; finalized_ = true;
} }
return s; return s;