Address review comments.
This commit is contained in:
parent
e1f36241fc
commit
3f7afc54bd
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
|
||||
#include "tensorflow/compiler/jit/xla_cluster_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -648,9 +649,8 @@ Predicate* PredicateFactory::MakeAndOrImpl(
|
||||
std::vector<Predicate*> to_add;
|
||||
for (Predicate* op : simplified_ops) {
|
||||
if (op->kind() == Predicate::Kind::kAndRecurrence) {
|
||||
AndRecurrencePredicate* and_rec =
|
||||
static_cast<AndRecurrencePredicate*>(op);
|
||||
if (negated_ops.count(and_rec->step())) {
|
||||
auto* and_rec = static_cast<AndRecurrencePredicate*>(op);
|
||||
if (negated_ops.contains(and_rec->step())) {
|
||||
// Remove and_rec and ~X and insert S. Note that checking the
|
||||
// existence of ~X through negated_ops is sufficient since it makes
|
||||
// sure the predicate is in the input operands. It does not need to
|
||||
@ -663,7 +663,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
|
||||
}
|
||||
auto it = simplified_ops.begin();
|
||||
while (it != simplified_ops.end()) {
|
||||
if (to_remove.count(*it)) {
|
||||
if (to_remove.contains(*it)) {
|
||||
it = simplified_ops.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
@ -736,8 +736,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
||||
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
|
||||
|
||||
Status Populate(bool force_pessimistic = false);
|
||||
Status PopulateFrame(absl::Span<Node* const> tpo, bool use_optimistic_mode,
|
||||
bool* is_success);
|
||||
Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
|
||||
bool* success);
|
||||
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
|
||||
Node* n, int oidx) const override;
|
||||
void Print() const override;
|
||||
@ -955,11 +955,13 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// If the node is inside some frames, get the name of the outermost non-empty
|
||||
// frame. Otherwise, get an empty frame name.
|
||||
Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
|
||||
const Node* src_node, string* frame) {
|
||||
absl::string_view frame) {
|
||||
int depth = 0;
|
||||
const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()];
|
||||
while (cfi_iter->parent_frame != src_node) {
|
||||
while (!cfi_iter->parent_frame->IsSource()) {
|
||||
n = cfi_iter->parent_frame;
|
||||
cfi_iter = &cfi_infos[n->id()];
|
||||
|
||||
@ -970,21 +972,23 @@ Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
|
||||
}
|
||||
}
|
||||
|
||||
*frame = cfi_iter->frame_name;
|
||||
frame = cfi_iter->frame_name;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Compute a special topological order for the Graph, where nodes having the
|
||||
// same root frame are placed adjacent to each other.
|
||||
// same root frame are placed adjacent to each other. The traversal is a
|
||||
// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how
|
||||
// many inputs of each node are ready; a node is ready to be scheduled if all
|
||||
// of its inputs are ready.
|
||||
// For details, see https://en.wikipedia.org/wiki/Topological_sorting
|
||||
Status GetFrameBasedTopologicalOrder(const Graph* g,
|
||||
absl::Span<const ControlFlowInfo> cf_infos,
|
||||
std::vector<Node*>* order) {
|
||||
absl::flat_hash_map<string, size_t> num_enters;
|
||||
absl::flat_hash_map<string, size_t> num_exits;
|
||||
absl::flat_hash_map<absl::string_view, size_t> num_enters;
|
||||
absl::flat_hash_map<absl::string_view, size_t> num_exits;
|
||||
std::vector<size_t> num_ready_inputs(g->num_node_ids(), 0);
|
||||
Node* src_node = g->source_node();
|
||||
std::deque<Node*> ready;
|
||||
ready.push_back(src_node);
|
||||
for (const auto* node : g->op_nodes()) {
|
||||
const ControlFlowInfo& cf = cf_infos[node->id()];
|
||||
bool is_root_level = cf.parent_frame == src_node;
|
||||
@ -995,6 +999,8 @@ Status GetFrameBasedTopologicalOrder(const Graph* g,
|
||||
} else if (IsExit(node) && is_root_level) {
|
||||
++num_exits[cf.frame_name];
|
||||
}
|
||||
// Edge NextIteration->Merge is counted before starting the traversal to
|
||||
// break the backedges.
|
||||
if (IsMerge(node)) {
|
||||
for (const Edge* e : node->in_edges()) {
|
||||
if (IsNextIteration(e->src())) {
|
||||
@ -1004,6 +1010,8 @@ Status GetFrameBasedTopologicalOrder(const Graph* g,
|
||||
}
|
||||
}
|
||||
|
||||
std::deque<Node*> ready;
|
||||
ready.push_back(src_node);
|
||||
absl::flat_hash_map<string, std::vector<Node*>> staging_enter_vecs;
|
||||
// Exit nodes shall all be from the same frame, as we process a frame at a
|
||||
// time. So, one vector is enough.
|
||||
@ -1044,7 +1052,7 @@ Status GetFrameBasedTopologicalOrder(const Graph* g,
|
||||
// make sure all nodes in the currently processing frame are visited
|
||||
// before starting processing other frames.
|
||||
string frame_name = cf_infos[staging_exit_vec.front()->id()].frame_name;
|
||||
CHECK(staging_exit_vec.size() == num_exits[frame_name]);
|
||||
CHECK_EQ(staging_exit_vec.size(), num_exits[frame_name]);
|
||||
ready.insert(ready.end(), staging_exit_vec.begin(),
|
||||
staging_exit_vec.end());
|
||||
staging_exit_vec.clear();
|
||||
@ -1098,8 +1106,8 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
|
||||
// does not affect the result after pattern-matching into the
|
||||
// AndRecurrence form.
|
||||
string frame_name = control_flow_info_[n->id()].frame_name;
|
||||
auto ret = frame_to_merge_node_.insert({frame_name, n});
|
||||
Node* representative = ret.first->second;
|
||||
auto insert_result = frame_to_merge_node_.insert({frame_name, n});
|
||||
Node* representative = insert_result.first->second;
|
||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||
representative, /*output_idx=*/0, /*must_be_true=*/false,
|
||||
&input_data_pred));
|
||||
@ -1240,52 +1248,51 @@ Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) {
|
||||
absl::StrJoin(unreachable_nodes, ", "));
|
||||
}
|
||||
|
||||
std::vector<Node*> tpo;
|
||||
std::vector<Node*> topo;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &tpo));
|
||||
GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &topo));
|
||||
|
||||
size_t frame_start = 0;
|
||||
for (size_t i = 0; i < tpo.size(); ++i) {
|
||||
for (size_t i = 0; i < topo.size(); ++i) {
|
||||
// Collect nodes until we see a node that has a different root frame.
|
||||
if (i != tpo.size() - 1) {
|
||||
string i_frame_name, next_frame_name;
|
||||
TF_RETURN_IF_ERROR(GetRootFrame(tpo[i], control_flow_info_,
|
||||
graph_.source_node(), &i_frame_name));
|
||||
TF_RETURN_IF_ERROR(GetRootFrame(tpo[i + 1], control_flow_info_,
|
||||
graph_.source_node(), &next_frame_name));
|
||||
if (i != topo.size() - 1) {
|
||||
absl::string_view i_frame_name, next_frame_name;
|
||||
TF_RETURN_IF_ERROR(GetRootFrame(topo[i], control_flow_info_,
|
||||
i_frame_name));
|
||||
TF_RETURN_IF_ERROR(GetRootFrame(topo[i + 1], control_flow_info_,
|
||||
next_frame_name));
|
||||
if (i_frame_name == next_frame_name) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
string frame_name = control_flow_info_[tpo[i]->id()].frame_name;
|
||||
absl::Span<Node*> sub_tpo(tpo.data() + frame_start, i - frame_start + 1);
|
||||
string frame_name = control_flow_info_[topo[i]->id()].frame_name;
|
||||
absl::Span<Node*> sub_topo(topo.data() + frame_start,
|
||||
/*length=*/i - frame_start + 1);
|
||||
frame_start = i + 1;
|
||||
|
||||
// First, try the optimistic mode.
|
||||
bool is_success = false;
|
||||
bool success = false;
|
||||
if (!force_pessimistic && !frame_name.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PopulateFrame(sub_tpo, /*use_optimistic_mode*/ true, &is_success));
|
||||
PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
|
||||
}
|
||||
if (!is_success) {
|
||||
if (!success) {
|
||||
// The optimistic mode does not converge. Let's fall back to the
|
||||
// pessimistic mode.
|
||||
TF_RETURN_IF_ERROR(
|
||||
PopulateFrame(sub_tpo, /*use_optimistic_mode*/ false, nullptr));
|
||||
PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr));
|
||||
}
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "Done populating frame " << frame_name << " using the "
|
||||
<< (is_success ? "optimistic" : "pessimistic") << " mode.";
|
||||
}
|
||||
<< (success ? "optimistic" : "pessimistic") << " mode.";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
|
||||
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
|
||||
bool use_optimistic_mode,
|
||||
bool* is_success) {
|
||||
bool* success) {
|
||||
// This is an abstract interpretation over the deadness propagation semantics of
|
||||
// the graph executor.
|
||||
//
|
||||
@ -1303,7 +1310,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
|
||||
// delta should not change in the second iteration.
|
||||
std::vector<bool> should_revisit;
|
||||
should_revisit.resize(graph_.num_node_ids());
|
||||
for (Node* n : tpo) {
|
||||
for (Node* n : topo) {
|
||||
VLOG(4) << "Visiting " << n->name();
|
||||
TF_RETURN_IF_ERROR(
|
||||
HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
|
||||
@ -1318,7 +1325,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
|
||||
}
|
||||
}
|
||||
|
||||
for (Node* n : tpo) {
|
||||
for (Node* n : topo) {
|
||||
// The nodes added to should_revisit in the previous loop need to be
|
||||
// revisited now. Reprocessing these initial nodes may add *their* consumers
|
||||
// to should_revisit, and these newly added nodes will also be processed by
|
||||
@ -1339,7 +1346,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
|
||||
if (use_optimistic_mode) {
|
||||
bool is_converged = true;
|
||||
absl::flat_hash_map<string, Predicate*> frame_to_pred;
|
||||
for (Node* n : tpo) {
|
||||
for (Node* n : topo) {
|
||||
if (!n->IsMerge()) {
|
||||
continue;
|
||||
}
|
||||
@ -1382,7 +1389,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
|
||||
|
||||
// Clear the assigned predicates if the optimistic mode does not converge.
|
||||
if (!is_converged) {
|
||||
for (Node* n : tpo) {
|
||||
for (Node* n : topo) {
|
||||
for (int oid = 0; oid < n->num_outputs(); ++oid) {
|
||||
predicate_map_.erase(TensorId(n->name(), oid));
|
||||
}
|
||||
@ -1390,8 +1397,8 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
|
||||
}
|
||||
}
|
||||
|
||||
if (is_success != nullptr) {
|
||||
*is_success = is_converged;
|
||||
if (success != nullptr) {
|
||||
*success = is_converged;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -639,7 +639,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ false));
|
||||
/*force_pessimistic=*/false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
@ -653,7 +653,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ true));
|
||||
/*force_pessimistic=*/true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
@ -683,7 +683,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ false));
|
||||
/*force_pessimistic=*/false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<frame>");
|
||||
@ -691,7 +691,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ true));
|
||||
/*force_pessimistic=*/true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
|
||||
"div0/iv:0");
|
||||
@ -746,7 +746,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ false));
|
||||
/*force_pessimistic=*/false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
|
||||
"{#true,&,*iv_outer/cond:0}<outer_loop>");
|
||||
@ -772,7 +772,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ true));
|
||||
/*force_pessimistic=*/true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
|
||||
"{#true,&,*iv_outer/cond:0}<outer_loop>");
|
||||
@ -925,7 +925,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) {
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ false));
|
||||
/*force_pessimistic=*/false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
@ -942,7 +942,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) {
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ true));
|
||||
/*force_pessimistic=*/true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
|
Loading…
Reference in New Issue
Block a user