Add similar implementation for various captured function Runs
This commit is contained in:
parent
4082c29741
commit
3f4d17bff8
@ -450,8 +450,9 @@ Status MakeIteratorFromInputElement(
|
||||
std::unique_ptr<IteratorBase>* out_iterator) {
|
||||
std::vector<Tensor> return_values;
|
||||
|
||||
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
|
||||
&return_values));
|
||||
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
|
||||
ctx, input_element, &return_values,
|
||||
std::shared_ptr<model::Node>(nullptr)));
|
||||
|
||||
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
|
||||
TensorShapeUtils::IsScalar(return_values[0].shape()))) {
|
||||
@ -848,11 +849,10 @@ Status InstantiatedCapturedFunction::Run(
|
||||
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
|
||||
},
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
if (collect_usage) {
|
||||
// Resource usage is for function execution is gathered from the executor.
|
||||
node->record_stop(EnvTime::NowNanos());
|
||||
if (node) {
|
||||
// Resource usage for function execution is gathered from the executor.
|
||||
if (collect_usage) node->record_stop(EnvTime::NowNanos());
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
node->record_start(EnvTime::NowNanos());
|
||||
if (ctx->stats_aggregator()) {
|
||||
string prefix_with_func_name =
|
||||
strings::StrCat(node->name(), stats_utils::kDelimiter,
|
||||
@ -863,6 +863,7 @@ Status InstantiatedCapturedFunction::Run(
|
||||
node->num_elements());
|
||||
}
|
||||
node->add_processing_time(stats_collector->processing_time());
|
||||
if (collect_usage) node->record_start(EnvTime::NowNanos());
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
}
|
||||
@ -871,7 +872,7 @@ Status InstantiatedCapturedFunction::Run(
|
||||
|
||||
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
||||
IteratorContext* ctx, const std::vector<Tensor>& args,
|
||||
std::vector<Tensor>* rets) const {
|
||||
std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
|
||||
auto& info = captured_func_->short_circuit_info();
|
||||
if (!info.indices.empty()) {
|
||||
return RunShortCircuit(info, args, captured_func_, rets);
|
||||
@ -888,6 +889,14 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
||||
CancellationManager cancellation_manager(ctx->cancellation_manager());
|
||||
f_opts.cancellation_manager = &cancellation_manager;
|
||||
|
||||
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
|
||||
if (node || ctx->stats_aggregator()) {
|
||||
stats_collector = std::make_shared<SimpleStepStatsCollector>();
|
||||
}
|
||||
const bool collect_usage =
|
||||
node && ctx->model() && ctx->model()->collect_resource_usage();
|
||||
f_opts.stats_collector = stats_collector.get();
|
||||
|
||||
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
||||
ret_types_);
|
||||
profiler::TraceMe activity(
|
||||
@ -897,7 +906,24 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
||||
f_opts.step_id, "#");
|
||||
},
|
||||
profiler::TraceMeLevel::kInfo);
|
||||
if (node) {
|
||||
// Resource usage for function execution is gathered from the executor.
|
||||
if (collect_usage) node->record_stop(EnvTime::NowNanos());
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
if (ctx->stats_aggregator()) {
|
||||
string prefix_with_func_name =
|
||||
strings::StrCat(node->name(), stats_utils::kDelimiter,
|
||||
captured_func_->func().name());
|
||||
ctx->stats_aggregator()->AddToHistogram(
|
||||
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
|
||||
{static_cast<float>(stats_collector->processing_time())},
|
||||
node->num_elements());
|
||||
}
|
||||
node->add_processing_time(stats_collector->processing_time());
|
||||
if (collect_usage) node->record_start(EnvTime::NowNanos());
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||
}
|
||||
return frame.ConsumeRetvals(rets);
|
||||
}
|
||||
|
||||
|
@ -221,7 +221,8 @@ class InstantiatedCapturedFunction {
|
||||
// possible.
|
||||
Status RunWithBorrowedArgs(IteratorContext* ctx,
|
||||
const std::vector<Tensor>& args,
|
||||
std::vector<Tensor>* rets) const;
|
||||
std::vector<Tensor>* rets,
|
||||
const std::shared_ptr<model::Node>& node) const;
|
||||
|
||||
// Synchronously runs the captured function on the given `args`, and stores
|
||||
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
|
||||
|
@ -229,7 +229,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
// Run the key function on the input element.
|
||||
std::vector<Tensor> key_func_output;
|
||||
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
|
||||
ctx, next_input_element, &key_func_output));
|
||||
ctx, next_input_element, &key_func_output, model_node()));
|
||||
|
||||
if (key_func_output.size() != 1 ||
|
||||
key_func_output[0].dtype() != DT_INT64 ||
|
||||
@ -275,7 +275,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
return Status::OK();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
|
||||
ctx, states_[keys_[keys_index_++]], out_tensors));
|
||||
ctx, states_[keys_[keys_index_++]], out_tensors, model_node()));
|
||||
*end_of_sequence = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -239,7 +239,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
// group.
|
||||
std::vector<Tensor> key_func_output;
|
||||
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
|
||||
ctx, next_input_element, &key_func_output));
|
||||
ctx, next_input_element, &key_func_output, model_node()));
|
||||
|
||||
if (key_func_output.size() != 1 ||
|
||||
key_func_output[0].dtype() != DT_INT64 ||
|
||||
|
@ -148,7 +148,8 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx,
|
||||
}
|
||||
std::vector<Tensor> output_tensors;
|
||||
TF_RETURN_IF_ERROR(
|
||||
function->RunWithBorrowedArgs(ctx, element, &output_tensors));
|
||||
function->RunWithBorrowedArgs(ctx, element, &output_tensors,
|
||||
std::shared_ptr<model::Node>(nullptr)));
|
||||
|
||||
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
|
||||
output_tensors[0].NumElements() != 1) {
|
||||
|
@ -666,7 +666,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
|
||||
|
||||
// Run the shard function
|
||||
TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
|
||||
ctx, tensors, &output_tensors));
|
||||
ctx, tensors, &output_tensors, model_node()));
|
||||
|
||||
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
|
||||
output_tensors[0].NumElements() != 1) {
|
||||
|
@ -151,7 +151,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
std::vector<Tensor> result;
|
||||
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
|
||||
ctx, *out_tensors, &result));
|
||||
ctx, *out_tensors, &result, model_node()));
|
||||
|
||||
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
|
||||
result[0].NumElements() != 1) {
|
||||
|
@ -148,7 +148,7 @@ class FilterDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
std::vector<Tensor> result;
|
||||
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
|
||||
ctx, *out_tensors, &result));
|
||||
ctx, *out_tensors, &result, model_node()));
|
||||
|
||||
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
|
||||
result[0].NumElements() != 1) {
|
||||
|
@ -128,7 +128,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
if (!initialized_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
instantiated_init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
|
||||
instantiated_init_func_->RunWithBorrowedArgs(
|
||||
ctx, {}, &state_, model_node()));
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
@ -137,8 +138,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status s = instantiated_next_func_->RunWithBorrowedArgs(ctx, state_,
|
||||
out_tensors);
|
||||
Status s = instantiated_next_func_->RunWithBorrowedArgs(
|
||||
ctx, state_, out_tensors, model_node());
|
||||
if (s.ok()) {
|
||||
*end_of_sequence = false;
|
||||
} else if (errors::IsOutOfRange(s)) {
|
||||
@ -150,7 +151,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
||||
// NOTE(mrry): We ignore any tensors returned by the finalize function.
|
||||
std::vector<Tensor> ignored;
|
||||
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
|
||||
ctx, state_, &ignored));
|
||||
ctx, state_, &ignored, model_node()));
|
||||
finalized_ = true;
|
||||
}
|
||||
return s;
|
||||
|
Loading…
x
Reference in New Issue
Block a user