Implement optimistic deadness analysis.

The optimistic deadness analysis can complement the original pessimistic
analysis and reduce false positives due to loop structures such as in
dynamic rnn.

In addition, a new rule {S,&,X} & ~X => S is added to try to cancel out
the leftover symbols after while loops. Otherwise, symbols keep accumulating
in the case of cascaded while loops.
This commit is contained in:
Trent Lo 2019-05-17 11:51:43 -07:00
parent 6232230c0e
commit e1f36241fc
3 changed files with 479 additions and 78 deletions

View File

@ -85,22 +85,28 @@ limitations under the License.
// true on iteration 0, 1, 2 respectively. This is made more precise in the
// comment on the AndRecurrence class.
//
// The general algorithm that deals with cycles does two RPO (reverse post
// order) passes over the graph. On the first pass it assigns a symbolic
// predicate to merge nodes with backedges. On the second pass it tries to
// pattern matche the predicates for the backedges of these merges and infer an
// AndRecurrence for the merge.
// The general algorithm that deals with cycles does two topological-order
// iterations over the graph. On the first iteration it assigns a symbolic
// predicate to merge nodes with backedges. On the second iteration it tries
// to pattern match the predicates for the backedges of these merges and infer
// an AndRecurrence for the merge. In other words, we do a data flow analysis
// where the data-flow lattice has two elements, Symbolic and NonSymbolic with
// Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are
// sufficient to converge.
//
// In other words, we do a pessimistic data flow analysis where the data-flow
// lattice has two elements, Symbolic and NonSymbolic with Symbolic >
// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to
// converge. We don't do an optimistic data flow analysis to make pattern
// matching easier: if we assigned the predicate of the initial value to the
// merge during the first pass, on the second pass the backedge may see a
// simplified value that would be difficult to pattern match.
// We first do an optimistic analysis and, if it does not converge, we then fall
// back to a pessimistic analysis. The optimistic analysis assigns the same
// symbolic predicate to all the merge nodes whose preceding enter nodes have
// the same frame name on the first iteration. On the second iteration, if all
// the merge nodes are pattern matched into the same AndRecurrence predicate
// instance, the optimistic assignment of the same symbolic predicate is correct
// and the analyzed result is taken.
//
// We still use symbolic predicates for merges for which we can't pattern match
// on the backedge predicate. This is conservatively correct.
// Otherwise, if the optimistic analysis fails to converge, we then obtain the
// result by falling back to the pessimistic analysis which assigns a unique
// symbolic predicate to each merge on the first iteration. We still use
// symbolic predicates for merges for which we can't pattern match on the
// backedge predicate. This is conservatively correct.
namespace tensorflow {
@ -636,6 +642,36 @@ Predicate* PredicateFactory::MakeAndOrImpl(
negated_ops.insert(negated_op);
}
// Simplify {S,&,X} & ~X & ... => S & ...
if (is_and) {
absl::flat_hash_set<Predicate*> to_remove;
std::vector<Predicate*> to_add;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kAndRecurrence) {
AndRecurrencePredicate* and_rec =
static_cast<AndRecurrencePredicate*>(op);
if (negated_ops.count(and_rec->step())) {
// Remove and_rec and ~X and insert S. Note that checking the
// existence of ~X through negated_ops is sufficient since it makes
// sure the predicate is in the input operands. It does not need to
// be in simplified_ops if it was already cancelled out.
to_remove.insert(and_rec);
to_remove.insert(MakeNotPredicate(and_rec->step()));
to_add.push_back(and_rec->start());
}
}
}
auto it = simplified_ops.begin();
while (it != simplified_ops.end()) {
if (to_remove.count(*it)) {
it = simplified_ops.erase(it);
} else {
++it;
}
}
simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
}
// If all ops contain the same subop, then factor it out thanks to the
// distributive property. Such as:
// - (A & B) | (A & C) | (A & D) => A & (B | C | D)
@ -699,8 +735,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
explicit DeadnessAnalysisImpl(const Graph* graph)
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate();
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
Status Populate(bool force_pessimistic = false);
Status PopulateFrame(absl::Span<Node* const> tpo, bool use_optimistic_mode,
bool* is_success);
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
Node* n, int oidx) const override;
void Print() const override;
@ -742,16 +779,19 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
}
Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
Status HandleMerge(Node* n, std::vector<bool>* should_revisit);
Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
bool use_optimistic_mode);
Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
Status HandleNode(Node* n, std::vector<bool>* should_revisit,
bool use_optimistic_mode = false);
const Graph& graph_;
absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
PredicateFactory predicate_factory_;
std::vector<ControlFlowInfo> control_flow_info_;
bool vlog_;
absl::flat_hash_map<string, Node*> frame_to_merge_node_;
};
TensorId InputEdgeToTensorId(const Edge* e) {
@ -914,10 +954,125 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
return Status::OK();
}
// Walks up the frame hierarchy of `n` until reaching a root-level frame (one
// whose parent frame is `src_node`) and returns that frame's name in `frame`.
// Fails with an internal error if the frame nesting is implausibly deep,
// which indicates a malformed graph or a bug in BuildControlFlowInfo.
Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
                    const Node* src_node, string* frame) {
  const ControlFlowInfo* info = &cfi_infos[n->id()];
  int num_hops = 0;
  while (info->parent_frame != src_node) {
    info = &cfi_infos[info->parent_frame->id()];
    if (num_hops++ > 5000) {
      return errors::Internal(
          "Frame of depth > 5000: Probably malformed graph or a bug in "
          "BuildControlFlowInfo");
    }
  }
  *frame = info->frame_name;
  return Status::OK();
}
// Compute a special topological order for the Graph, where nodes having the
// same root frame are placed adjacent to each other. This lets callers
// process the graph one root frame at a time (see Populate()).
Status GetFrameBasedTopologicalOrder(const Graph* g,
                                     absl::Span<const ControlFlowInfo> cf_infos,
                                     std::vector<Node*>* order) {
  // Number of root-level Enter (resp. Exit) nodes per frame. A frame's staged
  // Enter (resp. Exit) nodes are only released into the ready queue once all
  // of them are ready, so that frames are emitted contiguously.
  absl::flat_hash_map<string, size_t> num_enters;
  absl::flat_hash_map<string, size_t> num_exits;
  // Number of inputs of each node that have already been visited.
  std::vector<size_t> num_ready_inputs(g->num_node_ids(), 0);
  Node* src_node = g->source_node();
  std::deque<Node*> ready;
  ready.push_back(src_node);

  for (const auto* node : g->op_nodes()) {
    const ControlFlowInfo& cf = cf_infos[node->id()];
    bool is_root_level = cf.parent_frame == src_node;
    if (IsEnter(node) && is_root_level) {
      // Since we only care about the root-level frames, full frame names are
      // the same as frame names.
      ++num_enters[cf.frame_name];
    } else if (IsExit(node) && is_root_level) {
      ++num_exits[cf.frame_name];
    }
    // Backedges (NextIteration->Merge) are not traversed below, so pre-count
    // them as ready; this lets a Merge be visited before its backedge source.
    if (IsMerge(node)) {
      for (const Edge* e : node->in_edges()) {
        if (IsNextIteration(e->src())) {
          ++num_ready_inputs[node->id()];
        }
      }
    }
  }

  // Ready root-level Enter nodes withheld until all Enters of their frame are
  // ready, keyed by frame name.
  absl::flat_hash_map<string, std::vector<Node*>> staging_enter_vecs;
  // Exit nodes shall all be from the same frame, as we process a frame at a
  // time. So, one vector is enough.
  std::vector<Node*> staging_exit_vec;
  while (!ready.empty()) {
    Node* curr_node = ready.front();
    ready.pop_front();

    VLOG(4) << "Visiting " << curr_node->name();
    order->push_back(curr_node);

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      int out_id = out->id();
      if (IsNextIteration(curr_node) && IsMerge(out)) {
        // Edge NextIteration->Merge has been counted.
        continue;
      }
      ++num_ready_inputs[out_id];
      if (!out->IsOp()) continue;  // Skip Sink/Source nodes.
      if (num_ready_inputs[out_id] != out->in_edges().size()) continue;

      bool is_root_level = cf_infos[out_id].parent_frame == src_node;
      // Bind by const reference to avoid copying the frame name.
      const string& frame_name = cf_infos[out_id].frame_name;
      if (IsEnter(out) && is_root_level) {
        staging_enter_vecs[frame_name].push_back(out);
      } else if (IsExit(out) && is_root_level) {
        staging_exit_vec.push_back(out);
      } else {
        ready.push_back(out);
      }
    }

    if (ready.empty()) {
      if (!staging_exit_vec.empty()) {
        // Move staging nodes into the ready queue if any. If there are staging
        // exits we must process them before processing the staging enters to
        // make sure all nodes in the currently processing frame are visited
        // before starting processing other frames.
        const string& frame_name =
            cf_infos[staging_exit_vec.front()->id()].frame_name;
        CHECK_EQ(staging_exit_vec.size(), num_exits[frame_name]);
        ready.insert(ready.end(), staging_exit_vec.begin(),
                     staging_exit_vec.end());
        staging_exit_vec.clear();
      } else {
        // Otherwise, try moving the staging enter nodes into the ready queue.
        for (auto iter = staging_enter_vecs.begin();
             iter != staging_enter_vecs.end(); ++iter) {
          const string& frame_name = iter->first;
          const std::vector<Node*>& staging_enters = iter->second;
          if (staging_enters.size() == num_enters[frame_name]) {
            ready.insert(ready.end(), staging_enters.begin(),
                         staging_enters.end());
            staging_enter_vecs.erase(iter);
            break;
          }
        }
      }
    }
  }
  // All staged nodes must have been released by now; leftovers indicate a
  // malformed graph (a frame whose Enter/Exit nodes never all became ready).
  CHECK(staging_enter_vecs.empty() && staging_exit_vec.empty());

  return Status::OK();
}
} // namespace
Status DeadnessAnalysisImpl::HandleMerge(Node* n,
std::vector<bool>* should_revisit) {
std::vector<bool>* should_revisit,
bool use_optimistic_mode) {
// Merge ignores deadness of its control inputs. A merge that isn't the
// target of a backedge is alive iff any of its data inputs are. The
// liveness of a merge that is the target of a backedge can sometimes be
@ -937,8 +1092,21 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// We're visiting this merge for the first time and it has an unvisited
// backedge.
Predicate* input_data_pred;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
if (use_optimistic_mode) {
// In the optimistic mode, we use the first-seen Merge node per
// frame as the representative Merge node. It is just convenient and
// does not affect the result after pattern-matching into the
// AndRecurrence form.
string frame_name = control_flow_info_[n->id()].frame_name;
auto ret = frame_to_merge_node_.insert({frame_name, n});
Node* representative = ret.first->second;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
representative, /*output_idx=*/0, /*must_be_true=*/false,
&input_data_pred));
} else {
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
}
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
should_revisit);
@ -948,7 +1116,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
std::vector<Predicate*> input_preds;
TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
// We're visiting this merge for the first time and it is a acyclic merge.
// We're visiting this merge for the first time and it is an acyclic merge.
Predicate* input_data_pred =
predicate_factory_.MakeOrPredicate(input_preds);
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
@ -1022,11 +1190,12 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
}
Status DeadnessAnalysisImpl::HandleNode(Node* n,
std::vector<bool>* should_revisit) {
std::vector<bool>* should_revisit,
bool use_optimistic_mode) {
if (n->IsSwitch()) {
TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
} else if (n->IsMerge()) {
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
} else if (n->IsControlTrigger()) {
SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
nullptr);
@ -1040,17 +1209,19 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
return Status::OK();
}
Status DeadnessAnalysisImpl::Populate() {
std::vector<Node*> rpo;
GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
return PopulateWithReversePostOrder(rpo);
}
Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
absl::Span<Node* const> rpo) {
// We populate the nodes along a special topological order where nodes having
// the same root frame are placed adjacent to each other. This grouping enables
// processing the graph per root frame at a time and guarantees that when a root
// frame is being processed, nodes in the downstream frames have not yet been
// processed. This property is important because we need to process an entire
// frame to know whether the optimistic mode converges or not. In other words,
// nodes in the downstream frames shall not be populated until all of their
// upstream frames are populated. In effect, this order enables processing each
// (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
// (root) frame. Note that we don't separate while loops belonging to the same
// nested while, as there is no clean cut for separating them in the topological
// order.
Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) {
std::vector<string> unreachable_nodes;
// Compute the loop structure of the graph.
TF_RETURN_IF_ERROR(
@ -1069,6 +1240,52 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
absl::StrJoin(unreachable_nodes, ", "));
}
std::vector<Node*> tpo;
TF_RETURN_IF_ERROR(
GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &tpo));
size_t frame_start = 0;
for (size_t i = 0; i < tpo.size(); ++i) {
// Collect nodes until we see a node who has a different root frame.
if (i != tpo.size() - 1) {
string i_frame_name, next_frame_name;
TF_RETURN_IF_ERROR(GetRootFrame(tpo[i], control_flow_info_,
graph_.source_node(), &i_frame_name));
TF_RETURN_IF_ERROR(GetRootFrame(tpo[i + 1], control_flow_info_,
graph_.source_node(), &next_frame_name));
if (i_frame_name == next_frame_name) {
continue;
}
}
string frame_name = control_flow_info_[tpo[i]->id()].frame_name;
absl::Span<Node*> sub_tpo(tpo.data() + frame_start, i - frame_start + 1);
frame_start = i + 1;
// First, try the optimistic mode.
bool is_success = false;
if (!force_pessimistic && !frame_name.empty()) {
TF_RETURN_IF_ERROR(
PopulateFrame(sub_tpo, /*use_optimistic_mode*/ true, &is_success));
}
if (!is_success) {
// The optimistic mode does not converge. Let's fall back to the
// pessimistic mode.
TF_RETURN_IF_ERROR(
PopulateFrame(sub_tpo, /*use_optimistic_mode*/ false, nullptr));
}
if (VLOG_IS_ON(2)) {
VLOG(2) << "Done populating frame " << frame_name << " using the "
<< (is_success ? "optimistic" : "pessimistic") << " mode.";
}
}
return Status::OK();
}
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
bool use_optimistic_mode,
bool* is_success) {
// This an abstract interpretation over the deadness propagation semantics of
// the graph executor.
//
@ -1086,9 +1303,10 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
// delta should not change in the second iteration.
std::vector<bool> should_revisit;
should_revisit.resize(graph_.num_node_ids());
for (Node* n : rpo) {
for (Node* n : tpo) {
VLOG(4) << "Visiting " << n->name();
TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr));
TF_RETURN_IF_ERROR(
HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
if (n->IsNextIteration()) {
// If this is a backedge for a merge node then remember to reprocess the
// merge the next time we run.
@ -1100,11 +1318,11 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
}
}
for (Node* n : rpo) {
for (Node* n : tpo) {
// The nodes added to should_revisit in the previous loop need to be
// revisited now. Reprocessing these initial nodes may add *their* consumers
// to should_revisit, and these newly added nodes will also be processed by
// this very same loop. Since we're traversing the graph in reverse post
// this very same loop. Since we're traversing the graph in topological
// order (producers before consumers) and HandleNode(n) can only ever add
// n's consumers to should_revisit, we won't "miss" an addition to
// should_revisit.
@ -1114,6 +1332,69 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
}
}
// Check if the optimistic analysis converges. Specifically, check whether
// all the predicates of the merge nodes in the same frame are the same. If
// yes, report success. If not, report failure and clear the assigned
// predicates.
if (use_optimistic_mode) {
bool is_converged = true;
absl::flat_hash_map<string, Predicate*> frame_to_pred;
for (Node* n : tpo) {
if (!n->IsMerge()) {
continue;
}
const Edge* e;
TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
if (e == nullptr) {
// Skip acyclic merge nodes.
continue;
}
Node* merge = n;
string frame_name = control_flow_info_[merge->id()].frame_name;
auto it = predicate_map_.find(TensorId(merge->name(), 0));
Predicate* merge_pred = it->second;
if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
is_converged = false;
VLOG(2) << "Running the optimistic mode on frame " << frame_name
<< " does not converge because node " << merge->name()
<< " cannot be mapped into the AndRecurrence form.";
break;
}
auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
if (!insert_result.second) {
// If we have already seen this frame name, verify the predicate is the
// same as the previously seen one's.
AndRecurrencePredicate* curr_andrec =
static_cast<AndRecurrencePredicate*>(merge_pred);
AndRecurrencePredicate* prev_andrec =
static_cast<AndRecurrencePredicate*>(insert_result.first->second);
if (curr_andrec != prev_andrec) {
is_converged = false;
VLOG(2) << "Running the optimistic mode on frame " << frame_name
<< " does not converge. Seeing different Merge predicates: \n"
<< curr_andrec->ToString() << " and \n"
<< prev_andrec->ToString();
break;
}
}
}
// Clear the assigned predicates if the optimistic mode does not converge.
if (!is_converged) {
for (Node* n : tpo) {
for (int oid = 0; oid < n->num_outputs(); ++oid) {
predicate_map_.erase(TensorId(n->name(), oid));
}
predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
}
}
if (is_success != nullptr) {
*is_success = is_converged;
}
}
return Status::OK();
}
@ -1170,22 +1451,14 @@ DeadnessAnalysisImpl::PredicateMapAsString() const {
}
namespace deadness_analysis_internal {
Status ComputePredicates(const Graph& graph,
PredicateMapTy* out_predicate_map) {
// Computes the deadness predicate of every output tensor in `graph` and
// returns the result as a map from tensor id to the predicate's string
// representation. For testing purposes only.
//
// If `force_pessimistic` is true, the analysis skips the optimistic mode and
// directly assigns a unique symbolic predicate per cyclic merge.
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
                         bool force_pessimistic) {
  DeadnessAnalysisImpl impl(&graph);
  // Populate exactly once; the stale parameterless call left over from the
  // previous revision is removed so the analysis is not run twice.
  TF_RETURN_IF_ERROR(impl.Populate(force_pessimistic));
  *out_predicate_map = impl.PredicateMapAsString();
  return Status::OK();
}
Status ComputePredicates(const Graph& graph,
absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map) {
DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
*out_predicate_map = impl.PredicateMapAsString();
return Status::OK();
}
} // namespace deadness_analysis_internal
} // namespace tensorflow

