parent
270a640994
commit
fd32a7f773
@ -165,7 +165,7 @@ Status RewriteAndPruneGraph(
|
|||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
|
||||||
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
|
VLOG(2) << "Post rewrite: " << DumpGraphToFile("tf2xla_post_rewrite", *graph);
|
||||||
PruneForReverseReachability(graph, retval_nodes);
|
PruneForReverseReachability(graph, std::move(retval_nodes));
|
||||||
FixupSourceAndSinkEdges(graph);
|
FixupSourceAndSinkEdges(graph);
|
||||||
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
|
VLOG(2) << "Post prune: " << DumpGraphToFile("tfcompile_post_prune", *graph);
|
||||||
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
|
// Sanity-check, to make sure the feeds and fetches still exist post-pruning.
|
||||||
|
@ -209,13 +209,9 @@ void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
|
|||||||
bool PruneForReverseReachability(Graph* g,
|
bool PruneForReverseReachability(Graph* g,
|
||||||
std::unordered_set<const Node*> visited) {
|
std::unordered_set<const Node*> visited) {
|
||||||
// Compute set of nodes that we need to traverse in order to reach
|
// Compute set of nodes that we need to traverse in order to reach
|
||||||
// the nodes in "nodes" by performing a breadth-first search from those
|
// the nodes in "visited" by performing a breadth-first search from those
|
||||||
// nodes, and accumulating the visited nodes.
|
// nodes, and accumulating the visited nodes.
|
||||||
std::deque<const Node*> queue;
|
std::deque<const Node*> queue(visited.begin(), visited.end());
|
||||||
for (const Node* n : visited) {
|
|
||||||
VLOG(2) << "Reverse reach init: " << n->name();
|
|
||||||
queue.push_back(n);
|
|
||||||
}
|
|
||||||
while (!queue.empty()) {
|
while (!queue.empty()) {
|
||||||
const Node* n = queue.front();
|
const Node* n = queue.front();
|
||||||
queue.pop_front();
|
queue.pop_front();
|
||||||
@ -227,21 +223,14 @@ bool PruneForReverseReachability(Graph* g,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make a pass over the graph to remove nodes not in "visited"
|
// Make a pass over the graph to remove nodes not in "visited".
|
||||||
std::vector<Node*> all_nodes;
|
|
||||||
all_nodes.reserve(g->num_nodes());
|
|
||||||
for (Node* n : g->nodes()) {
|
|
||||||
all_nodes.push_back(n);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool any_removed = false;
|
bool any_removed = false;
|
||||||
for (Node* n : all_nodes) {
|
for (Node* n : g->op_nodes()) {
|
||||||
if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
|
if (visited.find(n) == visited.end()) {
|
||||||
g->RemoveNode(n);
|
g->RemoveNode(n);
|
||||||
any_removed = true;
|
any_removed = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return any_removed;
|
return any_removed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -484,6 +484,9 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
|
|||||||
} else {
|
} else {
|
||||||
e = free_edges_.back();
|
e = free_edges_.back();
|
||||||
free_edges_.pop_back();
|
free_edges_.pop_back();
|
||||||
|
#ifdef ADDRESS_SANITIZER
|
||||||
|
ASAN_UNPOISON_MEMORY_REGION(e, sizeof(Edge));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
e->id_ = edges_.size();
|
e->id_ = edges_.size();
|
||||||
e->src_ = source;
|
e->src_ = source;
|
||||||
@ -511,13 +514,10 @@ void Graph::RemoveEdge(const Edge* e) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Graph::RecycleEdge(const Edge* e) {
|
void Graph::RecycleEdge(const Edge* e) {
|
||||||
Edge* del = const_cast<Edge*>(e);
|
free_edges_.push_back(const_cast<Edge*>(e));
|
||||||
del->src_ = nullptr;
|
#ifdef ADDRESS_SANITIZER
|
||||||
del->dst_ = nullptr;
|
ASAN_POISON_MEMORY_REGION(e, sizeof(Edge));
|
||||||
del->id_ = -1;
|
#endif
|
||||||
del->src_output_ = kControlSlot - 1;
|
|
||||||
del->dst_input_ = kControlSlot - 1;
|
|
||||||
free_edges_.push_back(del);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const Edge* Graph::AddControlEdge(Node* source, Node* dest,
|
const Edge* Graph::AddControlEdge(Node* source, Node* dest,
|
||||||
|
@ -782,5 +782,42 @@ BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
|
|||||||
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
|
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
|
||||||
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
|
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
|
||||||
|
|
||||||
|
static void BM_RemoveNode(int iters, int num_nodes, int num_edges_per_node) {
|
||||||
|
testing::StopTiming();
|
||||||
|
const GraphDef graph_def =
|
||||||
|
test::CreateGraphDef(num_nodes, num_edges_per_node);
|
||||||
|
const auto registry = OpRegistry::Global();
|
||||||
|
GraphConstructorOptions opts;
|
||||||
|
for (int i = 0; i < iters; ++i) {
|
||||||
|
Graph graph(registry);
|
||||||
|
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
|
||||||
|
testing::StartTiming();
|
||||||
|
for (Node* n : graph.op_nodes()) {
|
||||||
|
graph.RemoveNode(n);
|
||||||
|
}
|
||||||
|
testing::StopTiming();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(10, 2);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 2);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 2);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 2);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 2);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(10, 4);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 4);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 4);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 4);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 4);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(10, 8);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 8);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 8);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 8);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 8);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(10, 16);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 6, 16);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 9, 16);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 12, 16);
|
||||||
|
BENCHMARK(BM_RemoveNode)->ArgPair(1 << 15, 16);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -207,7 +207,7 @@ Status PruneForTargets(Graph* g, const NameIndex& name_index,
|
|||||||
return errors::NotFound("PruneForTargets: Some target nodes not found: ",
|
return errors::NotFound("PruneForTargets: Some target nodes not found: ",
|
||||||
not_found);
|
not_found);
|
||||||
}
|
}
|
||||||
PruneForReverseReachability(g, targets);
|
PruneForReverseReachability(g, std::move(targets));
|
||||||
|
|
||||||
// Reconnect nodes with no outgoing edges to the sink node
|
// Reconnect nodes with no outgoing edges to the sink node
|
||||||
FixupSourceAndSinkEdges(g);
|
FixupSourceAndSinkEdges(g);
|
||||||
|
Loading…
Reference in New Issue
Block a user