Add similar implementation for various captured function Runs
This commit is contained in:
parent
4082c29741
commit
3f4d17bff8
@ -450,8 +450,9 @@ Status MakeIteratorFromInputElement(
|
|||||||
std::unique_ptr<IteratorBase>* out_iterator) {
|
std::unique_ptr<IteratorBase>* out_iterator) {
|
||||||
std::vector<Tensor> return_values;
|
std::vector<Tensor> return_values;
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
|
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
|
||||||
&return_values));
|
ctx, input_element, &return_values,
|
||||||
|
std::shared_ptr<model::Node>(nullptr)));
|
||||||
|
|
||||||
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
|
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
|
||||||
TensorShapeUtils::IsScalar(return_values[0].shape()))) {
|
TensorShapeUtils::IsScalar(return_values[0].shape()))) {
|
||||||
@ -848,11 +849,10 @@ Status InstantiatedCapturedFunction::Run(
|
|||||||
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
|
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
|
||||||
},
|
},
|
||||||
profiler::TraceMeLevel::kInfo);
|
profiler::TraceMeLevel::kInfo);
|
||||||
if (collect_usage) {
|
if (node) {
|
||||||
// Resource usage is for function execution is gathered from the executor.
|
// Resource usage for function execution is gathered from the executor.
|
||||||
node->record_stop(EnvTime::NowNanos());
|
if (collect_usage) node->record_stop(EnvTime::NowNanos());
|
||||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||||
node->record_start(EnvTime::NowNanos());
|
|
||||||
if (ctx->stats_aggregator()) {
|
if (ctx->stats_aggregator()) {
|
||||||
string prefix_with_func_name =
|
string prefix_with_func_name =
|
||||||
strings::StrCat(node->name(), stats_utils::kDelimiter,
|
strings::StrCat(node->name(), stats_utils::kDelimiter,
|
||||||
@ -863,6 +863,7 @@ Status InstantiatedCapturedFunction::Run(
|
|||||||
node->num_elements());
|
node->num_elements());
|
||||||
}
|
}
|
||||||
node->add_processing_time(stats_collector->processing_time());
|
node->add_processing_time(stats_collector->processing_time());
|
||||||
|
if (collect_usage) node->record_start(EnvTime::NowNanos());
|
||||||
} else {
|
} else {
|
||||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||||
}
|
}
|
||||||
@ -871,7 +872,7 @@ Status InstantiatedCapturedFunction::Run(
|
|||||||
|
|
||||||
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
||||||
IteratorContext* ctx, const std::vector<Tensor>& args,
|
IteratorContext* ctx, const std::vector<Tensor>& args,
|
||||||
std::vector<Tensor>* rets) const {
|
std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
|
||||||
auto& info = captured_func_->short_circuit_info();
|
auto& info = captured_func_->short_circuit_info();
|
||||||
if (!info.indices.empty()) {
|
if (!info.indices.empty()) {
|
||||||
return RunShortCircuit(info, args, captured_func_, rets);
|
return RunShortCircuit(info, args, captured_func_, rets);
|
||||||
@ -888,6 +889,14 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
|||||||
CancellationManager cancellation_manager(ctx->cancellation_manager());
|
CancellationManager cancellation_manager(ctx->cancellation_manager());
|
||||||
f_opts.cancellation_manager = &cancellation_manager;
|
f_opts.cancellation_manager = &cancellation_manager;
|
||||||
|
|
||||||
|
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
|
||||||
|
if (node || ctx->stats_aggregator()) {
|
||||||
|
stats_collector = std::make_shared<SimpleStepStatsCollector>();
|
||||||
|
}
|
||||||
|
const bool collect_usage =
|
||||||
|
node && ctx->model() && ctx->model()->collect_resource_usage();
|
||||||
|
f_opts.stats_collector = stats_collector.get();
|
||||||
|
|
||||||
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
||||||
ret_types_);
|
ret_types_);
|
||||||
profiler::TraceMe activity(
|
profiler::TraceMe activity(
|
||||||
@ -897,7 +906,24 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
|||||||
f_opts.step_id, "#");
|
f_opts.step_id, "#");
|
||||||
},
|
},
|
||||||
profiler::TraceMeLevel::kInfo);
|
profiler::TraceMeLevel::kInfo);
|
||||||
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
if (node) {
|
||||||
|
// Resource usage for function execution is gathered from the executor.
|
||||||
|
if (collect_usage) node->record_stop(EnvTime::NowNanos());
|
||||||
|
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||||
|
if (ctx->stats_aggregator()) {
|
||||||
|
string prefix_with_func_name =
|
||||||
|
strings::StrCat(node->name(), stats_utils::kDelimiter,
|
||||||
|
captured_func_->func().name());
|
||||||
|
ctx->stats_aggregator()->AddToHistogram(
|
||||||
|
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
|
||||||
|
{static_cast<float>(stats_collector->processing_time())},
|
||||||
|
node->num_elements());
|
||||||
|
}
|
||||||
|
node->add_processing_time(stats_collector->processing_time());
|
||||||
|
if (collect_usage) node->record_start(EnvTime::NowNanos());
|
||||||
|
} else {
|
||||||
|
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
|
||||||
|
}
|
||||||
return frame.ConsumeRetvals(rets);
|
return frame.ConsumeRetvals(rets);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -221,7 +221,8 @@ class InstantiatedCapturedFunction {
|
|||||||
// possible.
|
// possible.
|
||||||
Status RunWithBorrowedArgs(IteratorContext* ctx,
|
Status RunWithBorrowedArgs(IteratorContext* ctx,
|
||||||
const std::vector<Tensor>& args,
|
const std::vector<Tensor>& args,
|
||||||
std::vector<Tensor>* rets) const;
|
std::vector<Tensor>* rets,
|
||||||
|
const std::shared_ptr<model::Node>& node) const;
|
||||||
|
|
||||||
// Synchronously runs the captured function on the given `args`, and stores
|
// Synchronously runs the captured function on the given `args`, and stores
|
||||||
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
|
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
|
||||||
|
@ -229,7 +229,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
// Run the key function on the input element.
|
// Run the key function on the input element.
|
||||||
std::vector<Tensor> key_func_output;
|
std::vector<Tensor> key_func_output;
|
||||||
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
|
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
|
||||||
ctx, next_input_element, &key_func_output));
|
ctx, next_input_element, &key_func_output, model_node()));
|
||||||
|
|
||||||
if (key_func_output.size() != 1 ||
|
if (key_func_output.size() != 1 ||
|
||||||
key_func_output[0].dtype() != DT_INT64 ||
|
key_func_output[0].dtype() != DT_INT64 ||
|
||||||
@ -275,7 +275,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
|
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
|
||||||
ctx, states_[keys_[keys_index_++]], out_tensors));
|
ctx, states_[keys_[keys_index_++]], out_tensors, model_node()));
|
||||||
*end_of_sequence = false;
|
*end_of_sequence = false;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -239,7 +239,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
// group.
|
// group.
|
||||||
std::vector<Tensor> key_func_output;
|
std::vector<Tensor> key_func_output;
|
||||||
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
|
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
|
||||||
ctx, next_input_element, &key_func_output));
|
ctx, next_input_element, &key_func_output, model_node()));
|
||||||
|
|
||||||
if (key_func_output.size() != 1 ||
|
if (key_func_output.size() != 1 ||
|
||||||
key_func_output[0].dtype() != DT_INT64 ||
|
key_func_output[0].dtype() != DT_INT64 ||
|
||||||
|
@ -148,7 +148,8 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx,
|
|||||||
}
|
}
|
||||||
std::vector<Tensor> output_tensors;
|
std::vector<Tensor> output_tensors;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
function->RunWithBorrowedArgs(ctx, element, &output_tensors));
|
function->RunWithBorrowedArgs(ctx, element, &output_tensors,
|
||||||
|
std::shared_ptr<model::Node>(nullptr)));
|
||||||
|
|
||||||
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
|
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
|
||||||
output_tensors[0].NumElements() != 1) {
|
output_tensors[0].NumElements() != 1) {
|
||||||
|
@ -666,7 +666,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
|
|||||||
|
|
||||||
// Run the shard function
|
// Run the shard function
|
||||||
TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
|
TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
|
||||||
ctx, tensors, &output_tensors));
|
ctx, tensors, &output_tensors, model_node()));
|
||||||
|
|
||||||
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
|
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
|
||||||
output_tensors[0].NumElements() != 1) {
|
output_tensors[0].NumElements() != 1) {
|
||||||
|
@ -151,7 +151,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
}
|
}
|
||||||
std::vector<Tensor> result;
|
std::vector<Tensor> result;
|
||||||
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
|
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
|
||||||
ctx, *out_tensors, &result));
|
ctx, *out_tensors, &result, model_node()));
|
||||||
|
|
||||||
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
|
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
|
||||||
result[0].NumElements() != 1) {
|
result[0].NumElements() != 1) {
|
||||||
|
@ -148,7 +148,7 @@ class FilterDatasetOp::Dataset : public DatasetBase {
|
|||||||
|
|
||||||
std::vector<Tensor> result;
|
std::vector<Tensor> result;
|
||||||
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
|
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
|
||||||
ctx, *out_tensors, &result));
|
ctx, *out_tensors, &result, model_node()));
|
||||||
|
|
||||||
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
|
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
|
||||||
result[0].NumElements() != 1) {
|
result[0].NumElements() != 1) {
|
||||||
|
@ -128,7 +128,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
|||||||
|
|
||||||
if (!initialized_) {
|
if (!initialized_) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
instantiated_init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
|
instantiated_init_func_->RunWithBorrowedArgs(
|
||||||
|
ctx, {}, &state_, model_node()));
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -137,8 +138,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status s = instantiated_next_func_->RunWithBorrowedArgs(ctx, state_,
|
Status s = instantiated_next_func_->RunWithBorrowedArgs(
|
||||||
out_tensors);
|
ctx, state_, out_tensors, model_node());
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
*end_of_sequence = false;
|
*end_of_sequence = false;
|
||||||
} else if (errors::IsOutOfRange(s)) {
|
} else if (errors::IsOutOfRange(s)) {
|
||||||
@ -150,7 +151,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
|||||||
// NOTE(mrry): We ignore any tensors returned by the finalize function.
|
// NOTE(mrry): We ignore any tensors returned by the finalize function.
|
||||||
std::vector<Tensor> ignored;
|
std::vector<Tensor> ignored;
|
||||||
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
|
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
|
||||||
ctx, state_, &ignored));
|
ctx, state_, &ignored, model_node()));
|
||||||
finalized_ = true;
|
finalized_ = true;
|
||||||
}
|
}
|
||||||
return s;
|
return s;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user