Merge pull request from mkuchnik:executor_based_runsync_metrics

PiperOrigin-RevId: 341011613
Change-Id: Ic4e8ec1beeae76fe9b6ba2ece93a1a27c38004a4
This commit is contained in:
TensorFlower Gardener 2020-11-06 02:10:14 -08:00
commit 609bad4d8b
18 changed files with 189 additions and 53 deletions

View File

@ -448,10 +448,21 @@ Status MakeIteratorFromInputElement(
const std::vector<Tensor>& input_element, int64 thread_index,
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator) {
return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index,
inst_captured_func, prefix, out_iterator,
/*node=*/nullptr);
}
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const IteratorBase* parent,
const std::vector<Tensor>& input_element, int64 thread_index,
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator,
const std::shared_ptr<model::Node>& node) {
std::vector<Tensor> return_values;
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
&return_values));
TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
ctx, input_element, &return_values, node));
if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
TensorShapeUtils::IsScalar(return_values[0].shape()))) {
@ -816,6 +827,12 @@ InstantiatedCapturedFunction::InstantiatedCapturedFunction(
Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const {
return Run(ctx, std::move(args), rets, /*node=*/nullptr);
}
Status InstantiatedCapturedFunction::Run(
IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, std::move(args), captured_func_, rets);
@ -832,6 +849,14 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
CancellationManager cancellation_manager(ctx->cancellation_manager());
f_opts.cancellation_manager = &cancellation_manager;
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
if (node || ctx->stats_aggregator()) {
stats_collector = std::make_shared<SimpleStepStatsCollector>();
}
const bool collect_usage =
node && ctx->model() && ctx->model()->collect_resource_usage();
f_opts.stats_collector = stats_collector.get();
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
ret_types_);
profiler::TraceMe activity(
@ -840,13 +865,37 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
if (node) {
// Resource usage for function execution is gathered from the executor.
// TODO(jsimsa): Factor out common code for Run, RunAsync, and
// RunWithBorrowedArguments
if (collect_usage) node->record_stop(EnvTime::NowNanos());
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
if (ctx->stats_aggregator()) {
string prefix_with_func_name = strings::StrCat(
node->name(), stats_utils::kDelimiter, captured_func_->func().name());
ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
node->num_elements());
}
node->add_processing_time(stats_collector->processing_time());
if (collect_usage) node->record_start(EnvTime::NowNanos());
} else {
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
}
return frame.ConsumeRetvals(rets);
}
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
IteratorContext* ctx, const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const {
std::vector<Tensor>* ret) const {
return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr);
}
Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
IteratorContext* ctx, const std::vector<Tensor>& args,
std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, args, captured_func_, rets);
@ -863,6 +912,14 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
CancellationManager cancellation_manager(ctx->cancellation_manager());
f_opts.cancellation_manager = &cancellation_manager;
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
if (node || ctx->stats_aggregator()) {
stats_collector = std::make_shared<SimpleStepStatsCollector>();
}
const bool collect_usage =
node && ctx->model() && ctx->model()->collect_resource_usage();
f_opts.stats_collector = stats_collector.get();
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
ret_types_);
profiler::TraceMe activity(
@ -872,7 +929,23 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
if (node) {
// Resource usage for function execution is gathered from the executor.
if (collect_usage) node->record_stop(EnvTime::NowNanos());
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
if (ctx->stats_aggregator()) {
string prefix_with_func_name = strings::StrCat(
node->name(), stats_utils::kDelimiter, captured_func_->func().name());
ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
node->num_elements());
}
node->add_processing_time(stats_collector->processing_time());
if (collect_usage) node->record_start(EnvTime::NowNanos());
} else {
TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
}
return frame.ConsumeRetvals(rets);
}

View File

