Add a similar implementation for the various captured-function Run methods

This commit is contained in:
Michael Kuchnik 2020-11-03 11:29:54 -05:00
parent 4082c29741
commit 3f4d17bff8
9 changed files with 49 additions and 20 deletions

View File

@@ -450,8 +450,9 @@ Status MakeIteratorFromInputElement(
std::unique_ptr<IteratorBase>* out_iterator) {
std::vector<Tensor> return_values;
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
&return_values));
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
ctx, input_element, &return_values,
std::shared_ptr<model::Node>(nullptr)));
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
TensorShapeUtils::IsScalar(return_values[0].shape()))) {
@@ -848,11 +849,10 @@ Status InstantiatedCapturedFunction::Run(
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
if (collect_usage) {
// Resource usage is for function execution is gathered from the executor.
node->record_stop(EnvTime::NowNanos());
if (node) {
// Resource usage for function execution is gathered from the executor.
if (collect_usage) node->record_stop(EnvTime::NowNanos());
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
node->record_start(EnvTime::NowNanos());
if (ctx->stats_aggregator()) {
string prefix_with_func_name =
strings::StrCat(node->name(), stats_utils::kDelimiter,
@@ -863,6 +863,7 @@ Status InstantiatedCapturedFunction::Run(
node->num_elements());
}
node->add_processing_time(stats_collector->processing_time());
if (collect_usage) node->record_start(EnvTime::NowNanos());
} else {
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
}
@@ -871,7 +872,7 @@
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
IteratorContext* ctx, const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const {
std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, args, captured_func_, rets);
@@ -888,6 +889,14 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
CancellationManager cancellation_manager(ctx->cancellation_manager());
f_opts.cancellation_manager = &cancellation_manager;
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
if (node || ctx->stats_aggregator()) {
stats_collector = std::make_shared<SimpleStepStatsCollector>();
}
const bool collect_usage =
node && ctx->model() && ctx->model()->collect_resource_usage();
f_opts.stats_collector = stats_collector.get();
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
ret_types_);
profiler::TraceMe activity(
@@ -897,7 +906,24 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
if (node) {
// Resource usage for function execution is gathered from the executor.
if (collect_usage) node->record_stop(EnvTime::NowNanos());
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
if (ctx->stats_aggregator()) {
string prefix_with_func_name =
strings::StrCat(node->name(), stats_utils::kDelimiter,
captured_func_->func().name());
ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
node->num_elements());
}
node->add_processing_time(stats_collector->processing_time());
if (collect_usage) node->record_start(EnvTime::NowNanos());
} else {
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
}
return frame.ConsumeRetvals(rets);
}

View File

@@ -221,7 +221,8 @@ class InstantiatedCapturedFunction {
// possible.
Status RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const;
std::vector<Tensor>* rets,
const std::shared_ptr<model::Node>& node) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when

View File

@@ -229,7 +229,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
// Run the key function on the input element.
std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
ctx, next_input_element, &key_func_output));
ctx, next_input_element, &key_func_output, model_node()));
if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||
@@ -275,7 +275,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
ctx, states_[keys_[keys_index_++]], out_tensors));
ctx, states_[keys_[keys_index_++]], out_tensors, model_node()));
*end_of_sequence = false;
return Status::OK();
}

View File

@@ -239,7 +239,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// group.
std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
ctx, next_input_element, &key_func_output));
ctx, next_input_element, &key_func_output, model_node()));
if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||

View File

@@ -148,7 +148,8 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx,
}
std::vector<Tensor> output_tensors;
TF_RETURN_IF_ERROR(
function->RunWithBorrowedArgs(ctx, element, &output_tensors));
function->RunWithBorrowedArgs(ctx, element, &output_tensors,
std::shared_ptr<model::Node>(nullptr)));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {

View File

@@ -666,7 +666,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
// Run the shard function
TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
ctx, tensors, &output_tensors));
ctx, tensors, &output_tensors, model_node()));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {

View File

@@ -151,7 +151,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
}
std::vector<Tensor> result;
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
ctx, *out_tensors, &result));
ctx, *out_tensors, &result, model_node()));
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
result[0].NumElements() != 1) {

View File

@@ -148,7 +148,7 @@ class FilterDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> result;
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
ctx, *out_tensors, &result));
ctx, *out_tensors, &result, model_node()));
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
result[0].NumElements() != 1) {

View File

@@ -128,7 +128,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
if (!initialized_) {
TF_RETURN_IF_ERROR(
instantiated_init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
instantiated_init_func_->RunWithBorrowedArgs(
ctx, {}, &state_, model_node()));
initialized_ = true;
}
@@ -137,8 +138,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
Status s = instantiated_next_func_->RunWithBorrowedArgs(ctx, state_,
out_tensors);
Status s = instantiated_next_func_->RunWithBorrowedArgs(
ctx, state_, out_tensors, model_node());
if (s.ok()) {
*end_of_sequence = false;
} else if (errors::IsOutOfRange(s)) {
@@ -150,7 +151,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
// NOTE(mrry): We ignore any tensors returned by the finalize function.
std::vector<Tensor> ignored;
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
ctx, state_, &ignored));
ctx, state_, &ignored, model_node()));
finalized_ = true;
}
return s;