diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 422955221f8..eb506d29571 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -2297,6 +2297,10 @@ tf_cc_test( ":core", ":core_cpu", ":core_cpu_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -2307,6 +2311,8 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/kernels:math", "//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:state", diff --git a/tensorflow/core/common_runtime/executor_test.cc b/tensorflow/core/common_runtime/executor_test.cc index c5231449a00..9a1b7cff813 100644 --- a/tensorflow/core/common_runtime/executor_test.cc +++ b/tensorflow/core/common_runtime/executor_test.cc @@ -17,15 +17,24 @@ limitations under the License. 
#include +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/common_runtime/lower_functional_ops.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -532,4 +541,115 @@ static void BM_FeedInputFetchOutput(int iters) { } BENCHMARK(BM_FeedInputFetchOutput); +// Defines a graph to perform the following computation: +// +// i = 0 +// while (i <= loop_iters) +// i += 1; +// +// ...using the functional `WhileOp` (if `lower` is false) or the +// `Switch`/`Merge`-style of control flow (if `lower` is true). +static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) { + testing::StopTiming(); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + // Add test functions for cond and body. + FunctionDefLibrary f_lib_proto; + + // Define the loop body as a function: `x = x + 1`. 
+ const Tensor one_t = test::AsScalar(1); + *f_lib_proto.add_function() = FunctionDefHelper::Define( + // Name + "XPlusOne", + // Args + {"x: int32"}, + // Return values + {"y: int32"}, + // Attr def + {}, + // Nodes + { + {{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}}, + {{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}}, + }); + + // Define the loop condition as a function: `x <= loop_iters`. + const Tensor loop_iters_t = test::AsScalar(loop_iters); + *f_lib_proto.add_function() = FunctionDefHelper::Define( + // Name + "LessThanOrEqualToN", + // Args + {"x: int32"}, + // Return values + {"z: bool"}, + // Attr def + {}, + // Nodes + { + {{"N"}, "Const", {}, {{"value", loop_iters_t}, {"dtype", DT_INT32}}}, + {{"z"}, "LessEqual", {"x", "N"}, {{"T", DT_INT32}}}, + }); + + Scope root = Scope::NewRootScope().ExitOnError(); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto)); + auto a = ops::Const(root.WithOpName("A"), 0, {}); + Node* while_node; + std::vector inputs({NodeBuilder::NodeOut(a.node())}); + AttrValue int32_attr; + int32_attr.set_type(DT_INT32); + AttrValue cond_func; + cond_func.mutable_func()->set_name("LessThanOrEqualToN"); + AttrValue body_func; + body_func.mutable_func()->set_name("XPlusOne"); + TF_ASSERT_OK( + NodeBuilder("while", "While", &root.graph()->flib_def()) + .Input(inputs) + .Attr("T", {DT_INT32}) + .Attr("cond", cond_func) + .Attr("body", body_func) + .Attr("parallel_iterations", 100) + .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true) + .Finalize(root.graph(), &while_node)); + auto c = ops::Identity( + root.WithOpName("C").WithControlDependencies(Output(while_node)), + Output(while_node)); + TF_ASSERT_OK(root.DoShapeInference(while_node)); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + if (lower) { + FunctionLibraryDefinition flib_def(graph->flib_def()); + GraphOptimizationPassOptions opt_options; + SessionOptions session_options; + session_options.config.mutable_graph_options() 
->mutable_optimizer_options() + ->set_do_function_inlining(true); + opt_options.session_options = &session_options; + opt_options.graph = &graph; + opt_options.flib_def = &flib_def; + LowerFunctionalOpsPass pass; + TF_ASSERT_OK(pass.Run(opt_options)); + } + + FixupSourceAndSinkEdges(graph.get()); + testing::StartTiming(); + test::Benchmark("cpu", graph.release()).Run(iters); +} + +static void BM_LoweredWhileLoop(int iters, int loop_iters) { + BM_WhileLoopHelper(iters, loop_iters, /* lower= */ true); +} +BENCHMARK(BM_LoweredWhileLoop)->Arg(0); +BENCHMARK(BM_LoweredWhileLoop)->Arg(1); +BENCHMARK(BM_LoweredWhileLoop)->Arg(10); +BENCHMARK(BM_LoweredWhileLoop)->Arg(100); +BENCHMARK(BM_LoweredWhileLoop)->Arg(1000); + +static void BM_FunctionalWhileLoop(int iters, int loop_iters) { + BM_WhileLoopHelper(iters, loop_iters, /* lower= */ false); +} +BENCHMARK(BM_FunctionalWhileLoop)->Arg(0); +BENCHMARK(BM_FunctionalWhileLoop)->Arg(1); +BENCHMARK(BM_FunctionalWhileLoop)->Arg(10); +BENCHMARK(BM_FunctionalWhileLoop)->Arg(100); +BENCHMARK(BM_FunctionalWhileLoop)->Arg(1000); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 4118534cb3e..1b1234d114f 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -19,7 +19,9 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/executor_factory.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,8 +64,10 @@ Benchmark::Benchmark(const string& device, Graph* g, // Allow NewDevice to allocate a new threadpool with different number of // threads for each new benchmark. LocalDevice::set_use_global_threadpool(false); - device_ = - DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0"); + + device_mgr_ = absl::make_unique( + DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0")); + device_ = device_mgr_->ListDevices()[0]; CHECK(device_) << "Could not create a " << device << " device"; pool_ = @@ -81,14 +85,24 @@ Benchmark::Benchmark(const string& device, Graph* g, const int graph_def_version = g->versions().producer(); + flib_def_ = absl::make_unique(g->flib_def()); + + pflr_ = std::unique_ptr( + new ProcessFunctionLibraryRuntime( + device_mgr_.get(), Env::Default(), nullptr, graph_def_version, + flib_def_.get(), OptimizerOptions(), pool_, nullptr, nullptr, nullptr, + Rendezvous::Factory())); + + flr_ = pflr_->GetFLR(device_->name()); + LocalExecutorParams params; - params.device = device_.get(); - params.function_library = nullptr; + params.device = device_; + params.function_library = flr_; params.create_kernel = [this, graph_def_version]( const std::shared_ptr& props, OpKernel** kernel) { - return CreateNonCachedKernel(device_.get(), nullptr, props, - graph_def_version, kernel); + return CreateNonCachedKernel(device_, flr_, props, graph_def_version, + kernel); }; params.delete_kernel = [](OpKernel* kernel) { DeleteNonCachedKernel(kernel); @@ -109,11 +123,12 @@ Benchmark::Benchmark(const string& 
device, Graph* g, Benchmark::~Benchmark() { if (device_) { rendez_->Unref(); - // We delete `exec_` before `device_` because the `exec_` destructor may + // We delete `exec_` before `device_mgr_` because the `exec_` destructor may // run kernel destructors that may attempt to access state borrowed from - // `device_`, such as the resource manager. + // `device_mgr_`, such as the resource manager. exec_.reset(); - device_.reset(); + pflr_.reset(); + device_mgr_.reset(); delete pool_; } } diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h index 742f40de0c2..9c6b1eb088c 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h @@ -29,7 +29,10 @@ limitations under the License. namespace tensorflow { class Device; +class FunctionLibraryRuntime; +class ProcessFunctionLibraryRuntime; struct SessionOptions; +class StaticDeviceMgr; namespace test { @@ -55,9 +58,13 @@ class Benchmark { const std::vector& outputs, int iters); private: - thread::ThreadPool* pool_ = nullptr; - std::unique_ptr device_ = nullptr; + thread::ThreadPool* pool_ = nullptr; // Not owned. + Device* device_ = nullptr; // Not owned. Rendezvous* rendez_ = nullptr; + std::unique_ptr device_mgr_; + std::unique_ptr flib_def_; + std::unique_ptr pflr_; + FunctionLibraryRuntime* flr_; // Not owned. std::unique_ptr exec_; TF_DISALLOW_COPY_AND_ASSIGN(Benchmark); diff --git a/tensorflow/core/common_runtime/propagator_state.cc b/tensorflow/core/common_runtime/propagator_state.cc index 6d714d2fae9..4fd5e0f97d9 100644 --- a/tensorflow/core/common_runtime/propagator_state.cc +++ b/tensorflow/core/common_runtime/propagator_state.cc @@ -37,7 +37,7 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state, // Initialize iteration 0. 
root_frame_->SetIteration( - 0, new PropagatorState::IterationState(root_frame_->pending_counts, + 0, new PropagatorState::IterationState(0, root_frame_->pending_counts, root_frame_->total_input_tensors)); outstanding_frames_.insert({root_frame_->frame_name, root_frame_}); @@ -51,12 +51,13 @@ PropagatorState::~PropagatorState() { void PropagatorState::ActivateRoots(gtl::ArraySlice roots, TaggedNodeSeq* ready) { + mutex_lock l(root_frame_->mu); + IterationState* root_iter = root_frame_->GetIteration(0); for (const NodeItem* item : roots) { DCHECK_EQ(item->num_inputs, 0); - ready->emplace_back(item, root_frame_, 0, false); + ready->emplace_back(item, root_frame_, root_iter, false); } - mutex_lock l(root_frame_->mu); - root_frame_->GetIteration(0)->outstanding_ops = ready->size(); + root_iter->outstanding_ops = ready->size(); } void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, @@ -75,7 +76,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* const item = tagged_node.node_item; FrameState* const input_frame = tagged_node.input_frame; - const int64 input_iter = tagged_node.input_iter; + IterationState* const input_iter = tagged_node.input_iter; const bool is_dead = tagged_node.is_dead; // Propagates outputs along out edges, and puts newly ready nodes @@ -83,7 +84,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, DCHECK(ready->empty()); bool is_frame_done = false; FrameState* output_frame = input_frame; - int64 output_iter = input_iter; + IterationState* output_iter = input_iter; if (!item->is_enter_exit_or_next_iter) { // Fast path for nodes types that don't need special handling @@ -95,9 +96,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, input_frame->DecrementOutstandingOpsLocked(input_iter, ready); } else if (item->is_enter) { FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); - output_iter = 0; { mutex_lock l(output_frame->mu); + 
output_iter = output_frame->GetIteration(0); if (item->is_constant_enter) { // Propagate to all active iterations if this is a loop invariant. output_frame->AddLoopInv(item, (*outputs)[0], ready); @@ -111,7 +112,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, if (is_dead) { mutex_lock l(input_frame->mu); // Stop and remember this node if it is a dead exit. - if (input_iter == input_frame->iteration_count) { + if (input_iter->iter_num == input_frame->iteration_count) { input_frame->dead_exits.push_back(item); } is_frame_done = @@ -132,7 +133,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, // Stop the deadness propagation. output_frame = nullptr; } else { - if (input_iter == input_frame->iteration_count && + if (input_iter->iter_num == input_frame->iteration_count && input_frame->num_outstanding_iterations == input_frame->max_parallel_iterations) { // Reached the maximum for parallel iterations. @@ -140,10 +141,11 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, output_frame = nullptr; } else { // If this is a new iteration, start it. - if (input_iter == input_frame->iteration_count) { - input_frame->IncrementIteration(ready); + if (input_iter->iter_num == input_frame->iteration_count) { + output_iter = input_frame->IncrementIteration(ready); + } else { + output_iter = input_frame->GetIteration(input_iter->iter_num + 1); } - output_iter = input_iter + 1; } } if (output_frame != nullptr) { @@ -159,7 +161,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, // completion of this node makes its frame completed. if (is_frame_done) { FrameState* parent_frame = input_frame->parent_frame; - const int64 parent_iter = input_frame->parent_iter; + IterationState* parent_iter = input_frame->parent_iter; DeleteFrame(input_frame, ready); if (parent_frame != nullptr) { // The completion of frame may cause completions in its parent frame. 
@@ -217,7 +219,8 @@ void PropagatorState::DumpState() { } } -void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, +void PropagatorState::FindOrCreateChildFrame(FrameState* frame, + IterationState* iter_state, const NodeItem& node_item, FrameState** child) { // Get the child frame name. @@ -225,8 +228,8 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, const string& enter_name = GetNodeAttrString(attrs, "frame_name"); DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node " << node_item.kernel->name(); - const string child_name = - strings::StrCat(frame->frame_name, ";", iter, ";", enter_name); + const string child_name = strings::StrCat( + frame->frame_name, ";", iter_state->iter_num, ";", enter_name); { mutex_lock executor_lock(mu_); @@ -251,14 +254,14 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, temp->frame_name = child_name; temp->frame_id = Hash64(child_name); temp->parent_frame = frame; - temp->parent_iter = iter; + temp->parent_iter = iter_state; temp->InitializeFrameInfo(enter_name); // Initialize iteration 0. { mutex_lock l(temp->mu); - temp->SetIteration( - 0, new IterationState(temp->pending_counts, temp->total_input_tensors)); + temp->SetIteration(0, new IterationState(0, temp->pending_counts, + temp->total_input_tensors)); } { @@ -268,7 +271,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, *child = it->second; } else { mutex_lock frame_lock(frame->mu); - frame->GetIteration(iter)->outstanding_frame_count++; + iter_state->outstanding_frame_count++; outstanding_frames_[child_name] = temp; *child = temp; temp = nullptr; @@ -280,20 +283,19 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { // First, propagate dead_exits (if any) to the parent frame. 
FrameState* parent_frame = frame->parent_frame; - const int64 parent_iter = frame->parent_iter; + IterationState* parent_iter_state = frame->parent_iter; if (parent_frame != nullptr) { mutex_lock parent_frame_lock(parent_frame->mu); // Propagate all the dead exits to the parent frame. mutex_lock this_frame_lock(frame->mu); for (const NodeItem* item : frame->dead_exits) { - auto parent_iter_state = parent_frame->GetIteration(parent_iter); - auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready, bool dst_dead) { if (dst_ready) { if (dst_item.is_control_trigger) dst_dead = false; - ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead); + ready->emplace_back(&dst_item, parent_frame, parent_iter_state, + dst_dead); parent_iter_state->outstanding_ops++; } }; @@ -356,17 +358,18 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { delete frame; } -void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter, +void PropagatorState::CleanupFramesIterations(FrameState* frame, + IterationState* iter_state, TaggedNodeSeq* ready) { bool is_frame_done = false; { mutex_lock frame_lock(frame->mu); - frame->GetIteration(iter)->outstanding_frame_count--; - is_frame_done = frame->CleanupIterations(iter, ready); + iter_state->outstanding_frame_count--; + is_frame_done = frame->CleanupIterations(iter_state, ready); } if (is_frame_done) { FrameState* parent_frame = frame->parent_frame; - const int64 parent_iter = frame->parent_iter; + IterationState* parent_iter = frame->parent_iter; DeleteFrame(frame, ready); if (parent_frame != nullptr) { // The completion of frame may cause completions in its parent frame. 
@@ -376,16 +379,13 @@ void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter, } } -void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item, - const bool is_dead, - int64 iter, - EntryVector* outputs, - TaggedNodeSeq* ready) { +void PropagatorState::FrameState::ActivateNodesFastPath( + const NodeItem* item, const bool is_dead, IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) { // If we know that none of the item's edge destinations require special // handling (i.e. none of the nodes is a merge or control trigger node), we // can take a fast path that avoids accessing the destination NodeItem. const GraphView& gview = immutable_state.graph_view(); - IterationState* iter_state = GetIteration(iter); // Add dst to the ready queue if it's ready // @@ -398,7 +398,7 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item, TaggedNode& t = ready->emplace_back(); \ t.node_item = dst_item; \ t.input_frame = this; \ - t.input_iter = iter; \ + t.input_iter = iter_state; \ t.is_dead = adjust_result.any_dead; \ iter_state->outstanding_ops++; \ } \ @@ -436,23 +436,20 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item, #undef MAYBE_ADD_TO_READY } -void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, - const bool is_dead, - int64 iter, - EntryVector* outputs, - TaggedNodeSeq* ready) { +void PropagatorState::FrameState::ActivateNodesSlowPath( + const NodeItem* item, const bool is_dead, IterationState* iter_state, + EntryVector* outputs, TaggedNodeSeq* ready) { // If any of the edge destinations is a merge or a control trigger node, // we need to read each destination NodeItem to determine what action // to take. 
const GraphView& gview = immutable_state.graph_view(); - IterationState* iter_state = GetIteration(iter); auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item, bool dst_ready, bool dst_dead) { // Add dst to the ready queue if it's ready if (dst_ready) { if (dst_item->is_control_trigger) dst_dead = false; - ready->emplace_back(dst_item, this, iter, dst_dead); + ready->emplace_back(dst_item, this, iter_state, dst_dead); iter_state->outstanding_ops++; } }; @@ -551,17 +548,18 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, } void PropagatorState::FrameState::ActivateNodes(const NodeItem* item, - const bool is_dead, int64 iter, + const bool is_dead, + IterationState* iter_state, EntryVector* outputs, TaggedNodeSeq* ready) { if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) { - ActivateNodesSlowPath(item, is_dead, iter, outputs, ready); + ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready); } else { - ActivateNodesFastPath(item, is_dead, iter, outputs, ready); + ActivateNodesFastPath(item, is_dead, iter_state, outputs, ready); } } -void PropagatorState::FrameState::ActivateNexts(int64 iter, +void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready) { // Propagate the deferred NextIteration nodes to the new iteration. for (auto& node_entry : next_iter_roots) { @@ -569,12 +567,12 @@ void PropagatorState::FrameState::ActivateNexts(int64 iter, const Entry& entry = node_entry.second; const bool is_dead = entry.state == Entry::State::NO_VALUE; EntryVector outputs{entry}; - ActivateNodes(item, is_dead, iter, &outputs, ready); + ActivateNodes(item, is_dead, iter_state, &outputs, ready); } next_iter_roots.clear(); } -void PropagatorState::FrameState::ActivateLoopInvs(int64 iter, +void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready) { // Propagate loop invariants to the new iteration. 
for (auto& node_entry : inv_values) { @@ -582,7 +580,7 @@ void PropagatorState::FrameState::ActivateLoopInvs(int64 iter, const Entry& entry = node_entry.second; const bool is_dead = entry.state == Entry::State::NO_VALUE; EntryVector outputs{entry}; - ActivateNodes(item, is_dead, iter, &outputs, ready); + ActivateNodes(item, is_dead, iter_state, &outputs, ready); } } @@ -596,33 +594,32 @@ void PropagatorState::FrameState::AddLoopInv(const NodeItem* item, const bool is_dead = entry.state == Entry::State::NO_VALUE; for (int i = 0; i <= iteration_count; ++i) { EntryVector outputs{entry}; - ActivateNodes(item, is_dead, i, &outputs, ready); + ActivateNodes(item, is_dead, GetIteration(i), &outputs, ready); } } -bool PropagatorState::FrameState::IsIterationDone(int64 iter) { - IterationState* iter_state = GetIteration(iter); +bool PropagatorState::FrameState::IsIterationDone(IterationState* iter_state) { if (iter_state->outstanding_ops == 0 && iter_state->outstanding_frame_count == 0) { - if (iter == 0) { + if (iter_state->iter_num == 0) { // The enclosing frame has no pending input. return num_pending_inputs == 0; } else { // The preceding iteration is deleted (and therefore done). - return (GetIteration(iter - 1) == nullptr); + return (GetIteration(iter_state->iter_num - 1) == nullptr); } } return false; } -void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) { +PropagatorState::IterationState* +PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) { iteration_count++; - const int64 next_iter = iteration_count; // Initialize the next iteration. 
- IterationState* iter_state = - new IterationState(pending_counts, total_input_tensors); - SetIteration(next_iter, iter_state); + IterationState* next_iter = + new IterationState(iteration_count, pending_counts, total_input_tensors); + SetIteration(iteration_count, next_iter); num_outstanding_iterations++; dead_exits.clear(); @@ -631,14 +628,15 @@ void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) { // Activate the loop invariants in the new iteration. ActivateLoopInvs(next_iter, ready); + + return next_iter; } -bool PropagatorState::FrameState::CleanupIterations(int64 iter, +bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready) { - int64 curr_iter = iter; - while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) { - // Delete the iteration curr_iter. - delete GetIteration(curr_iter); + int64 curr_iter = iter_state->iter_num; + while (curr_iter <= iteration_count && IsIterationDone(iter_state)) { + delete iter_state; SetIteration(curr_iter, nullptr); --num_outstanding_iterations; ++curr_iter; @@ -648,6 +646,10 @@ bool PropagatorState::FrameState::CleanupIterations(int64 iter, if (!next_iter_roots.empty()) { IncrementIteration(ready); } + + if (curr_iter <= iteration_count) { + iter_state = GetIteration(curr_iter); + } } return IsFrameDone(); } @@ -677,21 +679,21 @@ void PropagatorState::FrameState::SetIteration(int64 iter, // Decrement the outstanding op count and clean up the iterations in the // frame. Return true iff the execution of the frame is done. bool PropagatorState::FrameState::DecrementOutstandingOps( - int64 iter, TaggedNodeSeq* ready) { + IterationState* iter_state, TaggedNodeSeq* ready) { mutex_lock l(mu); - return DecrementOutstandingOpsLocked(iter, ready); + return DecrementOutstandingOpsLocked(iter_state, ready); } // Decrement the outstanding op count and clean up the iterations in the // frame. Return true iff the execution of the frame is done. 
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked( - int64 iter, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { - IterationState* istate = GetIteration(iter); - istate->outstanding_ops--; - if (istate->outstanding_ops != 0) { + IterationState* iter_state, TaggedNodeSeq* ready) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { + iter_state->outstanding_ops--; + if (iter_state->outstanding_ops != 0) { return false; } else { - return CleanupIterations(iter, ready); + return CleanupIterations(iter_state, ready); } } diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index 13aadde7ff0..459e28a83ee 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -49,8 +49,10 @@ class PropagatorState { ~PropagatorState(); private: - // Forward declaration so that `TaggedNode` can include a `FrameState*`. + // Forward declaration so that `TaggedNode` can include a `FrameState*` and an + // `IterationState*`. struct FrameState; + struct IterationState; public: // A `TaggedNode` corresponds to a single invocation of a node's kernel, @@ -59,12 +61,12 @@ class PropagatorState { struct TaggedNode { const NodeItem* node_item; FrameState* input_frame; - int64 input_iter; + IterationState* input_iter; bool is_dead; TaggedNode() = default; - TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter, - bool dead) + TaggedNode(const NodeItem* node_item, FrameState* in_frame, + IterationState* in_iter, bool dead) : node_item(node_item), input_frame(in_frame), input_iter(in_iter), @@ -73,7 +75,7 @@ class PropagatorState { const NodeItem& get_node_item() const { return *node_item; } bool get_is_dead() const { return is_dead; } - int64 get_iter_num() const { return input_iter; } + int64 get_iter_num() const; }; // A drop-in replacement for std::deque. 
We typically don't @@ -116,16 +118,18 @@ class PropagatorState { typedef gtl::InlinedVector TaggedNodeSeq; private: + // The state of an iteration in a particular frame. struct IterationState { - explicit IterationState(const PendingCounts* pending_counts, + explicit IterationState(int64 iter_num, const PendingCounts* pending_counts, int total_input_tensors) - : input_tensors(new Entry[total_input_tensors]), + : iter_num(iter_num), + input_tensors(new Entry[total_input_tensors]), outstanding_ops(0), outstanding_frame_count(0), counts(*pending_counts) { // Initialize with copy of *pending_counts } - // The state of an iteration. + const int64 iter_num; // The index of this iteration in the enclosing loop. // One copy per iteration. For iteration k, i-th node's j-th input is in // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is @@ -221,10 +225,10 @@ class PropagatorState { // frame_name. uint64 frame_id; - // The iteration id of its parent frame when this frame is created. - // -1 if there is no parent frame. The frame_name/parent_iter pair + // The iteration state of its parent frame when this frame is created. + // nullptr if there is no parent frame. The frame_name/parent_iter pair // uniquely identifies this FrameState. - int64 parent_iter = -1; + IterationState* parent_iter = nullptr; // The FrameState of its parent frame. FrameState* parent_frame = nullptr; @@ -291,28 +295,33 @@ class PropagatorState { // Decrement the outstanding op count and clean up the iterations in the // frame. Return true iff the execution of the frame is done. - bool DecrementOutstandingOps(int64 iter, TaggedNodeSeq* ready); + bool DecrementOutstandingOps(IterationState* iter_state, + TaggedNodeSeq* ready); // Decrement the outstanding op count and clean up the iterations in the // frame. Return true iff the execution of the frame is done. 
- bool DecrementOutstandingOpsLocked(int64 iter, TaggedNodeSeq* ready); + bool DecrementOutstandingOpsLocked(IterationState* iter_state, + TaggedNodeSeq* ready); // Returns true if the computation in the frame is completed. bool IsFrameDone(); // Returns true if the iteration of the frame is completed. - bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + bool IsIterationDone(IterationState* iter_state) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu); // Increments the iteration id. If this is a new iteration, initialize it. - void IncrementIteration(TaggedNodeSeq* ready) + // + // Returns a pointer to the new iteration. + IterationState* IncrementIteration(TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); // Activate all the deferred NextIteration nodes in a new iteration. - void ActivateNexts(int64 iter, TaggedNodeSeq* ready) + void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); // Activate all the current loop invariants in a new iteration. - void ActivateLoopInvs(int64 iter, TaggedNodeSeq* ready) + void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); // Add a new loop invariant and make it available to all active @@ -322,12 +331,12 @@ class PropagatorState { // Activate the successors of a node. Contents of *outputs are left in an // indeterminate state after returning from this method. - void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter, - EntryVector* outputs, TaggedNodeSeq* ready) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu); + void ActivateNodes(const NodeItem* item, const bool is_dead, + IterationState* iter_state, EntryVector* outputs, + TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); - // Cleanup iterations of this frame starting from iteration iter. - bool CleanupIterations(int64 iter, TaggedNodeSeq* ready) + // Cleanup iterations of this frame starting from the given iteration. 
+ bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); void DumpIterationState(PropagatorState* parent) { @@ -350,12 +359,12 @@ class PropagatorState { private: // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. void ActivateNodesFastPath(const NodeItem* item, const bool is_dead, - int64 iter, EntryVector* outputs, + IterationState* iter_state, EntryVector* outputs, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead, - int64 iter, EntryVector* outputs, + IterationState* iter_state, EntryVector* outputs, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); }; @@ -379,13 +388,13 @@ class PropagatorState { // same address while the iteration is live. Entry* GetInputTensors(const TaggedNode& tagged_node) const TF_NO_THREAD_SAFETY_ANALYSIS { - return tagged_node.input_frame->GetIteration(tagged_node.input_iter) - ->input_tensors + + return tagged_node.input_iter->input_tensors + tagged_node.node_item->input_start; } FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { - return {tagged_node.input_frame->frame_id, tagged_node.input_iter}; + return {tagged_node.input_frame->frame_id, + tagged_node.input_iter->iter_num}; } // Provide debugging output of the state of the executor. @@ -397,9 +406,8 @@ class PropagatorState { // optional debugging support. if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { mutex_lock l(tagged_node.input_frame->mu); - tagged_node.input_frame->GetIteration(tagged_node.input_iter) - ->mark_started( - immutable_state_.pending_ids()[tagged_node.node_item->node_id]); + tagged_node.input_iter->mark_started( + immutable_state_.pending_ids()[tagged_node.node_item->node_id]); } } @@ -408,16 +416,15 @@ class PropagatorState { // optional debugging support. 
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { mutex_lock l(tagged_node.input_frame->mu); - tagged_node.input_frame->GetIteration(tagged_node.input_iter) - ->mark_completed( - immutable_state_.pending_ids()[tagged_node.node_item->node_id]); + tagged_node.input_iter->mark_completed( + immutable_state_.pending_ids()[tagged_node.node_item->node_id]); } } private: // Find an existing or create a new child frame in the frame 'frame' at // iteration 'iter'. - void FindOrCreateChildFrame(FrameState* frame, int64 iter, + void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state, const NodeItem& node_item, FrameState** child); // Delete a frame. Called when the frame is done. @@ -425,7 +432,7 @@ class PropagatorState { // Cleanup frames and iterations starting from frame/iter. Called when // a child frame is done. - void CleanupFramesIterations(FrameState* frame, int64 iter, + void CleanupFramesIterations(FrameState* frame, IterationState* iter_state, TaggedNodeSeq* ready); // Provide debugging output about an outstanding iteration in the executor. @@ -450,6 +457,10 @@ class PropagatorState { TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState); }; +inline int64 PropagatorState::TaggedNode::get_iter_num() const { + return input_iter->iter_num; +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_