[Executor] Optimize PropagatorState::FindOrCreateChildFrame().

At present, the `PropagatorState` must look up the "frame_name" and "parallel_iterations" attributes on the `Enter` node's `NodeDef` each time it propagates its outputs. This information is mostly cached already in the `ImmutableExecutorState`.

This change makes the following optimizations:

1. Cache a `FrameInfo*` for each enter node and provide O(1) lookup using an `std::vector<FrameInfo*>`.
2. Add `FrameInfo::parallel_iterations`.
3. Perform `frame_name` and `parallel_iterations` resolution once at `ImmutableExecutorState` construction time.
4. Avoid building and storing `std::string FrameState::frame_name`, which is only used for verbose logging. Instead use the `uint64 FrameState::frame_id` as the key in all data structures, to optimize lookup. This is safe because we already depend on a lack of collisions between frame IDs (since the frame IDs are used in rendezvous keys when tensors are sent between devices).

This change also modifies the "executor_test.cc" loop microbenchmarks to cover different numbers of loop variables (which, in the lowered case, directly translates to the number of "Enter" ops).

PiperOrigin-RevId: 310467528
Change-Id: Id70d284e99fd7537156a3a5e10da827aea4791f9
This commit is contained in:
Derek Murray 2020-05-07 17:21:16 -07:00 committed by TensorFlower Gardener
parent 804dd8af35
commit 1a03912767
6 changed files with 158 additions and 83 deletions

View File

@ -1212,6 +1212,7 @@ cc_library(
":propagator_debug_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:hash",
"//tensorflow/core/profiler/lib:traceme",
],
)

View File

