Merge pull request #44523 from mkuchnik:executor_based_runsync_metrics
PiperOrigin-RevId: 341011613
Change-Id: Ic4e8ec1beeae76fe9b6ba2ece93a1a27c38004a4

Commit 609bad4d8b
Files changed (tensorflow/core/kernels/data):
  captured_function.cc
  captured_function.h
  experimental/choose_fastest_branch_dataset_op.cc
  experimental/group_by_reducer_dataset_op.cc
  experimental/group_by_window_dataset_op.cc
  experimental/io_ops.cc
  experimental/parallel_interleave_dataset_op.cc
  experimental/scan_dataset_op.cc
  experimental/snapshot_dataset_op.cc
  experimental/take_while_dataset_op.cc
  filter_dataset_op.cc
  flat_map_dataset_op.cc
  generator_dataset_op.cc
  interleave_dataset_op.cc
  iterator_ops.cc
  map_dataset_op.cc
  parallel_interleave_dataset_op.cc
  parallel_map_dataset_op.cc

tensorflow/core/kernels/data/captured_function.cc

@@ -448,10 +448,21 @@ Status MakeIteratorFromInputElement(
const std::vector<Tensor>& input_element, int64 thread_index,
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator) {
+ return MakeIteratorFromInputElement(ctx, parent, input_element, thread_index,
+ inst_captured_func, prefix, out_iterator,
+ /*node=*/nullptr);
+ }
+
+ Status MakeIteratorFromInputElement(
+ IteratorContext* ctx, const IteratorBase* parent,
+ const std::vector<Tensor>& input_element, int64 thread_index,
+ const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
+ std::unique_ptr<IteratorBase>* out_iterator,
+ const std::shared_ptr<model::Node>& node) {
std::vector<Tensor> return_values;

- TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(ctx, input_element,
- &return_values));
+ TF_RETURN_IF_ERROR(inst_captured_func.RunWithBorrowedArgs(
+ ctx, input_element, &return_values, node));

if (!(return_values.size() == 1 && return_values[0].dtype() == DT_VARIANT &&
TensorShapeUtils::IsScalar(return_values[0].shape()))) {

@@ -816,6 +827,12 @@ InstantiatedCapturedFunction::InstantiatedCapturedFunction(
Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const {
+ return Run(ctx, std::move(args), rets, /*node=*/nullptr);
+ }
+
+ Status InstantiatedCapturedFunction::Run(
+ IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
+ const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, std::move(args), captured_func_, rets);

@@ -832,6 +849,14 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
CancellationManager cancellation_manager(ctx->cancellation_manager());
f_opts.cancellation_manager = &cancellation_manager;

+ std::shared_ptr<SimpleStepStatsCollector> stats_collector;
+ if (node || ctx->stats_aggregator()) {
+ stats_collector = std::make_shared<SimpleStepStatsCollector>();
+ }
+ const bool collect_usage =
+ node && ctx->model() && ctx->model()->collect_resource_usage();
+ f_opts.stats_collector = stats_collector.get();
+
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
ret_types_);
profiler::TraceMe activity(

@@ -840,13 +865,37 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
"InstantiatedCapturedFunction::Run#id=", f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
- TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
+ if (node) {
+ // Resource usage for function execution is gathered from the executor.
+ // TODO(jsimsa): Factor out common code for Run, RunAsync, and
+ // RunWithBorrowedArguments
+ if (collect_usage) node->record_stop(EnvTime::NowNanos());
+ TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
+ if (ctx->stats_aggregator()) {
+ string prefix_with_func_name = strings::StrCat(
+ node->name(), stats_utils::kDelimiter, captured_func_->func().name());
+ ctx->stats_aggregator()->AddToHistogram(
+ stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
+ {static_cast<float>(stats_collector->processing_time())},
+ node->num_elements());
+ }
+ node->add_processing_time(stats_collector->processing_time());
+ if (collect_usage) node->record_start(EnvTime::NowNanos());
+ } else {
+ TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
+ }
return frame.ConsumeRetvals(rets);
}

Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
IteratorContext* ctx, const std::vector<Tensor>& args,
- std::vector<Tensor>* rets) const {
+ std::vector<Tensor>* ret) const {
+ return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr);
+ }
+
+ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
+ IteratorContext* ctx, const std::vector<Tensor>& args,
+ std::vector<Tensor>* rets, const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
return RunShortCircuit(info, args, captured_func_, rets);

@@ -863,6 +912,14 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
CancellationManager cancellation_manager(ctx->cancellation_manager());
f_opts.cancellation_manager = &cancellation_manager;

+ std::shared_ptr<SimpleStepStatsCollector> stats_collector;
+ if (node || ctx->stats_aggregator()) {
+ stats_collector = std::make_shared<SimpleStepStatsCollector>();
+ }
+ const bool collect_usage =
+ node && ctx->model() && ctx->model()->collect_resource_usage();
+ f_opts.stats_collector = stats_collector.get();
+
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
ret_types_);
profiler::TraceMe activity(

@@ -872,7 +929,23 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
- TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
+ if (node) {
+ // Resource usage for function execution is gathered from the executor.
+ if (collect_usage) node->record_stop(EnvTime::NowNanos());
+ TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
+ if (ctx->stats_aggregator()) {
+ string prefix_with_func_name = strings::StrCat(
+ node->name(), stats_utils::kDelimiter, captured_func_->func().name());
+ ctx->stats_aggregator()->AddToHistogram(
+ stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
+ {static_cast<float>(stats_collector->processing_time())},
+ node->num_elements());
+ }
+ node->add_processing_time(stats_collector->processing_time());
+ if (collect_usage) node->record_start(EnvTime::NowNanos());
+ } else {
+ TF_RETURN_IF_ERROR(lib_->RunSync(std::move(f_opts), f_handle_, &frame));
+ }
return frame.ConsumeRetvals(rets);
}
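For readers skimming this diff: every synchronous entry point above now follows the same accounting order when a non-null `node` is passed in — stop the node's active-time clock, run the function through the executor with a `SimpleStepStatsCollector` attached, add the collector's processing time to the node, then restart the clock. Below is a minimal, self-contained sketch of that ordering using hypothetical stand-in types (`FakeNode`, `FakeStatsCollector`, a local `RunSync`); it is not TensorFlow code, just an illustration of the pattern under those assumptions.

// Hypothetical, simplified stand-ins for model::Node, SimpleStepStatsCollector
// and the function-library RunSync call; only the ordering mirrors the change.
#include <chrono>
#include <cstdint>
#include <iostream>
#include <memory>
#include <thread>

static int64_t NowNanos() {
  return std::chrono::duration_cast<std::chrono::nanoseconds>(
             std::chrono::steady_clock::now().time_since_epoch())
      .count();
}

struct FakeStatsCollector {
  int64_t processing_time = 0;  // filled in by the fake "executor" below
};

struct FakeNode {
  int64_t active_since = 0;
  int64_t active_time = 0;      // time attributed to the iterator itself
  int64_t processing_time = 0;  // time attributed to the user-defined function
  void record_stop(int64_t now) { active_time += now - active_since; }
  void record_start(int64_t now) { active_since = now; }
  void add_processing_time(int64_t t) { processing_time += t; }
};

// Fake executor: runs the "user function" and reports how long it took.
static void RunSync(FakeStatsCollector* stats) {
  const int64_t start = NowNanos();
  std::this_thread::sleep_for(std::chrono::milliseconds(5));
  stats->processing_time = NowNanos() - start;
}

// The ordering used by the new node-aware Run()/RunWithBorrowedArgs().
static void TimedRun(const std::shared_ptr<FakeNode>& node) {
  FakeStatsCollector stats;
  if (node) {
    node->record_stop(NowNanos());   // pause the iterator's own clock
    RunSync(&stats);                 // function time measured by the executor
    node->add_processing_time(stats.processing_time);
    node->record_start(NowNanos());  // resume the iterator's own clock
  } else {
    RunSync(&stats);                 // legacy path: no modeling
  }
}

int main() {
  auto node = std::make_shared<FakeNode>();
  node->record_start(NowNanos());
  TimedRun(node);
  node->record_stop(NowNanos());
  std::cout << "processing_time(ns)=" << node->processing_time
            << " active_time(ns)=" << node->active_time << "\n";
  return 0;
}

The point visible in both Run() hunks is that the function's runtime is excluded from the node's active time (record_stop before RunSync, record_start after), so the same nanoseconds are not counted twice by the autotuning model.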
tensorflow/core/kernels/data/captured_function.h

@@ -48,6 +48,16 @@ Status MakeIteratorFromInputElement(
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);

+ // Creates an iterator for a dataset which is created by applying the given
+ // function to the given input element. Pass non-null `node` to record
+ // processing time for modeling Iterator's GetNext() resource usage.
+ Status MakeIteratorFromInputElement(
+ IteratorContext* ctx, const IteratorBase* parent,
+ const std::vector<Tensor>& input_element, int64 thread_index,
+ const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
+ std::unique_ptr<IteratorBase>* out_iterator,
+ const std::shared_ptr<model::Node>& node);
+
// Determines whether the given node is stateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node);

@@ -215,6 +225,15 @@ class InstantiatedCapturedFunction {
Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) const;

+ // Runs the instantiated captured function. This method takes ownership of
+ // the tensors in `args`, in order to be able to deallocate them as early as
+ // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain
+ // ownership of the `args`. Pass non-null `node` to record processing time
+ // for modeling Iterator's GetNext() resource usage.
+ Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
+ std::vector<Tensor>* rets,
+ const std::shared_ptr<model::Node>& node) const;
+
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible.

@@ -222,6 +241,15 @@ class InstantiatedCapturedFunction {
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) const;

+ // Synchronously runs the captured function on the given `args`, and stores
+ // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
+ // possible. Pass non-null `node` to record processing time for modeling
+ // Iterator's GetNext() resource usage.
+ Status RunWithBorrowedArgs(IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ std::vector<Tensor>* rets,
+ const std::shared_ptr<model::Node>& node) const;
+
// Synchronously runs the captured function on the given `args`, and stores
// the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when
// possible. This can be useful for calling a captured function in cases where

@@ -234,7 +262,8 @@ class InstantiatedCapturedFunction {
// Asynchronously runs the captured function on the given `args`, stores the
// results in `*rets`, and calls the given `done` callback when the function
// returns. This method takes ownership of the tensors in `args`, in order to
- // be able to deallocate them as early as possible.
+ // be able to deallocate them as early as possible. Pass non-null `node` to
+ // record processing time for modeling Iterator's GetNext() resource usage.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,
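The declarations above keep the existing signatures and add node-aware overloads next to them; as the captured_function.cc hunks show, the old signatures now simply forward with /*node=*/nullptr, so existing call sites keep compiling. A minimal sketch of that overload-and-forward shape, using hypothetical stand-in types rather than the real TensorFlow classes:

// Hypothetical, simplified stand-ins for IteratorContext, Tensor, Status and
// model::Node; only the overload/forwarding shape mirrors the real header.
#include <memory>
#include <utility>
#include <vector>

struct IteratorContext {};
struct Tensor {};
struct Node {};
struct Status { static Status OK() { return {}; } };

class InstantiatedFunc {
 public:
  // Pre-existing signature: callers that do not model resource usage are
  // unchanged and get the nullptr behavior.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets) const {
    return Run(ctx, std::move(args), rets, /*node=*/nullptr);
  }

  // New overload: pass a non-null `node` to have the function's processing
  // time recorded against that node.
  Status Run(IteratorContext* ctx, std::vector<Tensor>&& args,
             std::vector<Tensor>* rets,
             const std::shared_ptr<Node>& node) const {
    // ... run the function; if `node` is set, attribute timing to it ...
    return Status::OK();
  }
};

int main() {
  InstantiatedFunc f;
  IteratorContext ctx;
  std::vector<Tensor> args, rets;
  f.Run(&ctx, std::move(args), &rets);  // old call sites stay valid
  return 0;
}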
tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc

@@ -366,7 +366,8 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
// Still running experiments
if (!current_iterator_) {
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
- /*is_experiment=*/true));
+ /*is_experiment=*/true,
+ /*is_get_next=*/true));
}

Status s = GetNextFromExperiment(ctx, out_tensors, end_of_sequence);

@@ -385,7 +386,8 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
if (!current_iterator_) {
SelectFastestInputIndex();
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
- /*is_experiment=*/false));
+ /*is_experiment=*/false,
+ /*is_get_next=*/true));
}
}

@@ -438,10 +440,12 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
if (!reader->Contains(full_name("input_impl_empty"))) {
if (branch_index_ < dataset()->captured_funcs_.size()) {
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, branch_index_,
- /*is_experiment=*/true));
+ /*is_experiment=*/true,
+ /*is_get_next=*/false));
} else {
TF_RETURN_IF_ERROR(MakeCurrentIterator(ctx, fastest_index_,
- /*is_experiment=*/false));
+ /*is_experiment=*/false,
+ /*is_get_next=*/false));
}
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_iterator_));
}

@@ -492,7 +496,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
}

Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index,
- bool is_experiment)
+ bool is_experiment, bool is_get_next)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
DCHECK_GE(branch_index, 0);
DCHECK_LT(branch_index, histograms_.size());

@@ -528,10 +532,18 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(
temp_dataset, wrapper_dataset_tensor_.get()));

- TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
- ctx, this, {*wrapper_dataset_tensor_}, branch_index,
- *instantiated_captured_funcs_[branch_index], prefix(),
- &current_iterator_));
+ if (is_get_next) {
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+ ctx, this, {*wrapper_dataset_tensor_}, branch_index,
+ *instantiated_captured_funcs_[branch_index], prefix(),
+ &current_iterator_, model_node()));
+ } else {
+ // NOTE: We intentionally ignore resource modeling outside GetNext().
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+ ctx, this, {*wrapper_dataset_tensor_}, branch_index,
+ *instantiated_captured_funcs_[branch_index], prefix(),
+ &current_iterator_, /*node=*/nullptr));
+ }

return Status::OK();
}
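The new is_get_next flag (the FlatMap change later in this diff adds the same flag) only decides which node pointer reaches MakeIteratorFromInputElement: the GetNext() path passes the iterator's model_node() so the nested iterator's processing time is modeled, while the restore path passes nullptr. A small self-contained sketch of that routing, with hypothetical names standing in for the TensorFlow types:

// Hypothetical sketch of routing model_node() vs. nullptr depending on whether
// the nested iterator is being built from GetNext() or from Restore().
#include <iostream>
#include <memory>

struct Node {};

// Stand-in for MakeIteratorFromInputElement's new `node` parameter.
void MakeNestedIterator(const std::shared_ptr<Node>& node) {
  std::cout << (node ? "modeling enabled\n" : "modeling disabled\n");
}

class Iterator {
 public:
  void BuildCurrentIterator(bool is_get_next) {
    if (is_get_next) {
      MakeNestedIterator(model_node());  // GetNext(): record processing time
    } else {
      // NOTE: resource modeling is intentionally ignored outside GetNext().
      MakeNestedIterator(/*node=*/nullptr);
    }
  }

 private:
  std::shared_ptr<Node> model_node() { return node_; }
  std::shared_ptr<Node> node_ = std::make_shared<Node>();
};

int main() {
  Iterator it;
  it.BuildCurrentIterator(/*is_get_next=*/true);   // from GetNext()
  it.BuildCurrentIterator(/*is_get_next=*/false);  // from RestoreInternal()
  return 0;
}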
tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc

@@ -229,7 +229,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
// Run the key function on the input element.
std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
- ctx, next_input_element, &key_func_output));
+ ctx, next_input_element, &key_func_output, model_node()));

if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||

@@ -244,7 +244,8 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
// Run the init function to create the initial state.
std::vector<Tensor> init_func_output;
TF_RETURN_IF_ERROR(instantiated_init_func_->Run(
- ctx, std::move(key_func_output), &init_func_output));
+ ctx, std::move(key_func_output), &init_func_output,
+ model_node()));
states_[key] = init_func_output;
}

@@ -258,7 +259,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {

std::vector<Tensor> reduce_func_output;
TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
- ctx, std::move(args), &reduce_func_output));
+ ctx, std::move(args), &reduce_func_output, model_node()));
states_[key] = reduce_func_output;
} else {
keys_.resize(states_.size());

@@ -274,7 +275,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
- ctx, states_[keys_[keys_index_++]], out_tensors));
+ ctx, states_[keys_[keys_index_++]], out_tensors, model_node()));
*end_of_sequence = false;
return Status::OK();
}
tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc

@@ -239,7 +239,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// group.
std::vector<Tensor> key_func_output;
TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
- ctx, next_input_element, &key_func_output));
+ ctx, next_input_element, &key_func_output, model_node()));

if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||

@@ -255,7 +255,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
// window size.
std::vector<Tensor> window_size_func_output;
TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run(
- ctx, std::move(key_func_output), &window_size_func_output));
+ ctx, std::move(key_func_output), &window_size_func_output,
+ model_node()));

if (window_size_func_output.size() != 1 ||
window_size_func_output[0].dtype() != DT_INT64 ||

@@ -486,8 +487,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor> args(
{std::move(key_arg), std::move(group_dataset_arg)});
std::vector<Tensor> return_values;
- TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(ctx, std::move(args),
- &return_values));
+ TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
+ ctx, std::move(args), &return_values, model_node()));

if (!(return_values.size() == 1 &&
return_values[0].dtype() == DT_VARIANT &&
tensorflow/core/kernels/data/experimental/io_ops.cc

@@ -147,8 +147,8 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx,
return Status::OK();
}
std::vector<Tensor> output_tensors;
- TF_RETURN_IF_ERROR(
- function->RunWithBorrowedArgs(ctx, element, &output_tensors));
+ TF_RETURN_IF_ERROR(function->RunWithBorrowedArgs(
+ ctx, element, &output_tensors, /*node=*/nullptr));

if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {

@@ -315,8 +315,9 @@ class LoadDatasetOp::Dataset : public DatasetBase {
std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor));

+ // NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(instantiated_captured_func_->Run(
- ctx, std::move(reader_input), &reader_output));
+ ctx, std::move(reader_input), &reader_output, /*node=*/nullptr));
if (reader_output.size() != 1) {
return errors::InvalidArgument(
"reader_func returns more than one argument.");
tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc

@@ -773,7 +773,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
MakeIteratorFromInputElement(
ctx.get(), this, worker_thread_states_[thread_index].input,
thread_index, *instantiated_captured_func_, prefix(),
- &worker_thread_states_[thread_index].iterator);
+ &worker_thread_states_[thread_index].iterator,
+ model_node());
iterator_creation_status =
worker_thread_states_[thread_index].iterator_creation_status;
if (!iterator_creation_status.ok()) {

@@ -1011,9 +1012,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
state->iterator.reset();
} else {
std::unique_ptr<IteratorBase> iterator;
+ // NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, state->input, index, *instantiated_captured_func_,
- prefix(), &iterator));
+ prefix(), &iterator, /*node=*/nullptr));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
state->iterator.swap(iterator);
}
tensorflow/core/kernels/data/experimental/scan_dataset_op.cc

@@ -200,8 +200,8 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
state_and_output.reserve(dataset()->state_types_.size() +
output_dtypes().size());

- Status s = instantiated_captured_func_->Run(ctx, std::move(args),
- &state_and_output);
+ Status s = instantiated_captured_func_->Run(
+ ctx, std::move(args), &state_and_output, model_node());
DCHECK(state_and_output.size() <=
dataset()->state_types_.size() + output_dtypes().size());
if (s.ok()) {
tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc

@@ -578,8 +578,9 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
std::vector<Tensor> reader_output;
reader_input.push_back(std::move(input_dataset_tensor));

+ // NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(instantiated_reader_func_->Run(
- ctx, std::move(reader_input), &reader_output));
+ ctx, std::move(reader_input), &reader_output, /*node=*/nullptr));
if (reader_output.size() != 1) {
return errors::InvalidArgument(
"reader_func returns more than one argument.");

@@ -673,7 +674,7 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Writer::GetShardIndex(

// Run the shard function
TF_RETURN_IF_ERROR(instantiated_shard_func_->RunWithBorrowedArgs(
- ctx, tensors, &output_tensors));
+ ctx, tensors, &output_tensors, model_node()));

if (output_tensors.size() != 1 || output_tensors[0].dtype() != DT_INT64 ||
output_tensors[0].NumElements() != 1) {
tensorflow/core/kernels/data/experimental/take_while_dataset_op.cc

@@ -151,7 +151,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
}
std::vector<Tensor> result;
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
- ctx, *out_tensors, &result));
+ ctx, *out_tensors, &result, model_node()));

if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
result[0].NumElements() != 1) {
tensorflow/core/kernels/data/filter_dataset_op.cc

@@ -148,7 +148,7 @@ class FilterDatasetOp::Dataset : public DatasetBase {

std::vector<Tensor> result;
TF_RETURN_IF_ERROR(instantiated_captured_func_->RunWithBorrowedArgs(
- ctx, *out_tensors, &result));
+ ctx, *out_tensors, &result, model_node()));

if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
result[0].NumElements() != 1) {
tensorflow/core/kernels/data/flat_map_dataset_op.cc

@@ -157,7 +157,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}

- TF_RETURN_IF_ERROR(BuildCurrentElementIteratorLocked(ctx));
+ TF_RETURN_IF_ERROR(
+ BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/true));
} while (true);
}

@@ -230,7 +231,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
&captured_func_inputs_.back()));
}
element_index_--;
- TF_RETURN_IF_ERROR(BuildCurrentElementIteratorLocked(ctx));
+ TF_RETURN_IF_ERROR(
+ BuildCurrentElementIteratorLocked(ctx, /*is_get_next=*/false));
TF_RETURN_IF_ERROR(
RestoreInput(ctx, reader, current_element_iterator_));
}

@@ -239,11 +241,21 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
}

private:
- Status BuildCurrentElementIteratorLocked(IteratorContext* ctx)
+ Status BuildCurrentElementIteratorLocked(IteratorContext* ctx,
+ bool is_get_next)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return MakeIteratorFromInputElement(
- ctx, this, captured_func_inputs_, element_index_++,
- *instantiated_captured_func_, prefix(), &current_element_iterator_);
+ if (is_get_next) {
+ return MakeIteratorFromInputElement(
+ ctx, this, captured_func_inputs_, element_index_++,
+ *instantiated_captured_func_, prefix(), &current_element_iterator_,
+ model_node());
+ } else {
+ // NOTE: We intentionally ignore resource modeling outside GetNext().
+ return MakeIteratorFromInputElement(
+ ctx, this, captured_func_inputs_, element_index_++,
+ *instantiated_captured_func_, prefix(), &current_element_iterator_,
+ /*node=*/nullptr);
+ }
}

mutex mu_;
tensorflow/core/kernels/data/generator_dataset_op.cc

@@ -127,8 +127,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
mutex_lock l(mu_);

if (!initialized_) {
- TF_RETURN_IF_ERROR(
- instantiated_init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ TF_RETURN_IF_ERROR(instantiated_init_func_->RunWithBorrowedArgs(
+ ctx, {}, &state_, model_node()));
initialized_ = true;
}

@@ -137,8 +137,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}

- Status s = instantiated_next_func_->RunWithBorrowedArgs(ctx, state_,
- out_tensors);
+ Status s = instantiated_next_func_->RunWithBorrowedArgs(
+ ctx, state_, out_tensors, model_node());
if (s.ok()) {
*end_of_sequence = false;
} else if (errors::IsOutOfRange(s)) {

@@ -150,7 +150,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
// NOTE(mrry): We ignore any tensors returned by the finalize function.
std::vector<Tensor> ignored;
TF_RETURN_IF_ERROR(instantiated_finalize_func_->RunWithBorrowedArgs(
- ctx, state_, &ignored));
+ ctx, state_, &ignored, model_node()));
finalized_ = true;
}
return s;
tensorflow/core/kernels/data/interleave_dataset_op.cc

@@ -182,7 +182,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, args_list_[cycle_index_], cycle_index_,
*instantiated_captured_func_, prefix(),
- &current_elements_[cycle_index_]));
+ &current_elements_[cycle_index_], model_node()));
++num_open_;
}
} else {

@@ -276,9 +276,10 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
full_name(strings::StrCat(kArgsList, "[", idx, "][", i, "]")),
&args_list_[idx][i]));
}
+ // NOTE: We intentionally ignore resource modeling outside GetNext().
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, args_list_[idx], idx, *instantiated_captured_func_,
- prefix(), &current_elements_[idx]));
+ prefix(), &current_elements_[idx], /*node=*/nullptr));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, current_elements_[idx]));
} else {
current_elements_[idx].reset();
tensorflow/core/kernels/data/iterator_ops.cc

@@ -736,7 +736,7 @@ class ReduceDatasetOp : public HybridAsyncOpKernel {

std::vector<Tensor> reduce_func_output;
TF_RETURN_IF_ERROR(instantiated_captured_func->Run(
- &iter_ctx, std::move(args), &reduce_func_output));
+ &iter_ctx, std::move(args), &reduce_func_output, /*node=*/nullptr));
if (reduce_func_output.size() != state.size()) {
return errors::InvalidArgument(
"The number of components of the initial state and the "
tensorflow/core/kernels/data/map_dataset_op.cc

@@ -158,8 +158,8 @@ class MapDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}

- Status s =
- instantiated_captured_func_->Run(ctx, std::move(args), out_tensors);
+ Status s = instantiated_captured_func_->Run(ctx, std::move(args),
+ out_tensors, model_node());
if (errors::IsOutOfRange(s)) {
if (dataset()->preserve_cardinality_) {
// To guarantee that the transformation preserves the cardinality of
tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc

@@ -999,7 +999,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
absl::make_unique<std::vector<Tensor>>(std::move(inputs));
status = MakeIteratorFromInputElement(
ctx_.get(), this, *element->inputs, element->id,
- *instantiated_captured_func_, prefix(), &element->iterator);
+ *instantiated_captured_func_, prefix(), &element->iterator,
+ model_node());
if (!status.ok()) {
element->inputs.reset();
element->iterator.reset();

@@ -1291,7 +1292,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
reader->ReadScalar(iterator_name, kIdSuffix, &element->id));
TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, this, *element->inputs, element->id,
- *instantiated_captured_func_.get(), prefix(), &iterator));
+ *instantiated_captured_func_.get(), prefix(), &iterator,
+ model_node()));
}
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
mutex_lock l(*mu_);
tensorflow/core/kernels/data/parallel_map_dataset_op.cc

@@ -444,7 +444,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
auto fn = std::bind(
[this, ctx, result](std::vector<Tensor> input_element) {
return instantiated_captured_func_->Run(
- ctx.get(), std::move(input_element), &result->return_values);
+ ctx.get(), std::move(input_element), &result->return_values,
+ model_node());
},
std::move(input_element));
// `ctx->runner()` may execute its logic synchronously so we wrap it in