[PropagatorState] Use IterationState* instead of int64 iteration ID in TaggedNode.

This change avoids making frequent calls (at least 2 per kernel dispatched) to `FrameState::GetIteration()`, which perform integer division. Since this work happens under the `FrameState` mutex, reducing the work done here should also ease contention on that mutex slightly.

This change adds microbenchmarks for both the functional (`WhileOp`) and lowered (`Switch`/`Merge`/`Enter`/`Exit`) loop implementations. To support these microbenchmarks, it adds support to "kernel_benchmark_testlib" for running functions during the benchmark.

PiperOrigin-RevId: 310175986
Change-Id: I9af0781e341e5d8cb19729de375078f9f4237c54
This commit is contained in:
Derek Murray 2020-05-06 10:13:40 -07:00 committed by TensorFlower Gardener
parent 17d5c85577
commit ae2a0e5c47
6 changed files with 278 additions and 117 deletions

View File

@ -2297,6 +2297,10 @@ tf_cc_test(
":core", ":core",
":core_cpu", ":core_cpu",
":core_cpu_internal", ":core_cpu_internal",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:framework_internal", "//tensorflow/core:framework_internal",
"//tensorflow/core:lib", "//tensorflow/core:lib",
@ -2307,6 +2311,8 @@ tf_cc_test(
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"//tensorflow/core/kernels:array", "//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:functional_ops",
"//tensorflow/core/kernels:math", "//tensorflow/core/kernels:math",
"//tensorflow/core/kernels:random_ops", "//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:state", "//tensorflow/core/kernels:state",

View File

@ -17,15 +17,24 @@ limitations under the License.
#include <algorithm> #include <algorithm>
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
@ -532,4 +541,115 @@ static void BM_FeedInputFetchOutput(int iters) {
} }
BENCHMARK(BM_FeedInputFetchOutput); BENCHMARK(BM_FeedInputFetchOutput);
// Defines a graph to perform the following computation:
//
// i = 0
// while (i < loop_iters)
// i += 1;
//
// ...using the functional `WhileOp` (if `lower` is false) or the
// `Switch`/`Merge`-style of control flow (if `lower` is true).
static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
testing::StopTiming();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
// Add test functions for cond and body.
FunctionDefLibrary f_lib_proto;
// Define the loop body as a function: `x = x + 1`.
const Tensor one_t = test::AsScalar<int32>(1);
*f_lib_proto.add_function() = FunctionDefHelper::Define(
// Name
"XPlusOne",
// Args
{"x: int32"},
// Return values
{"y: int32"},
// Attr def
{},
// Nodes
{
{{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}},
{{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}},
});
// Define the loop condition as a function: `x < loop_iters`.
const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
*f_lib_proto.add_function() = FunctionDefHelper::Define(
// Name
"LessThanOrEqualToN",
// Args
{"x: int32"},
// Return values
{"z: bool"},
// Attr def
{},
// Nodes
{
{{"N"}, "Const", {}, {{"value", loop_iters_t}, {"dtype", DT_INT32}}},
{{"z"}, "LessEqual", {"x", "N"}, {{"T", DT_INT32}}},
});
Scope root = Scope::NewRootScope().ExitOnError();
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
auto a = ops::Const(root.WithOpName("A"), 0, {});
Node* while_node;
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
AttrValue int32_attr;
int32_attr.set_type(DT_INT32);
AttrValue cond_func;
cond_func.mutable_func()->set_name("LessThanOrEqualToN");
AttrValue body_func;
body_func.mutable_func()->set_name("XPlusOne");
TF_ASSERT_OK(
NodeBuilder("while", "While", &root.graph()->flib_def())
.Input(inputs)
.Attr("T", {DT_INT32})
.Attr("cond", cond_func)
.Attr("body", body_func)
.Attr("parallel_iterations", 100)
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
.Finalize(root.graph(), &while_node));
auto c = ops::Identity(
root.WithOpName("C").WithControlDependencies(Output(while_node)),
Output(while_node));
TF_ASSERT_OK(root.DoShapeInference(while_node));
TF_ASSERT_OK(root.ToGraph(graph.get()));
if (lower) {
FunctionLibraryDefinition flib_def(graph->flib_def());
GraphOptimizationPassOptions opt_options;
SessionOptions session_options;
session_options.config.mutable_graph_options()
->mutable_optimizer_options()
->set_do_function_inlining(true);
opt_options.session_options = &session_options;
opt_options.graph = &graph;
opt_options.flib_def = &flib_def;
LowerFunctionalOpsPass pass;
TF_ASSERT_OK(pass.Run(opt_options));
}
FixupSourceAndSinkEdges(graph.get());
testing::StartTiming();
test::Benchmark("cpu", graph.release()).Run(iters);
}
static void BM_LoweredWhileLoop(int iters, int loop_iters) {
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ true);
}
BENCHMARK(BM_LoweredWhileLoop)->Arg(0);
BENCHMARK(BM_LoweredWhileLoop)->Arg(1);
BENCHMARK(BM_LoweredWhileLoop)->Arg(10);
BENCHMARK(BM_LoweredWhileLoop)->Arg(100);
BENCHMARK(BM_LoweredWhileLoop)->Arg(1000);
static void BM_FunctionalWhileLoop(int iters, int loop_iters) {
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ false);
}
BENCHMARK(BM_FunctionalWhileLoop)->Arg(0);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(10);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(100);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1000);
} // namespace tensorflow } // namespace tensorflow