@ -549,7 +549,8 @@ BENCHMARK(BM_FeedInputFetchOutput);
//
// ...using the functional `WhileOp` (if `lower` is false) or the
// `Switch`/`Merge`-style of control flow (if `lower` is true).
static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
static void BM_WhileLoopHelper(int iters, int loop_iters, int loop_vars,
bool lower) {
testing::StopTiming();
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
@ -558,20 +559,44 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
// Define the loop body as a function: `x = x + 1`.
const Tensor one_t = test::AsScalar<int32>(1);
std::vector<string> args;
args.reserve(loop_vars);
args.push_back("x: int32");
for (int i = 1; i < loop_vars; ++i) {
args.push_back(strings::StrCat("x", i, ": int32"));
}
std::vector<string> body_rets;
body_rets.reserve(loop_vars);
body_rets.push_back("y: int32");
for (int i = 1; i < loop_vars; ++i) {
body_rets.push_back(strings::StrCat("y", i, ": int32"));
}
std::vector<FunctionDefHelper::Node> body_nodes;
body_nodes.reserve(1 + loop_vars);
body_nodes.push_back(
{{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}});
body_nodes.push_back({{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}});
for (int i = 1; i < loop_vars; ++i) {
body_nodes.push_back({{strings::StrCat("y", i)},
"Identity",
{strings::StrCat("x", i)},
{{"T", DT_INT32}}});
}
*f_lib_proto.add_function() = FunctionDefHelper::Define(
// Name
"XPlusOne",
// Args
{"x: int32"},
args,
// Return values
{"y: int32"},
body_rets,
// Attr def
{},
// Nodes
{
{{"one"}, "Const", {}, {{"value", one_t}, {"dtype", DT_INT32}}},
{{"y"}, "Add", {"x", "one"}, {{"T", DT_INT32}}},
});
body_nodes);
// Define the loop condition as a function: `x < loop_iters`.
const Tensor loop_iters_t = test::AsScalar<int32>(loop_iters);
@ -579,7 +604,7 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
// Name
"LessThanOrEqualToN",
// Args
{"x: int32"},
args,
// Return values
{"z: bool"},
// Attr def
@ -594,7 +619,12 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
auto a = ops::Const(root.WithOpName("A"), 0, {});
Node* while_node;
std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
std::vector<NodeBuilder::NodeOut> inputs;
std::vector<DataType> input_types(loop_vars, DT_INT32);
inputs.reserve(loop_vars);
for (int i = 0; i < loop_vars; ++i) {
inputs.push_back(NodeBuilder::NodeOut(a.node()));
}
AttrValue int32_attr;
int32_attr.set_type(DT_INT32);
AttrValue cond_func;
@ -604,7 +634,7 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
TF_ASSERT_OK(
NodeBuilder("while", "While", &root.graph()->flib_def())
.Input(inputs)
.Attr("T", {DT_INT32})
.Attr("T", input_types)
.Attr("cond", cond_func)
.Attr("body", body_func)
.Attr("parallel_iterations", 100)
@ -635,21 +665,33 @@ static void BM_WhileLoopHelper(int iters, int loop_iters, bool lower) {
test::Benchmark("cpu", graph.release()).Run(iters);
}
static void BM_LoweredWhileLoop(int iters, int loop_iters) {
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ true);
static void BM_LoweredWhileLoop(int iters, int loop_iters, int loop_vars) {
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ true);
}
BENCHMARK(BM_LoweredWhileLoop)->Arg(0);
BENCHMARK(BM_LoweredWhileLoop)->Arg(1);
BENCHMARK(BM_LoweredWhileLoop)->Arg(10);
BENCHMARK(BM_LoweredWhileLoop)->Arg(100);
BENCHMARK(BM_LoweredWhileLoop)->Arg(1000);
BENCHMARK(BM_LoweredWhileLoop)
->ArgPair(0, 1)
->ArgPair(1, 1)
->ArgPair(10, 1)
->ArgPair(100, 1)
->ArgPair(1000, 1)
->ArgPair(0, 100)
->ArgPair(1, 100)
->ArgPair(10, 100)
->ArgPair(100, 100)
->ArgPair(1000, 100);
static void BM_FunctionalWhileLoop(int iters, int loop_iters) {
BM_WhileLoopHelper(iters, loop_iters, /* lower= */ false);
static void BM_FunctionalWhileLoop(int iters, int loop_iters, int loop_vars) {
BM_WhileLoopHelper(iters, loop_iters, loop_vars, /* lower= */ false);
}
BENCHMARK(BM_FunctionalWhileLoop)->Arg(0);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(10);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(100);
BENCHMARK(BM_FunctionalWhileLoop)->Arg(1000);
BENCHMARK(BM_FunctionalWhileLoop)
->ArgPair(0, 1)
->ArgPair(1, 1)
->ArgPair(10, 1)
->ArgPair(100, 1)
->ArgPair(1000, 1)
->ArgPair(0, 100)
->ArgPair(1, 100)
->ArgPair(10, 100)
->ArgPair(100, 100)
->ArgPair(1000, 100);
} // namespace tensorflow

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@ -39,9 +40,6 @@ ImmutableExecutorState::~ImmutableExecutorState() {
params_.delete_kernel(item->kernel);
}
}
for (auto fiter : frame_info_) {
delete fiter.second;
}
}
namespace {
@ -71,11 +69,16 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending,
ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
const string& fname) {
auto slot = &frame_info_[fname];
if (*slot == nullptr) {
*slot = new FrameInfo;
auto iter = frame_info_.find(fname);
if (iter != frame_info_.end()) {
return iter->second.get();
} else {
auto frame_info = absl::make_unique<FrameInfo>(fname);
absl::string_view fname_view = frame_info->name;
auto emplace_result =
frame_info_.emplace(fname_view, std::move(frame_info));
return emplace_result.first->second.get();
}
return *slot;
}
Status ImmutableExecutorState::Initialize(const Graph& graph) {
@ -89,7 +92,7 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
EnsureFrameInfo(it)->nodes =
absl::make_unique<std::vector<const NodeItem*>>();
}
root_frame_info_ = frame_info_[""];
root_frame_info_ = frame_info_[""].get();
pending_ids_.resize(gview_.num_nodes());
@ -157,6 +160,28 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) {
TF_RETURN_IF_ERROR(
GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
item->is_constant_enter = is_constant_enter;
string frame_name;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
FrameInfo* frame_info = frame_info_[frame_name].get();
int parallel_iterations;
TF_RETURN_IF_ERROR(
GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));
if (frame_info->parallel_iterations == -1) {
frame_info->parallel_iterations = parallel_iterations;
} else if (frame_info->parallel_iterations != parallel_iterations) {
LOG(WARNING) << "Loop frame \"" << frame_name
<< "\" had two different values for parallel_iterations: "
<< frame_info->parallel_iterations << " vs. "
<< parallel_iterations << ".";
}
if (enter_frame_info_.size() <= id) {
enter_frame_info_.resize(id + 1);
}
enter_frame_info_[id] = frame_info;
} else {
item->is_constant_enter = false;
}

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
@ -41,11 +42,16 @@ class Graph;
class ImmutableExecutorState {
public:
struct FrameInfo {
FrameInfo()
: input_count(0),
explicit FrameInfo(string name)
: name(std::move(name)),
input_count(0),
total_inputs(0),
pending_counts(nullptr),
nodes(nullptr) {}
nodes(nullptr),
parallel_iterations(-1) {}
// The name of the frame.
string name;
// The total number of inputs to a frame.
int input_count;
@ -63,6 +69,9 @@ class ImmutableExecutorState {
// The nodes in a frame. Used only for debugging.
std::unique_ptr<std::vector<const NodeItem*>> nodes;
// The number of iterations of this frame that can execute concurrently.
int32 parallel_iterations;
};
explicit ImmutableExecutorState(const LocalExecutorParams& p)
@ -83,17 +92,13 @@ class ImmutableExecutorState {
}
const std::vector<const NodeItem*>& root_nodes() const { return root_nodes_; }
const FrameInfo* get_frame_info(const string& frame_name) const {
auto it_frame_info = frame_info_.find(frame_name);
if (it_frame_info == frame_info_.end()) {
return nullptr;
} else {
return it_frame_info->second;
}
}
const FrameInfo& get_root_frame_info() const { return *root_frame_info_; }
const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const {
DCHECK(node_item.is_enter);
return *enter_frame_info_[node_item.node_id];
}
bool requires_control_flow_support() const { return requires_control_flow_; }
// Copies the pending counts for nodes in this graph to the given array.
@ -135,9 +140,14 @@ class ImmutableExecutorState {
// Mapping from frame name to static information about the frame.
// TODO(yuanbyu): We could cache it along with the graph so as to avoid
// the overhead of constructing it for each executor instance.
gtl::FlatMap<string, FrameInfo*> frame_info_;
absl::flat_hash_map<absl::string_view, std::unique_ptr<FrameInfo>>
frame_info_;
const FrameInfo* root_frame_info_; // Not owned.
// If the graph contains any "Enter" or "RefEnter" nodes, this vector maps
// dense node IDs to the corresponding FrameInfo.
std::vector<FrameInfo*> enter_frame_info_;
// If `requires_control_flow_` is false, this points to an array of initial
// pending counts for the nodes in the graph, indexed by node ID.
std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_;

View File

@ -16,9 +16,11 @@ limitations under the License.
#include "tensorflow/core/common_runtime/propagator_state.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace tensorflow {
@ -33,14 +35,14 @@ PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
// We assume root_frame_->frame_name.empty().
root_frame_ = new FrameState(immutable_state_, 1);
root_frame_->frame_id = 0; // must be 0
root_frame_->InitializeFrameInfo(root_frame_->frame_name);
root_frame_->InitializeFrameInfo(immutable_state_.get_root_frame_info());
// Initialize iteration 0.
root_frame_->SetIteration(
0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
root_frame_->total_input_tensors));
outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
outstanding_frames_.emplace(root_frame_->frame_id, root_frame_);
}
PropagatorState::~PropagatorState() {
@ -224,16 +226,16 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
const NodeItem& node_item,
FrameState** child) {
// Identify the child frame from the Enter node's frame information.
AttrSlice attrs(node_item.kernel->def());
const string& enter_name = GetNodeAttrString(attrs, "frame_name");
DCHECK(!enter_name.empty()) << "Could not find \"frame_name\" attr in node "
<< node_item.kernel->name();
const string child_name = strings::StrCat(
frame->frame_name, ";", iter_state->iter_num, ";", enter_name);
const ImmutableExecutorState::FrameInfo& frame_info =
immutable_state_.get_enter_frame_info(node_item);
const uint64 child_id = Hash64Combine(
frame->frame_id,
Hash64Combine(iter_state->iter_num, Hash64(frame_info.name)));
{
mutex_lock executor_lock(mu_);
auto it = outstanding_frames_.find(child_name);
tf_shared_lock executor_lock(mu_);
auto it = outstanding_frames_.find(child_id);
if (it != outstanding_frames_.end()) {
*child = it->second;
return;
@ -242,20 +244,18 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
// Need to create a new frame instance.
// Note that this new frame instance is created without any locks.
if (vlog_) VLOG(2) << "Create frame: " << child_name;
if (vlog_) {
const string child_name = strings::StrCat(
frame->frame_name, ";", iter_state->iter_num, ";", frame_info.name);
VLOG(2) << "Create frame: " << child_name << " id: " << child_id;
}
int parallel_iters;
bool found_parallel_iters =
TryGetNodeAttr(attrs, "parallel_iterations", &parallel_iters);
DCHECK(found_parallel_iters)
<< "Could not find \"parallel_iterations\" attr in node "
<< node_item.kernel->name();
FrameState* temp = new FrameState(immutable_state_, parallel_iters);
temp->frame_name = child_name;
temp->frame_id = Hash64(child_name);
FrameState* temp =
new FrameState(immutable_state_, frame_info.parallel_iterations);
temp->frame_id = child_id;
temp->parent_frame = frame;
temp->parent_iter = iter_state;
temp->InitializeFrameInfo(enter_name);
temp->InitializeFrameInfo(frame_info);
// Initialize iteration 0.
{
@ -266,13 +266,13 @@ void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
{
mutex_lock executor_lock(mu_);
auto it = outstanding_frames_.find(child_name);
auto it = outstanding_frames_.find(child_id);
if (it != outstanding_frames_.end()) {
*child = it->second;
} else {
mutex_lock frame_lock(frame->mu);
iter_state->outstanding_frame_count++;
outstanding_frames_[child_name] = temp;
outstanding_frames_[child_id] = temp;
*child = temp;
temp = nullptr;
}
@ -349,11 +349,10 @@ void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
}
// Delete the frame.
const string& frame_name = frame->frame_name;
if (vlog_) VLOG(2) << "Delete frame " << frame_name;
if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id;
{
mutex_lock executor_lock(mu_);
outstanding_frames_.erase(frame_name);
outstanding_frames_.erase(frame->frame_id);
}
delete frame;
}
@ -655,14 +654,11 @@ bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
}
void PropagatorState::FrameState::InitializeFrameInfo(
const string& enter_name) {
const ImmutableExecutorState::FrameInfo* finfo =
immutable_state.get_frame_info(enter_name);
DCHECK_NE(finfo, nullptr);
pending_counts = finfo->pending_counts.get();
total_input_tensors = finfo->total_inputs;
num_pending_inputs = finfo->input_count;
nodes = finfo->nodes.get();
const ImmutableExecutorState::FrameInfo& finfo) {
pending_counts = finfo.pending_counts.get();
total_input_tensors = finfo.total_inputs;
num_pending_inputs = finfo.input_count;
nodes = finfo.nodes.get();
}
void PropagatorState::FrameState::SetIteration(int64 iter,

View File

@ -279,7 +279,7 @@ class PropagatorState {
// during structured traversal: parent_frame->mu < mu.
mutex mu;
void InitializeFrameInfo(const string& enter_name);
void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo);
inline IterationState* GetIteration(int64 iter)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
@ -447,12 +447,13 @@ class PropagatorState {
// The root frame in which the execution of this step is started.
FrameState* root_frame_;
// Mapping from frame name to outstanding frames. A new frame is created
// Mapping from frame ID to outstanding frames. A new frame is created
// at some iteration of an active frame. So the unique key for the new
// child frame is composed of the name of the parent frame, the iteration
// child frame is a hash composed of the ID of the parent frame, the iteration
// number at which the parent frame is creating the new frame, and the
// name of the new frame from nodedef.
gtl::FlatMap<string, FrameState*> outstanding_frames_ TF_GUARDED_BY(mu_);
absl::flat_hash_map<uint64, FrameState*> outstanding_frames_
TF_GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState);
};