@ -48,6 +48,16 @@ Status MakeIteratorFromInputElement(
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
// Creates an iterator for a dataset which is created by applying the given
// function to the given input element. Pass non-null `node` to record
// processing time for modeling Iterator's GetNext() resource usage.
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const IteratorBase* parent,
const std::vector<Tensor>& input_element, int64 thread_index,
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator,
const std::shared_ptr<model::Node>& node);
// Determines whether the given node is stateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node);
@ -215,6 +225,15 @@ class InstantiatedCapturedFunction {
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const;
// Runs the instantiated captured function. This method takes ownership of
// the tensors in `args`, in order to be able to deallocate them as early as
// possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
// ownership of the `args`. Pass non-null `node` to record processing time
// for modeling Iterator's GetNext() resource usage.
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
const std::shared_ptr<model::Node>& node) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible.
@ -222,6 +241,15 @@ class InstantiatedCapturedFunction {
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible. Pass non-null `node` to record processing time for modeling
// Iterator's GetNext() resource usage.
Status RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets,
const std::shared_ptr<model::Node>& node) const;
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible. This can be useful for calling a captured function in cases where
@ -234,7 +262,8 @@ class InstantiatedCapturedFunction {
// Asynchronously runs the captured function on the given `args`, stores the
// results in `*rets`, and calls the given `done` callback when the function
// returns. This method takes ownership of the tensors in `args`, in order to
// be able to deallocate them as early as possible.
// be able to deallocate them as early as possible. Pass non-null `node` to
// record processing time for modeling Iterator's GetNext() resource usage.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,

View File

@ -366,7 +366,8 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
// Still running experiments
if (!current_iterator_) {
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
/*is_experiment=*/true));
/*is_experiment=*/true,
/*is_get_next=*/true));
}
Status s = GetNextFromExperiment(ctx, out_tensors, end_of_sequence);
@ -385,7 +386,8 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
if (!current_iterator_) {
SelectFastestInputIndex();
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
/*is_experiment=*/false));
/*is_experiment=*/false,
/*is_get_next=*/true));
}
}
@ -438,10 +440,12 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
if (!reader->Contains(full_name("input_impl_empty"))) {
if (branch_index_ < dataset()->captured_funcs_.size()) {
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
/*is_experiment=*/true));
/*is_experiment=*/true,
/*is_get_next=*/false));
} else {
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
/*is_experiment=*/false));
/*is_experiment=*/false,
/*is_get_next=*/false));
}
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_iterator_));
}
@ -492,7 +496,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
}
Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index,
bool is_experiment)
bool is_experiment, bool is_get_next)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
DCHECK_GE(branch_index, 0);
DCHECK_LT(branch_index, histograms_.size());
@ -528,10 +532,18 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(
temp_dataset, wrapper_dataset_tensor_.get()));
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, {*wrapper_dataset_tensor_}, branch_index,
*instantiated_captured_funcs_[branch_index], prefix(),
&current_iterator_));
if (is_get_next) {
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, {*wrapper_dataset_tensor_}, branch_index,
*instantiated_captured_funcs_[branch_index], prefix(),
&current_iterator_, model_node()));
} else {
// NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, {*wrapper_dataset_tensor_}, branch_index,
*instantiated_captured_funcs_[branch_index], prefix(),
&current_iterator_, /*node=*/nullptr));
}
return Status::OK();
}

View File

@ -229,7 +229,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
// Run the key function on the input element.
std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
ctx, next_input_element, &key_func_output));
ctx, next_input_element, &key_func_output, model_node()));
if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||
@ -244,7 +244,8 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
// Run the init function to create the initial state.
std::vector<Tensor> init_func_output;
TF_RETURN_IF_ERROR(instantiated_init_func_->Run(
ctx, std::move(key_func_output), &init_func_output));
ctx, std::move(key_func_output), &init_func_output,
model_node()));
states_[key] = init_func_output;
}
@ -258,7 +259,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> reduce_func_output;
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
ctx, std::move(args), &reduce_func_output));
ctx, std::move(args), &reduce_func_output, model_node()));
states_[key] = reduce_func_output;
} else {
keys_.resize(states_.size());
@ -274,7 +275,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
ctx, states_[keys_[keys_index_++]], out_tensors));
ctx, states_[keys_[keys_index_++]], out_tensors, model_node()));
*end_of_sequence = false;
return Status::OK();
}

View File