View File

@ -19,7 +19,9 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/executor_factory.h" #include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
@ -62,8 +64,10 @@ Benchmark::Benchmark(const string& device, Graph* g,
// Allow NewDevice to allocate a new threadpool with different number of // Allow NewDevice to allocate a new threadpool with different number of
// threads for each new benchmark. // threads for each new benchmark.
LocalDevice::set_use_global_threadpool(false); LocalDevice::set_use_global_threadpool(false);
device_ =
DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0"); device_mgr_ = absl::make_unique<StaticDeviceMgr>(
DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0"));
device_ = device_mgr_->ListDevices()[0];
CHECK(device_) << "Could not create a " << device << " device"; CHECK(device_) << "Could not create a " << device << " device";
pool_ = pool_ =
@ -81,14 +85,24 @@ Benchmark::Benchmark(const string& device, Graph* g,
const int graph_def_version = g->versions().producer(); const int graph_def_version = g->versions().producer();
flib_def_ = absl::make_unique<FunctionLibraryDefinition>(g->flib_def());
pflr_ = std::unique_ptr<ProcessFunctionLibraryRuntime>(
new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), nullptr, graph_def_version,
flib_def_.get(), OptimizerOptions(), pool_, nullptr, nullptr, nullptr,
Rendezvous::Factory()));
flr_ = pflr_->GetFLR(device_->name());
LocalExecutorParams params; LocalExecutorParams params;
params.device = device_.get(); params.device = device_;
params.function_library = nullptr; params.function_library = flr_;
params.create_kernel = [this, graph_def_version]( params.create_kernel = [this, graph_def_version](
const std::shared_ptr<const NodeProperties>& props, const std::shared_ptr<const NodeProperties>& props,
OpKernel** kernel) { OpKernel** kernel) {
return CreateNonCachedKernel(device_.get(), nullptr, props, return CreateNonCachedKernel(device_, flr_, props, graph_def_version,
graph_def_version, kernel); kernel);
}; };
params.delete_kernel = [](OpKernel* kernel) { params.delete_kernel = [](OpKernel* kernel) {
DeleteNonCachedKernel(kernel); DeleteNonCachedKernel(kernel);
@ -109,11 +123,12 @@ Benchmark::Benchmark(const string& device, Graph* g,
Benchmark::~Benchmark() { Benchmark::~Benchmark() {
if (device_) { if (device_) {
rendez_->Unref(); rendez_->Unref();
// We delete `exec_` before `device_` because the `exec_` destructor may // We delete `exec_` before `device_mgr_` because the `exec_` destructor may
// run kernel destructors that may attempt to access state borrowed from // run kernel destructors that may attempt to access state borrowed from
// `device_`, such as the resource manager. // `device_mgr_`, such as the resource manager.
exec_.reset(); exec_.reset();
device_.reset(); pflr_.reset();
device_mgr_.reset();
delete pool_; delete pool_;
} }
} }

View File

@ -29,7 +29,10 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
class Device; class Device;
class FunctionLibraryRuntime;
class ProcessFunctionLibraryRuntime;
struct SessionOptions; struct SessionOptions;
class StaticDeviceMgr;
namespace test { namespace test {
@ -55,9 +58,13 @@ class Benchmark {
const std::vector<string>& outputs, int iters); const std::vector<string>& outputs, int iters);
private: private:
thread::ThreadPool* pool_ = nullptr; thread::ThreadPool* pool_ = nullptr; // Not owned.
std::unique_ptr<Device> device_ = nullptr; Device* device_ = nullptr; // Not owned.
Rendezvous* rendez_ = nullptr; Rendezvous* rendez_ = nullptr;
std::unique_ptr<StaticDeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
FunctionLibraryRuntime* flr_; // Not owned.
std::unique_ptr<Executor> exec_; std::unique_ptr<Executor> exec_;
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark); TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);

View File

@ -37,7 +37,7 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
// Initialize iteration 0. // Initialize iteration 0.
root_frame_->SetIteration( root_frame_->SetIteration(
0, new PropagatorState::IterationState(root_frame_->pending_counts, 0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
root_frame_->total_input_tensors)); root_frame_->total_input_tensors));
outstanding_frames_.insert({root_frame_->frame_name, root_frame_}); outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
@ -51,12 +51,13 @@ PropagatorState::~PropagatorState() {
void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
TaggedNodeSeq* ready) { TaggedNodeSeq* ready) {
mutex_lock l(root_frame_->mu);
IterationState* root_iter = root_frame_->GetIteration(0);
for (const NodeItem* item : roots) { for (const NodeItem* item : roots) {
DCHECK_EQ(item->num_inputs, 0); DCHECK_EQ(item->num_inputs, 0);
ready->emplace_back(item, root_frame_, 0, false); ready->emplace_back(item, root_frame_, root_iter, false);
} }
mutex_lock l(root_frame_->mu); root_iter->outstanding_ops = ready->size();
root_frame_->GetIteration(0)->outstanding_ops = ready->size();
} }
void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
@ -75,7 +76,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
const NodeItem* const item = tagged_node.node_item; const NodeItem* const item = tagged_node.node_item;
FrameState* const input_frame = tagged_node.input_frame; FrameState* const input_frame = tagged_node.input_frame;
const int64 input_iter = tagged_node.input_iter; IterationState* const input_iter = tagged_node.input_iter;
const bool is_dead = tagged_node.is_dead; const bool is_dead = tagged_node.is_dead;
// Propagates outputs along out edges, and puts newly ready nodes // Propagates outputs along out edges, and puts newly ready nodes
@ -83,7 +84,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
DCHECK(ready->empty()); DCHECK(ready->empty());
bool is_frame_done = false; bool is_frame_done = false;
FrameState* output_frame = input_frame; FrameState* output_frame = input_frame;
int64 output_iter = input_iter; IterationState* output_iter = input_iter;
if (!item->is_enter_exit_or_next_iter) { if (!item->is_enter_exit_or_next_iter) {
// Fast path for nodes types that don't need special handling // Fast path for nodes types that don't need special handling
@ -95,9 +96,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
input_frame->DecrementOutstandingOpsLocked(input_iter, ready); input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
} else if (item->is_enter) { } else if (item->is_enter) {
FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
output_iter = 0;
{ {
mutex_lock l(output_frame->mu); mutex_lock l(output_frame->mu);
output_iter = output_frame->GetIteration(0);
if (item->is_constant_enter) { if (item->is_constant_enter) {
// Propagate to all active iterations if this is a loop invariant. // Propagate to all active iterations if this is a loop invariant.
output_frame->AddLoopInv(item, (*outputs)[0], ready); output_frame->AddLoopInv(item, (*outputs)[0], ready);
@ -111,7 +112,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
if (is_dead) { if (is_dead) {
mutex_lock l(input_frame->mu); mutex_lock l(input_frame->mu);
// Stop and remember this node if it is a dead exit. // Stop and remember this node if it is a dead exit.
if (input_iter == input_frame->iteration_count) { if (input_iter->iter_num == input_frame->iteration_count) {
input_frame->dead_exits.push_back(item); input_frame->dead_exits.push_back(item);
} }
is_frame_done = is_frame_done =
@ -132,7 +133,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
// Stop the deadness propagation. // Stop the deadness propagation.
output_frame = nullptr; output_frame = nullptr;
} else { } else {
if (input_iter == input_frame->iteration_count && if (input_iter->iter_num == input_frame->iteration_count &&
input_frame->num_outstanding_iterations == input_frame->num_outstanding_iterations ==
input_frame->max_parallel_iterations) { input_frame->max_parallel_iterations) {
// Reached the maximum for parallel iterations. // Reached the maximum for parallel iterations.
@ -140,10 +141,11 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
output_frame = nullptr; output_frame = nullptr;
} else { } else {
// If this is a new iteration, start it. // If this is a new iteration, start it.
if (input_iter == input_frame->iteration_count) { if (input_iter->iter_num == input_frame->iteration_count) {
input_frame->IncrementIteration(ready); output_iter = input_frame->IncrementIteration(ready);
} else {
output_iter = input_frame->GetIteration(input_iter->iter_num + 1);
} }
output_iter = input_iter + 1;
} }
} }
if (output_frame != nullptr) { if (output_frame != nullptr) {
@ -159,7 +161,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
// completion of this node makes its frame completed. // completion of this node makes its frame completed.
if (is_frame_done) { if (is_frame_done) {
FrameState* parent_frame = input_frame->parent_frame; FrameState* parent_frame = input_frame->parent_frame;
const int64 parent_iter = input_frame->parent_iter; IterationState* parent_iter = input_frame->parent_iter;
DeleteFrame(input_frame, ready); DeleteFrame(input_frame, ready);
if (parent_frame != nullptr) { if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame. // The completion of frame may cause completions in its parent frame.
@ -217,7 +219,8 @@ void PropagatorState::DumpState() {
} }
} }
void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter, void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
IterationState* iter_state,
const NodeItem& node_item, const NodeItem& node_item,
FrameState** child) { FrameState** child) {
// Get the child frame name. // Get the child frame name.
@ -225,8 +228,8 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
const string& enter_name = GetNodeAttrString(attrs, "frame_name"); const string& enter_name = GetNodeAttrString(attrs, "frame_name");
DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node " DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
<< node_item.kernel->name(); << node_item.kernel->name();
const string child_name = const string child_name = strings::StrCat(
strings::StrCat(frame->frame_name, ";", iter, ";", enter_name); frame->frame_name, ";", iter_state->iter_num, ";", enter_name);
{ {
mutex_lock executor_lock(mu_); mutex_lock executor_lock(mu_);
@ -251,14 +254,14 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
temp->frame_name = child_name; temp->frame_name = child_name;
temp->frame_id = Hash64(child_name); temp->frame_id = Hash64(child_name);
temp->parent_frame = frame; temp->parent_frame = frame;
temp->parent_iter = iter; temp->parent_iter = iter_state;
temp->InitializeFrameInfo(enter_name); temp->InitializeFrameInfo(enter_name);
// Initialize iteration 0. // Initialize iteration 0.
{ {
mutex_lock l(temp->mu); mutex_lock l(temp->mu);
temp->SetIteration( temp->SetIteration(0, new IterationState(0, temp->pending_counts,
0, new IterationState(temp->pending_counts, temp->total_input_tensors)); temp->total_input_tensors));
} }
{ {
@ -268,7 +271,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
*child = it->second; *child = it->second;
} else { } else {
mutex_lock frame_lock(frame->mu); mutex_lock frame_lock(frame->mu);
frame->GetIteration(iter)->outstanding_frame_count++; iter_state->outstanding_frame_count++;
outstanding_frames_[child_name] = temp; outstanding_frames_[child_name] = temp;
*child = temp; *child = temp;
temp = nullptr; temp = nullptr;
@ -280,20 +283,19 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) { void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
// First, propagate dead_exits (if any) to the parent frame. // First, propagate dead_exits (if any) to the parent frame.
FrameState* parent_frame = frame->parent_frame; FrameState* parent_frame = frame->parent_frame;
const int64 parent_iter = frame->parent_iter; IterationState* parent_iter_state = frame->parent_iter;
if (parent_frame != nullptr) { if (parent_frame != nullptr) {
mutex_lock parent_frame_lock(parent_frame->mu); mutex_lock parent_frame_lock(parent_frame->mu);
// Propagate all the dead exits to the parent frame. // Propagate all the dead exits to the parent frame.
mutex_lock this_frame_lock(frame->mu); mutex_lock this_frame_lock(frame->mu);
for (const NodeItem* item : frame->dead_exits) { for (const NodeItem* item : frame->dead_exits) {
auto parent_iter_state = parent_frame->GetIteration(parent_iter);
auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready, auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
bool dst_dead) { bool dst_dead) {
if (dst_ready) { if (dst_ready) {
if (dst_item.is_control_trigger) dst_dead = false; if (dst_item.is_control_trigger) dst_dead = false;
ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead); ready->emplace_back(&dst_item, parent_frame, parent_iter_state,
dst_dead);
parent_iter_state->outstanding_ops++; parent_iter_state->outstanding_ops++;
} }
}; };
@ -356,17 +358,18 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
delete frame; delete frame;
} }
void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter, void PropagatorState::CleanupFramesIterations(FrameState* frame,
IterationState* iter_state,
TaggedNodeSeq* ready) { TaggedNodeSeq* ready) {
bool is_frame_done = false; bool is_frame_done = false;
{ {
mutex_lock frame_lock(frame->mu); mutex_lock frame_lock(frame->mu);
frame->GetIteration(iter)->outstanding_frame_count--; iter_state->outstanding_frame_count--;
is_frame_done = frame->CleanupIterations(iter, ready); is_frame_done = frame->CleanupIterations(iter_state, ready);
} }
if (is_frame_done) { if (is_frame_done) {
FrameState* parent_frame = frame->parent_frame; FrameState* parent_frame = frame->parent_frame;
const int64 parent_iter = frame->parent_iter; IterationState* parent_iter = frame->parent_iter;
DeleteFrame(frame, ready); DeleteFrame(frame, ready);
if (parent_frame != nullptr) { if (parent_frame != nullptr) {
// The completion of frame may cause completions in its parent frame. // The completion of frame may cause completions in its parent frame.
@ -376,16 +379,13 @@ void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter,
} }
} }
void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item, void PropagatorState::FrameState::ActivateNodesFastPath(
const bool is_dead, const NodeItem* item, const bool is_dead, IterationState* iter_state,
int64 iter, EntryVector* outputs, TaggedNodeSeq* ready) {
EntryVector* outputs,
TaggedNodeSeq* ready) {
// If we know that none of the item's edge destinations require special // If we know that none of the item's edge destinations require special
// handling (i.e. none of the nodes is a merge or control trigger node), we // handling (i.e. none of the nodes is a merge or control trigger node), we
// can take a fast path that avoids accessing the destination NodeItem. // can take a fast path that avoids accessing the destination NodeItem.
const GraphView& gview = immutable_state.graph_view(); const GraphView& gview = immutable_state.graph_view();
IterationState* iter_state = GetIteration(iter);
// Add dst to the ready queue if it's ready // Add dst to the ready queue if it's ready
// //
@ -398,7 +398,7 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
TaggedNode& t = ready->emplace_back(); \ TaggedNode& t = ready->emplace_back(); \
t.node_item = dst_item; \ t.node_item = dst_item; \
t.input_frame = this; \ t.input_frame = this; \
t.input_iter = iter; \ t.input_iter = iter_state; \
t.is_dead = adjust_result.any_dead; \ t.is_dead = adjust_result.any_dead; \
iter_state->outstanding_ops++; \ iter_state->outstanding_ops++; \
} \ } \
@ -436,23 +436,20 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
#undef MAYBE_ADD_TO_READY #undef MAYBE_ADD_TO_READY
} }
void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item, void PropagatorState::FrameState::ActivateNodesSlowPath(
const bool is_dead, const NodeItem* item, const bool is_dead, IterationState* iter_state,
int64 iter, EntryVector* outputs, TaggedNodeSeq* ready) {
EntryVector* outputs,
TaggedNodeSeq* ready) {
// If any of the edge destinations is a merge or a control trigger node, // If any of the edge destinations is a merge or a control trigger node,
// we need to read each destination NodeItem to determine what action // we need to read each destination NodeItem to determine what action
// to take. // to take.
const GraphView& gview = immutable_state.graph_view(); const GraphView& gview = immutable_state.graph_view();
IterationState* iter_state = GetIteration(iter);
auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item, auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
bool dst_ready, bool dst_dead) { bool dst_ready, bool dst_dead) {
// Add dst to the ready queue if it's ready // Add dst to the ready queue if it's ready
if (dst_ready) { if (dst_ready) {
if (dst_item->is_control_trigger) dst_dead = false; if (dst_item->is_control_trigger) dst_dead = false;
ready->emplace_back(dst_item, this, iter, dst_dead); ready->emplace_back(dst_item, this, iter_state, dst_dead);
iter_state->outstanding_ops++; iter_state->outstanding_ops++;
} }
}; };
@ -551,17 +548,18 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
} }
void PropagatorState::FrameState::ActivateNodes(const NodeItem* item, void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
const bool is_dead, int64 iter, const bool is_dead,
IterationState* iter_state,
EntryVector* outputs, EntryVector* outputs,
TaggedNodeSeq* ready) { TaggedNodeSeq* ready) {
if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) { if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
ActivateNodesSlowPath(item, is_dead, iter, outputs, ready); ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
} else { } else {
ActivateNodesFastPath(item, is_dead, iter, outputs, ready); ActivateNodesFastPath(item, is_dead, iter_state, outputs, ready);
} }
} }
void PropagatorState::FrameState::ActivateNexts(int64 iter, void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state,
TaggedNodeSeq* ready) { TaggedNodeSeq* ready) {
// Propagate the deferred NextIteration nodes to the new iteration. // Propagate the deferred NextIteration nodes to the new iteration.
for (auto& node_entry : next_iter_roots) { for (auto& node_entry : next_iter_roots) {
@ -569,12 +567,12 @@ void PropagatorState::FrameState::ActivateNexts(int64 iter,
const Entry& entry = node_entry.second; const Entry& entry = node_entry.second;
const bool is_dead = entry.state == Entry::State::NO_VALUE; const bool is_dead = entry.state == Entry::State::NO_VALUE;
EntryVector outputs{entry}; EntryVector outputs{entry};
ActivateNodes(item, is_dead, iter, &outputs, ready); ActivateNodes(item, is_dead, iter_state, &outputs, ready);
} }
next_iter_roots.clear(); next_iter_roots.clear();
} }
void PropagatorState::FrameState::ActivateLoopInvs(int64 iter, void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state,
TaggedNodeSeq* ready) { TaggedNodeSeq* ready) {
// Propagate loop invariants to the new iteration. // Propagate loop invariants to the new iteration.
for (auto& node_entry : inv_values) { for (auto& node_entry : inv_values) {
@ -582,7 +580,7 @@ void PropagatorState::FrameState::ActivateLoopInvs(int64 iter,
const Entry& entry = node_entry.second; const Entry& entry = node_entry.second;
const bool is_dead = entry.state == Entry::State::NO_VALUE; const bool is_dead = entry.state == Entry::State::NO_VALUE;
EntryVector outputs{entry}; EntryVector outputs{entry};
ActivateNodes(item, is_dead, iter, &outputs, ready); ActivateNodes(item, is_dead, iter_state, &outputs, ready);
} }
} }
@ -596,33 +594,32 @@ void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
const bool is_dead = entry.state == Entry::State::NO_VALUE; const bool is_dead = entry.state == Entry::State::NO_VALUE;
for (int i = 0; i <= iteration_count; ++i) { for (int i = 0; i <= iteration_count; ++i) {
EntryVector outputs{entry}; EntryVector outputs{entry};
ActivateNodes(item, is_dead, i, &outputs, ready); ActivateNodes(item, is_dead, GetIteration(i), &outputs, ready);
} }
} }
bool PropagatorState::FrameState::IsIterationDone(int64 iter) { bool PropagatorState::FrameState::IsIterationDone(IterationState* iter_state) {
IterationState* iter_state = GetIteration(iter);
if (iter_state->outstanding_ops == 0 && if (iter_state->outstanding_ops == 0 &&
iter_state->outstanding_frame_count == 0) { iter_state->outstanding_frame_count == 0) {
if (iter == 0) { if (iter_state->iter_num == 0) {
// The enclosing frame has no pending input. // The enclosing frame has no pending input.
return num_pending_inputs == 0; return num_pending_inputs == 0;
} else { } else {
// The preceding iteration is deleted (and therefore done). // The preceding iteration is deleted (and therefore done).
return (GetIteration(iter - 1) == nullptr); return (GetIteration(iter_state->iter_num - 1) == nullptr);
} }
} }
return false; return false;
} }
void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) { PropagatorState::IterationState*
PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
iteration_count++; iteration_count++;
const int64 next_iter = iteration_count;
// Initialize the next iteration. // Initialize the next iteration.
IterationState* iter_state = IterationState* next_iter =
new IterationState(pending_counts, total_input_tensors); new IterationState(iteration_count, pending_counts, total_input_tensors);
SetIteration(next_iter, iter_state); SetIteration(iteration_count, next_iter);
num_outstanding_iterations++; num_outstanding_iterations++;
dead_exits.clear(); dead_exits.clear();
@ -631,14 +628,15 @@ void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
// Activate the loop invariants in the new iteration. // Activate the loop invariants in the new iteration.
ActivateLoopInvs(next_iter, ready); ActivateLoopInvs(next_iter, ready);
return next_iter;
} }
bool PropagatorState::FrameState::CleanupIterations(int64 iter, bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
TaggedNodeSeq* ready) { TaggedNodeSeq* ready) {
int64 curr_iter = iter; int64 curr_iter = iter_state->iter_num;
while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) { while (curr_iter <= iteration_count && IsIterationDone(iter_state)) {
// Delete the iteration curr_iter. delete iter_state;
delete GetIteration(curr_iter);
SetIteration(curr_iter, nullptr); SetIteration(curr_iter, nullptr);
--num_outstanding_iterations; --num_outstanding_iterations;
++curr_iter; ++curr_iter;
@ -648,6 +646,10 @@ bool PropagatorState::FrameState::CleanupIterations(int64 iter,
if (!next_iter_roots.empty()) { if (!next_iter_roots.empty()) {
IncrementIteration(ready); IncrementIteration(ready);
} }
if (curr_iter <= iteration_count) {
iter_state = GetIteration(curr_iter);
}
} }
return IsFrameDone(); return IsFrameDone();
} }
@ -677,21 +679,21 @@ void PropagatorState::FrameState::SetIteration(int64 iter,
// Decrement the outstanding op count and clean up the iterations in the // Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done. // frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOps( bool PropagatorState::FrameState::DecrementOutstandingOps(
int64 iter, TaggedNodeSeq* ready) { IterationState* iter_state, TaggedNodeSeq* ready) {
mutex_lock l(mu); mutex_lock l(mu);
return DecrementOutstandingOpsLocked(iter, ready); return DecrementOutstandingOpsLocked(iter_state, ready);
} }
// Decrement the outstanding op count and clean up the iterations in the // Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done. // frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked( bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
int64 iter, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { IterationState* iter_state, TaggedNodeSeq* ready)
IterationState* istate = GetIteration(iter); TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
istate->outstanding_ops--; iter_state->outstanding_ops--;
if (istate->outstanding_ops != 0) { if (iter_state->outstanding_ops != 0) {
return false; return false;
} else { } else {
return CleanupIterations(iter, ready); return CleanupIterations(iter_state, ready);
} }
} }

