From fd32a7f773198e8521589e6a1cabbcdd54151699 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 4 Apr 2019 09:30:49 -0700 Subject: [PATCH] Automated rollback of commit fe30579d67b7c9c1af652fa42ccc94dd200627d4 PiperOrigin-RevId: 241944783 --- tensorflow/compiler/tf2xla/tf2xla.cc | 2 +- tensorflow/core/graph/algorithm.cc | 21 ++++------------ tensorflow/core/graph/graph.cc | 14 +++++------ tensorflow/core/graph/graph_test.cc | 37 ++++++++++++++++++++++++++++ tensorflow/core/graph/subgraph.cc | 2 +- 5 files changed, 51 insertions(+), 25 deletions(-) diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 57fb8367446..d10529774a9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -165,7 +165,7 @@ Status RewriteAndPruneGraph( TF_RETURN_IF_ERROR( AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph); - PruneForReverseReachability(graph, retval_nodes); + PruneForReverseReachability(graph, std::move(retval_nodes)); FixupSourceAndSinkEdges(graph); VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph); // Sanity-check, to make sure the feeds and fetches still exist post-pruning. diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc index ff972e3ca0d..7c469c2f8d8 100644 --- a/tensorflow/core/graph/algorithm.cc +++ b/tensorflow/core/graph/algorithm.cc @@ -209,13 +209,9 @@ void GetReversePostOrder(const Graph& g, std::vector* order, bool PruneForReverseReachability(Graph* g, std::unordered_set visited) { // Compute set of nodes that we need to traverse in order to reach - // the nodes in "nodes" by performing a breadth-first search from those + // the nodes in "visited" by performing a breadth-first search from those // nodes, and accumulating the visited nodes. - std::deque queue; - for (const Node* n : visited) { - VLOG(2) << "Reverse reach init: " << n->name(); - queue.push_back(n); - } + std::deque queue(visited.begin(), visited.end()); while (!queue.empty()) { const Node* n = queue.front(); queue.pop_front(); @@ -227,21 +223,14 @@ bool PruneForReverseReachability(Graph* g, } } - // Make a pass over the graph to remove nodes not in "visited" - std::vector all_nodes; - all_nodes.reserve(g->num_nodes()); - for (Node* n : g->nodes()) { - all_nodes.push_back(n); - } - + // Make a pass over the graph to remove nodes not in "visited". bool any_removed = false; - for (Node* n : all_nodes) { - if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) { + for (Node* n : g->op_nodes()) { + if (visited.find(n) == visited.end()) { g->RemoveNode(n); any_removed = true; } } - return any_removed; } diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 43010ea60b2..e2b1f7b3044 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -484,6 +484,9 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) { } else { e = free_edges_.back(); free_edges_.pop_back(); +#ifdef ADDRESS_SANITIZER + ASAN_UNPOISON_MEMORY_REGION(e, sizeof(Edge)); +#endif } e->id_ = edges_.size(); e->src_ = source; @@ -511,13 +514,10 @@ void Graph::RemoveEdge(const Edge* e) { } void Graph::RecycleEdge(const Edge* e) { - Edge* del = const_cast(e); - del->src_ = nullptr; - del->dst_ = nullptr; - del->id_ = -1; - del->src_output_ = kControlSlot - 1; - del->dst_input_ = kControlSlot - 1; - free_edges_.push_back(del); + free_edges_.push_back(const_cast(e)); +#ifdef ADDRESS_SANITIZER + ASAN_POISON_MEMORY_REGION(e, sizeof(Edge)); +#endif } const Edge* Graph::AddControlEdge(Node* source, Node* dest, diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 5fa42d32fd9..64e0fa70e5c 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -782,5 +782,42 @@ BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16); +static void BM_RemoveNode(int iters, int num_nodes, int num_edges_per_node) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateGraphDef(num_nodes, num_edges_per_node); + const auto registry = OpRegistry::Global(); + GraphConstructorOptions opts; + for (int i = 0; i < iters; ++i) { + Graph graph(registry); + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); + testing::StartTiming(); + for (Node* n : graph.op_nodes()) { + graph.RemoveNode(n); + } + testing::StopTiming(); + } +} +BENCHMARK(BM_RemoveNode)->ArgPair(10, 2); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 2); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 2); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 2); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 2); +BENCHMARK(BM_RemoveNode)->ArgPair(10, 4); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 4); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 4); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 4); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 4); +BENCHMARK(BM_RemoveNode)->ArgPair(10, 8); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 8); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 8); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 8); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 8); +BENCHMARK(BM_RemoveNode)->ArgPair(10, 16); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 16); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 16); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 16); +BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 16); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 60337e30aa5..7d839723f89 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -207,7 +207,7 @@ Status PruneForTargets(Graph* g, const NameIndex& name_index, return errors::NotFound("PruneForTargets: Some target nodes not found: ", not_found); } - PruneForReverseReachability(g, targets); + PruneForReverseReachability(g, std::move(targets)); // Reconnect nodes with no outgoing edges to the sink node FixupSourceAndSinkEdges(g);