Optimize tf graph manipulation:

* Don't copy all nodes in PruneForReverseReachability.
  * std::move target nodes when calling PruneForReverseReachability.
  * Don't clear the content of edges when moving them to the free list (Graph::RecycleEdge): they are immediately overwritten when re-used, so there is no need to touch the memory where they reside when recycling them.

PiperOrigin-RevId: 241841395
This commit is contained in:
A. Unique TensorFlower 2019-04-03 17:43:50 -07:00 committed by TensorFlower Gardener
parent d258fb3028
commit fe30579d67
5 changed files with 51 additions and 25 deletions

View File

@ -165,7 +165,7 @@ Status RewriteAndPruneGraph(
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes)); AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph); VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
PruneForReverseReachability(graph, retval_nodes); PruneForReverseReachability(graph, std::move(retval_nodes));
FixupSourceAndSinkEdges(graph); FixupSourceAndSinkEdges(graph);
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph); VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
// Sanity-check, to make sure the feeds and fetches still exist post-pruning. // Sanity-check, to make sure the feeds and fetches still exist post-pruning.

View File

@ -209,13 +209,9 @@ void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
bool PruneForReverseReachability(Graph* g, bool PruneForReverseReachability(Graph* g,
std::unordered_set<const Node*> visited) { std::unordered_set<const Node*> visited) {
// Compute set of nodes that we need to traverse in order to reach // Compute set of nodes that we need to traverse in order to reach
// the nodes in "nodes" by performing a breadth-first search from those // the nodes in "visited" by performing a breadth-first search from those
// nodes, and accumulating the visited nodes. // nodes, and accumulating the visited nodes.
std::deque<const Node*> queue; std::deque<const Node*> queue(visited.begin(), visited.end());
for (const Node* n : visited) {
VLOG(2) << "Reverse reach init: " << n->name();
queue.push_back(n);
}
while (!queue.empty()) { while (!queue.empty()) {
const Node* n = queue.front(); const Node* n = queue.front();
queue.pop_front(); queue.pop_front();
@ -227,21 +223,14 @@ bool PruneForReverseReachability(Graph* g,
} }
} }
// Make a pass over the graph to remove nodes not in "visited" // Make a pass over the graph to remove nodes not in "visited".
std::vector<Node*> all_nodes;
all_nodes.reserve(g->num_nodes());
for (Node* n : g->nodes()) {
all_nodes.push_back(n);
}
bool any_removed = false; bool any_removed = false;
for (Node* n : all_nodes) { for (Node* n : g->op_nodes()) {
if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) { if (visited.find(n) == visited.end()) {
g->RemoveNode(n); g->RemoveNode(n);
any_removed = true; any_removed = true;
} }
} }
return any_removed; return any_removed;
} }

View File

@ -484,6 +484,9 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
} else { } else {
e = free_edges_.back(); e = free_edges_.back();
free_edges_.pop_back(); free_edges_.pop_back();
#ifdef ADDRESS_SANITIZER
ASAN_UNPOISON_MEMORY_REGION(static_cast<void*>(e), sizeof(Edge));
#endif
} }
e->id_ = edges_.size(); e->id_ = edges_.size();
e->src_ = source; e->src_ = source;
@ -511,13 +514,10 @@ void Graph::RemoveEdge(const Edge* e) {
} }
void Graph::RecycleEdge(const Edge* e) { void Graph::RecycleEdge(const Edge* e) {
Edge* del = const_cast<Edge*>(e); free_edges_.push_back(const_cast<Edge*>(e));
del->src_ = nullptr; #ifdef ADDRESS_SANITIZER
del->dst_ = nullptr; ASAN_POISON_MEMORY_REGION(static_cast<void*>(e), sizeof(Edge));
del->id_ = -1; #endif
del->src_output_ = kControlSlot - 1;
del->dst_input_ = kControlSlot - 1;
free_edges_.push_back(del);
} }
const Edge* Graph::AddControlEdge(Node* source, Node* dest, const Edge* Graph::AddControlEdge(Node* source, Node* dest,

View File

@ -782,5 +782,42 @@ BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
// Benchmarks Graph::RemoveNode: on each iteration a fresh Graph is built from
// the same GraphDef and every node yielded by graph.op_nodes() is removed.
//
// iters:              number of benchmark iterations to run.
// num_nodes:          forwarded to test::CreateGraphDef.
// num_edges_per_node: forwarded to test::CreateGraphDef.
static void BM_RemoveNode(int iters, int num_nodes, int num_edges_per_node) {
// Keep the clock stopped while the shared input GraphDef is constructed.
testing::StopTiming();
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global();
GraphConstructorOptions opts;
for (int i = 0; i < iters; ++i) {
// Graph construction happens with timing stopped, so it is not measured.
Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
// Only the removal loop below falls inside the timed region.
testing::StartTiming();
for (Node* n : graph.op_nodes()) {
graph.RemoveNode(n);
}
testing::StopTiming();
}
}
// Register BM_RemoveNode for each (num_nodes, num_edges_per_node) pair:
// num_nodes in {10, 1<<6, 1<<9, 1<<12, 1<<15} crossed with
// num_edges_per_node in {2, 4, 8, 16}.
BENCHMARK(BM_RemoveNode)->ArgPair(10, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 2);
BENCHMARK(BM_RemoveNode)->ArgPair(10, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 4);
BENCHMARK(BM_RemoveNode)->ArgPair(10, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 8);
BENCHMARK(BM_RemoveNode)->ArgPair(10, 16);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 16);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 16);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 16);
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 16);
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -207,7 +207,7 @@ Status PruneForTargets(Graph* g, const NameIndex& name_index,
return errors::NotFound("PruneForTargets: Some target nodes not found: ", return errors::NotFound("PruneForTargets: Some target nodes not found: ",
not_found); not_found);
} }
PruneForReverseReachability(g, targets); PruneForReverseReachability(g, std::move(targets));
// Reconnect nodes with no outgoing edges to the sink node // Reconnect nodes with no outgoing edges to the sink node
FixupSourceAndSinkEdges(g); FixupSourceAndSinkEdges(g);