@ -239,7 +239,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// group.
std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
ctx, next_input_element, &key_func_output));
ctx, next_input_element, &key_func_output, model_node()));
if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||
@ -255,7 +255,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// window size.
std::vector<Tensor> window_size_func_output;
TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run(
ctx, std::move(key_func_output), &window_size_func_output));
ctx, std::move(key_func_output), &window_size_func_output,
model_node()));
if (window_size_func_output.size() != 1 ||
window_size_func_output[0].dtype() != DT_INT64 ||
@ -486,8 +487,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> args(
{std::move(key_arg), std::move(group_dataset_arg)});
std::vector<Tensor> return_values;
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(ctx, std::move(args),
&return_values));
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
ctx, std::move(args), &return_values, model_node()));
if (!(return_values.size() == 1 &&
return_values[0].dtype() == DT_VARIANT &&

View File

@ -147,8 +147,8 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx,
return Status::OK();
}
std::vector<Tensor> output_tensors;
TF_RETURN_IF_ERROR(
function->RunWithBorrowedArgs(ctx, element, &output_tensors));
TF_RETURN_IF_ERROR(function->RunWithBorrowedArgs(
ctx, element, &output_tensors, /*node=*/nullptr));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {
@ -315,8 +315,9 @@ class LoadDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor));
// NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(instantiated_captured_func_->Run(
ctx, std::move(reader_input), &reader_output));
ctx, std::move(reader_input), &reader_output, /*node=*/nullptr));
if (reader_output.size() != 1) {
return errors::InvalidArgument(
"reader_func returns more than one argument.");

View File

@ -773,7 +773,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
MakeIteratorFromInputElement(
ctx.get(), this, worker_thread_states_[thread_index].input,
thread_index, *instantiated_captured_func_, prefix(),
&worker_thread_states_[thread_index].iterator);
&worker_thread_states_[thread_index].iterator,
model_node());
iterator_creation_status =
worker_thread_states_[thread_index].iterator_creation_status;
if (!iterator_creation_status.ok()) {
@ -1011,9 +1012,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
state->iterator.reset();
} else {
std::unique_ptr<IteratorBase> iterator;
// NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, state->input, index, *instantiated_captured_func_,
prefix(), &iterator));
prefix(), &iterator, /*node=*/nullptr));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
state->iterator.swap(iterator);
}

View File

@ -200,8 +200,8 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
state_and_output.reserve(dataset()->state_types_.size() +
output_dtypes().size());
Status s = instantiated_captured_func_->Run(ctx, std::move(args),
&state_and_output);
Status s = instantiated_captured_func_->Run(
ctx, std::move(args), &state_and_output, model_node());
DCHECK(state_and_output.size() <=
dataset()->state_types_.size() + output_dtypes().size());
if (s.ok()) {

View File

@ -578,8 +578,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor));
// NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
ctx, std::move(reader_input), &reader_output));
ctx, std::move(reader_input), &reader_output, /*node=*/nullptr));
if (reader_output.size() != 1) {
return errors::InvalidArgument(
"reader_func returns more than one argument.");
@ -673,7 +674,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(
// Run the shard function
TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
ctx, tensors, &output_tensors));
ctx, tensors, &output_tensors, model_node()));
if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {

View File

@ -151,7 +151,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
}
std::vector<Tensor> result;
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
ctx, *out_tensors, &result));
ctx, *out_tensors, &result, model_node()));
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
result[0].NumElements() != 1) {

View File

@ -148,7 +148,7 @@ class FilterDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> result;
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
ctx, *out_tensors, &result));
ctx, *out_tensors, &result, model_node()));
if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
result[0].NumElements() != 1) {

View File

@ -157,7 +157,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
TF_RETURN_IF_ERROR(BuildCurrentElementIteratorLocked(ctx));
TF_RETURN_IF_ERROR(
BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/true));
} while (true);
}
@ -230,7 +231,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
&captured_func_inputs_.back()));
}
element_index_--;
TF_RETURN_IF_ERROR(BuildCurrentElementIteratorLocked(ctx));
TF_RETURN_IF_ERROR(
BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false));
TF_RETURN_IF_ERROR(
RestoreInput(ctx, reader, current_element_iterator_));
}
@ -239,11 +241,21 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
}
private:
Status BuildCurrentElementIteratorLocked(IteratorContext* ctx)
Status BuildCurrentElementIteratorLocked(IteratorContext* ctx,
bool is_get_next)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return MakeIteratorFromInputElement(
ctx, this, captured_func_inputs_, element_index_++,
*instantiated_captured_func_, prefix(), &current_element_iterator_);
if (is_get_next) {
return MakeIteratorFromInputElement(
ctx, this, captured_func_inputs_, element_index_++,
*instantiated_captured_func_, prefix(), &current_element_iterator_,
model_node());
} else {
// NOTE: We intentionally ignore resource modeling outside GetNext().
return MakeIteratorFromInputElement(
ctx, this, captured_func_inputs_, element_index_++,
*instantiated_captured_func_, prefix(), &current_element_iterator_,
/*node=*/nullptr);
}
}
mutex mu_;

