[Executor] Optimize PropagatorState::FindOrCreateChildFrame().
At present, the `PropagatorState` must look up the "frame_name" and "parallel_iterations" attributes on the `Enter` node's `NodeDef` each time it propagates its outputs. This information is mostly cached already in the `ImmutableExecutorState`. This change makes the following optimizations: 1. Cache a `FrameInfo*` for each enter node and provide O(1) lookup using an `std::vector<FrameInfo*>`. 2. Add `FrameInfo::parallel_iterations`. 3. Perform `frame_name` and `parallel_iterations` resolution once at `ImmutableExecutorState` construction time. 4. Avoid building and storing `std::string FrameState::frame_name`, which is only used for verbose logging. Instead use the `uint64 FrameState::frame_id` as the key in all data structures, to optimize lookup. This is safe because we already depend on a lack of collisions between frame IDs (since the frame IDs are used in rendezvous keys when tensors are sent between devices). This change also modifies the "executor_test.cc" loop microbenchmarks to cover different numbers of loop variables (which, in the lowered case, directly translates to the number of "Enter" ops). PiperOrigin-RevId: 310467528 Change-Id: Id70d284e99fd7537156a3a5e10da827aea4791f9
This commit is contained in:
parent
804dd8af35
commit
1a03912767
@ -1212,6 +1212,7 @@ cc_library(
|
|||||||
":propagator_debug_utils",
|
":propagator_debug_utils",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/platform:hash",
|
||||||
"//tensorflow/core/profiler/lib:traceme",
|
"//tensorflow/core/profiler/lib:traceme",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -549,7 +549,8 @@ BENCHMARK(BM_FeedInputFetchOutput);
|
|||||||
//
|
//
|
||||||
// ...using the functional `WhileOp` (if `lower` is false) or the
|
// ...using the functional `WhileOp` (if `lower` is false) or the
|
||||||
// `Switch`/`Merge`-style of control flow (if `lower` is true).
|
// `Switch`/`Merge`-style of control flow (if `lower` is true).
|
||||||
static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
|
static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
|
||||||
|
bool lower) {
|
||||||
testing::StopTiming();
|
testing::StopTiming();
|
||||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
|
||||||
@ -558,20 +559,44 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
|
|||||||
|
|
||||||
// Define the loop body as a function: `x = x + 1`.
|
// Define the loop body as a function: `x = x + 1`.
|
||||||
const Tensor one_t = test::AsScalar<int32>(1);
|
const Tensor one_t = test::AsScalar<int32>(1);
|
||||||
|
|
||||||
|
std::vector<string> args;
|
||||||
|
args.reserve(loop_vars);
|
||||||
|
args.push_back("x: int32");
|
||||||
|
for (int i = 1; i < loop_vars; ++i) {
|
||||||
|
args.push_back(strings::StrCat("x", i, ": int32"));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<string> body_rets;
|
||||||
|
body_rets.reserve(loop_vars);
|
||||||
|
body_rets.push_back("y: int32");
|
||||||
|
for (int i = 1; i < loop_vars; ++i) {
|
||||||
|
body_rets.push_back(strings::StrCat("y", i, ": int32"));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<FunctionDefHelper::Node> body_nodes;
|
||||||
|
body_nodes.reserve(1 + loop_vars);
|
||||||
|
body_nodes.push_back(
|
||||||
|
{{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}});
|
||||||
|
body_nodes.push_back({{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}});
|
||||||
|
for (int i = 1; i < loop_vars; ++i) {
|
||||||
|
body_nodes.push_back({{strings::StrCat("y", i)},
|
||||||
|
"Identity",
|
||||||
|
{strings::StrCat("x", i)},
|
||||||
|
{{"T", DT_INT32}}});
|
||||||
|
}
|
||||||
|
|
||||||
*f_lib_proto.add_function() = FunctionDefHelper::Define(
|
*f_lib_proto.add_function() = FunctionDefHelper::Define(
|
||||||
// Name
|
// Name
|
||||||
"XPlusOne",
|
"XPlusOne",
|
||||||
// Args
|
// Args
|
||||||
{"x: int32"},
|
args,
|
||||||
// Return values
|
// Return values
|
||||||
{"y: int32"},
|
body_rets,
|
||||||
// Attr def
|
// Attr def
|
||||||
{},
|
{},
|
||||||
// Nodes
|
// Nodes
|
||||||
{
|
body_nodes);
|
||||||
{{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}},
|
|
||||||
{{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}},
|
|
||||||
});
|
|
||||||
|
|
||||||
// Define the loop condition as a function: `x < loop_iters`.
|
// Define the loop condition as a function: `x < loop_iters`.
|
||||||
const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
|
const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
|
||||||
@ -579,7 +604,7 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
|
|||||||
// Name
|
// Name
|
||||||
"LessThanOrEqualToN",
|
"LessThanOrEqualToN",
|
||||||
// Args
|
// Args
|
||||||
{"x: int32"},
|
args,
|
||||||
// Return values
|
// Return values
|
||||||
{"z: bool"},
|
{"z: bool"},
|
||||||
// Attr def
|
// Attr def
|
||||||
@ -594,7 +619,12 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
|
|||||||
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
|
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
|
||||||
auto a = ops::Const(root.WithOpName("A"), 0, {});
|
auto a = ops::Const(root.WithOpName("A"), 0, {});
|
||||||
Node* while_node;
|
Node* while_node;
|
||||||
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
|
std::vector<NodeBuilder::NodeOut> inputs;
|
||||||
|
std::vector<DataType> input_types(loop_vars, DT_INT32);
|
||||||
|
inputs.reserve(loop_vars);
|
||||||
|
for (int i = 0; i < loop_vars; ++i) {
|
||||||
|
inputs.push_back(NodeBuilder::NodeOut(a.node()));
|
||||||
|
}
|
||||||
AttrValue int32_attr;
|
AttrValue int32_attr;
|
||||||
int32_attr.set_type(DT_INT32);
|
int32_attr.set_type(DT_INT32);
|
||||||
AttrValue cond_func;
|
AttrValue cond_func;
|
||||||
@ -604,7 +634,7 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
|
|||||||
TF_ASSERT_OK(
|
TF_ASSERT_OK(
|
||||||
NodeBuilder("while", "While", &root.graph()->flib_def())
|
NodeBuilder("while", "While", &root.graph()->flib_def())
|
||||||
.Input(inputs)
|
.Input(inputs)
|
||||||
.Attr("T", {DT_INT32})
|
.Attr("T", input_types)
|
||||||
.Attr("cond", cond_func)
|
.Attr("cond", cond_func)
|
||||||
.Attr("body", body_func)
|
.Attr("body", body_func)
|
||||||
.Attr("parallel_iterations", 100)
|
.Attr("parallel_iterations", 100)
|
||||||
@ -635,21 +665,33 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
|
|||||||
test::Benchmark("cpu", graph.release()).Run(iters);
|
test::Benchmark("cpu", graph.release()).Run(iters);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void BM_LoweredWhileLoop(int iters, int loop_iters) {
|
static void BM_LoweredWhileLoop(int iters, int loop_iters, int loop_vars) {
|
||||||
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ true);
|
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ true);
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_LoweredWhileLoop)->Arg(0);
|
BENCHMARK(BM_LoweredWhileLoop)
|
||||||
BENCHMARK(BM_LoweredWhileLoop)->Arg(1);
|
->ArgPair(0, 1)
|
||||||
BENCHMARK(BM_LoweredWhileLoop)->Arg(10);
|
->ArgPair(1, 1)
|
||||||
BENCHMARK(BM_LoweredWhileLoop)->Arg(100);
|
->ArgPair(10, 1)
|
||||||
BENCHMARK(BM_LoweredWhileLoop)->Arg(1000);
|
->ArgPair(100, 1)
|
||||||
|
->ArgPair(1000, 1)
|
||||||
|
->ArgPair(0, 100)
|
||||||
|
->ArgPair(1, 100)
|
||||||
|
->ArgPair(10, 100)
|
||||||
|
->ArgPair(100, 100)
|
||||||
|
->ArgPair(1000, 100);
|
||||||
|
|
||||||
static void BM_FunctionalWhileLoop(int iters, int loop_iters) {
|
static void BM_FunctionalWhileLoop(int iters, int loop_iters, int loop_vars) {
|
||||||
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ false);
|
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ false);
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_FunctionalWhileLoop)->Arg(0);
|
BENCHMARK(BM_FunctionalWhileLoop)
|
||||||
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1);
|
->ArgPair(0, 1)
|
||||||
BENCHMARK(BM_FunctionalWhileLoop)->Arg(10);
|
->ArgPair(1, 1)
|
||||||
BENCHMARK(BM_FunctionalWhileLoop)->Arg(100);
|
->ArgPair(10, 1)
|
||||||
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1000);
|
->ArgPair(100, 1)
|
||||||
|
->ArgPair(1000, 1)
|
||||||
|
->ArgPair(0, 100)
|
||||||
|
->ArgPair(1, 100)
|
||||||
|
->ArgPair(10, 100)
|
||||||
|
->ArgPair(100, 100)
|
||||||
|
->ArgPair(1000, 100);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|||||||
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/graph/edgeset.h"
|
#include "tensorflow/core/graph/edgeset.h"
|
||||||
#include "tensorflow/core/graph/graph.h"
|
#include "tensorflow/core/graph/graph.h"
|
||||||
#include "tensorflow/core/graph/graph_node_util.h"
|
#include "tensorflow/core/graph/graph_node_util.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -39,9 +40,6 @@ ImmutableExecutorState::~ImmutableExecutorState() {
|
|||||||
params_.delete_kernel(item->kernel);
|
params_.delete_kernel(item->kernel);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto fiter : frame_info_) {
|
|
||||||
delete fiter.second;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
@ -71,11 +69,16 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending,
|
|||||||
|
|
||||||
ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
|
ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
|
||||||
const string& fname) {
|
const string& fname) {
|
||||||
auto slot = &frame_info_[fname];
|
auto iter = frame_info_.find(fname);
|
||||||
if (*slot == nullptr) {
|
if (iter != frame_info_.end()) {
|
||||||
*slot = new FrameInfo;
|
return iter->second.get();
|
||||||
|
} else {
|
||||||
|
auto frame_info = absl::make_unique<FrameInfo>(fname);
|
||||||
|
absl::string_view fname_view = frame_info->name;
|
||||||
|
auto emplace_result =
|
||||||
|
frame_info_.emplace(fname_view, std::move(frame_info));
|
||||||
|
return emplace_result.first->second.get();
|
||||||
}
|
}
|
||||||
return *slot;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ImmutableExecutorState::Initialize(const Graph& graph) {
|
Status ImmutableExecutorState::Initialize(const Graph& graph) {
|
||||||
@ -89,7 +92,7 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
|
|||||||
EnsureFrameInfo(it)->nodes =
|
EnsureFrameInfo(it)->nodes =
|
||||||
absl::make_unique<std::vector<const NodeItem*>>();
|
absl::make_unique<std::vector<const NodeItem*>>();
|
||||||
}
|
}
|
||||||
root_frame_info_ = frame_info_[""];
|
root_frame_info_ = frame_info_[""].get();
|
||||||
|
|
||||||
pending_ids_.resize(gview_.num_nodes());
|
pending_ids_.resize(gview_.num_nodes());
|
||||||
|
|
||||||
@ -157,6 +160,28 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
|
GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
|
||||||
item->is_constant_enter = is_constant_enter;
|
item->is_constant_enter = is_constant_enter;
|
||||||
|
|
||||||
|
string frame_name;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
|
||||||
|
FrameInfo* frame_info = frame_info_[frame_name].get();
|
||||||
|
|
||||||
|
int parallel_iterations;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
GetNodeAttr(n->attrs(), "parallel_iterations", ¶llel_iterations));
|
||||||
|
|
||||||
|
if (frame_info->parallel_iterations == -1) {
|
||||||
|
frame_info->parallel_iterations = parallel_iterations;
|
||||||
|
} else if (frame_info->parallel_iterations != parallel_iterations) {
|
||||||
|
LOG(WARNING) << "Loop frame \"" << frame_name
|
||||||
|
<< "\" had two different values for parallel_iterations: "
|
||||||
|
<< frame_info->parallel_iterations << " vs. "
|
||||||
|
<< parallel_iterations << ".";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (enter_frame_info_.size() <= id) {
|
||||||
|
enter_frame_info_.resize(id + 1);
|
||||||
|
}
|
||||||
|
enter_frame_info_[id] = frame_info;
|
||||||
} else {
|
} else {
|
||||||
item->is_constant_enter = false;
|
item->is_constant_enter = false;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/core/common_runtime/graph_view.h"
|
#include "tensorflow/core/common_runtime/graph_view.h"
|
||||||
#include "tensorflow/core/common_runtime/local_executor_params.h"
|
#include "tensorflow/core/common_runtime/local_executor_params.h"
|
||||||
#include "tensorflow/core/common_runtime/pending_counts.h"
|
#include "tensorflow/core/common_runtime/pending_counts.h"
|
||||||
@ -41,11 +42,16 @@ class Graph;
|
|||||||
class ImmutableExecutorState {
|
class ImmutableExecutorState {
|
||||||
public:
|
public:
|
||||||
struct FrameInfo {
|
struct FrameInfo {
|
||||||
FrameInfo()
|
explicit FrameInfo(string name)
|
||||||
: input_count(0),
|
: name(std::move(name)),
|
||||||
|
input_count(0),
|
||||||
total_inputs(0),
|
total_inputs(0),
|
||||||
pending_counts(nullptr),
|
pending_counts(nullptr),
|
||||||
nodes(nullptr) {}
|
nodes(nullptr),
|
||||||
|
parallel_iterations(-1) {}
|
||||||
|
|
||||||
|
// The name of the frame.
|
||||||
|
string name;
|
||||||
|
|
||||||
// The total number of inputs to a frame.
|
// The total number of inputs to a frame.
|
||||||
int input_count;
|
int input_count;
|
||||||
@ -63,6 +69,9 @@ class ImmutableExecutorState {
|
|||||||
|
|
||||||
// The nodes in a frame. Used only for debugging.
|
// The nodes in a frame. Used only for debugging.
|
||||||
std::unique_ptr<std::vector<const NodeItem*>> nodes;
|
std::unique_ptr<std::vector<const NodeItem*>> nodes;
|
||||||
|
|
||||||
|
// The number of iterations of this frame that can execute concurrently.
|
||||||
|
int32 parallel_iterations;
|
||||||
};
|
};
|
||||||
|
|
||||||
explicit ImmutableExecutorState(const LocalExecutorParams& p)
|
explicit ImmutableExecutorState(const LocalExecutorParams& p)
|
||||||
@ -83,17 +92,13 @@ class ImmutableExecutorState {
|
|||||||
}
|
}
|
||||||
const std::vector<const NodeItem*>& root_nodes() const { return root_nodes_; }
|
const std::vector<const NodeItem*>& root_nodes() const { return root_nodes_; }
|
||||||
|
|
||||||
const FrameInfo* get_frame_info(const string& frame_name) const {
|
|
||||||
auto it_frame_info = frame_info_.find(frame_name);
|
|
||||||
if (it_frame_info == frame_info_.end()) {
|
|
||||||
return nullptr;
|
|
||||||
} else {
|
|
||||||
return it_frame_info->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }
|
const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }
|
||||||
|
|
||||||
|
const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const {
|
||||||
|
DCHECK(node_item.is_enter);
|
||||||
|
return *enter_frame_info_[node_item.node_id];
|
||||||
|
}
|
||||||
|
|
||||||
bool requires_control_flow_support() const { return requires_control_flow_; }
|
bool requires_control_flow_support() const { return requires_control_flow_; }
|
||||||
|
|
||||||
// Copies the pending counts for nodes in this graph to the given array.
|
// Copies the pending counts for nodes in this graph to the given array.
|
||||||
@ -135,9 +140,14 @@ class ImmutableExecutorState {
|
|||||||
// Mapping from frame name to static information about the frame.
|
// Mapping from frame name to static information about the frame.
|
||||||
// TODO(yuanbyu): We could cache it along with the graph so to avoid
|
// TODO(yuanbyu): We could cache it along with the graph so to avoid
|
||||||
// the overhead of constructing it for each executor instance.
|
// the overhead of constructing it for each executor instance.
|
||||||
gtl::FlatMap<string, FrameInfo*> frame_info_;
|
absl::flat_hash_map<absl::string_view, std::unique_ptr<FrameInfo>>
|
||||||
|
frame_info_;
|
||||||
const FrameInfo* root_frame_info_; // Not owned.
|
const FrameInfo* root_frame_info_; // Not owned.
|
||||||
|
|
||||||
|
// If the graph contains any "Enter" or "RefEnter" nodes, this vector maps
|
||||||
|
// dense node IDs to the corresponding FrameInfo.
|
||||||
|
std::vector<FrameInfo*> enter_frame_info_;
|
||||||
|
|
||||||
// If `requires_control_flow_` is false, this points to an array of initial
|
// If `requires_control_flow_` is false, this points to an array of initial
|
||||||
// pending counts for the nodes in the graph, indexed by node ID.
|
// pending counts for the nodes in the graph, indexed by node ID.
|
||||||
std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_;
|
std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_;
|
||||||
|
|||||||
@ -16,9 +16,11 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/propagator_state.h"
|
#include "tensorflow/core/common_runtime/propagator_state.h"
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/graph_view.h"
|
#include "tensorflow/core/common_runtime/graph_view.h"
|
||||||
|
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
|
||||||
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
|
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
|
#include "tensorflow/core/platform/hash.h"
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -33,14 +35,14 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
|
|||||||
// We assume root_frame_->frame_name.empty().
|
// We assume root_frame_->frame_name.empty().
|
||||||
root_frame_ = new FrameState(immutable_state_, 1);
|
root_frame_ = new FrameState(immutable_state_, 1);
|
||||||
root_frame_->frame_id = 0; // must be 0
|
root_frame_->frame_id = 0; // must be 0
|
||||||
root_frame_->InitializeFrameInfo(root_frame_->frame_name);
|
root_frame_->InitializeFrameInfo(immutable_state_.get_root_frame_info());
|
||||||
|
|
||||||
// Initialize iteration 0.
|
// Initialize iteration 0.
|
||||||
root_frame_->SetIteration(
|
root_frame_->SetIteration(
|
||||||
0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
|
0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
|
||||||
root_frame_->total_input_tensors));
|
root_frame_->total_input_tensors));
|
||||||
|
|
||||||
outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
|
outstanding_frames_.emplace(root_frame_->frame_id, root_frame_);
|
||||||
}
|
}
|
||||||
|
|
||||||
PropagatorState::~PropagatorState() {
|
PropagatorState::~PropagatorState() {
|
||||||
@ -224,16 +226,16 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
|
|||||||
const NodeItem& node_item,
|
const NodeItem& node_item,
|
||||||
FrameState** child) {
|
FrameState** child) {
|
||||||
// Get the child frame name.
|
// Get the child frame name.
|
||||||
AttrSlice attrs(node_item.kernel->def());
|
const ImmutableExecutorState::FrameInfo& frame_info =
|
||||||
const string& enter_name = GetNodeAttrString(attrs, "frame_name");
|
immutable_state_.get_enter_frame_info(node_item);
|
||||||
DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
|
|
||||||
<< node_item.kernel->name();
|
const uint64 child_id = Hash64Combine(
|
||||||
const string child_name = strings::StrCat(
|
frame->frame_id,
|
||||||
frame->frame_name, ";", iter_state->iter_num, ";", enter_name);
|
Hash64Combine(iter_state->iter_num, Hash64(frame_info.name)));
|
||||||
|
|
||||||
{
|
{
|
||||||
mutex_lock executor_lock(mu_);
|
tf_shared_lock executor_lock(mu_);
|
||||||
auto it = outstanding_frames_.find(child_name);
|
auto it = outstanding_frames_.find(child_id);
|
||||||
if (it != outstanding_frames_.end()) {
|
if (it != outstanding_frames_.end()) {
|
||||||
*child = it->second;
|
*child = it->second;
|
||||||
return;
|
return;
|
||||||
@ -242,20 +244,18 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
|
|||||||
|
|
||||||
// Need to create a new frame instance.
|
// Need to create a new frame instance.
|
||||||
// Note that this new frame instance is created without any locks.
|
// Note that this new frame instance is created without any locks.
|
||||||
if (vlog_) VLOG(2) << "Create frame: " << child_name;
|
if (vlog_) {
|
||||||
|
const string child_name = strings::StrCat(
|
||||||
|
frame->frame_name, ";", iter_state->iter_num, ";", frame_info.name);
|
||||||
|
VLOG(2) << "Create frame: " << child_name << " id: " << child_id;
|
||||||
|
}
|
||||||
|
|
||||||
int parallel_iters;
|
FrameState* temp =
|
||||||
bool found_parallel_iters =
|
new FrameState(immutable_state_, frame_info.parallel_iterations);
|
||||||
TryGetNodeAttr(attrs, "parallel_iterations", ¶llel_iters);
|
temp->frame_id = child_id;
|
||||||
DCHECK(found_parallel_iters)
|
|
||||||
<< "Could not find \"parallel_iterations\" attr in node "
|
|
||||||
<< node_item.kernel->name();
|
|
||||||
FrameState* temp = new FrameState(immutable_state_, parallel_iters);
|
|
||||||
temp->frame_name = child_name;
|
|
||||||
temp->frame_id = Hash64(child_name);
|
|
||||||
temp->parent_frame = frame;
|
temp->parent_frame = frame;
|
||||||
temp->parent_iter = iter_state;
|
temp->parent_iter = iter_state;
|
||||||
temp->InitializeFrameInfo(enter_name);
|
temp->InitializeFrameInfo(frame_info);
|
||||||
|
|
||||||
// Initialize iteration 0.
|
// Initialize iteration 0.
|
||||||
{
|
{
|
||||||
@ -266,13 +266,13 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
|
|||||||
|
|
||||||
{
|
{
|
||||||
mutex_lock executor_lock(mu_);
|
mutex_lock executor_lock(mu_);
|
||||||
auto it = outstanding_frames_.find(child_name);
|
auto it = outstanding_frames_.find(child_id);
|
||||||
if (it != outstanding_frames_.end()) {
|
if (it != outstanding_frames_.end()) {
|
||||||
*child = it->second;
|
*child = it->second;
|
||||||
} else {
|
} else {
|
||||||
mutex_lock frame_lock(frame->mu);
|
mutex_lock frame_lock(frame->mu);
|
||||||
iter_state->outstanding_frame_count++;
|
iter_state->outstanding_frame_count++;
|
||||||
outstanding_frames_[child_name] = temp;
|
outstanding_frames_[child_id] = temp;
|
||||||
*child = temp;
|
*child = temp;
|
||||||
temp = nullptr;
|
temp = nullptr;
|
||||||
}
|
}
|
||||||
@ -349,11 +349,10 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete the frame.
|
// Delete the frame.
|
||||||
const string& frame_name = frame->frame_name;
|
if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id;
|
||||||
if (vlog_) VLOG(2) << "Delete frame " << frame_name;
|
|
||||||
{
|
{
|
||||||
mutex_lock executor_lock(mu_);
|
mutex_lock executor_lock(mu_);
|
||||||
outstanding_frames_.erase(frame_name);
|
outstanding_frames_.erase(frame->frame_id);
|
||||||
}
|
}
|
||||||
delete frame;
|
delete frame;
|
||||||
}
|
}
|
||||||
@ -655,14 +654,11 @@ bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FrameState::InitializeFrameInfo(
|
void PropagatorState::FrameState::InitializeFrameInfo(
|
||||||
const string& enter_name) {
|
const ImmutableExecutorState::FrameInfo& finfo) {
|
||||||
const ImmutableExecutorState::FrameInfo* finfo =
|
pending_counts = finfo.pending_counts.get();
|
||||||
immutable_state.get_frame_info(enter_name);
|
total_input_tensors = finfo.total_inputs;
|
||||||
DCHECK_NE(finfo, nullptr);
|
num_pending_inputs = finfo.input_count;
|
||||||
pending_counts = finfo->pending_counts.get();
|
nodes = finfo.nodes.get();
|
||||||
total_input_tensors = finfo->total_inputs;
|
|
||||||
num_pending_inputs = finfo->input_count;
|
|
||||||
nodes = finfo->nodes.get();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PropagatorState::FrameState::SetIteration(int64 iter,
|
void PropagatorState::FrameState::SetIteration(int64 iter,
|
||||||
|
|||||||
@ -279,7 +279,7 @@ class PropagatorState {
|
|||||||
// during structured traversal: parent_frame->mu < mu.
|
// during structured traversal: parent_frame->mu < mu.
|
||||||
mutex mu;
|
mutex mu;
|
||||||
|
|
||||||
void InitializeFrameInfo(const string& enter_name);
|
void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo);
|
||||||
|
|
||||||
inline IterationState* GetIteration(int64 iter)
|
inline IterationState* GetIteration(int64 iter)
|
||||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
|
||||||
@ -447,12 +447,13 @@ class PropagatorState {
|
|||||||
// The root frame in which the execution of this step is started.
|
// The root frame in which the execution of this step is started.
|
||||||
FrameState* root_frame_;
|
FrameState* root_frame_;
|
||||||
|
|
||||||
// Mapping from frame name to outstanding frames. A new frame is created
|
// Mapping from frame ID to outstanding frames. A new frame is created
|
||||||
// at some iteration of an active frame. So the unique key for the new
|
// at some iteration of an active frame. So the unique key for the new
|
||||||
// child frame is composed of the name of the parent frame, the iteration
|
// child frame is a hash composed of the ID of the parent frame, the iteration
|
||||||
// number at which the parent frame is creating the new frame, and the
|
// number at which the parent frame is creating the new frame, and the
|
||||||
// name of the new frame from nodedef.
|
// name of the new frame from nodedef.
|
||||||
gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_);
|
absl::flat_hash_map<uint64, FrameState*> outstanding_frames_
|
||||||
|
TF_GUARDED_BY(mu_);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
|
TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user