[PropagatorState] Use IterationState* instead of int64 iteration ID in TaggedNode.
This change avoids making frequent calls (at least two per dispatched kernel) to `FrameState::GetIteration()`, each of which performs an integer division. Since this work happens under the `FrameState` mutex, reducing the work done here should also ease contention on that mutex slightly.

This change also adds microbenchmarks for both the functional (`WhileOp`) and lowered (`Switch`/`Merge`/`Enter`/`Exit`) loop implementations. To support these microbenchmarks, it adds support to `kernel_benchmark_testlib` for running functions during the benchmark.

PiperOrigin-RevId: 310175986
Change-Id: I9af0781e341e5d8cb19729de375078f9f4237c54
Commit: ae2a0e5c47 · Parent: 17d5c85577
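The core of the change: `TaggedNode` now carries a pointer to the `IterationState` itself rather than an `int64` iteration number, so the hot propagation paths no longer need `FrameState::GetIteration()`, which (per the message above) performs an integer division to map an iteration number onto the frame's buffer of live iterations. A minimal sketch of the data-structure change, using simplified stand-in types rather than the real TensorFlow classes:

// Simplified sketch only; these are stand-in types, not the real classes in
// tensorflow/core/common_runtime/propagator_state.h.
#include <cstdint>
#include <vector>

struct IterationState {
  int64_t iter_num = 0;     // Index of this iteration in the enclosing loop.
  int outstanding_ops = 0;  // Kernels still pending in this iteration.
};

struct FrameState {
  std::vector<IterationState*> iterations;  // Buffer of live iterations.

  // Old-style lookup: every access pays an integer modulo to map the
  // iteration number onto a slot in the buffer (the division the commit
  // message refers to), and callers hold the frame mutex while doing it.
  IterationState* GetIteration(int64_t iter) {
    return iterations[iter % iterations.size()];
  }
};

// Before: TaggedNode stores the iteration number, so reaching the iteration
// state requires input_frame->GetIteration(input_iter).
struct TaggedNodeOld {
  FrameState* input_frame = nullptr;
  int64_t input_iter = 0;
};

// After: TaggedNode stores the IterationState pointer directly; the numeric
// id is still available as input_iter->iter_num when it is needed.
struct TaggedNodeNew {
  FrameState* input_frame = nullptr;
  IterationState* input_iter = nullptr;
  int64_t get_iter_num() const { return input_iter->iter_num; }
};

In the real diff this shows up as `GetFrameAndIter()` and the debug helpers reading `tagged_node.input_iter->iter_num`, and `IncrementIteration()` returning the freshly created `IterationState*` so `PropagateOutputs()` can hand it to the next `TaggedNode` without another lookup.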
@@ -2297,6 +2297,10 @@ tf_cc_test(
        ":core",
        ":core_cpu",
        ":core_cpu_internal",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
@@ -2307,6 +2311,8 @@ tf_cc_test(
        "//tensorflow/core:testlib",
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:control_flow_ops",
        "//tensorflow/core/kernels:function_ops",
        "//tensorflow/core/kernels:functional_ops",
        "//tensorflow/core/kernels:math",
        "//tensorflow/core/kernels:random_ops",
        "//tensorflow/core/kernels:state",
@@ -17,15 +17,24 @@ limitations under the License.

#include <algorithm>

#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -532,4 +541,115 @@ static void BM_FeedInputFetchOutput(int iters) {
}
BENCHMARK(BM_FeedInputFetchOutput);

// Defines a graph to perform the following computation:
//
//   i = 0
//   while (i < loop_iters)
//     i += 1;
//
// ...using the functional `WhileOp` (if `lower` is false) or the
// `Switch`/`Merge`-style of control flow (if `lower` is true).
static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
  testing::StopTiming();
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Add test functions for cond and body.
  FunctionDefLibrary f_lib_proto;

  // Define the loop body as a function: `x = x + 1`.
  const Tensor one_t = test::AsScalar<int32>(1);
  *f_lib_proto.add_function() = FunctionDefHelper::Define(
      // Name
      "XPlusOne",
      // Args
      {"x: int32"},
      // Return values
      {"y: int32"},
      // Attr def
      {},
      // Nodes
      {
          {{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}},
          {{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}},
      });

  // Define the loop condition as a function: `x < loop_iters`.
  const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
  *f_lib_proto.add_function() = FunctionDefHelper::Define(
      // Name
      "LessThanOrEqualToN",
      // Args
      {"x: int32"},
      // Return values
      {"z: bool"},
      // Attr def
      {},
      // Nodes
      {
          {{"N"}, "Const", {}, {{"value", loop_iters_t}, {"dtype", DT_INT32}}},
          {{"z"}, "LessEqual", {"x", "N"}, {{"T", DT_INT32}}},
      });

  Scope root = Scope::NewRootScope().ExitOnError();
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
  auto a = ops::Const(root.WithOpName("A"), 0, {});
  Node* while_node;
  std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
  AttrValue int32_attr;
  int32_attr.set_type(DT_INT32);
  AttrValue cond_func;
  cond_func.mutable_func()->set_name("LessThanOrEqualToN");
  AttrValue body_func;
  body_func.mutable_func()->set_name("XPlusOne");
  TF_ASSERT_OK(
      NodeBuilder("while", "While", &root.graph()->flib_def())
          .Input(inputs)
          .Attr("T", {DT_INT32})
          .Attr("cond", cond_func)
          .Attr("body", body_func)
          .Attr("parallel_iterations", 100)
          .Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
          .Finalize(root.graph(), &while_node));
  auto c = ops::Identity(
      root.WithOpName("C").WithControlDependencies(Output(while_node)),
      Output(while_node));
  TF_ASSERT_OK(root.DoShapeInference(while_node));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  if (lower) {
    FunctionLibraryDefinition flib_def(graph->flib_def());
    GraphOptimizationPassOptions opt_options;
    SessionOptions session_options;
    session_options.config.mutable_graph_options()
        ->mutable_optimizer_options()
        ->set_do_function_inlining(true);
    opt_options.session_options = &session_options;
    opt_options.graph = &graph;
    opt_options.flib_def = &flib_def;
    LowerFunctionalOpsPass pass;
    TF_ASSERT_OK(pass.Run(opt_options));
  }

  FixupSourceAndSinkEdges(graph.get());
  testing::StartTiming();
  test::Benchmark("cpu", graph.release()).Run(iters);
}

static void BM_LoweredWhileLoop(int iters, int loop_iters) {
  BM_WhileLoopHelper(iters, loop_iters, /* lower= */ true);
}
BENCHMARK(BM_LoweredWhileLoop)->Arg(0);
BENCHMARK(BM_LoweredWhileLoop)->Arg(1);
BENCHMARK(BM_LoweredWhileLoop)->Arg(10);
BENCHMARK(BM_LoweredWhileLoop)->Arg(100);
BENCHMARK(BM_LoweredWhileLoop)->Arg(1000);

static void BM_FunctionalWhileLoop(int iters, int loop_iters) {
  BM_WhileLoopHelper(iters, loop_iters, /* lower= */ false);
}
BENCHMARK(BM_FunctionalWhileLoop)->Arg(0);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(10);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(100);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1000);
} // namespace tensorflow
@@ -19,7 +19,9 @@ limitations under the License.

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -62,8 +64,10 @@ Benchmark::Benchmark(const string& device, Graph* g,
  // Allow NewDevice to allocate a new threadpool with different number of
  // threads for each new benchmark.
  LocalDevice::set_use_global_threadpool(false);
  device_ =
      DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0");

  device_mgr_ = absl::make_unique<StaticDeviceMgr>(
      DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0"));
  device_ = device_mgr_->ListDevices()[0];
  CHECK(device_) << "Could not create a " << device << " device";

  pool_ =
@@ -81,14 +85,24 @@ Benchmark::Benchmark(const string& device, Graph* g,

  const int graph_def_version = g->versions().producer();

  flib_def_ = absl::make_unique<FunctionLibraryDefinition>(g->flib_def());

  pflr_ = std::unique_ptr<ProcessFunctionLibraryRuntime>(
      new ProcessFunctionLibraryRuntime(
          device_mgr_.get(), Env::Default(), nullptr, graph_def_version,
          flib_def_.get(), OptimizerOptions(), pool_, nullptr, nullptr, nullptr,
          Rendezvous::Factory()));

  flr_ = pflr_->GetFLR(device_->name());

  LocalExecutorParams params;
  params.device = device_.get();
  params.function_library = nullptr;
  params.device = device_;
  params.function_library = flr_;
  params.create_kernel = [this, graph_def_version](
                             const std::shared_ptr<const NodeProperties>& props,
                             OpKernel** kernel) {
    return CreateNonCachedKernel(device_.get(), nullptr, props,
                                 graph_def_version, kernel);
    return CreateNonCachedKernel(device_, flr_, props, graph_def_version,
                                 kernel);
  };
  params.delete_kernel = [](OpKernel* kernel) {
    DeleteNonCachedKernel(kernel);
@@ -109,11 +123,12 @@ Benchmark::Benchmark(const string& device, Graph* g,
Benchmark::~Benchmark() {
  if (device_) {
    rendez_->Unref();
    // We delete `exec_` before `device_` because the `exec_` destructor may
    // We delete `exec_` before `device_mgr_` because the `exec_` destructor may
    // run kernel destructors that may attempt to access state borrowed from
    // `device_`, such as the resource manager.
    // `device_mgr_`, such as the resource manager.
    exec_.reset();
    device_.reset();
    pflr_.reset();
    device_mgr_.reset();
    delete pool_;
  }
}
@@ -29,7 +29,10 @@ limitations under the License.
namespace tensorflow {

class Device;
class FunctionLibraryRuntime;
class ProcessFunctionLibraryRuntime;
struct SessionOptions;
class StaticDeviceMgr;

namespace test {

@@ -55,9 +58,13 @@ class Benchmark {
                              const std::vector<string>& outputs, int iters);

 private:
  thread::ThreadPool* pool_ = nullptr;
  std::unique_ptr<Device> device_ = nullptr;
  thread::ThreadPool* pool_ = nullptr;  // Not owned.
  Device* device_ = nullptr;  // Not owned.
  Rendezvous* rendez_ = nullptr;
  std::unique_ptr<StaticDeviceMgr> device_mgr_;
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
  FunctionLibraryRuntime* flr_;  // Not owned.
  std::unique_ptr<Executor> exec_;

  TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
@@ -37,7 +37,7 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,

  // Initialize iteration 0.
  root_frame_->SetIteration(
      0, new PropagatorState::IterationState(root_frame_->pending_counts,
      0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
                                             root_frame_->total_input_tensors));

  outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
@@ -51,12 +51,13 @@ PropagatorState::~PropagatorState() {

void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                                    TaggedNodeSeq* ready) {
  mutex_lock l(root_frame_->mu);
  IterationState* root_iter = root_frame_->GetIteration(0);
  for (const NodeItem* item : roots) {
    DCHECK_EQ(item->num_inputs, 0);
    ready->emplace_back(item, root_frame_, 0, false);
    ready->emplace_back(item, root_frame_, root_iter, false);
  }
  mutex_lock l(root_frame_->mu);
  root_frame_->GetIteration(0)->outstanding_ops = ready->size();
  root_iter->outstanding_ops = ready->size();
}

void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
@@ -75,7 +76,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,

  const NodeItem* const item = tagged_node.node_item;
  FrameState* const input_frame = tagged_node.input_frame;
  const int64 input_iter = tagged_node.input_iter;
  IterationState* const input_iter = tagged_node.input_iter;
  const bool is_dead = tagged_node.is_dead;

  // Propagates outputs along out edges, and puts newly ready nodes
@@ -83,7 +84,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
  DCHECK(ready->empty());
  bool is_frame_done = false;
  FrameState* output_frame = input_frame;
  int64 output_iter = input_iter;
  IterationState* output_iter = input_iter;

  if (!item->is_enter_exit_or_next_iter) {
    // Fast path for nodes types that don't need special handling
@@ -95,9 +96,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
    input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
  } else if (item->is_enter) {
    FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
    output_iter = 0;
    {
      mutex_lock l(output_frame->mu);
      output_iter = output_frame->GetIteration(0);
      if (item->is_constant_enter) {
        // Propagate to all active iterations if this is a loop invariant.
        output_frame->AddLoopInv(item, (*outputs)[0], ready);
@@ -111,7 +112,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
    if (is_dead) {
      mutex_lock l(input_frame->mu);
      // Stop and remember this node if it is a dead exit.
      if (input_iter == input_frame->iteration_count) {
      if (input_iter->iter_num == input_frame->iteration_count) {
        input_frame->dead_exits.push_back(item);
      }
      is_frame_done =
@@ -132,7 +133,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
      // Stop the deadness propagation.
      output_frame = nullptr;
    } else {
      if (input_iter == input_frame->iteration_count &&
      if (input_iter->iter_num == input_frame->iteration_count &&
          input_frame->num_outstanding_iterations ==
              input_frame->max_parallel_iterations) {
        // Reached the maximum for parallel iterations.
@@ -140,10 +141,11 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
        output_frame = nullptr;
      } else {
        // If this is a new iteration, start it.
        if (input_iter == input_frame->iteration_count) {
          input_frame->IncrementIteration(ready);
        if (input_iter->iter_num == input_frame->iteration_count) {
          output_iter = input_frame->IncrementIteration(ready);
        } else {
          output_iter = input_frame->GetIteration(input_iter->iter_num + 1);
        }
        output_iter = input_iter + 1;
      }
    }
    if (output_frame != nullptr) {
@@ -159,7 +161,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
  // completion of this node makes its frame completed.
  if (is_frame_done) {
    FrameState* parent_frame = input_frame->parent_frame;
    const int64 parent_iter = input_frame->parent_iter;
    IterationState* parent_iter = input_frame->parent_iter;
    DeleteFrame(input_frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
@@ -217,7 +219,8 @@ void PropagatorState::DumpState() {
  }
}

void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
                                             IterationState* iter_state,
                                             const NodeItem& node_item,
                                             FrameState** child) {
  // Get the child frame name.
@@ -225,8 +228,8 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
  const string& enter_name = GetNodeAttrString(attrs, "frame_name");
  DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
                              << node_item.kernel->name();
  const string child_name =
      strings::StrCat(frame->frame_name, ";", iter, ";", enter_name);
  const string child_name = strings::StrCat(
      frame->frame_name, ";", iter_state->iter_num, ";", enter_name);

  {
    mutex_lock executor_lock(mu_);
@@ -251,14 +254,14 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
  temp->frame_name = child_name;
  temp->frame_id = Hash64(child_name);
  temp->parent_frame = frame;
  temp->parent_iter = iter;
  temp->parent_iter = iter_state;
  temp->InitializeFrameInfo(enter_name);

  // Initialize iteration 0.
  {
    mutex_lock l(temp->mu);
    temp->SetIteration(
        0, new IterationState(temp->pending_counts, temp->total_input_tensors));
    temp->SetIteration(0, new IterationState(0, temp->pending_counts,
                                             temp->total_input_tensors));
  }

  {
@@ -268,7 +271,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
      *child = it->second;
    } else {
      mutex_lock frame_lock(frame->mu);
      frame->GetIteration(iter)->outstanding_frame_count++;
      iter_state->outstanding_frame_count++;
      outstanding_frames_[child_name] = temp;
      *child = temp;
      temp = nullptr;
@@ -280,20 +283,19 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
  // First, propagate dead_exits (if any) to the parent frame.
  FrameState* parent_frame = frame->parent_frame;
  const int64 parent_iter = frame->parent_iter;
  IterationState* parent_iter_state = frame->parent_iter;
  if (parent_frame != nullptr) {
    mutex_lock parent_frame_lock(parent_frame->mu);
    // Propagate all the dead exits to the parent frame.
    mutex_lock this_frame_lock(frame->mu);

    for (const NodeItem* item : frame->dead_exits) {
      auto parent_iter_state = parent_frame->GetIteration(parent_iter);

      auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
                                    bool dst_dead) {
        if (dst_ready) {
          if (dst_item.is_control_trigger) dst_dead = false;
          ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead);
          ready->emplace_back(&dst_item, parent_frame, parent_iter_state,
                              dst_dead);
          parent_iter_state->outstanding_ops++;
        }
      };
@@ -356,17 +358,18 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
  delete frame;
}

void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter,
void PropagatorState::CleanupFramesIterations(FrameState* frame,
                                              IterationState* iter_state,
                                              TaggedNodeSeq* ready) {
  bool is_frame_done = false;
  {
    mutex_lock frame_lock(frame->mu);
    frame->GetIteration(iter)->outstanding_frame_count--;
    is_frame_done = frame->CleanupIterations(iter, ready);
    iter_state->outstanding_frame_count--;
    is_frame_done = frame->CleanupIterations(iter_state, ready);
  }
  if (is_frame_done) {
    FrameState* parent_frame = frame->parent_frame;
    const int64 parent_iter = frame->parent_iter;
    IterationState* parent_iter = frame->parent_iter;
    DeleteFrame(frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
@@ -376,16 +379,13 @@ void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter,
  }
}

void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
                                                        const bool is_dead,
                                                        int64 iter,
                                                        EntryVector* outputs,
                                                        TaggedNodeSeq* ready) {
void PropagatorState::FrameState::ActivateNodesFastPath(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  // If we know that none of the item's edge destinations require special
  // handling (i.e. none of the nodes is a merge or control trigger node), we
  // can take a fast path that avoids accessing the destination NodeItem.
  const GraphView& gview = immutable_state.graph_view();
  IterationState* iter_state = GetIteration(iter);

  // Add dst to the ready queue if it's ready
  //
@@ -398,7 +398,7 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
    TaggedNode& t = ready->emplace_back(); \
    t.node_item = dst_item; \
    t.input_frame = this; \
    t.input_iter = iter; \
    t.input_iter = iter_state; \
    t.is_dead = adjust_result.any_dead; \
    iter_state->outstanding_ops++; \
  } \
@@ -436,23 +436,20 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
#undef MAYBE_ADD_TO_READY
}

void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
                                                        const bool is_dead,
                                                        int64 iter,
                                                        EntryVector* outputs,
                                                        TaggedNodeSeq* ready) {
void PropagatorState::FrameState::ActivateNodesSlowPath(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  // If any of the edge destinations is a merge or a control trigger node,
  // we need to read each destination NodeItem to determine what action
  // to take.
  const GraphView& gview = immutable_state.graph_view();
  IterationState* iter_state = GetIteration(iter);

  auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
                                bool dst_ready, bool dst_dead) {
    // Add dst to the ready queue if it's ready
    if (dst_ready) {
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->emplace_back(dst_item, this, iter, dst_dead);
      ready->emplace_back(dst_item, this, iter_state, dst_dead);
      iter_state->outstanding_ops++;
    }
  };
@@ -551,17 +548,18 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
}

void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
                                                const bool is_dead, int64 iter,
                                                const bool is_dead,
                                                IterationState* iter_state,
                                                EntryVector* outputs,
                                                TaggedNodeSeq* ready) {
  if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
    ActivateNodesSlowPath(item, is_dead, iter, outputs, ready);
    ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
  } else {
    ActivateNodesFastPath(item, is_dead, iter, outputs, ready);
    ActivateNodesFastPath(item, is_dead, iter_state, outputs, ready);
  }
}

void PropagatorState::FrameState::ActivateNexts(int64 iter,
void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state,
                                                TaggedNodeSeq* ready) {
  // Propagate the deferred NextIteration nodes to the new iteration.
  for (auto& node_entry : next_iter_roots) {
@@ -569,12 +567,12 @@ void PropagatorState::FrameState::ActivateNexts(int64 iter,
    const Entry& entry = node_entry.second;
    const bool is_dead = entry.state == Entry::State::NO_VALUE;
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, iter, &outputs, ready);
    ActivateNodes(item, is_dead, iter_state, &outputs, ready);
  }
  next_iter_roots.clear();
}

void PropagatorState::FrameState::ActivateLoopInvs(int64 iter,
void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state,
                                                   TaggedNodeSeq* ready) {
  // Propagate loop invariants to the new iteration.
  for (auto& node_entry : inv_values) {
@@ -582,7 +580,7 @@ void PropagatorState::FrameState::ActivateLoopInvs(int64 iter,
    const Entry& entry = node_entry.second;
    const bool is_dead = entry.state == Entry::State::NO_VALUE;
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, iter, &outputs, ready);
    ActivateNodes(item, is_dead, iter_state, &outputs, ready);
  }
}

@@ -596,33 +594,32 @@ void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
  const bool is_dead = entry.state == Entry::State::NO_VALUE;
  for (int i = 0; i <= iteration_count; ++i) {
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, i, &outputs, ready);
    ActivateNodes(item, is_dead, GetIteration(i), &outputs, ready);
  }
}

bool PropagatorState::FrameState::IsIterationDone(int64 iter) {
  IterationState* iter_state = GetIteration(iter);
bool PropagatorState::FrameState::IsIterationDone(IterationState* iter_state) {
  if (iter_state->outstanding_ops == 0 &&
      iter_state->outstanding_frame_count == 0) {
    if (iter == 0) {
    if (iter_state->iter_num == 0) {
      // The enclosing frame has no pending input.
      return num_pending_inputs == 0;
    } else {
      // The preceding iteration is deleted (and therefore done).
      return (GetIteration(iter - 1) == nullptr);
      return (GetIteration(iter_state->iter_num - 1) == nullptr);
    }
  }
  return false;
}

void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
PropagatorState::IterationState*
PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
  iteration_count++;
  const int64 next_iter = iteration_count;

  // Initialize the next iteration.
  IterationState* iter_state =
      new IterationState(pending_counts, total_input_tensors);
  SetIteration(next_iter, iter_state);
  IterationState* next_iter =
      new IterationState(iteration_count, pending_counts, total_input_tensors);
  SetIteration(iteration_count, next_iter);
  num_outstanding_iterations++;
  dead_exits.clear();

@@ -631,14 +628,15 @@ void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {

  // Activate the loop invariants in the new iteration.
  ActivateLoopInvs(next_iter, ready);

  return next_iter;
}

bool PropagatorState::FrameState::CleanupIterations(int64 iter,
bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
                                                    TaggedNodeSeq* ready) {
  int64 curr_iter = iter;
  while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
    // Delete the iteration curr_iter.
    delete GetIteration(curr_iter);
  int64 curr_iter = iter_state->iter_num;
  while (curr_iter <= iteration_count && IsIterationDone(iter_state)) {
    delete iter_state;
    SetIteration(curr_iter, nullptr);
    --num_outstanding_iterations;
    ++curr_iter;
@@ -648,6 +646,10 @@ bool PropagatorState::FrameState::CleanupIterations(int64 iter,
    if (!next_iter_roots.empty()) {
      IncrementIteration(ready);
    }

    if (curr_iter <= iteration_count) {
      iter_state = GetIteration(curr_iter);
    }
  }
  return IsFrameDone();
}
@@ -677,21 +679,21 @@ void PropagatorState::FrameState::SetIteration(int64 iter,
// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOps(
    int64 iter, TaggedNodeSeq* ready) {
    IterationState* iter_state, TaggedNodeSeq* ready) {
  mutex_lock l(mu);
  return DecrementOutstandingOpsLocked(iter, ready);
  return DecrementOutstandingOpsLocked(iter_state, ready);
}

// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
    int64 iter, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  IterationState* istate = GetIteration(iter);
  istate->outstanding_ops--;
  if (istate->outstanding_ops != 0) {
    IterationState* iter_state, TaggedNodeSeq* ready)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  iter_state->outstanding_ops--;
  if (iter_state->outstanding_ops != 0) {
    return false;
  } else {
    return CleanupIterations(iter, ready);
    return CleanupIterations(iter_state, ready);
  }
}
@@ -49,8 +49,10 @@ class PropagatorState {
  ~PropagatorState();

 private:
  // Forward declaration so that `TaggedNode` can include a `FrameState*`.
  // Forward declaration so that `TaggedNode` can include a `FrameState*` and an
  // `IterationState*`.
  struct FrameState;
  struct IterationState;

 public:
  // A `TaggedNode` corresponds to a single invocation of a node's kernel,
@@ -59,12 +61,12 @@ class PropagatorState {
  struct TaggedNode {
    const NodeItem* node_item;
    FrameState* input_frame;
    int64 input_iter;
    IterationState* input_iter;
    bool is_dead;

    TaggedNode() = default;
    TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter,
               bool dead)
    TaggedNode(const NodeItem* node_item, FrameState* in_frame,
               IterationState* in_iter, bool dead)
        : node_item(node_item),
          input_frame(in_frame),
          input_iter(in_iter),
@@ -73,7 +75,7 @@ class PropagatorState {
    const NodeItem& get_node_item() const { return *node_item; }

    bool get_is_dead() const { return is_dead; }
    int64 get_iter_num() const { return input_iter; }
    int64 get_iter_num() const;
  };

  // A drop-in replacement for std::deque<TaggedNode>. We typically don't
@@ -116,16 +118,18 @@ class PropagatorState {
  typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;

 private:
  // The state of an iteration in a particular frame.
  struct IterationState {
    explicit IterationState(const PendingCounts* pending_counts,
    explicit IterationState(int64 iter_num, const PendingCounts* pending_counts,
                            int total_input_tensors)
        : input_tensors(new Entry[total_input_tensors]),
        : iter_num(iter_num),
          input_tensors(new Entry[total_input_tensors]),
          outstanding_ops(0),
          outstanding_frame_count(0),
          counts(*pending_counts) {  // Initialize with copy of *pending_counts
    }

    // The state of an iteration.
    const int64 iter_num;  // The index of this iteration in the enclosing loop.

    // One copy per iteration. For iteration k, i-th node's j-th input is in
    // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
@@ -221,10 +225,10 @@ class PropagatorState {
    // frame_name.
    uint64 frame_id;

    // The iteration id of its parent frame when this frame is created.
    // -1 if there is no parent frame. The frame_name/parent_iter pair
    // The iteration state of its parent frame when this frame is created.
    // nullptr if there is no parent frame. The frame_name/parent_iter pair
    // uniquely identifies this FrameState.
    int64 parent_iter = -1;
    IterationState* parent_iter = nullptr;

    // The FrameState of its parent frame.
    FrameState* parent_frame = nullptr;
@@ -291,28 +295,33 @@ class PropagatorState {

    // Decrement the outstanding op count and clean up the iterations in the
    // frame. Return true iff the execution of the frame is done.
    bool DecrementOutstandingOps(int64 iter, TaggedNodeSeq* ready);
    bool DecrementOutstandingOps(IterationState* iter_state,
                                 TaggedNodeSeq* ready);

    // Decrement the outstanding op count and clean up the iterations in the
    // frame. Return true iff the execution of the frame is done.
    bool DecrementOutstandingOpsLocked(int64 iter, TaggedNodeSeq* ready);
    bool DecrementOutstandingOpsLocked(IterationState* iter_state,
                                       TaggedNodeSeq* ready);

    // Returns true if the computation in the frame is completed.
    bool IsFrameDone();

    // Returns true if the iteration of the frame is completed.
    bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
    bool IsIterationDone(IterationState* iter_state)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Increments the iteration id. If this is a new iteration, initialize it.
    void IncrementIteration(TaggedNodeSeq* ready)
    //
    // Returns a pointer to the new iteration.
    IterationState* IncrementIteration(TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the deferred NextIteration nodes in a new iteration.
    void ActivateNexts(int64 iter, TaggedNodeSeq* ready)
    void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Activate all the current loop invariants in a new iteration.
    void ActivateLoopInvs(int64 iter, TaggedNodeSeq* ready)
    void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Add a new loop invariant and make it available to all active
@@ -322,12 +331,12 @@ class PropagatorState {

    // Activate the successors of a node. Contents of *outputs are left in an
    // indeterminate state after returning from this method.
    void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
                       EntryVector* outputs, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
    void ActivateNodes(const NodeItem* item, const bool is_dead,
                       IterationState* iter_state, EntryVector* outputs,
                       TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    // Cleanup iterations of this frame starting from iteration iter.
    bool CleanupIterations(int64 iter, TaggedNodeSeq* ready)
    // Cleanup iterations of this frame starting from the given iteration.
    bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    void DumpIterationState(PropagatorState* parent) {
@@ -350,12 +359,12 @@ class PropagatorState {
   private:
    // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
    void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
                               int64 iter, EntryVector* outputs,
                               IterationState* iter_state, EntryVector* outputs,
                               TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);

    void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
                               int64 iter, EntryVector* outputs,
                               IterationState* iter_state, EntryVector* outputs,
                               TaggedNodeSeq* ready)
        TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
  };
@@ -379,13 +388,13 @@ class PropagatorState {
  // same address while the iteration is live.
  Entry* GetInputTensors(const TaggedNode& tagged_node) const
      TF_NO_THREAD_SAFETY_ANALYSIS {
    return tagged_node.input_frame->GetIteration(tagged_node.input_iter)
               ->input_tensors +
    return tagged_node.input_iter->input_tensors +
           tagged_node.node_item->input_start;
  }

  FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
    return {tagged_node.input_frame->frame_id, tagged_node.input_iter};
    return {tagged_node.input_frame->frame_id,
            tagged_node.input_iter->iter_num};
  }

  // Provide debugging output of the state of the executor.
@@ -397,9 +406,8 @@ class PropagatorState {
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_frame->GetIteration(tagged_node.input_iter)
          ->mark_started(
              immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
      tagged_node.input_iter->mark_started(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

@@ -408,16 +416,15 @@ class PropagatorState {
    // optional debugging support.
    if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
      mutex_lock l(tagged_node.input_frame->mu);
      tagged_node.input_frame->GetIteration(tagged_node.input_iter)
          ->mark_completed(
              immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
      tagged_node.input_iter->mark_completed(
          immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
    }
  }

 private:
  // Find an existing or create a new child frame in the frame 'frame' at
  // iteration 'iter'.
  void FindOrCreateChildFrame(FrameState* frame, int64 iter,
  void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state,
                              const NodeItem& node_item, FrameState** child);

  // Delete a frame. Called when the frame is done.
@@ -425,7 +432,7 @@ class PropagatorState {

  // Cleanup frames and iterations starting from frame/iter. Called when
  // a child frame is done.
  void CleanupFramesIterations(FrameState* frame, int64 iter,
  void CleanupFramesIterations(FrameState* frame, IterationState* iter_state,
                               TaggedNodeSeq* ready);

  // Provide debugging output about an outstanding iteration in the executor.
@@ -450,6 +457,10 @@ class PropagatorState {
  TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
};

inline int64 PropagatorState::TaggedNode::get_iter_num() const {
  return input_iter->iter_num;
}

} // namespace tensorflow

#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_