View File

@ -127,8 +127,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
mutex_lock l(mu_);
if (!initialized_) {
TF_RETURN_IF_ERROR(
instantiated_init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
TF_RETURN_IF_ERROR(instantiated_init_func_->RunWithBorrowedArgs(
ctx, {}, &state_, model_node()));
initialized_ = true;
}
@ -137,8 +137,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
Status s = instantiated_next_func_->RunWithBorrowedArgs(ctx, state_,
out_tensors);
Status s = instantiated_next_func_->RunWithBorrowedArgs(
ctx, state_, out_tensors, model_node());
if (s.ok()) {
*end_of_sequence = false;
} else if (errors::IsOutOfRange(s)) {
@ -150,7 +150,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
// NOTE(mrry): We ignore any tensors returned by the finalize function.
std::vector<Tensor> ignored;
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
ctx, state_, &ignored));
ctx, state_, &ignored, model_node()));
finalized_ = true;
}
return s;

View File

@ -182,7 +182,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, args_list_[cycle_index_], cycle_index_,
*instantiated_captured_func_, prefix(),
&current_elements_[cycle_index_]));
&current_elements_[cycle_index_], model_node()));
++num_open_;
}
} else {
@ -276,9 +276,10 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
full_name(strings::StrCat(kArgsList, "[", idx, "][", i, "]")),
&args_list_[idx][i]));
}
// NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, args_list_[idx], idx, *instantiated_captured_func_,
prefix(), &current_elements_[idx]));
prefix(), &current_elements_[idx], /*node=*/nullptr));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_elements_[idx]));
} else {
current_elements_[idx].reset();

View File

@ -736,7 +736,7 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {
std::vector<Tensor> reduce_func_output;
TF_RETURN_IF_ERROR(instantiated_captured_func->Run(
&iter_ctx, std::move(args), &reduce_func_output));
&iter_ctx, std::move(args), &reduce_func_output, /*node=*/nullptr));
if (reduce_func_output.size() != state.size()) {
return errors::InvalidArgument(
"The number of components of the initial state and the "

View File

@ -158,8 +158,8 @@ class MapDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
Status s =
instantiated_captured_func_->Run(ctx, std::move(args), out_tensors);
Status s = instantiated_captured_func_->Run(ctx, std::move(args),
out_tensors, model_node());
if (errors::IsOutOfRange(s)) {
if (dataset()->preserve_cardinality_) {
// To guarantee that the transformation preserves the cardinality of

View File

@ -999,7 +999,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
absl::make_unique<std::vector<Tensor>>(std::move(inputs));
status = MakeIteratorFromInputElement(
ctx_.get(), this, *element->inputs, element->id,
*instantiated_captured_func_, prefix(), &element->iterator);
*instantiated_captured_func_, prefix(), &element->iterator,
model_node());
if (!status.ok()) {
element->inputs.reset();
element->iterator.reset();
@ -1291,7 +1292,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
reader->ReadScalar(iterator_name, kIdSuffix, &element->id));
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, *element->inputs, element->id,
*instantiated_captured_func_.get(), prefix(), &iterator));
*instantiated_captured_func_.get(), prefix(), &iterator,
model_node()));
}
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
mutex_lock l(*mu_);

View File

@ -444,7 +444,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
auto fn = std::bind(
[this, ctx, result](std::vector<Tensor> input_element) {
return instantiated_captured_func_->Run(
ctx.get(), std::move(input_element), &result->return_values);
ctx.get(), std::move(input_element), &result->return_values,
model_node());
},
std::move(input_element));
// `ctx->runner()` may execute its logic synchronously so we wrap it in