Implement optimistic deadness analysis.
The optimistic deadness analysis can complement the original pessimistic analysis and reduce false positives due to loop structures such as in dynamic rnn. In addition, a new rule {S,&,X} & ~X => S is added to try to cancel out the leftover symbols after while loops. Otherwise, symbols keep accumulating in the case of cascaded while loops.
This commit is contained in:
parent
6232230c0e
commit
e1f36241fc
@ -85,22 +85,28 @@ limitations under the License.
|
||||
// true on iteration 0, 1, 2 respectively. This is made more precise in the
|
||||
// comment on the AndRecurrence class.
|
||||
//
|
||||
// The general algorithm that deals with cycles does two RPO (reverse post
|
||||
// order) passes over the graph. On the first pass it assigns a symbolic
|
||||
// predicate to merge nodes with backedges. On the second pass it tries to
|
||||
// pattern matche the predicates for the backedges of these merges and infer an
|
||||
// AndRecurrence for the merge.
|
||||
// The general algorithm that deals with cycles does two topological-order
|
||||
// iterations over the graph. On the first iteration it assigns a symbolic
|
||||
// predicate to merge nodes with backedges. On the second iteration it tries
|
||||
// to pattern match the predicates for the backedges of these merges and infer
|
||||
// an AndRecurrence for the merge. In other words, we do a data flow analysis
|
||||
// where the data-flow lattice has two elements, Symbolic and NonSymbolic with
|
||||
// Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are
|
||||
// sufficient to converge.
|
||||
//
|
||||
// In other words, we do a pessimistic data flow analysis where the data-flow
|
||||
// lattice has two elements, Symbolic and NonSymbolic with Symbolic >
|
||||
// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to
|
||||
// converge. We don't do an optimistic data flow analysis to make pattern
|
||||
// matching easier: if we assigned the predicate of the initial value to the
|
||||
// merge during the first pass, on the second pass the backedge may see a
|
||||
// simplified value that would be difficult to pattern match.
|
||||
// We first do an optimistic analysis and, if it does not converge, we then fall
|
||||
// back to a pessimistic analysis. The optimistic analysis assigns the same
|
||||
// symbolic predicate to all the merge nodes whose preceding enter nodes have
|
||||
// the same frame name on the first iteration. On the second iteration, if all
|
||||
// the merge nodes are pattern matched into the same AndRecurrence predicate
|
||||
// instance, the optimistic assignment of the same symbolic predicate is correct
|
||||
// and the analyzed result is taken.
|
||||
//
|
||||
// We still use symbolic predicates for merges for which we can't pattern match
|
||||
// on the backedge predicate. This is conservatively correct.
|
||||
// Otherwise, if the optimistic analysis fails to converge, we then obtain the
|
||||
// result by falling back to the pessimistic analysis which assigns a unique
|
||||
// symbolic predicate to each merge on the first iteration. We still use
|
||||
// symbolic predicates for merges for which we can't pattern match on the
|
||||
// backedge predicate. This is conservatively correct.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -636,6 +642,36 @@ Predicate* PredicateFactory::MakeAndOrImpl(
|
||||
negated_ops.insert(negated_op);
|
||||
}
|
||||
|
||||
// Simplify {S,&,X} & ~X & ... => S & ...
|
||||
if (is_and) {
|
||||
absl::flat_hash_set<Predicate*> to_remove;
|
||||
std::vector<Predicate*> to_add;
|
||||
for (Predicate* op : simplified_ops) {
|
||||
if (op->kind() == Predicate::Kind::kAndRecurrence) {
|
||||
AndRecurrencePredicate* and_rec =
|
||||
static_cast<AndRecurrencePredicate*>(op);
|
||||
if (negated_ops.count(and_rec->step())) {
|
||||
// Remove and_rec and ~X and insert S. Note that checking the
|
||||
// existence of ~X through negated_ops is sufficient since it makes
|
||||
// sure the predicate is in the input operands. It does not need to
|
||||
// be in simplified_ops if it was already cancelled out.
|
||||
to_remove.insert(and_rec);
|
||||
to_remove.insert(MakeNotPredicate(and_rec->step()));
|
||||
to_add.push_back(and_rec->start());
|
||||
}
|
||||
}
|
||||
}
|
||||
auto it = simplified_ops.begin();
|
||||
while (it != simplified_ops.end()) {
|
||||
if (to_remove.count(*it)) {
|
||||
it = simplified_ops.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end());
|
||||
}
|
||||
|
||||
// If all ops contain the same subop, then factor it out thanks to the
|
||||
// distributive property. Such as:
|
||||
// - (A & B) | (A & C) | (A & D) => A & (B | C | D)
|
||||
@ -699,8 +735,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
||||
explicit DeadnessAnalysisImpl(const Graph* graph)
|
||||
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
|
||||
|
||||
Status Populate();
|
||||
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
|
||||
Status Populate(bool force_pessimistic = false);
|
||||
Status PopulateFrame(absl::Span<Node* const> tpo, bool use_optimistic_mode,
|
||||
bool* is_success);
|
||||
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
|
||||
Node* n, int oidx) const override;
|
||||
void Print() const override;
|
||||
@ -742,16 +779,19 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
|
||||
}
|
||||
|
||||
Status HandleSwitch(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleMerge(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleMerge(Node* n, std::vector<bool>* should_revisit,
|
||||
bool use_optimistic_mode);
|
||||
Status HandleRecv(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleGeneric(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
|
||||
Status HandleNode(Node* n, std::vector<bool>* should_revisit,
|
||||
bool use_optimistic_mode = false);
|
||||
|
||||
const Graph& graph_;
|
||||
absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
|
||||
PredicateFactory predicate_factory_;
|
||||
std::vector<ControlFlowInfo> control_flow_info_;
|
||||
bool vlog_;
|
||||
absl::flat_hash_map<string, Node*> frame_to_merge_node_;
|
||||
};
|
||||
|
||||
TensorId InputEdgeToTensorId(const Edge* e) {
|
||||
@ -914,10 +954,125 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Finds the outermost ("root") frame that encloses `n`.
//
// Starting from the control-flow entry of `n`, follows `parent_frame` links
// through `cfi_infos` (which is indexed by node id) until reaching an entry
// whose parent frame is `src_node`, then stores that entry's `frame_name` in
// `*frame`.  If the entry for `n` already has `src_node` as its parent frame,
// the frame name of `n` itself is returned.
//
// Returns an Internal error if more than 5000 links are followed, which is
// treated as a sign of a malformed graph or a bug in BuildControlFlowInfo.
Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
                    const Node* src_node, string* frame) {
  constexpr int kMaxFrameDepth = 5000;

  const ControlFlowInfo* info = &cfi_infos[n->id()];
  for (int depth = 0; info->parent_frame != src_node; ++depth) {
    // Step one level outwards to the frame that contains the current one.
    info = &cfi_infos[info->parent_frame->id()];

    if (depth > kMaxFrameDepth) {
      return errors::Internal(
          "Frame of depth > 5000: Probably malformed graph or a bug in "
          "BuildControlFlowInfo");
    }
  }

  *frame = info->frame_name;
  return Status::OK();
}
|
||||
|
||||
// Compute a special topological order for the Graph, where nodes having the
|
||||
// same root frame are placed adjacent to each other.
|
||||
Status GetFrameBasedTopologicalOrder(const Graph* g,
|
||||
absl::Span<const ControlFlowInfo> cf_infos,
|
||||
std::vector<Node*>* order) {
|
||||
absl::flat_hash_map<string, size_t> num_enters;
|
||||
absl::flat_hash_map<string, size_t> num_exits;
|
||||
std::vector<size_t> num_ready_inputs(g->num_node_ids(), 0);
|
||||
Node* src_node = g->source_node();
|
||||
std::deque<Node*> ready;
|
||||
ready.push_back(src_node);
|
||||
for (const auto* node : g->op_nodes()) {
|
||||
const ControlFlowInfo& cf = cf_infos[node->id()];
|
||||
bool is_root_level = cf.parent_frame == src_node;
|
||||
if (IsEnter(node) && is_root_level) {
|
||||
// Since we care only the root-level frame, full frame names are the same
|
||||
// as frame names.
|
||||
++num_enters[cf.frame_name];
|
||||
} else if (IsExit(node) && is_root_level) {
|
||||
++num_exits[cf.frame_name];
|
||||
}
|
||||
if (IsMerge(node)) {
|
||||
for (const Edge* e : node->in_edges()) {
|
||||
if (IsNextIteration(e->src())) {
|
||||
++num_ready_inputs[node->id()];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
absl::flat_hash_map<string, std::vector<Node*>> staging_enter_vecs;
|
||||
// Exit nodes shall all be from the same frame, as we process a frame at a
|
||||
// time. So, one vector is enough.
|
||||
std::vector<Node*> staging_exit_vec;
|
||||
while (!ready.empty()) {
|
||||
Node* curr_node = ready.front();
|
||||
ready.pop_front();
|
||||
|
||||
VLOG(4) << "Visiting " << curr_node->name();
|
||||
order->push_back(curr_node);
|
||||
|
||||
for (const Edge* out_edge : curr_node->out_edges()) {
|
||||
Node* out = out_edge->dst();
|
||||
int out_id = out->id();
|
||||
if (IsNextIteration(curr_node) && IsMerge(out)) {
|
||||
// Edge NextIteration->Merge has been counted.
|
||||
continue;
|
||||
}
|
||||
++num_ready_inputs[out->id()];
|
||||
if (!out->IsOp()) continue; // Skip Sink/Source nodes.
|
||||
if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
|
||||
|
||||
bool is_root_level = cf_infos[out_id].parent_frame == src_node;
|
||||
string frame_name = cf_infos[out_id].frame_name;
|
||||
if (IsEnter(out) && is_root_level) {
|
||||
staging_enter_vecs[frame_name].push_back(out);
|
||||
} else if (IsExit(out) && is_root_level) {
|
||||
staging_exit_vec.push_back(out);
|
||||
} else {
|
||||
ready.push_back(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (ready.empty()) {
|
||||
if (!staging_exit_vec.empty()) {
|
||||
// Move staging nodes into the ready queue if any. If there are staging
|
||||
// exits we must process them before processing the staging enters to
|
||||
// make sure all nodes in the currently processing frame are visited
|
||||
// before starting processing other frames.
|
||||
string frame_name = cf_infos[staging_exit_vec.front()->id()].frame_name;
|
||||
CHECK(staging_exit_vec.size() == num_exits[frame_name]);
|
||||
ready.insert(ready.end(), staging_exit_vec.begin(),
|
||||
staging_exit_vec.end());
|
||||
staging_exit_vec.clear();
|
||||
} else {
|
||||
// Otherwise, try moving the staging enter nodes into the ready queue.
|
||||
for (auto iter = staging_enter_vecs.begin();
|
||||
iter != staging_enter_vecs.end(); ++iter) {
|
||||
string frame_name = iter->first;
|
||||
const std::vector<Node*>& staging_enters = iter->second;
|
||||
if (staging_enters.size() == num_enters[frame_name]) {
|
||||
ready.insert(ready.end(), staging_enters.begin(),
|
||||
staging_enters.end());
|
||||
staging_enter_vecs.erase(iter);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CHECK(staging_enter_vecs.empty() && staging_exit_vec.empty());
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Status DeadnessAnalysisImpl::HandleMerge(Node* n,
|
||||
std::vector<bool>* should_revisit) {
|
||||
std::vector<bool>* should_revisit,
|
||||
bool use_optimistic_mode) {
|
||||
// Merge ignores deadness of its control inputs. A merge that isn't the
|
||||
// target of a backedge is alive iff any of its data inputs are. The
|
||||
// liveness of a merge that is the target of a backedge can sometimes be
|
||||
@ -937,8 +1092,21 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
|
||||
// We're visiting this merge for the first time and it has an unvisited
|
||||
// backedge.
|
||||
Predicate* input_data_pred;
|
||||
if (use_optimistic_mode) {
|
||||
// In the optimistic mode, we use the first-seen Merge node per
|
||||
// frame as the representative Merge node. It is just convenient and
|
||||
// does not affect the result after pattern-matching into the
|
||||
// AndRecurrence form.
|
||||
string frame_name = control_flow_info_[n->id()].frame_name;
|
||||
auto ret = frame_to_merge_node_.insert({frame_name, n});
|
||||
Node* representative = ret.first->second;
|
||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||
representative, /*output_idx=*/0, /*must_be_true=*/false,
|
||||
&input_data_pred));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
|
||||
n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred));
|
||||
}
|
||||
|
||||
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
|
||||
should_revisit);
|
||||
@ -948,7 +1116,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
|
||||
std::vector<Predicate*> input_preds;
|
||||
TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
|
||||
|
||||
// We're visiting this merge for the first time and it is a acyclic merge.
|
||||
// We're visiting this merge for the first time and it is an acyclic merge.
|
||||
Predicate* input_data_pred =
|
||||
predicate_factory_.MakeOrPredicate(input_preds);
|
||||
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
|
||||
@ -1022,11 +1190,12 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::HandleNode(Node* n,
|
||||
std::vector<bool>* should_revisit) {
|
||||
std::vector<bool>* should_revisit,
|
||||
bool use_optimistic_mode) {
|
||||
if (n->IsSwitch()) {
|
||||
TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit));
|
||||
} else if (n->IsMerge()) {
|
||||
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit));
|
||||
TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode));
|
||||
} else if (n->IsControlTrigger()) {
|
||||
SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(),
|
||||
nullptr);
|
||||
@ -1040,17 +1209,19 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::Populate() {
|
||||
std::vector<Node*> rpo;
|
||||
GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(),
|
||||
/*edge_filter=*/[](const Edge& edge) {
|
||||
return !edge.src()->IsNextIteration();
|
||||
});
|
||||
return PopulateWithReversePostOrder(rpo);
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
absl::Span<Node* const> rpo) {
|
||||
// We populate the nodes along a special topological order where nodes having
|
||||
// the same root frame are placed adjacent to each other. This grouping enables
|
||||
// processing the graph per root frame at a time and guarantees that when a root
|
||||
// frame is being processed, nodes in the downstream frames have not yet been
|
||||
// processed. This property is important because we need to process an entire
|
||||
// frame to know whether the optimistic mode converges or not. In other words,
|
||||
// nodes in the downstream frames shall not be populated until all of its
|
||||
// upstream frames are populated. In effect, this order enables processing each
|
||||
// (nested) tf.while one-by-one, as each (nested) tf.while creates a unique
|
||||
// (root) frame. Note that we don't separate while loops belonging to the same
|
||||
// nested while, as there is no clean cut for separating them in the topological
|
||||
// order.
|
||||
Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) {
|
||||
std::vector<string> unreachable_nodes;
|
||||
// Compute the loop structure of the graph.
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -1069,6 +1240,52 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
absl::StrJoin(unreachable_nodes, ", "));
|
||||
}
|
||||
|
||||
std::vector<Node*> tpo;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &tpo));
|
||||
|
||||
size_t frame_start = 0;
|
||||
for (size_t i = 0; i < tpo.size(); ++i) {
|
||||
// Collect nodes until we see a node that has a different root frame.
|
||||
if (i != tpo.size() - 1) {
|
||||
string i_frame_name, next_frame_name;
|
||||
TF_RETURN_IF_ERROR(GetRootFrame(tpo[i], control_flow_info_,
|
||||
graph_.source_node(), &i_frame_name));
|
||||
TF_RETURN_IF_ERROR(GetRootFrame(tpo[i + 1], control_flow_info_,
|
||||
graph_.source_node(), &next_frame_name));
|
||||
if (i_frame_name == next_frame_name) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
string frame_name = control_flow_info_[tpo[i]->id()].frame_name;
|
||||
absl::Span<Node*> sub_tpo(tpo.data() + frame_start, i - frame_start + 1);
|
||||
frame_start = i + 1;
|
||||
|
||||
// First, try the optimistic mode.
|
||||
bool is_success = false;
|
||||
if (!force_pessimistic && !frame_name.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PopulateFrame(sub_tpo, /*use_optimistic_mode*/ true, &is_success));
|
||||
}
|
||||
if (!is_success) {
|
||||
// The optimistic mode does not converge. Let's fall back to the
|
||||
// pessimistic mode.
|
||||
TF_RETURN_IF_ERROR(
|
||||
PopulateFrame(sub_tpo, /*use_optimistic_mode*/ false, nullptr));
|
||||
}
|
||||
if (VLOG_IS_ON(2)) {
|
||||
VLOG(2) << "Done populating frame " << frame_name << " using the "
|
||||
<< (is_success ? "optimistic" : "pessimistic") << " mode.";
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> tpo,
|
||||
bool use_optimistic_mode,
|
||||
bool* is_success) {
|
||||
// This is an abstract interpretation over the deadness propagation semantics of
|
||||
// the graph executor.
|
||||
//
|
||||
@ -1086,9 +1303,10 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
// delta should not change in the second iteration.
|
||||
std::vector<bool> should_revisit;
|
||||
should_revisit.resize(graph_.num_node_ids());
|
||||
for (Node* n : rpo) {
|
||||
for (Node* n : tpo) {
|
||||
VLOG(4) << "Visiting " << n->name();
|
||||
TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr));
|
||||
TF_RETURN_IF_ERROR(
|
||||
HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode));
|
||||
if (n->IsNextIteration()) {
|
||||
// If this is a backedge for a merge node then remember to reprocess the
|
||||
// merge the next time we run.
|
||||
@ -1100,11 +1318,11 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
}
|
||||
}
|
||||
|
||||
for (Node* n : rpo) {
|
||||
for (Node* n : tpo) {
|
||||
// The nodes added to should_revisit in the previous loop need to be
|
||||
// revisited now. Reprocessing these initial nodes may add *their* consumers
|
||||
// to should_revisit, and these newly added nodes will also be processed by
|
||||
// this very same loop. Since we're traversing the graph in reverse post
|
||||
// this very same loop. Since we're traversing the graph in topological
|
||||
// order (producers before consumers) and HandleNode(n) can only ever add
|
||||
// n's consumers to should_revisit, we won't "miss" an addition to
|
||||
// should_revisit.
|
||||
@ -1114,6 +1332,69 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder(
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the optimistic analysis converges. Specifically, check whether
|
||||
// all the predicates of the merge nodes in the same frame are the same. If
|
||||
// yes, report success. If not, report failure and clear the assigned
|
||||
// predicates.
|
||||
if (use_optimistic_mode) {
|
||||
bool is_converged = true;
|
||||
absl::flat_hash_map<string, Predicate*> frame_to_pred;
|
||||
for (Node* n : tpo) {
|
||||
if (!n->IsMerge()) {
|
||||
continue;
|
||||
}
|
||||
const Edge* e;
|
||||
TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e));
|
||||
if (e == nullptr) {
|
||||
// Skip acyclic merge nodes.
|
||||
continue;
|
||||
}
|
||||
Node* merge = n;
|
||||
string frame_name = control_flow_info_[merge->id()].frame_name;
|
||||
auto it = predicate_map_.find(TensorId(merge->name(), 0));
|
||||
Predicate* merge_pred = it->second;
|
||||
if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
|
||||
is_converged = false;
|
||||
VLOG(2) << "Running the optimistic mode on frame " << frame_name
|
||||
<< " does not converge because node " << merge->name()
|
||||
<< " cannot be mapped into the AndRecurrence form.";
|
||||
break;
|
||||
}
|
||||
|
||||
auto insert_result = frame_to_pred.insert({frame_name, merge_pred});
|
||||
if (!insert_result.second) {
|
||||
// If we have already seen this frame name, verify the predicate is the
|
||||
// same as the previously seen one's.
|
||||
AndRecurrencePredicate* curr_andrec =
|
||||
static_cast<AndRecurrencePredicate*>(merge_pred);
|
||||
AndRecurrencePredicate* prev_andrec =
|
||||
static_cast<AndRecurrencePredicate*>(insert_result.first->second);
|
||||
if (curr_andrec != prev_andrec) {
|
||||
is_converged = false;
|
||||
VLOG(2) << "Running the optimistic mode on frame " << frame_name
|
||||
<< " does not converge. Seeing different Merge predicates: \n"
|
||||
<< curr_andrec->ToString() << " and \n"
|
||||
<< prev_andrec->ToString();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the assigned predicates if the optimistic mode does not converge.
|
||||
if (!is_converged) {
|
||||
for (Node* n : tpo) {
|
||||
for (int oid = 0; oid < n->num_outputs(); ++oid) {
|
||||
predicate_map_.erase(TensorId(n->name(), oid));
|
||||
}
|
||||
predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot));
|
||||
}
|
||||
}
|
||||
|
||||
if (is_success != nullptr) {
|
||||
*is_success = is_converged;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1170,22 +1451,14 @@ DeadnessAnalysisImpl::PredicateMapAsString() const {
|
||||
}
|
||||
|
||||
namespace deadness_analysis_internal {
|
||||
Status ComputePredicates(const Graph& graph,
|
||||
PredicateMapTy* out_predicate_map) {
|
||||
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
|
||||
bool force_pessimistic) {
|
||||
DeadnessAnalysisImpl impl(&graph);
|
||||
TF_RETURN_IF_ERROR(impl.Populate());
|
||||
TF_RETURN_IF_ERROR(impl.Populate(force_pessimistic));
|
||||
*out_predicate_map = impl.PredicateMapAsString();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ComputePredicates(const Graph& graph,
|
||||
absl::Span<Node* const> reverse_post_order,
|
||||
PredicateMapTy* out_predicate_map) {
|
||||
DeadnessAnalysisImpl impl(&graph);
|
||||
TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order));
|
||||
*out_predicate_map = impl.PredicateMapAsString();
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace deadness_analysis_internal
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -25,15 +25,9 @@ namespace deadness_analysis_internal {
|
||||
// Returns a map describing the predicate each Tensor was mapped to. For
|
||||
// testing purposes only.
|
||||
using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
|
||||
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
|
||||
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
|
||||
bool force_pessimistic = false);
|
||||
|
||||
// Returns a map describing the predicate each Tensor was mapped to. For
|
||||
// testing purposes only. Makes deadness analysis visit the graph in the order
|
||||
// specified in `reverse_post_order` which must be a valid RPO for the graph
|
||||
// minus NextIteration->Merge edges.
|
||||
Status ComputePredicates(const Graph& graph,
|
||||
absl::Span<Node* const> reverse_post_order,
|
||||
PredicateMapTy* out_predicate_map);
|
||||
} // namespace deadness_analysis_internal
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -638,7 +638,22 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
|
||||
predicate_map[ControlOutputFor(iv.induction_var)]);
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
|
||||
predicate_map[ControlOutputFor(iv.induction_var)]);
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
|
||||
predicate_map[ControlOutputFor(iv.induction_var)]);
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<loop>");
|
||||
@ -660,16 +675,6 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
|
||||
CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0);
|
||||
FixupSourceAndSinkEdges(root.graph());
|
||||
|
||||
// To make deadness analysis think that dependent_iv is a loop we need an RPO
|
||||
// that visits the merge before the backedge. This is a legal RPO for
|
||||
// deadness analysis since it ignores NextIteration->Merge edges during RPO.
|
||||
// Right now dependent_iv has an edge from Merge to NextIteration so do the
|
||||
// RPO with this edge in place. Then remove this edge to get our test case.
|
||||
std::vector<Node*> rpo;
|
||||
GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{},
|
||||
/*edge_filter=*/[](const Edge& edge) {
|
||||
return !edge.src()->IsNextIteration();
|
||||
});
|
||||
TF_ASSERT_OK(root.graph()->UpdateEdge(
|
||||
iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0));
|
||||
|
||||
@ -677,7 +682,16 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
|
||||
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map));
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
|
||||
"{#true,&,*iv0/cond:0}<frame>");
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
|
||||
"div0/iv:0");
|
||||
@ -731,7 +745,34 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ false));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
|
||||
"{#true,&,*iv_outer/cond:0}<outer_loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
|
||||
"{(*iv_outer/cond:0 & "
|
||||
"{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
|
||||
"cond:0}<inner_loop;outer_loop>");
|
||||
|
||||
// force_pessimistic = true or not should produce the same results because
|
||||
// of fallback. However, note that the order of iv_inner/cond:0 and
|
||||
// iv_inner/iv:0 is different because the optimistic approach does not
|
||||
// create predicates for all merges and it can change the predicate id and
|
||||
// hence the symbol order.
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
|
||||
"{{#true,&,(iv_outer/iv:0 & "
|
||||
"*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
|
||||
"iv_inner/iv:0)}<inner_loop;outer_loop>");
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
|
||||
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
|
||||
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
|
||||
}
|
||||
{
|
||||
PredicateMapTy predicate_map;
|
||||
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
|
||||
/*force_pessimistic*/ true));
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
|
||||
"{#true,&,*iv_outer/cond:0}<outer_loop>");
|
||||
@ -744,15 +785,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
|
||||
"{{#true,&,(iv_outer/iv:0 & "
|
||||
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
|
||||
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
|
||||
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
|
||||
"{{#true,&,(iv_outer/iv:0 & "
|
||||
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
|
||||
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
|
||||
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
|
||||
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
|
||||
"{{#true,&,(iv_outer/iv:0 & "
|
||||
"*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
|
||||
"*iv_inner/cond:0)}<inner_loop;outer_loop>");
|
||||
predicate_map[ControlOutputFor(dependent_inner_iv0)]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -817,6 +853,104 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) {
|
||||
}
|
||||
}
|
||||
|
||||
// Builds an "inner_loop" nested in an "outer_loop", each with two dependent
// loop-invariant values, and multiplies a value received from outside the
// loops (entered into both frames as a constant) into the backedge of
// div1_inner.  add0 then consumes div0_inner and div1_inner, and the analysis
// is expected to report that its inputs have mismatching deadness.
TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) {
  Scope scope = Scope::NewRootScope().ExitOnError();

  // Outer loop and the value that gates entry into the inner loop.
  InductionVarInfo outer_iv =
      CreateInductionVariable(scope, "iv_outer", "outer_loop", 0);
  Output constant = ops::Const(scope.WithOpName("constant"), 5);
  Output constant_in_outer = ops::internal::Enter(
      scope.WithOpName("constant_enter_outer_loop"), constant, "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  ops::Switch outer_is_live(scope.WithOpName("outer_is_live"),
                            constant_in_outer, outer_iv.loop_cond);
  InductionVarInfo inner_iv = CreateInductionVariable(
      scope, "iv_inner", "inner_loop", outer_is_live.output_true);

  // Two dependent values in the outer loop...
  DependentInductionVar outer_div0 = CreateDependentLoopInvariantValue(
      scope, "div0_outer", "outer_loop", outer_iv.loop_cond, 0);
  DependentInductionVar outer_div1 = CreateDependentLoopInvariantValue(
      scope, "div1_outer", "outer_loop", outer_iv.loop_cond, 0);

  // ...each feeding a dependent value in the inner loop.
  DependentInductionVar inner_div0 = CreateDependentLoopInvariantValue(
      scope, "div0_inner", "inner_loop", inner_iv.loop_cond,
      outer_div0.induction_var);
  DependentInductionVar inner_div1 = CreateDependentLoopInvariantValue(
      scope, "div1_inner", "inner_loop", inner_iv.loop_cond,
      outer_div1.induction_var);

  // The captured value: received outside the loops, then entered into the
  // outer and the inner frame in turn.
  Output recv = ops::_Recv(scope.WithOpName("captured"), DT_INT32, "tensor_a",
                           "sender", 0, "receiver");
  Output recv_in_outer = ops::internal::Enter(
      scope.WithOpName("capture_enter_outer"), recv, "outer_loop",
      ops::internal::Enter::Attrs().IsConstant(true));
  Output recv_in_inner = ops::internal::Enter(
      scope.WithOpName("capture_enter_inner"), recv_in_outer, "inner_loop",
      ops::internal::Enter::Attrs().IsConstant(true));

  // Route div1_inner's latch input through a multiplication with the capture.
  Output mul0 = ops::Mul(scope.WithOpName("mul0"), inner_div1.induction_var,
                         recv_in_inner);
  TF_ASSERT_OK(scope.graph()->UpdateEdge(
      mul0.node(), 0, inner_div1.latch.output_true.node(), 0));

  Output add0 = ops::Add(scope.WithOpName("add0"), inner_div0.induction_var,
                         inner_div1.induction_var);

  VLogGraphIfAsked(*scope.graph());

  std::unique_ptr<DeadnessAnalysis> analysis;
  TF_ASSERT_OK(AnalyzeDeadness(scope.graph(), &analysis));

  TF_ASSERT_OK_AND_ASSIGN(
      bool has_inputs_with_mismatching_deadness,
      HasInputsWithMismatchingDeadness(*analysis, *add0.node()));
  EXPECT_TRUE(has_inputs_with_mismatching_deadness);
}
|
||||
|
||||
// Two dependent values in the same loop whose latch inputs are cross-wired:
// div0's latch is fed by div1's induction variable and vice versa.  The
// optimistic mode is expected to give both the same recurrence as the
// induction variable; the pessimistic mode is expected to leave each of them
// as its own symbol.
TEST(DeadnessAnalysisTest, CyclicRecurrence) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  InductionVarInfo loop_iv = CreateInductionVariable(scope, "iv0", "loop", 0);
  DependentInductionVar first =
      CreateDependentLoopInvariantValue(scope, "div0", "loop",
                                        loop_iv.loop_cond, 0);
  DependentInductionVar second =
      CreateDependentLoopInvariantValue(scope, "div1", "loop",
                                        loop_iv.loop_cond, 0);
  FixupSourceAndSinkEdges(scope.graph());

  // Cross-wire the two latches.
  TF_ASSERT_OK(scope.graph()->UpdateEdge(second.induction_var.node(), 0,
                                         first.latch.output_true.node(), 0));
  TF_ASSERT_OK(scope.graph()->UpdateEdge(first.induction_var.node(), 0,
                                         second.latch.output_true.node(), 0));

  VLogGraphIfAsked(*scope.graph());

  const char* const kLoopRecurrence = "{#true,&,*iv0/cond:0}<loop>";

  // Optimistic mode (with fallback).
  {
    PredicateMapTy optimistic;
    TF_ASSERT_OK(ComputePredicates(*scope.graph(), &optimistic,
                                   /*force_pessimistic*/ false));

    EXPECT_EQ(optimistic[ControlOutputFor(loop_iv.induction_var)],
              kLoopRecurrence);
    EXPECT_EQ(optimistic[ControlOutputFor(first.induction_var)],
              kLoopRecurrence);
    EXPECT_EQ(optimistic[ControlOutputFor(second.induction_var)],
              kLoopRecurrence);

    // This tests the rule {S,&,X} & ~X => S.
    TensorId false_output_of_second_latch = {
        second.latch.output_false.node()->name(),
        second.latch.output_false.index()};
    EXPECT_EQ(optimistic[false_output_of_second_latch], "(#true)");
  }

  // Pessimistic mode only.
  {
    PredicateMapTy pessimistic;
    TF_ASSERT_OK(ComputePredicates(*scope.graph(), &pessimistic,
                                   /*force_pessimistic*/ true));

    EXPECT_EQ(pessimistic[ControlOutputFor(loop_iv.induction_var)],
              kLoopRecurrence);
    EXPECT_EQ(pessimistic[ControlOutputFor(first.induction_var)], "div0/iv:0");
    EXPECT_EQ(pessimistic[ControlOutputFor(second.induction_var)],
              "div1/iv:0");
  }
}
|
||||
|
||||
TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);
|
||||
|
Loading…
Reference in New Issue
Block a user