[PropagatorState] Use IterationState*
instead of int64
iteration ID in TaggedNode
.
This change avoids making frequent calls (at least 2 per kernel dispatched) to `FrameState::GetIteration()`, which perform integer division. Since this work happens under the `FrameState` mutex, reducing the work done here should also ease contention on that mutex slightly. This change adds microbenchmarks for both the functional (`WhileOp`) and lowered (`Switch`/`Merge`/`Enter`/`Exit`) loop implementations. To support these microbenchmarks, it adds support to "kernel_benchmark_testlib" for running functions during the benchmark. PiperOrigin-RevId: 310175986 Change-Id: I9af0781e341e5d8cb19729de375078f9f4237c54
This commit is contained in:
parent
17d5c85577
commit
ae2a0e5c47
@ -2297,6 +2297,10 @@ tf_cc_test(
|
|||||||
":core",
|
":core",
|
||||||
":core_cpu",
|
":core_cpu",
|
||||||
":core_cpu_internal",
|
":core_cpu_internal",
|
||||||
|
"//tensorflow/cc:cc_ops",
|
||||||
|
"//tensorflow/cc:cc_ops_internal",
|
||||||
|
"//tensorflow/cc:function_ops",
|
||||||
|
"//tensorflow/cc:ops",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
@ -2307,6 +2311,8 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
"//tensorflow/core/kernels:array",
|
"//tensorflow/core/kernels:array",
|
||||||
"//tensorflow/core/kernels:control_flow_ops",
|
"//tensorflow/core/kernels:control_flow_ops",
|
||||||
|
"//tensorflow/core/kernels:function_ops",
|
||||||
|
"//tensorflow/core/kernels:functional_ops",
|
||||||
"//tensorflow/core/kernels:math",
|
"//tensorflow/core/kernels:math",
|
||||||
"//tensorflow/core/kernels:random_ops",
|
"//tensorflow/core/kernels:random_ops",
|
||||||
"//tensorflow/core/kernels:state",
|
"//tensorflow/core/kernels:state",
|
||||||
|
@ -17,15 +17,24 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "tensorflow/cc/framework/ops.h"
|
||||||
|
#include "tensorflow/cc/ops/array_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/const_op.h"
|
||||||
|
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||||
|
#include "tensorflow/cc/ops/function_ops.h"
|
||||||
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
#include "tensorflow/core/common_runtime/graph_constructor.h"
|
||||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||||
|
#include "tensorflow/core/common_runtime/lower_functional_ops.h"
|
||||||
#include "tensorflow/core/common_runtime/process_util.h"
|
#include "tensorflow/core/common_runtime/process_util.h"
|
||||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||||
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/rendezvous.h"
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
#include "tensorflow/core/framework/step_stats.pb.h"
|
#include "tensorflow/core/framework/step_stats.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/algorithm.h"
|
#include "tensorflow/core/graph/algorithm.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
@ -532,4 +541,115 @@ static void BM_FeedInputFetchOutput(int iters) {
|
|||||||
}
|
}
|
||||||
BENCHMARK(BM_FeedInputFetchOutput);
|
BENCHMARK(BM_FeedInputFetchOutput);
|
||||||
|
|
||||||
|
// Defines a graph to perform the following computation:
|
||||||
|
//
|
||||||
|
// i = 0
|
||||||
|
// while (i < loop_iters)
|
||||||
|
// i += 1;
|
||||||
|
//
|
||||||
|
// ...using the functional `WhileOp` (if `lower` is false) or the
|
||||||
|
// `Switch`/`Merge`-style of control flow (if `lower` is true).
|
||||||
|
static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
|
||||||
|
testing::StopTiming();
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
|
||||||
|
// Add test functions for cond and body.
|
||||||
|
FunctionDefLibrary f_lib_proto;
|
||||||
|
|
||||||
|
// Define the loop body as a function: `x = x + 1`.
|
||||||
|
const Tensor one_t = test::AsScalar<int32>(1);
|
||||||
|
*f_lib_proto.add_function() = FunctionDefHelper::Define(
|
||||||
|
// Name
|
||||||
|
"XPlusOne",
|
||||||
|
// Args
|
||||||
|
{"x: int32"},
|
||||||
|
// Return values
|
||||||
|
{"y: int32"},
|
||||||
|
// Attr def
|
||||||
|
{},
|
||||||
|
// Nodes
|
||||||
|
{
|
||||||
|
{{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}},
|
||||||
|
{{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Define the loop condition as a function: `x < loop_iters`.
|
||||||
|
const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
|
||||||
|
*f_lib_proto.add_function() = FunctionDefHelper::Define(
|
||||||
|
// Name
|
||||||
|
"LessThanOrEqualToN",
|
||||||
|
// Args
|
||||||
|
{"x: int32"},
|
||||||
|
// Return values
|
||||||
|
{"z: bool"},
|
||||||
|
// Attr def
|
||||||
|
{},
|
||||||
|
// Nodes
|
||||||
|
{
|
||||||
|
{{"N"}, "Const", {}, {{"value", loop_iters_t}, {"dtype", DT_INT32}}},
|
||||||
|
{{"z"}, "LessEqual", {"x", "N"}, {{"T", DT_INT32}}},
|
||||||
|
});
|
||||||
|
|
||||||
|
Scope root = Scope::NewRootScope().ExitOnError();
|
||||||
|
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
|
||||||
|
auto a = ops::Const(root.WithOpName("A"), 0, {});
|
||||||
|
Node* while_node;
|
||||||
|
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
|
||||||
|
AttrValue int32_attr;
|
||||||
|
int32_attr.set_type(DT_INT32);
|
||||||
|
AttrValue cond_func;
|
||||||
|
cond_func.mutable_func()->set_name("LessThanOrEqualToN");
|
||||||
|
AttrValue body_func;
|
||||||
|
body_func.mutable_func()->set_name("XPlusOne");
|
||||||
|
TF_ASSERT_OK(
|
||||||
|
NodeBuilder("while", "While", &root.graph()->flib_def())
|
||||||
|
.Input(inputs)
|
||||||
|
.Attr("T", {DT_INT32})
|
||||||
|
.Attr("cond", cond_func)
|
||||||
|
.Attr("body", body_func)
|
||||||
|
.Attr("parallel_iterations", 100)
|
||||||
|
.Attr(LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr, true)
|
||||||
|
.Finalize(root.graph(), &while_node));
|
||||||
|
auto c = ops::Identity(
|
||||||
|
root.WithOpName("C").WithControlDependencies(Output(while_node)),
|
||||||
|
Output(while_node));
|
||||||
|
TF_ASSERT_OK(root.DoShapeInference(while_node));
|
||||||
|
TF_ASSERT_OK(root.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
if (lower) {
|
||||||
|
FunctionLibraryDefinition flib_def(graph->flib_def());
|
||||||
|
GraphOptimizationPassOptions opt_options;
|
||||||
|
SessionOptions session_options;
|
||||||
|
session_options.config.mutable_graph_options()
|
||||||
|
->mutable_optimizer_options()
|
||||||
|
->set_do_function_inlining(true);
|
||||||
|
opt_options.session_options = &session_options;
|
||||||
|
opt_options.graph = &graph;
|
||||||
|
opt_options.flib_def = &flib_def;
|
||||||
|
LowerFunctionalOpsPass pass;
|
||||||
|
TF_ASSERT_OK(pass.Run(opt_options));
|
||||||
|
}
|
||||||
|
|
||||||
|
FixupSourceAndSinkEdges(graph.get());
|
||||||
|
testing::StartTiming();
|
||||||
|
test::Benchmark("cpu", graph.release()).Run(iters);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void BM_LoweredWhileLoop(int iters, int loop_iters) {
|
||||||
|
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ true);
|
||||||
|
}
|
||||||
|
BENCHMARK(BM_LoweredWhileLoop)->Arg(0);
|
||||||
|
BENCHMARK(BM_LoweredWhileLoop)->Arg(1);
|
||||||
|
BENCHMARK(BM_LoweredWhileLoop)->Arg(10);
|
||||||
|
BENCHMARK(BM_LoweredWhileLoop)->Arg(100);
|
||||||
|
BENCHMARK(BM_LoweredWhileLoop)->Arg(1000);
|
||||||
|
|
||||||
|
static void BM_FunctionalWhileLoop(int iters, int loop_iters) {
|
||||||
|
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ false);
|
||||||
|
}
|
||||||
|
BENCHMARK(BM_FunctionalWhileLoop)->Arg(0);
|
||||||
|
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1);
|
||||||
|
BENCHMARK(BM_FunctionalWhileLoop)->Arg(10);
|
||||||
|
BENCHMARK(BM_FunctionalWhileLoop)->Arg(100);
|
||||||
|
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1000);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -19,7 +19,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
#include "tensorflow/core/common_runtime/executor_factory.h"
|
#include "tensorflow/core/common_runtime/executor_factory.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
#include "tensorflow/core/common_runtime/local_device.h"
|
#include "tensorflow/core/common_runtime/local_device.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
@ -62,8 +64,10 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
|||||||
// Allow NewDevice to allocate a new threadpool with different number of
|
// Allow NewDevice to allocate a new threadpool with different number of
|
||||||
// threads for each new benchmark.
|
// threads for each new benchmark.
|
||||||
LocalDevice::set_use_global_threadpool(false);
|
LocalDevice::set_use_global_threadpool(false);
|
||||||
device_ =
|
|
||||||
DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0");
|
device_mgr_ = absl::make_unique<StaticDeviceMgr>(
|
||||||
|
DeviceFactory::NewDevice(t, *options, "/job:localhost/replica:0/task:0"));
|
||||||
|
device_ = device_mgr_->ListDevices()[0];
|
||||||
CHECK(device_) << "Could not create a " << device << " device";
|
CHECK(device_) << "Could not create a " << device << " device";
|
||||||
|
|
||||||
pool_ =
|
pool_ =
|
||||||
@ -81,14 +85,24 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
|||||||
|
|
||||||
const int graph_def_version = g->versions().producer();
|
const int graph_def_version = g->versions().producer();
|
||||||
|
|
||||||
|
flib_def_ = absl::make_unique<FunctionLibraryDefinition>(g->flib_def());
|
||||||
|
|
||||||
|
pflr_ = std::unique_ptr<ProcessFunctionLibraryRuntime>(
|
||||||
|
new ProcessFunctionLibraryRuntime(
|
||||||
|
device_mgr_.get(), Env::Default(), nullptr, graph_def_version,
|
||||||
|
flib_def_.get(), OptimizerOptions(), pool_, nullptr, nullptr, nullptr,
|
||||||
|
Rendezvous::Factory()));
|
||||||
|
|
||||||
|
flr_ = pflr_->GetFLR(device_->name());
|
||||||
|
|
||||||
LocalExecutorParams params;
|
LocalExecutorParams params;
|
||||||
params.device = device_.get();
|
params.device = device_;
|
||||||
params.function_library = nullptr;
|
params.function_library = flr_;
|
||||||
params.create_kernel = [this, graph_def_version](
|
params.create_kernel = [this, graph_def_version](
|
||||||
const std::shared_ptr<const NodeProperties>& props,
|
const std::shared_ptr<const NodeProperties>& props,
|
||||||
OpKernel** kernel) {
|
OpKernel** kernel) {
|
||||||
return CreateNonCachedKernel(device_.get(), nullptr, props,
|
return CreateNonCachedKernel(device_, flr_, props, graph_def_version,
|
||||||
graph_def_version, kernel);
|
kernel);
|
||||||
};
|
};
|
||||||
params.delete_kernel = [](OpKernel* kernel) {
|
params.delete_kernel = [](OpKernel* kernel) {
|
||||||
DeleteNonCachedKernel(kernel);
|
DeleteNonCachedKernel(kernel);
|
||||||
@ -109,11 +123,12 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
|||||||
Benchmark::~Benchmark() {
|
Benchmark::~Benchmark() {
|
||||||
if (device_) {
|
if (device_) {
|
||||||
rendez_->Unref();
|
rendez_->Unref();
|
||||||
// We delete `exec_` before `device_` because the `exec_` destructor may
|
// We delete `exec_` before `device_mgr_` because the `exec_` destructor may
|
||||||
// run kernel destructors that may attempt to access state borrowed from
|
// run kernel destructors that may attempt to access state borrowed from
|
||||||
// `device_`, such as the resource manager.
|
// `device_mgr_`, such as the resource manager.
|
||||||
exec_.reset();
|
exec_.reset();
|
||||||
device_.reset();
|
pflr_.reset();
|
||||||
|
device_mgr_.reset();
|
||||||
delete pool_;
|
delete pool_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -29,7 +29,10 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
class Device;
|
class Device;
|
||||||
|
class FunctionLibraryRuntime;
|
||||||
|
class ProcessFunctionLibraryRuntime;
|
||||||
struct SessionOptions;
|
struct SessionOptions;
|
||||||
|
class StaticDeviceMgr;
|
||||||
|
|
||||||
namespace test {
|
namespace test {
|
||||||
|
|
||||||
@ -55,9 +58,13 @@ class Benchmark {
|
|||||||
const std::vector<string>& outputs, int iters);
|
const std::vector<string>& outputs, int iters);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
thread::ThreadPool* pool_ = nullptr;
|
thread::ThreadPool* pool_ = nullptr; // Not owned.
|
||||||
std::unique_ptr<Device> device_ = nullptr;
|
Device* device_ = nullptr; // Not owned.
|
||||||
Rendezvous* rendez_ = nullptr;
|
Rendezvous* rendez_ = nullptr;
|
||||||
|
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||||
|
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
|
||||||
|
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
|
||||||
|
FunctionLibraryRuntime* flr_; // Not owned.
|
||||||
std::unique_ptr<Executor> exec_;
|
std::unique_ptr<Executor> exec_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
|
TF_DISALLOW_COPY_AND_ASSIGN(Benchmark);
|
||||||
|
@ -37,7 +37,7 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
|
|||||||
|
|
||||||
// Initialize iteration 0.
|
// Initialize iteration 0.
|
||||||
root_frame_->SetIteration(
|
root_frame_->SetIteration(
|
||||||
0, new PropagatorState::IterationState(root_frame_->pending_counts,
|
0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
|
||||||
root_frame_->total_input_tensors));
|
root_frame_->total_input_tensors));
|
||||||
|
|
||||||
outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
|
outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
|
||||||
@ -51,12 +51,13 @@ PropagatorState::~PropagatorState() {
|
|||||||
|
|
||||||
void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
|
void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
|
||||||
TaggedNodeSeq* ready) {
|
TaggedNodeSeq* ready) {
|
||||||
|
mutex_lock l(root_frame_->mu);
|
||||||
|
IterationState* root_iter = root_frame_->GetIteration(0);
|
||||||
for (const NodeItem* item : roots) {
|
for (const NodeItem* item : roots) {
|
||||||
DCHECK_EQ(item->num_inputs, 0);
|
DCHECK_EQ(item->num_inputs, 0);
|
||||||
ready->emplace_back(item, root_frame_, 0, false);
|
ready->emplace_back(item, root_frame_, root_iter, false);
|
||||||
}
|
}
|
||||||
mutex_lock l(root_frame_->mu);
|
root_iter->outstanding_ops = ready->size();
|
||||||
root_frame_->GetIteration(0)->outstanding_ops = ready->size();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
||||||
@ -75,7 +76,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
|
|
||||||
const NodeItem* const item = tagged_node.node_item;
|
const NodeItem* const item = tagged_node.node_item;
|
||||||
FrameState* const input_frame = tagged_node.input_frame;
|
FrameState* const input_frame = tagged_node.input_frame;
|
||||||
const int64 input_iter = tagged_node.input_iter;
|
IterationState* const input_iter = tagged_node.input_iter;
|
||||||
const bool is_dead = tagged_node.is_dead;
|
const bool is_dead = tagged_node.is_dead;
|
||||||
|
|
||||||
// Propagates outputs along out edges, and puts newly ready nodes
|
// Propagates outputs along out edges, and puts newly ready nodes
|
||||||
@ -83,7 +84,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
DCHECK(ready->empty());
|
DCHECK(ready->empty());
|
||||||
bool is_frame_done = false;
|
bool is_frame_done = false;
|
||||||
FrameState* output_frame = input_frame;
|
FrameState* output_frame = input_frame;
|
||||||
int64 output_iter = input_iter;
|
IterationState* output_iter = input_iter;
|
||||||
|
|
||||||
if (!item->is_enter_exit_or_next_iter) {
|
if (!item->is_enter_exit_or_next_iter) {
|
||||||
// Fast path for nodes types that don't need special handling
|
// Fast path for nodes types that don't need special handling
|
||||||
@ -95,9 +96,9 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
|
input_frame->DecrementOutstandingOpsLocked(input_iter, ready);
|
||||||
} else if (item->is_enter) {
|
} else if (item->is_enter) {
|
||||||
FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
|
FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
|
||||||
output_iter = 0;
|
|
||||||
{
|
{
|
||||||
mutex_lock l(output_frame->mu);
|
mutex_lock l(output_frame->mu);
|
||||||
|
output_iter = output_frame->GetIteration(0);
|
||||||
if (item->is_constant_enter) {
|
if (item->is_constant_enter) {
|
||||||
// Propagate to all active iterations if this is a loop invariant.
|
// Propagate to all active iterations if this is a loop invariant.
|
||||||
output_frame->AddLoopInv(item, (*outputs)[0], ready);
|
output_frame->AddLoopInv(item, (*outputs)[0], ready);
|
||||||
@ -111,7 +112,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
if (is_dead) {
|
if (is_dead) {
|
||||||
mutex_lock l(input_frame->mu);
|
mutex_lock l(input_frame->mu);
|
||||||
// Stop and remember this node if it is a dead exit.
|
// Stop and remember this node if it is a dead exit.
|
||||||
if (input_iter == input_frame->iteration_count) {
|
if (input_iter->iter_num == input_frame->iteration_count) {
|
||||||
input_frame->dead_exits.push_back(item);
|
input_frame->dead_exits.push_back(item);
|
||||||
}
|
}
|
||||||
is_frame_done =
|
is_frame_done =
|
||||||
@ -132,7 +133,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
// Stop the deadness propagation.
|
// Stop the deadness propagation.
|
||||||
output_frame = nullptr;
|
output_frame = nullptr;
|
||||||
} else {
|
} else {
|
||||||
if (input_iter == input_frame->iteration_count &&
|
if (input_iter->iter_num == input_frame->iteration_count &&
|
||||||
input_frame->num_outstanding_iterations ==
|
input_frame->num_outstanding_iterations ==
|
||||||
input_frame->max_parallel_iterations) {
|
input_frame->max_parallel_iterations) {
|
||||||
// Reached the maximum for parallel iterations.
|
// Reached the maximum for parallel iterations.
|
||||||
@ -140,10 +141,11 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
output_frame = nullptr;
|
output_frame = nullptr;
|
||||||
} else {
|
} else {
|
||||||
// If this is a new iteration, start it.
|
// If this is a new iteration, start it.
|
||||||
if (input_iter == input_frame->iteration_count) {
|
if (input_iter->iter_num == input_frame->iteration_count) {
|
||||||
input_frame->IncrementIteration(ready);
|
output_iter = input_frame->IncrementIteration(ready);
|
||||||
|
} else {
|
||||||
|
output_iter = input_frame->GetIteration(input_iter->iter_num + 1);
|
||||||
}
|
}
|
||||||
output_iter = input_iter + 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (output_frame != nullptr) {
|
if (output_frame != nullptr) {
|
||||||
@ -159,7 +161,7 @@ void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
|
|||||||
// completion of this node makes its frame completed.
|
// completion of this node makes its frame completed.
|
||||||
if (is_frame_done) {
|
if (is_frame_done) {
|
||||||
FrameState* parent_frame = input_frame->parent_frame;
|
FrameState* parent_frame = input_frame->parent_frame;
|
||||||
const int64 parent_iter = input_frame->parent_iter;
|
IterationState* parent_iter = input_frame->parent_iter;
|
||||||
DeleteFrame(input_frame, ready);
|
DeleteFrame(input_frame, ready);
|
||||||
if (parent_frame != nullptr) {
|
if (parent_frame != nullptr) {
|
||||||
// The completion of frame may cause completions in its parent frame.
|
// The completion of frame may cause completions in its parent frame.
|
||||||
@ -217,7 +219,8 @@ void PropagatorState::DumpState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
|
||||||
|
IterationState* iter_state,
|
||||||
const NodeItem& node_item,
|
const NodeItem& node_item,
|
||||||
FrameState** child) {
|
FrameState** child) {
|
||||||
// Get the child frame name.
|
// Get the child frame name.
|
||||||
@ -225,8 +228,8 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
|||||||
const string& enter_name = GetNodeAttrString(attrs, "frame_name");
|
const string& enter_name = GetNodeAttrString(attrs, "frame_name");
|
||||||
DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
|
DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
|
||||||
<< node_item.kernel->name();
|
<< node_item.kernel->name();
|
||||||
const string child_name =
|
const string child_name = strings::StrCat(
|
||||||
strings::StrCat(frame->frame_name, ";", iter, ";", enter_name);
|
frame->frame_name, ";", iter_state->iter_num, ";", enter_name);
|
||||||
|
|
||||||
{
|
{
|
||||||
mutex_lock executor_lock(mu_);
|
mutex_lock executor_lock(mu_);
|
||||||
@ -251,14 +254,14 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
|||||||
temp->frame_name = child_name;
|
temp->frame_name = child_name;
|
||||||
temp->frame_id = Hash64(child_name);
|
temp->frame_id = Hash64(child_name);
|
||||||
temp->parent_frame = frame;
|
temp->parent_frame = frame;
|
||||||
temp->parent_iter = iter;
|
temp->parent_iter = iter_state;
|
||||||
temp->InitializeFrameInfo(enter_name);
|
temp->InitializeFrameInfo(enter_name);
|
||||||
|
|
||||||
// Initialize iteration 0.
|
// Initialize iteration 0.
|
||||||
{
|
{
|
||||||
mutex_lock l(temp->mu);
|
mutex_lock l(temp->mu);
|
||||||
temp->SetIteration(
|
temp->SetIteration(0, new IterationState(0, temp->pending_counts,
|
||||||
0, new IterationState(temp->pending_counts, temp->total_input_tensors));
|
temp->total_input_tensors));
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -268,7 +271,7 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
|||||||
*child = it->second;
|
*child = it->second;
|
||||||
} else {
|
} else {
|
||||||
mutex_lock frame_lock(frame->mu);
|
mutex_lock frame_lock(frame->mu);
|
||||||
frame->GetIteration(iter)->outstanding_frame_count++;
|
iter_state->outstanding_frame_count++;
|
||||||
outstanding_frames_[child_name] = temp;
|
outstanding_frames_[child_name] = temp;
|
||||||
*child = temp;
|
*child = temp;
|
||||||
temp = nullptr;
|
temp = nullptr;
|
||||||
@ -280,20 +283,19 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
|||||||
void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
|
void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
|
||||||
// First, propagate dead_exits (if any) to the parent frame.
|
// First, propagate dead_exits (if any) to the parent frame.
|
||||||
FrameState* parent_frame = frame->parent_frame;
|
FrameState* parent_frame = frame->parent_frame;
|
||||||
const int64 parent_iter = frame->parent_iter;
|
IterationState* parent_iter_state = frame->parent_iter;
|
||||||
if (parent_frame != nullptr) {
|
if (parent_frame != nullptr) {
|
||||||
mutex_lock parent_frame_lock(parent_frame->mu);
|
mutex_lock parent_frame_lock(parent_frame->mu);
|
||||||
// Propagate all the dead exits to the parent frame.
|
// Propagate all the dead exits to the parent frame.
|
||||||
mutex_lock this_frame_lock(frame->mu);
|
mutex_lock this_frame_lock(frame->mu);
|
||||||
|
|
||||||
for (const NodeItem* item : frame->dead_exits) {
|
for (const NodeItem* item : frame->dead_exits) {
|
||||||
auto parent_iter_state = parent_frame->GetIteration(parent_iter);
|
|
||||||
|
|
||||||
auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
|
auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
|
||||||
bool dst_dead) {
|
bool dst_dead) {
|
||||||
if (dst_ready) {
|
if (dst_ready) {
|
||||||
if (dst_item.is_control_trigger) dst_dead = false;
|
if (dst_item.is_control_trigger) dst_dead = false;
|
||||||
ready->emplace_back(&dst_item, parent_frame, parent_iter, dst_dead);
|
ready->emplace_back(&dst_item, parent_frame, parent_iter_state,
|
||||||
|
dst_dead);
|
||||||
parent_iter_state->outstanding_ops++;
|
parent_iter_state->outstanding_ops++;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -356,17 +358,18 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
|
|||||||
delete frame;
|
delete frame;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter,
|
void PropagatorState::CleanupFramesIterations(FrameState* frame,
|
||||||
|
IterationState* iter_state,
|
||||||
TaggedNodeSeq* ready) {
|
TaggedNodeSeq* ready) {
|
||||||
bool is_frame_done = false;
|
bool is_frame_done = false;
|
||||||
{
|
{
|
||||||
mutex_lock frame_lock(frame->mu);
|
mutex_lock frame_lock(frame->mu);
|
||||||
frame->GetIteration(iter)->outstanding_frame_count--;
|
iter_state->outstanding_frame_count--;
|
||||||
is_frame_done = frame->CleanupIterations(iter, ready);
|
is_frame_done = frame->CleanupIterations(iter_state, ready);
|
||||||
}
|
}
|
||||||
if (is_frame_done) {
|
if (is_frame_done) {
|
||||||
FrameState* parent_frame = frame->parent_frame;
|
FrameState* parent_frame = frame->parent_frame;
|
||||||
const int64 parent_iter = frame->parent_iter;
|
IterationState* parent_iter = frame->parent_iter;
|
||||||
DeleteFrame(frame, ready);
|
DeleteFrame(frame, ready);
|
||||||
if (parent_frame != nullptr) {
|
if (parent_frame != nullptr) {
|
||||||
// The completion of frame may cause completions in its parent frame.
|
// The completion of frame may cause completions in its parent frame.
|
||||||
@ -376,16 +379,13 @@ void PropagatorState::CleanupFramesIterations(FrameState* frame, int64 iter,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
|
void PropagatorState::FrameState::ActivateNodesFastPath(
|
||||||
const bool is_dead,
|
const NodeItem* item, const bool is_dead, IterationState* iter_state,
|
||||||
int64 iter,
|
EntryVector* outputs, TaggedNodeSeq* ready) {
|
||||||
EntryVector* outputs,
|
|
||||||
TaggedNodeSeq* ready) {
|
|
||||||
// If we know that none of the item's edge destinations require special
|
// If we know that none of the item's edge destinations require special
|
||||||
// handling (i.e. none of the nodes is a merge or control trigger node), we
|
// handling (i.e. none of the nodes is a merge or control trigger node), we
|
||||||
// can take a fast path that avoids accessing the destination NodeItem.
|
// can take a fast path that avoids accessing the destination NodeItem.
|
||||||
const GraphView& gview = immutable_state.graph_view();
|
const GraphView& gview = immutable_state.graph_view();
|
||||||
IterationState* iter_state = GetIteration(iter);
|
|
||||||
|
|
||||||
// Add dst to the ready queue if it's ready
|
// Add dst to the ready queue if it's ready
|
||||||
//
|
//
|
||||||
@ -398,7 +398,7 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
|
|||||||
TaggedNode& t = ready->emplace_back(); \
|
TaggedNode& t = ready->emplace_back(); \
|
||||||
t.node_item = dst_item; \
|
t.node_item = dst_item; \
|
||||||
t.input_frame = this; \
|
t.input_frame = this; \
|
||||||
t.input_iter = iter; \
|
t.input_iter = iter_state; \
|
||||||
t.is_dead = adjust_result.any_dead; \
|
t.is_dead = adjust_result.any_dead; \
|
||||||
iter_state->outstanding_ops++; \
|
iter_state->outstanding_ops++; \
|
||||||
} \
|
} \
|
||||||
@ -436,23 +436,20 @@ void PropagatorState::FrameState::ActivateNodesFastPath(const NodeItem* item,
|
|||||||
#undef MAYBE_ADD_TO_READY
|
#undef MAYBE_ADD_TO_READY
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
|
void PropagatorState::FrameState::ActivateNodesSlowPath(
|
||||||
const bool is_dead,
|
const NodeItem* item, const bool is_dead, IterationState* iter_state,
|
||||||
int64 iter,
|
EntryVector* outputs, TaggedNodeSeq* ready) {
|
||||||
EntryVector* outputs,
|
|
||||||
TaggedNodeSeq* ready) {
|
|
||||||
// If any of the edge destinations is a merge or a control trigger node,
|
// If any of the edge destinations is a merge or a control trigger node,
|
||||||
// we need to read each destination NodeItem to determine what action
|
// we need to read each destination NodeItem to determine what action
|
||||||
// to take.
|
// to take.
|
||||||
const GraphView& gview = immutable_state.graph_view();
|
const GraphView& gview = immutable_state.graph_view();
|
||||||
IterationState* iter_state = GetIteration(iter);
|
|
||||||
|
|
||||||
auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
|
auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
|
||||||
bool dst_ready, bool dst_dead) {
|
bool dst_ready, bool dst_dead) {
|
||||||
// Add dst to the ready queue if it's ready
|
// Add dst to the ready queue if it's ready
|
||||||
if (dst_ready) {
|
if (dst_ready) {
|
||||||
if (dst_item->is_control_trigger) dst_dead = false;
|
if (dst_item->is_control_trigger) dst_dead = false;
|
||||||
ready->emplace_back(dst_item, this, iter, dst_dead);
|
ready->emplace_back(dst_item, this, iter_state, dst_dead);
|
||||||
iter_state->outstanding_ops++;
|
iter_state->outstanding_ops++;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -551,17 +548,18 @@ void PropagatorState::FrameState::ActivateNodesSlowPath(const NodeItem* item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
|
void PropagatorState::FrameState::ActivateNodes(const NodeItem* item,
|
||||||
const bool is_dead, int64 iter,
|
const bool is_dead,
|
||||||
|
IterationState* iter_state,
|
||||||
EntryVector* outputs,
|
EntryVector* outputs,
|
||||||
TaggedNodeSeq* ready) {
|
TaggedNodeSeq* ready) {
|
||||||
if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
|
if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
|
||||||
ActivateNodesSlowPath(item, is_dead, iter, outputs, ready);
|
ActivateNodesSlowPath(item, is_dead, iter_state, outputs, ready);
|
||||||
} else {
|
} else {
|
||||||
ActivateNodesFastPath(item, is_dead, iter, outputs, ready);
|
ActivateNodesFastPath(item, is_dead, iter_state, outputs, ready);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FrameState::ActivateNexts(int64 iter,
|
void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state,
|
||||||
TaggedNodeSeq* ready) {
|
TaggedNodeSeq* ready) {
|
||||||
// Propagate the deferred NextIteration nodes to the new iteration.
|
// Propagate the deferred NextIteration nodes to the new iteration.
|
||||||
for (auto& node_entry : next_iter_roots) {
|
for (auto& node_entry : next_iter_roots) {
|
||||||
@ -569,12 +567,12 @@ void PropagatorState::FrameState::ActivateNexts(int64 iter,
|
|||||||
const Entry& entry = node_entry.second;
|
const Entry& entry = node_entry.second;
|
||||||
const bool is_dead = entry.state == Entry::State::NO_VALUE;
|
const bool is_dead = entry.state == Entry::State::NO_VALUE;
|
||||||
EntryVector outputs{entry};
|
EntryVector outputs{entry};
|
||||||
ActivateNodes(item, is_dead, iter, &outputs, ready);
|
ActivateNodes(item, is_dead, iter_state, &outputs, ready);
|
||||||
}
|
}
|
||||||
next_iter_roots.clear();
|
next_iter_roots.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FrameState::ActivateLoopInvs(int64 iter,
|
void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state,
|
||||||
TaggedNodeSeq* ready) {
|
TaggedNodeSeq* ready) {
|
||||||
// Propagate loop invariants to the new iteration.
|
// Propagate loop invariants to the new iteration.
|
||||||
for (auto& node_entry : inv_values) {
|
for (auto& node_entry : inv_values) {
|
||||||
@ -582,7 +580,7 @@ void PropagatorState::FrameState::ActivateLoopInvs(int64 iter,
|
|||||||
const Entry& entry = node_entry.second;
|
const Entry& entry = node_entry.second;
|
||||||
const bool is_dead = entry.state == Entry::State::NO_VALUE;
|
const bool is_dead = entry.state == Entry::State::NO_VALUE;
|
||||||
EntryVector outputs{entry};
|
EntryVector outputs{entry};
|
||||||
ActivateNodes(item, is_dead, iter, &outputs, ready);
|
ActivateNodes(item, is_dead, iter_state, &outputs, ready);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -596,33 +594,32 @@ void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
|
|||||||
const bool is_dead = entry.state == Entry::State::NO_VALUE;
|
const bool is_dead = entry.state == Entry::State::NO_VALUE;
|
||||||
for (int i = 0; i <= iteration_count; ++i) {
|
for (int i = 0; i <= iteration_count; ++i) {
|
||||||
EntryVector outputs{entry};
|
EntryVector outputs{entry};
|
||||||
ActivateNodes(item, is_dead, i, &outputs, ready);
|
ActivateNodes(item, is_dead, GetIteration(i), &outputs, ready);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PropagatorState::FrameState::IsIterationDone(int64 iter) {
|
bool PropagatorState::FrameState::IsIterationDone(IterationState* iter_state) {
|
||||||
IterationState* iter_state = GetIteration(iter);
|
|
||||||
if (iter_state->outstanding_ops == 0 &&
|
if (iter_state->outstanding_ops == 0 &&
|
||||||
iter_state->outstanding_frame_count == 0) {
|
iter_state->outstanding_frame_count == 0) {
|
||||||
if (iter == 0) {
|
if (iter_state->iter_num == 0) {
|
||||||
// The enclosing frame has no pending input.
|
// The enclosing frame has no pending input.
|
||||||
return num_pending_inputs == 0;
|
return num_pending_inputs == 0;
|
||||||
} else {
|
} else {
|
||||||
// The preceding iteration is deleted (and therefore done).
|
// The preceding iteration is deleted (and therefore done).
|
||||||
return (GetIteration(iter - 1) == nullptr);
|
return (GetIteration(iter_state->iter_num - 1) == nullptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
|
PropagatorState::IterationState*
|
||||||
|
PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
|
||||||
iteration_count++;
|
iteration_count++;
|
||||||
const int64 next_iter = iteration_count;
|
|
||||||
|
|
||||||
// Initialize the next iteration.
|
// Initialize the next iteration.
|
||||||
IterationState* iter_state =
|
IterationState* next_iter =
|
||||||
new IterationState(pending_counts, total_input_tensors);
|
new IterationState(iteration_count, pending_counts, total_input_tensors);
|
||||||
SetIteration(next_iter, iter_state);
|
SetIteration(iteration_count, next_iter);
|
||||||
num_outstanding_iterations++;
|
num_outstanding_iterations++;
|
||||||
dead_exits.clear();
|
dead_exits.clear();
|
||||||
|
|
||||||
@ -631,14 +628,15 @@ void PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
|
|||||||
|
|
||||||
// Activate the loop invariants in the new iteration.
|
// Activate the loop invariants in the new iteration.
|
||||||
ActivateLoopInvs(next_iter, ready);
|
ActivateLoopInvs(next_iter, ready);
|
||||||
|
|
||||||
|
return next_iter;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PropagatorState::FrameState::CleanupIterations(int64 iter,
|
bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
|
||||||
TaggedNodeSeq* ready) {
|
TaggedNodeSeq* ready) {
|
||||||
int64 curr_iter = iter;
|
int64 curr_iter = iter_state->iter_num;
|
||||||
while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
|
while (curr_iter <= iteration_count && IsIterationDone(iter_state)) {
|
||||||
// Delete the iteration curr_iter.
|
delete iter_state;
|
||||||
delete GetIteration(curr_iter);
|
|
||||||
SetIteration(curr_iter, nullptr);
|
SetIteration(curr_iter, nullptr);
|
||||||
--num_outstanding_iterations;
|
--num_outstanding_iterations;
|
||||||
++curr_iter;
|
++curr_iter;
|
||||||
@ -648,6 +646,10 @@ bool PropagatorState::FrameState::CleanupIterations(int64 iter,
|
|||||||
if (!next_iter_roots.empty()) {
|
if (!next_iter_roots.empty()) {
|
||||||
IncrementIteration(ready);
|
IncrementIteration(ready);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (curr_iter <= iteration_count) {
|
||||||
|
iter_state = GetIteration(curr_iter);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return IsFrameDone();
|
return IsFrameDone();
|
||||||
}
|
}
|
||||||
@ -677,21 +679,21 @@ void PropagatorState::FrameState::SetIteration(int64 iter,
|
|||||||
// Decrement the outstanding op count and clean up the iterations in the
|
// Decrement the outstanding op count and clean up the iterations in the
|
||||||
// frame. Return true iff the execution of the frame is done.
|
// frame. Return true iff the execution of the frame is done.
|
||||||
bool PropagatorState::FrameState::DecrementOutstandingOps(
|
bool PropagatorState::FrameState::DecrementOutstandingOps(
|
||||||
int64 iter, TaggedNodeSeq* ready) {
|
IterationState* iter_state, TaggedNodeSeq* ready) {
|
||||||
mutex_lock l(mu);
|
mutex_lock l(mu);
|
||||||
return DecrementOutstandingOpsLocked(iter, ready);
|
return DecrementOutstandingOpsLocked(iter_state, ready);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrement the outstanding op count and clean up the iterations in the
|
// Decrement the outstanding op count and clean up the iterations in the
|
||||||
// frame. Return true iff the execution of the frame is done.
|
// frame. Return true iff the execution of the frame is done.
|
||||||
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
|
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
|
||||||
int64 iter, TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
IterationState* iter_state, TaggedNodeSeq* ready)
|
||||||
IterationState* istate = GetIteration(iter);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
||||||
istate->outstanding_ops--;
|
iter_state->outstanding_ops--;
|
||||||
if (istate->outstanding_ops != 0) {
|
if (iter_state->outstanding_ops != 0) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
return CleanupIterations(iter, ready);
|
return CleanupIterations(iter_state, ready);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,8 +49,10 @@ class PropagatorState {
|
|||||||
~PropagatorState();
|
~PropagatorState();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Forward declaration so that `TaggedNode` can include a `FrameState*`.
|
// Forward declaration so that `TaggedNode` can include a `FrameState*` and an
|
||||||
|
// `IterationState*`.
|
||||||
struct FrameState;
|
struct FrameState;
|
||||||
|
struct IterationState;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// A `TaggedNode` corresponds to a single invocation of a node's kernel,
|
// A `TaggedNode` corresponds to a single invocation of a node's kernel,
|
||||||
@ -59,12 +61,12 @@ class PropagatorState {
|
|||||||
struct TaggedNode {
|
struct TaggedNode {
|
||||||
const NodeItem* node_item;
|
const NodeItem* node_item;
|
||||||
FrameState* input_frame;
|
FrameState* input_frame;
|
||||||
int64 input_iter;
|
IterationState* input_iter;
|
||||||
bool is_dead;
|
bool is_dead;
|
||||||
|
|
||||||
TaggedNode() = default;
|
TaggedNode() = default;
|
||||||
TaggedNode(const NodeItem* node_item, FrameState* in_frame, int64 in_iter,
|
TaggedNode(const NodeItem* node_item, FrameState* in_frame,
|
||||||
bool dead)
|
IterationState* in_iter, bool dead)
|
||||||
: node_item(node_item),
|
: node_item(node_item),
|
||||||
input_frame(in_frame),
|
input_frame(in_frame),
|
||||||
input_iter(in_iter),
|
input_iter(in_iter),
|
||||||
@ -73,7 +75,7 @@ class PropagatorState {
|
|||||||
const NodeItem& get_node_item() const { return *node_item; }
|
const NodeItem& get_node_item() const { return *node_item; }
|
||||||
|
|
||||||
bool get_is_dead() const { return is_dead; }
|
bool get_is_dead() const { return is_dead; }
|
||||||
int64 get_iter_num() const { return input_iter; }
|
int64 get_iter_num() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
// A drop-in replacement for std::deque<TaggedNode>. We typically don't
|
// A drop-in replacement for std::deque<TaggedNode>. We typically don't
|
||||||
@ -116,16 +118,18 @@ class PropagatorState {
|
|||||||
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
|
typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// The state of an iteration in a particular frame.
|
||||||
struct IterationState {
|
struct IterationState {
|
||||||
explicit IterationState(const PendingCounts* pending_counts,
|
explicit IterationState(int64 iter_num, const PendingCounts* pending_counts,
|
||||||
int total_input_tensors)
|
int total_input_tensors)
|
||||||
: input_tensors(new Entry[total_input_tensors]),
|
: iter_num(iter_num),
|
||||||
|
input_tensors(new Entry[total_input_tensors]),
|
||||||
outstanding_ops(0),
|
outstanding_ops(0),
|
||||||
outstanding_frame_count(0),
|
outstanding_frame_count(0),
|
||||||
counts(*pending_counts) { // Initialize with copy of *pending_counts
|
counts(*pending_counts) { // Initialize with copy of *pending_counts
|
||||||
}
|
}
|
||||||
|
|
||||||
// The state of an iteration.
|
const int64 iter_num; // The index of this iteration in the enclosing loop.
|
||||||
|
|
||||||
// One copy per iteration. For iteration k, i-th node's j-th input is in
|
// One copy per iteration. For iteration k, i-th node's j-th input is in
|
||||||
// input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
|
// input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is
|
||||||
@ -221,10 +225,10 @@ class PropagatorState {
|
|||||||
// frame_name.
|
// frame_name.
|
||||||
uint64 frame_id;
|
uint64 frame_id;
|
||||||
|
|
||||||
// The iteration id of its parent frame when this frame is created.
|
// The iteration state of its parent frame when this frame is created.
|
||||||
// -1 if there is no parent frame. The frame_name/parent_iter pair
|
// nullptr if there is no parent frame. The frame_name/parent_iter pair
|
||||||
// uniquely identifies this FrameState.
|
// uniquely identifies this FrameState.
|
||||||
int64 parent_iter = -1;
|
IterationState* parent_iter = nullptr;
|
||||||
|
|
||||||
// The FrameState of its parent frame.
|
// The FrameState of its parent frame.
|
||||||
FrameState* parent_frame = nullptr;
|
FrameState* parent_frame = nullptr;
|
||||||
@ -291,28 +295,33 @@ class PropagatorState {
|
|||||||
|
|
||||||
// Decrement the outstanding op count and clean up the iterations in the
|
// Decrement the outstanding op count and clean up the iterations in the
|
||||||
// frame. Return true iff the execution of the frame is done.
|
// frame. Return true iff the execution of the frame is done.
|
||||||
bool DecrementOutstandingOps(int64 iter, TaggedNodeSeq* ready);
|
bool DecrementOutstandingOps(IterationState* iter_state,
|
||||||
|
TaggedNodeSeq* ready);
|
||||||
|
|
||||||
// Decrement the outstanding op count and clean up the iterations in the
|
// Decrement the outstanding op count and clean up the iterations in the
|
||||||
// frame. Return true iff the execution of the frame is done.
|
// frame. Return true iff the execution of the frame is done.
|
||||||
bool DecrementOutstandingOpsLocked(int64 iter, TaggedNodeSeq* ready);
|
bool DecrementOutstandingOpsLocked(IterationState* iter_state,
|
||||||
|
TaggedNodeSeq* ready);
|
||||||
|
|
||||||
// Returns true if the computation in the frame is completed.
|
// Returns true if the computation in the frame is completed.
|
||||||
bool IsFrameDone();
|
bool IsFrameDone();
|
||||||
|
|
||||||
// Returns true if the iteration of the frame is completed.
|
// Returns true if the iteration of the frame is completed.
|
||||||
bool IsIterationDone(int64 iter) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
bool IsIterationDone(IterationState* iter_state)
|
||||||
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
||||||
|
|
||||||
// Increments the iteration id. If this is a new iteration, initialize it.
|
// Increments the iteration id. If this is a new iteration, initialize it.
|
||||||
void IncrementIteration(TaggedNodeSeq* ready)
|
//
|
||||||
|
// Returns a pointer to the new iteration.
|
||||||
|
IterationState* IncrementIteration(TaggedNodeSeq* ready)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
||||||
|
|
||||||
// Activate all the deferred NextIteration nodes in a new iteration.
|
// Activate all the deferred NextIteration nodes in a new iteration.
|
||||||
void ActivateNexts(int64 iter, TaggedNodeSeq* ready)
|
void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
||||||
|
|
||||||
// Activate all the current loop invariants in a new iteration.
|
// Activate all the current loop invariants in a new iteration.
|
||||||
void ActivateLoopInvs(int64 iter, TaggedNodeSeq* ready)
|
void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
||||||
|
|
||||||
// Add a new loop invariant and make it available to all active
|
// Add a new loop invariant and make it available to all active
|
||||||
@ -322,12 +331,12 @@ class PropagatorState {
|
|||||||
|
|
||||||
// Activate the successors of a node. Contents of *outputs are left in an
|
// Activate the successors of a node. Contents of *outputs are left in an
|
||||||
// indeterminate state after returning from this method.
|
// indeterminate state after returning from this method.
|
||||||
void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
|
void ActivateNodes(const NodeItem* item, const bool is_dead,
|
||||||
EntryVector* outputs, TaggedNodeSeq* ready)
|
IterationState* iter_state, EntryVector* outputs,
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
||||||
|
|
||||||
// Cleanup iterations of this frame starting from iteration iter.
|
// Cleanup iterations of this frame starting from the given iteration.
|
||||||
bool CleanupIterations(int64 iter, TaggedNodeSeq* ready)
|
bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
||||||
|
|
||||||
void DumpIterationState(PropagatorState* parent) {
|
void DumpIterationState(PropagatorState* parent) {
|
||||||
@ -350,12 +359,12 @@ class PropagatorState {
|
|||||||
private:
|
private:
|
||||||
// REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
|
// REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`.
|
||||||
void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
|
void ActivateNodesFastPath(const NodeItem* item, const bool is_dead,
|
||||||
int64 iter, EntryVector* outputs,
|
IterationState* iter_state, EntryVector* outputs,
|
||||||
TaggedNodeSeq* ready)
|
TaggedNodeSeq* ready)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
||||||
|
|
||||||
void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
|
void ActivateNodesSlowPath(const NodeItem* item, const bool is_dead,
|
||||||
int64 iter, EntryVector* outputs,
|
IterationState* iter_state, EntryVector* outputs,
|
||||||
TaggedNodeSeq* ready)
|
TaggedNodeSeq* ready)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu);
|
||||||
};
|
};
|
||||||
@ -379,13 +388,13 @@ class PropagatorState {
|
|||||||
// same address while the iteration is live.
|
// same address while the iteration is live.
|
||||||
Entry* GetInputTensors(const TaggedNode& tagged_node) const
|
Entry* GetInputTensors(const TaggedNode& tagged_node) const
|
||||||
TF_NO_THREAD_SAFETY_ANALYSIS {
|
TF_NO_THREAD_SAFETY_ANALYSIS {
|
||||||
return tagged_node.input_frame->GetIteration(tagged_node.input_iter)
|
return tagged_node.input_iter->input_tensors +
|
||||||
->input_tensors +
|
|
||||||
tagged_node.node_item->input_start;
|
tagged_node.node_item->input_start;
|
||||||
}
|
}
|
||||||
|
|
||||||
FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
|
FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
|
||||||
return {tagged_node.input_frame->frame_id, tagged_node.input_iter};
|
return {tagged_node.input_frame->frame_id,
|
||||||
|
tagged_node.input_iter->iter_num};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Provide debugging output of the state of the executor.
|
// Provide debugging output of the state of the executor.
|
||||||
@ -397,9 +406,8 @@ class PropagatorState {
|
|||||||
// optional debugging support.
|
// optional debugging support.
|
||||||
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
|
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
|
||||||
mutex_lock l(tagged_node.input_frame->mu);
|
mutex_lock l(tagged_node.input_frame->mu);
|
||||||
tagged_node.input_frame->GetIteration(tagged_node.input_iter)
|
tagged_node.input_iter->mark_started(
|
||||||
->mark_started(
|
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
|
||||||
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -408,16 +416,15 @@ class PropagatorState {
|
|||||||
// optional debugging support.
|
// optional debugging support.
|
||||||
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
|
if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
|
||||||
mutex_lock l(tagged_node.input_frame->mu);
|
mutex_lock l(tagged_node.input_frame->mu);
|
||||||
tagged_node.input_frame->GetIteration(tagged_node.input_iter)
|
tagged_node.input_iter->mark_completed(
|
||||||
->mark_completed(
|
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
|
||||||
immutable_state_.pending_ids()[tagged_node.node_item->node_id]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Find an existing or create a new child frame in the frame 'frame' at
|
// Find an existing or create a new child frame in the frame 'frame' at
|
||||||
// iteration 'iter'.
|
// iteration 'iter'.
|
||||||
void FindOrCreateChildFrame(FrameState* frame, int64 iter,
|
void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state,
|
||||||
const NodeItem& node_item, FrameState** child);
|
const NodeItem& node_item, FrameState** child);
|
||||||
|
|
||||||
// Delete a frame. Called when the frame is done.
|
// Delete a frame. Called when the frame is done.
|
||||||
@ -425,7 +432,7 @@ class PropagatorState {
|
|||||||
|
|
||||||
// Cleanup frames and iterations starting from frame/iter. Called when
|
// Cleanup frames and iterations starting from frame/iter. Called when
|
||||||
// a child frame is done.
|
// a child frame is done.
|
||||||
void CleanupFramesIterations(FrameState* frame, int64 iter,
|
void CleanupFramesIterations(FrameState* frame, IterationState* iter_state,
|
||||||
TaggedNodeSeq* ready);
|
TaggedNodeSeq* ready);
|
||||||
|
|
||||||
// Provide debugging output about an outstanding iteration in the executor.
|
// Provide debugging output about an outstanding iteration in the executor.
|
||||||
@ -450,6 +457,10 @@ class PropagatorState {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
|
TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline int64 PropagatorState::TaggedNode::get_iter_num() const {
|
||||||
|
return input_iter->iter_num;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
|
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_
|
||||||
|
Loading…
Reference in New Issue
Block a user