Address review comments.

This commit is contained in:
Trent Lo 2019-05-17 23:34:00 -07:00
parent e1f36241fc
commit 3f7afc54bd
2 changed files with 59 additions and 52 deletions

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/status_macros.h"
@ -648,9 +649,8 @@ Predicate* PredicateFactory::MakeAndOrImpl(
std::vector<Predicate*> to_add; std::vector<Predicate*> to_add;
for (Predicate* op : simplified_ops) { for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kAndRecurrence) { if (op->kind() == Predicate::Kind::kAndRecurrence) {
AndRecurrencePredicate* and_rec = auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
static_cast<AndRecurrencePredicate*>(op); if (negated_ops.contains(and_rec->step())) {
if (negated_ops.count(and_rec->step())) {
// Remove and_rec and ~X and insert S. Note that checking the // Remove and_rec and ~X and insert S. Note that checking the
// existence of ~X through negated_ops is sufficient since it makes // existence of ~X through negated_ops is sufficient since it makes
// sure the predicate is in the input operands. It does not need to // sure the predicate is in the input operands. It does not need to
@ -663,7 +663,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
} }
auto it = simplified_ops.begin(); auto it = simplified_ops.begin();
while (it != simplified_ops.end()) { while (it != simplified_ops.end()) {
if (to_remove.count(*it)) { if (to_remove.contains(*it)) {
it = simplified_ops.erase(it); it = simplified_ops.erase(it);
} else { } else {
++it; ++it;
@ -736,8 +736,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {} : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate(bool force_pessimistic = false); Status Populate(bool force_pessimistic = false);
Status PopulateFrame(absl::Span<Node* const> tpo, bool use_optimistic_mode, Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
bool* is_success); bool* success);
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor( StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
Node* n, int oidx) const override; Node* n, int oidx) const override;
void Print() const override; void Print() const override;
@ -955,11 +955,13 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
return Status::OK(); return Status::OK();
} }
// If the node is inside some frames, get the name of the outermost non-empty
// frame. Otherwise, get an empty frame name.
Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos, Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
const Node* src_node, string* frame) { absl::string_view frame) {
int depth = 0; int depth = 0;
const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
while (cfi_iter->parent_frame != src_node) { while (!cfi_iter->parent_frame->IsSource()) {
n = cfi_iter->parent_frame; n = cfi_iter->parent_frame;
cfi_iter = &cfi_infos[n->id()]; cfi_iter = &cfi_infos[n->id()];
@ -970,21 +972,23 @@ Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
} }
} }
*frame = cfi_iter->frame_name; frame = cfi_iter->frame_name;
return Status::OK(); return Status::OK();
} }
// Compute a special topological order for the Graph, where nodes having the // Compute a special topological order for the Graph, where nodes having the
// same root frame are placed adjacent to each other. // same root frame are placed adjacent to each other. The traversal is a
// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how
// many inputs of each node are ready; a node is ready to be scheduled if all
// of its inputs are ready.
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
Status GetFrameBasedTopologicalOrder(const Graph* g, Status GetFrameBasedTopologicalOrder(const Graph* g,
absl::Span<const ControlFlowInfo> cf_infos, absl::Span<const ControlFlowInfo> cf_infos,
std::vector<Node*>* order) { std::vector<Node*>* order) {
absl::flat_hash_map<string, size_t> num_enters; absl::flat_hash_map<absl::string_view, size_t> num_enters;
absl::flat_hash_map<string, size_t> num_exits; absl::flat_hash_map<absl::string_view, size_t> num_exits;
std::vector<size_t> num_ready_inputs(g->num_node_ids(), 0); std::vector<size_t> num_ready_inputs(g->num_node_ids(), 0);
Node* src_node = g->source_node(); Node* src_node = g->source_node();
std::deque<Node*> ready;
ready.push_back(src_node);
for (const auto* node : g->op_nodes()) { for (const auto* node : g->op_nodes()) {
const ControlFlowInfo& cf = cf_infos[node->id()]; const ControlFlowInfo& cf = cf_infos[node->id()];
bool is_root_level = cf.parent_frame == src_node; bool is_root_level = cf.parent_frame == src_node;
@ -995,6 +999,8 @@ Status GetFrameBasedTopologicalOrder(const Graph* g,
} else if (IsExit(node) && is_root_level) { } else if (IsExit(node) && is_root_level) {
++num_exits[cf.frame_name]; ++num_exits[cf.frame_name];
} }
// Edge NextIteration->Merge is counted before starting the traversal to
// break the backedges.
if (IsMerge(node)) { if (IsMerge(node)) {
for (const Edge* e : node->in_edges()) { for (const Edge* e : node->in_edges()) {
if (IsNextIteration(e->src())) { if (IsNextIteration(e->src())) {
@ -1004,6 +1010,8 @@ Status GetFrameBasedTopologicalOrder(const Graph* g,
} }
} }
std::deque<Node*> ready;
ready.push_back(src_node);
absl::flat_hash_map<string, std::vector<Node*>> staging_enter_vecs; absl::flat_hash_map<string, std::vector<Node*>> staging_enter_vecs;
// Exit nodes shall all be from the same frame, as we process a frame at a // Exit nodes shall all be from the same frame, as we process a frame at a
// time. So, one vector is enough. // time. So, one vector is enough.
@ -1044,7 +1052,7 @@ Status GetFrameBasedTopologicalOrder(const Graph* g,
// make sure all nodes in the currently processing frame are visited // make sure all nodes in the currently processing frame are visited
// before starting processing other frames. // before starting processing other frames.
string frame_name = cf_infos[staging_exit_vec.front()->id()].frame_name; string frame_name = cf_infos[staging_exit_vec.front()->id()].frame_name;
CHECK(staging_exit_vec.size() == num_exits[frame_name]); CHECK_EQ(staging_exit_vec.size(), num_exits[frame_name]);
ready.insert(ready.end(), staging_exit_vec.begin(), ready.insert(ready.end(), staging_exit_vec.begin(),
staging_exit_vec.end()); staging_exit_vec.end());
staging_exit_vec.clear(); staging_exit_vec.clear();
@ -1098,8 +1106,8 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// does not affect the result after pattern-matching into the // does not affect the result after pattern-matching into the
// AndRecurrence form. // AndRecurrence form.
string frame_name = control_flow_info_[n->id()].frame_name; string frame_name = control_flow_info_[n->id()].frame_name;
auto ret = frame_to_merge_node_.insert({frame_name, n}); auto insert_result = frame_to_merge_node_.insert({frame_name, n});
Node* representative = ret.first->second; Node* representative = insert_result.first->second;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
representative, /*output_idx=*/0, /*must_be_true=*/false, representative, /*output_idx=*/0, /*must_be_true=*/false,
&input_data_pred)); &input_data_pred));
@ -1240,52 +1248,51 @@ Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) {
absl::StrJoin(unreachable_nodes, ", ")); absl::StrJoin(unreachable_nodes, ", "));
} }
std::vector<Node*> tpo; std::vector<Node*> topo;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &tpo)); GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &topo));
size_t frame_start = 0; size_t frame_start = 0;
for (size_t i = 0; i < tpo.size(); ++i) { for (size_t i = 0; i < topo.size(); ++i) {
// Collect nodes until we see a node who has a different root frame. // Collect nodes until we see a node who has a different root frame.
if (i != tpo.size() - 1) { if (i != topo.size() - 1) {
string i_frame_name, next_frame_name; absl::string_view i_frame_name, next_frame_name;
TF_RETURN_IF_ERROR(GetRootFrame(tpo[i], control_flow_info_, TF_RETURN_IF_ERROR(GetRootFrame(topo[i], control_flow_info_,
graph_.source_node(), &i_frame_name)); i_frame_name));
TF_RETURN_IF_ERROR(GetRootFrame(tpo[i + 1], control_flow_info_, TF_RETURN_IF_ERROR(GetRootFrame(topo[i + 1], control_flow_info_,
graph_.source_node(), &next_frame_name)); next_frame_name));
if (i_frame_name == next_frame_name) { if (i_frame_name == next_frame_name) {
continue; continue;
} }
} }
string frame_name = control_flow_info_[tpo[i]->id()].frame_name; string frame_name = control_flow_info_[topo[i]->id()].frame_name;
absl::Span<Node*> sub_tpo(tpo.data() + frame_start, i - frame_start + 1); absl::Span<Node*> sub_topo(topo.data() + frame_start,
/*length=*/i - frame_start + 1);
frame_start = i + 1; frame_start = i + 1;
// First, try the optimistic mode. // First, try the optimistic mode.
bool is_success = false; bool success = false;
if (!force_pessimistic && !frame_name.empty()) { if (!force_pessimistic && !frame_name.empty()) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
PopulateFrame(sub_tpo, /*use_optimistic_mode*/ true, &is_success)); PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
} }
if (!is_success) { if (!success) {
// The optimistic mode does not converge. Let's fall back to the // The optimistic mode does not converge. Let's fall back to the
// pessimistic mode. // pessimistic mode.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
PopulateFrame(sub_tpo, /*use_optimistic_mode*/ false, nullptr)); PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
} }
if (VLOG_IS_ON(2)) {
VLOG(2) << "Done populating frame " << frame_name << " using the " VLOG(2) << "Done populating frame " << frame_name << " using the "
<< (is_success ? "optimistic" : "pessimistic") << " mode."; << (success ? "optimistic" : "pessimistic") << " mode.";
}
} }
return Status::OK(); return Status::OK();
} }
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo, Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
bool use_optimistic_mode, bool use_optimistic_mode,
bool* is_success) { bool* success) {
// This is an abstract interpretation over the deadness propagation semantics of // This is an abstract interpretation over the deadness propagation semantics of
// the graph executor. // the graph executor.
// //
@ -1303,7 +1310,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
// delta should not change in the second iteration. // delta should not change in the second iteration.
std::vector<bool> should_revisit; std::vector<bool> should_revisit;
should_revisit.resize(graph_.num_node_ids()); should_revisit.resize(graph_.num_node_ids());
for (Node* n : tpo) { for (Node* n : topo) {
VLOG(4) << "Visiting " << n->name(); VLOG(4) << "Visiting " << n->name();
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode)); HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
@ -1318,7 +1325,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
} }
} }
for (Node* n : tpo) { for (Node* n : topo) {
// The nodes added to should_revisit in the previous loop need to be // The nodes added to should_revisit in the previous loop need to be
// revisited now. Reprocessing these initial nodes may add *their* consumers // revisited now. Reprocessing these initial nodes may add *their* consumers
// to should_revisit, and these newly added nodes will also be processed by // to should_revisit, and these newly added nodes will also be processed by
@ -1339,7 +1346,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
if (use_optimistic_mode) { if (use_optimistic_mode) {
bool is_converged = true; bool is_converged = true;
absl::flat_hash_map<string, Predicate*> frame_to_pred; absl::flat_hash_map<string, Predicate*> frame_to_pred;
for (Node* n : tpo) { for (Node* n : topo) {
if (!n->IsMerge()) { if (!n->IsMerge()) {
continue; continue;
} }
@ -1382,7 +1389,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
// Clear the assigned predicates if the optimistic mode does not converge. // Clear the assigned predicates if the optimistic mode does not converge.
if (!is_converged) { if (!is_converged) {
for (Node* n : tpo) { for (Node* n : topo) {
for (int oid = 0; oid < n->num_outputs(); ++oid) { for (int oid = 0; oid < n->num_outputs(); ++oid) {
predicate_map_.erase(TensorId(n->name(), oid)); predicate_map_.erase(TensorId(n->name(), oid));
} }
@ -1390,8 +1397,8 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
} }
} }
if (is_success != nullptr) { if (success != nullptr) {
*is_success = is_converged; *success = is_converged;
} }
} }

View File

@ -639,7 +639,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ false)); /*force_pessimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");
@ -653,7 +653,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ true)); /*force_pessimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");
@ -683,7 +683,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ false)); /*force_pessimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"{#true,&,*iv0/cond:0}<frame>"); "{#true,&,*iv0/cond:0}<frame>");
@ -691,7 +691,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ true)); /*force_pessimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"div0/iv:0"); "div0/iv:0");
@ -746,7 +746,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ false)); /*force_pessimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>"); "{#true,&,*iv_outer/cond:0}<outer_loop>");
@ -772,7 +772,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ true)); /*force_pessimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>"); "{#true,&,*iv_outer/cond:0}<outer_loop>");
@ -925,7 +925,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ false)); /*force_pessimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");
@ -942,7 +942,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ true)); /*force_pessimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");