From 30749f263eb5d3bc8568400a3e3c559216ba685d Mon Sep 17 00:00:00 2001
From: Russell Power <power@google.com>
Date: Mon, 1 Feb 2021 16:48:44 -0800
Subject: [PATCH] Add minor optimization for graph copies.

Reserve input/output edgeset sizes when copying graphs.

PiperOrigin-RevId: 355055758
Change-Id: Id78260cda6f8bf9ed30663ecc819b5936fff26a8
---
 .../core/common_runtime/graph_constructor.cc  | 27 +---------------
 tensorflow/core/graph/edgeset.h               |  9 ++++++
 tensorflow/core/graph/graph.cc                | 31 +++++++++++++++++++
 tensorflow/core/graph/graph.h                 |  2 ++
 tensorflow/core/graph/graph_test.cc           |  3 +-
 5 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc
index 639739e9cac..f971b6ee9af 100644
--- a/tensorflow/core/common_runtime/graph_constructor.cc
+++ b/tensorflow/core/common_runtime/graph_constructor.cc
@@ -1557,31 +1557,6 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
   }
 }
 
-void CopyGraph(const Graph& src, Graph* dest) {
-  dest->SetConstructionContext(src.GetConstructionContextInternal());
-
-  for (Node* n : dest->nodes()) {
-    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
-  }
-
-  // Copy GraphDef versions
-  dest->set_versions(src.versions());
-
-  // Copy the nodes.
-  // "Node in src" -> "Node in *dest"
-  gtl::FlatMap<const Node*, Node*> node_map;
-  node_map[src.source_node()] = dest->source_node();
-  node_map[src.sink_node()] = dest->sink_node();
-  for (Node* n : src.op_nodes()) {
-    node_map[n] = dest->CopyNode(n);
-  }
-
-  // Copy the edges
-  for (const Edge* e : src.edges()) {
-    Node* src_copy = node_map[e->src()];
-    Node* dst_copy = node_map[e->dst()];
-    dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
-  }
-}
+void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/graph/edgeset.h b/tensorflow/core/graph/edgeset.h
index 2776c8491c2..c019dd3b957 100644
--- a/tensorflow/core/graph/edgeset.h
+++ b/tensorflow/core/graph/edgeset.h
@@ -47,6 +47,15 @@ class EdgeSet {
   void clear();
   std::pair<iterator, bool> insert(value_type value);
   size_type erase(key_type key);
+  void reserve(size_type new_size) {
+    if (new_size > kInline) {
+      auto s = new gtl::FlatSet<const Edge*>(new_size);
+      s->insert(reinterpret_cast<const Edge**>(std::begin(ptrs_)),
+                reinterpret_cast<const Edge**>(&ptrs_[0] + size()));
+      ptrs_[0] = this;
+      ptrs_[1] = s;
+    }
+  }
 
   // Caller is not allowed to mutate the EdgeSet while iterating.
   const_iterator begin() const;
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 2e7e7fbf4c3..56f16000ea2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -414,6 +414,37 @@ Graph::~Graph() {
 const VersionDef& Graph::versions() const { return *versions_; }
 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
 
+void Graph::Copy(const Graph& src) {
+  SetConstructionContext(src.GetConstructionContextInternal());
+  for (Node* n : nodes()) {
+    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
+  }
+
+  // Copy GraphDef versions
+  set_versions(src.versions());
+
+  // Copy the nodes.
+  // "Node in src" -> "Node in *dest"
+  gtl::FlatMap<const Node*, Node*> node_map;
+  node_map.reserve(src.num_nodes());
+  node_map[src.source_node()] = source_node();
+  node_map[src.sink_node()] = sink_node();
+  for (Node* n : src.op_nodes()) {
+    auto copy = CopyNode(n);
+    copy->in_edges_.reserve(n->in_edges().size());
+    copy->out_edges_.reserve(n->out_edges().size());
+    node_map[n] = copy;
+  }
+
+  // Copy the edges
+  edges_.reserve(src.num_edges());
+  for (const Edge* e : src.edges()) {
+    Node* src_copy = node_map[e->src()];
+    Node* dst_copy = node_map[e->dst()];
+    AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
+  }
+}
+
 Node* Graph::AddNode(NodeDef node_def, Status* status) {
   const OpRegistrationData* op_reg_data;
   status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index cc5c5b2f2d3..998d2ab5621 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -537,6 +537,8 @@ class Graph {
   // REQUIRES: node->IsOp()
   void RemoveNode(Node* node);
 
+  void Copy(const Graph& src);
+
   // Adds an edge that connects the xth output of `source` to the yth input of
   // `dest` and returns it. Does not update dest's NodeDef.
   const Edge* AddEdge(Node* source, int x, Node* dest, int y);
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc
index 2801bd7c961..2219b797e36 100644
--- a/tensorflow/core/graph/graph_test.cc
+++ b/tensorflow/core/graph/graph_test.cc
@@ -787,9 +787,10 @@ void BM_RemoveNode(::testing::benchmark::State& state) {
   const auto registry = OpRegistry::Global();
   GraphConstructorOptions opts;
   for (auto s : state) {
+    state.PauseTiming();
     Graph graph(registry);
     TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
-    testing::StartTiming();
+    state.ResumeTiming();
     for (Node* n : graph.op_nodes()) {
       graph.RemoveNode(n);
     }