Automated rollback of commit fe30579d67

PiperOrigin-RevId: 241944783
This commit is contained in:
A. Unique TensorFlower 2019-04-04 09:30:49 -07:00 committed by TensorFlower Gardener
parent 270a640994
commit fd32a7f773
5 changed files with 51 additions and 25 deletions

View File

@ -165,7 +165,7 @@ Status RewriteAndPruneGraph(
TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
PruneForReverseReachability(graph, retval_nodes);
PruneForReverseReachability(graph, std::move(retval_nodes));
FixupSourceAndSinkEdges(graph);
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.

View File

@ -209,13 +209,9 @@ void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
bool PruneForReverseReachability(Graph* g,
std::unordered_set<const Node*> visited) {
// Compute set of nodes that we need to traverse in order to reach
// the nodes in "nodes" by performing a breadth-first search from those
// the nodes in "visited" by performing a breadth-first search from those
// nodes, and accumulating the visited nodes.
std::deque<const Node*> queue;
for (const Node* n : visited) {
VLOG(2) << "Reverse reach init: " << n->name();
queue.push_back(n);
}
std::deque<const Node*> queue(visited.begin(), visited.end());
while (!queue.empty()) {
const Node* n = queue.front();
queue.pop_front();
@ -227,21 +223,14 @@ bool PruneForReverseReachability(Graph* g,
}
}
// Make a pass over the graph to remove nodes not in "visited"
std::vector<Node*> all_nodes;
all_nodes.reserve(g->num_nodes());
for (Node* n : g->nodes()) {
all_nodes.push_back(n);
}
// Make a pass over the graph to remove nodes not in "visited".
bool any_removed = false;
for (Node* n : all_nodes) {
if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
for (Node* n : g->op_nodes()) {
if (visited.find(n) == visited.end()) {
g->RemoveNode(n);
any_removed = true;
}
}
return any_removed;
}

View File

@ -484,6 +484,9 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
} else {
e = free_edges_.back();
free_edges_.pop_back();
#ifdef ADDRESS_SANITIZER
ASAN_UNPOISON_MEMORY_REGION(e, sizeof(Edge));
#endif
}
e->id_ = edges_.size();
e->src_ = source;
@ -511,13 +514,10 @@ void Graph::RemoveEdge(const Edge* e) {
}
void Graph::RecycleEdge(const Edge* e) {
Edge* del = const_cast<Edge*>(e);
del->src_ = nullptr;
del->dst_ = nullptr;
del->id_ = -1;
del->src_output_ = kControlSlot - 1;
del->dst_input_ = kControlSlot - 1;
free_edges_.push_back(del);
free_edges_.push_back(const_cast<Edge*>(e));
#ifdef ADDRESS_SANITIZER
ASAN_POISON_MEMORY_REGION(e, sizeof(Edge));
#endif
}
const Edge* Graph::AddControlEdge(Node* source, Node* dest,

View File

@ -782,5 +782,42 @@ BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
static void BM_RemoveNode(int iters, int num_nodes, int num_edges_per_node) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global();
GraphConstructorOptions opts;
for (int i = 0; i < iters; ++i) {
Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
testing::StartTiming();
for (Node* n : graph.op_nodes()) {
graph.RemoveNode(n);
}
testing::StopTiming();
}
}
BENCHMARK(BM_RemoveNode)->ArgPair(10, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(10, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(10, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(10, 16);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 16);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 16);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 16);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 16);
} // namespace
} // namespace tensorflow

View File

@ -207,7 +207,7 @@ Status PruneForTargets(Graph* g, const NameIndex& name_index,
return errors::NotFound("PruneForTargets: Some target nodes not found: ",
not_found);
}
PruneForReverseReachability(g, targets);
PruneForReverseReachability(g, std::move(targets));
// Reconnect nodes with no outgoing edges to the sink node
FixupSourceAndSinkEdges(g);