View File

@ -49,8 +49,10 @@ class PropagatorState {
~PropagatorState(); ~PropagatorState();
private: private:
// Forward declaration so that `TaggedNode` can include a `FrameState*`. // Forward declaration so that `TaggedNode` can include a `FrameState*` and an
// `IterationState*`.
struct FrameState; struct FrameState;
struct IterationState;
public: public:
// A `TaggedNode` corresponds to a single invocation of a node's kernel, // A `TaggedNode` corresponds to a single invocation of a node's kernel,
@ -59,12 +61,12 @@ class PropagatorState {
struct TaggedNode { struct TaggedNode {
const NodeItem* node_item; const NodeItem* node_item;
FrameState* input_frame; FrameState* input_frame;
int64 input_iter; IterationState* input_iter;
bool is_dead; bool is_dead;
TaggedNode() = default; TaggedNode() = default;
TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter, TaggedNode(const NodeItem* node_item, FrameState* in_frame,
bool dead) IterationState* in_iter, bool dead)
: node_item(node_item), : node_item(node_item),
input_frame(in_frame), input_frame(in_frame),
input_iter(in_iter), input_iter(in_iter),
@ -73,7 +75,7 @@ class PropagatorState {
const NodeItem& get_node_item() const { return *node_item; } const NodeItem& get_node_item() const { return *node_item; }
bool get_is_dead() const { return is_dead; } bool get_is_dead() const { return is_dead; }
int64 get_iter_num() const { return input_iter; } int64 get_iter_num() const;
}; };
// A drop-in replacement for std::deque<TaggedNode>. We typically don't // A drop-in replacement for std::deque<TaggedNode>. We typically don't
@ -116,16 +118,18 @@ class PropagatorState {
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
private: private:
// The state of an iteration in a particular frame.
struct IterationState { struct IterationState {
explicit IterationState(const PendingCounts* pending_counts, explicit IterationState(int64 iter_num, const PendingCounts* pending_counts,
int total_input_tensors) int total_input_tensors)
: input_tensors(new Entry[total_input_tensors]), : iter_num(iter_num),
input_tensors(new Entry[total_input_tensors]),
outstanding_ops(0), outstanding_ops(0),
outstanding_frame_count(0), outstanding_frame_count(0),
counts(*pending_counts) { // Initialize with copy of *pending_counts counts(*pending_counts) { // Initialize with copy of *pending_counts
} }
// The state of an iteration. const int64 iter_num; // The index of this iteration in the enclosing loop.
// One copy per iteration. For iteration k, i-th node's j-th input is in // One copy per iteration. For iteration k, i-th node's j-th input is in
// input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
@ -221,10 +225,10 @@ class PropagatorState {
// frame_name. // frame_name.
uint64 frame_id; uint64 frame_id;
// The iteration id of its parent frame when this frame is created. // The iteration state of its parent frame when this frame is created.
// -1 if there is no parent frame. The frame_name/parent_iter pair // nullptr if there is no parent frame. The frame_name/parent_iter pair
// uniquely identifies this FrameState. // uniquely identifies this FrameState.
int64 parent_iter = -1; IterationState* parent_iter = nullptr;
// The FrameState of its parent frame. // The FrameState of its parent frame.
FrameState* parent_frame = nullptr; FrameState* parent_frame = nullptr;
@ -291,28 +295,33 @@ class PropagatorState {
// Decrement the outstanding op count and clean up the iterations in the // Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done. // frame. Return true iff the execution of the frame is done.
bool DecrementOutstandingOps(int64 iter, TaggedNodeSeq* ready); bool DecrementOutstandingOps(IterationState* iter_state,
TaggedNodeSeq* ready);
// Decrement the outstanding op count and clean up the iterations in the // Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done. // frame. Return true iff the execution of the frame is done.
bool DecrementOutstandingOpsLocked(int64 iter, TaggedNodeSeq* ready); bool DecrementOutstandingOpsLocked(IterationState* iter_state,
TaggedNodeSeq* ready);
// Returns true if the computation in the frame is completed. // Returns true if the computation in the frame is completed.
bool IsFrameDone(); bool IsFrameDone();
// Returns true if the iteration of the frame is completed. // Returns true if the iteration of the frame is completed.
bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); bool IsIterationDone(IterationState* iter_state)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Increments the iteration id. If this is a new iteration, initialize it. // Increments the iteration id. If this is a new iteration, initialize it.
void IncrementIteration(TaggedNodeSeq* ready) //
// Returns a pointer to the new iteration.
IterationState* IncrementIteration(TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu); TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Activate all the deferred NextIteration nodes in a new iteration. // Activate all the deferred NextIteration nodes in a new iteration.
void ActivateNexts(int64 iter, TaggedNodeSeq* ready) void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu); TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Activate all the current loop invariants in a new iteration. // Activate all the current loop invariants in a new iteration.
void ActivateLoopInvs(int64 iter, TaggedNodeSeq* ready) void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu); TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Add a new loop invariant and make it available to all active // Add a new loop invariant and make it available to all active
@ -322,12 +331,12 @@ class PropagatorState {
// Activate the successors of a node. Contents of *outputs are left in an // Activate the successors of a node. Contents of *outputs are left in an
// indeterminate state after returning from this method. // indeterminate state after returning from this method.
void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter, void ActivateNodes(const NodeItem* item, const bool is_dead,
EntryVector* outputs, TaggedNodeSeq* ready) IterationState* iter_state, EntryVector* outputs,
TF_EXCLUSIVE_LOCKS_REQUIRED(mu); TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
// Cleanup iterations of this frame starting from iteration iter. // Cleanup iterations of this frame starting from the given iteration.
bool CleanupIterations(int64 iter, TaggedNodeSeq* ready) bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu); TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
void DumpIterationState(PropagatorState* parent) { void DumpIterationState(PropagatorState* parent) {
@ -350,12 +359,12 @@ class PropagatorState {
private: private:
// REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
void ActivateNodesFastPath(const NodeItem* item, const bool is_dead, void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
int64 iter, EntryVector* outputs, IterationState* iter_state, EntryVector* outputs,
TaggedNodeSeq* ready) TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu); TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead, void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
int64 iter, EntryVector* outputs, IterationState* iter_state, EntryVector* outputs,
TaggedNodeSeq* ready) TaggedNodeSeq* ready)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu); TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
}; };
@ -379,13 +388,13 @@ class PropagatorState {
// same address while the iteration is live. // same address while the iteration is live.
Entry* GetInputTensors(const TaggedNode& tagged_node) const Entry* GetInputTensors(const TaggedNode& tagged_node) const
TF_NO_THREAD_SAFETY_ANALYSIS { TF_NO_THREAD_SAFETY_ANALYSIS {
return tagged_node.input_frame->GetIteration(tagged_node.input_iter) return tagged_node.input_iter->input_tensors +
->input_tensors +
tagged_node.node_item->input_start; tagged_node.node_item->input_start;
} }
FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
return {tagged_node.input_frame->frame_id, tagged_node.input_iter}; return {tagged_node.input_frame->frame_id,
tagged_node.input_iter->iter_num};
} }
// Provide debugging output of the state of the executor. // Provide debugging output of the state of the executor.
@ -397,9 +406,8 @@ class PropagatorState {
// optional debugging support. // optional debugging support.
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
mutex_lock l(tagged_node.input_frame->mu); mutex_lock l(tagged_node.input_frame->mu);
tagged_node.input_frame->GetIteration(tagged_node.input_iter) tagged_node.input_iter->mark_started(
->mark_started( immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
} }
} }
@ -408,16 +416,15 @@ class PropagatorState {
// optional debugging support. // optional debugging support.
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
mutex_lock l(tagged_node.input_frame->mu); mutex_lock l(tagged_node.input_frame->mu);
tagged_node.input_frame->GetIteration(tagged_node.input_iter) tagged_node.input_iter->mark_completed(
->mark_completed( immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
} }
} }
private: private:
// Find an existing or create a new child frame in the frame 'frame' at // Find an existing or create a new child frame in the frame 'frame' at
// iteration 'iter'. // iteration 'iter'.
void FindOrCreateChildFrame(FrameState* frame, int64 iter, void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state,
const NodeItem& node_item, FrameState** child); const NodeItem& node_item, FrameState** child);
// Delete a frame. Called when the frame is done. // Delete a frame. Called when the frame is done.
@ -425,7 +432,7 @@ class PropagatorState {
// Cleanup frames and iterations starting from frame/iter. Called when // Cleanup frames and iterations starting from frame/iter. Called when
// a child frame is done. // a child frame is done.
void CleanupFramesIterations(FrameState* frame, int64 iter, void CleanupFramesIterations(FrameState* frame, IterationState* iter_state,
TaggedNodeSeq* ready); TaggedNodeSeq* ready);
// Provide debugging output about an outstanding iteration in the executor. // Provide debugging output about an outstanding iteration in the executor.
@ -450,6 +457,10 @@ class PropagatorState {
TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState); TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
}; };
inline int64 PropagatorState::TaggedNode::get_iter_num() const {
return input_iter->iter_num;
}
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_