View File

@ -25,15 +25,9 @@ namespace deadness_analysis_internal {
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only.
using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
bool force_pessimistic = false);
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only. Makes deadness analysis visit the graph in the order
// specified in `reverse_post_order` which must be a valid RPO for the graph
// minus NextIteration->Merge edges.
Status ComputePredicates(const Graph& graph,
absl::Span<Node* const> reverse_post_order,
PredicateMapTy* out_predicate_map);
} // namespace deadness_analysis_internal
} // namespace tensorflow

View File

@ -638,7 +638,22 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
predicate_map[ControlOutputFor(iv.induction_var)]);
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
predicate_map[ControlOutputFor(iv.induction_var)]);
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
predicate_map[ControlOutputFor(iv.induction_var)]);
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
@ -660,16 +675,6 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
FixupSourceAndSinkEdges(root.graph());
// To make deadness analysis think that dependent_iv is a loop we need an RPO
// that visits the merge before the backedge. This is a legal RPO for
// deadness analysis since it ignores NextIteration->Merge edges during RPO.
// Right now dependent_iv has an edge from Merge to NextIteration so do the
// RPO with this edge in place. Then remove this edge to get our test case.
std::vector<Node*> rpo;
GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{},
/*edge_filter=*/[](const Edge& edge) {
return !edge.src()->IsNextIteration();
});
TF_ASSERT_OK(root.graph()->UpdateEdge(
iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));
@ -677,7 +682,16 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map));
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ false));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"{#true,&,*iv0/cond:0}<frame>");
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ true));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"div0/iv:0");
@ -731,7 +745,34 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
"{(*iv_outer/cond:0 & "
"{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
"cond:0}<inner_loop;outer_loop>");
// force_pessimistic = true or not should produce the same results because
// of fallback. However, note that the order of iv_inner/cond:0 and
// iv_inner/iv:0 is different because the optimistic approach does not
// create predicates for all merges and it can change the predicate id and
// hence the symbol order.
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
"iv_inner/iv:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
}
{
PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic*/ true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>");
@ -744,15 +785,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
"{{#true,&,(iv_outer/iv:0 & "
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
}
}
@ -817,6 +853,104 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
}
}
// Builds two nested while loops plus a value captured from outside the loops
// (a _Recv routed through IsConstant Enter nodes into both frames).
// div1_inner's backedge is rerouted through mul0, which multiplies by the
// captured value, while div0_inner's backedge is left intact; the test then
// asserts that add0, which consumes both inner loop-invariant values, is
// reported as having inputs with mismatching deadness.
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv_outer =
      CreateInductionVariable(root, "iv_outer", "outer_loop", 0);
  // Enter the constant 5 into the outer frame as a loop constant.
  Output enter_constant_outer_loop = ops::internal::Enter(
      root.WithOpName("constant_enter_outer_loop"),
      ops::Const(root.WithOpName("constant"), 5), "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  ops::Switch inner_value(root.WithOpName("outer_is_live"),
                          enter_constant_outer_loop, iv_outer.loop_cond);
  InductionVarInfo iv_inner = CreateInductionVariable(
      root, "iv_inner", "inner_loop", inner_value.output_true);

  // Loop-invariant values dependent on the outer and inner loop conditions.
  DependentInductionVar div0_outer = CreateDependentLoopInvariantValue(
      root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0);
  DependentInductionVar div1_outer = CreateDependentLoopInvariantValue(
      root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0);
  DependentInductionVar div0_inner = CreateDependentLoopInvariantValue(
      root, "div0_inner", "inner_loop", iv_inner.loop_cond,
      div0_outer.induction_var);
  DependentInductionVar div1_inner = CreateDependentLoopInvariantValue(
      root, "div1_inner", "inner_loop", iv_inner.loop_cond,
      div1_outer.induction_var);

  // Capture a tensor from outside the loops and Enter it into the outer and
  // then the inner frame, both times as a loop constant.
  Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32,
                               "tensor_a", "sender", 0, "receiver");
  Output capture_enter_outer = ops::internal::Enter(
      root.WithOpName("capture_enter_outer"), captured, "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  Output capture_enter_inner = ops::internal::Enter(
      root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var,
                         capture_enter_inner);
  // Reroute div1_inner's backedge through mul0.
  TF_ASSERT_OK(root.graph()->UpdateEdge(
      mul0.node(), 0, div1_inner.latch.output_true.node(), 0));

  Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var,
                         div1_inner.induction_var);

  VLogGraphIfAsked(*root.graph());

  {
    std::unique_ptr<DeadnessAnalysis> result;
    TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));

    TF_ASSERT_OK_AND_ASSIGN(
        bool has_inputs_with_mismatching_deadness,
        HasInputsWithMismatchingDeadness(*result, *add0.node()));
    EXPECT_TRUE(has_inputs_with_mismatching_deadness);
  }
}
// Builds a single loop in which two dependent loop-invariant values feed each
// other's backedges (div1 feeds div0's latch and div0 feeds div1's latch),
// forming a cyclic recurrence between the two merge nodes.
//
// In the optimistic mode both merges resolve to the same AndRecurrence as the
// induction variable; the pessimistic mode leaves a distinct symbolic
// predicate per merge (div0/iv:0 and div1/iv:0).
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
  Scope root = Scope::NewRootScope().ExitOnError();
  InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0);
  DependentInductionVar div0 =
      CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0);
  DependentInductionVar div1 =
      CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0);
  FixupSourceAndSinkEdges(root.graph());

  // Cross-wire the backedges: div1 feeds div0's latch and div0 feeds div1's.
  TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0,
                                        div0.latch.output_true.node(), 0));
  TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0,
                                        div1.latch.output_true.node(), 0));

  VLogGraphIfAsked(*root.graph());

  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*force_pessimistic=*/false));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");

    // This tests the rule {S,&,X} & ~X => S: the false output of div1's latch
    // simplifies all the way down to (#true).
    TensorId switch_false_out = {div1.latch.output_false.node()->name(),
                                 div1.latch.output_false.index()};
    EXPECT_EQ(predicate_map[switch_false_out], "(#true)");
  }
  {
    PredicateMapTy predicate_map;
    TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
                                   /*force_pessimistic=*/true));

    EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
              "{#true,&,*iv0/cond:0}<loop>");
    EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0");
    EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0");
  }
}
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
Scope root = Scope::NewRootScope().ExitOnError();
InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);