From 30749f263eb5d3bc8568400a3e3c559216ba685d Mon Sep 17 00:00:00 2001 From: Russell Power <power@google.com> Date: Mon, 1 Feb 2021 16:48:44 -0800 Subject: [PATCH] Add minor optimization for graph copies. Reserve input/output edgeset sizes when copying graphs. PiperOrigin-RevId: 355055758 Change-Id: Id78260cda6f8bf9ed30663ecc819b5936fff26a8 --- .../core/common_runtime/graph_constructor.cc | 27 +--------------- tensorflow/core/graph/edgeset.h | 9 ++++++ tensorflow/core/graph/graph.cc | 31 +++++++++++++++++++ tensorflow/core/graph/graph.h | 2 ++ tensorflow/core/graph/graph_test.cc | 3 +- 5 files changed, 45 insertions(+), 27 deletions(-) diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 639739e9cac..f971b6ee9af 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -1557,31 +1557,6 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, } } -void CopyGraph(const Graph& src, Graph* dest) { - dest->SetConstructionContext(src.GetConstructionContextInternal()); - - for (Node* n : dest->nodes()) { - CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty"; - } - - // Copy GraphDef versions - dest->set_versions(src.versions()); - - // Copy the nodes. - // "Node in src" -> "Node in *dest" - gtl::FlatMap<const Node*, Node*> node_map; - node_map[src.source_node()] = dest->source_node(); - node_map[src.sink_node()] = dest->sink_node(); - for (Node* n : src.op_nodes()) { - node_map[n] = dest->CopyNode(n); - } - - // Copy the edges - for (const Edge* e : src.edges()) { - Node* src_copy = node_map[e->src()]; - Node* dst_copy = node_map[e->dst()]; - dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); - } -} +void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); } } // namespace tensorflow diff --git a/tensorflow/core/graph/edgeset.h b/tensorflow/core/graph/edgeset.h index 2776c8491c2..c019dd3b957 100644 --- a/tensorflow/core/graph/edgeset.h +++ b/tensorflow/core/graph/edgeset.h @@ -47,6 +47,15 @@ class EdgeSet { void clear(); std::pair<iterator, bool> insert(value_type value); size_type erase(key_type key); + void reserve(size_type new_size) { + if (new_size > kInline) { + auto s = new gtl::FlatSet<const Edge*>(new_size); + s->insert(reinterpret_cast<const Edge**>(std::begin(ptrs_)), + reinterpret_cast<const Edge**>(&ptrs_[0] + size())); + ptrs_[0] = this; + ptrs_[1] = s; + } + } // Caller is not allowed to mutate the EdgeSet while iterating. const_iterator begin() const; diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 2e7e7fbf4c3..56f16000ea2 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -414,6 +414,37 @@ Graph::~Graph() { const VersionDef& Graph::versions() const { return *versions_; } void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; } +void Graph::Copy(const Graph& src) { + SetConstructionContext(src.GetConstructionContextInternal()); + for (Node* n : nodes()) { + CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty"; + } + + // Copy GraphDef versions + set_versions(src.versions()); + + // Copy the nodes. + // "Node in src" -> "Node in *dest" + gtl::FlatMap<const Node*, Node*> node_map; + node_map.reserve(src.num_nodes()); + node_map[src.source_node()] = source_node(); + node_map[src.sink_node()] = sink_node(); + for (Node* n : src.op_nodes()) { + auto copy = CopyNode(n); + copy->in_edges_.reserve(n->in_edges().size()); + copy->out_edges_.reserve(n->out_edges().size()); + node_map[n] = copy; + } + + // Copy the edges + edges_.reserve(src.num_edges()); + for (const Edge* e : src.edges()) { + Node* src_copy = node_map[e->src()]; + Node* dst_copy = node_map[e->dst()]; + AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); + } +} + Node* Graph::AddNode(NodeDef node_def, Status* status) { const OpRegistrationData* op_reg_data; status->Update(ops_.LookUp(node_def.op(), &op_reg_data)); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index cc5c5b2f2d3..998d2ab5621 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -537,6 +537,8 @@ class Graph { // REQUIRES: node->IsOp() void RemoveNode(Node* node); + void Copy(const Graph& src); + // Adds an edge that connects the xth output of `source` to the yth input of // `dest` and returns it. Does not update dest's NodeDef. const Edge* AddEdge(Node* source, int x, Node* dest, int y); diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 2801bd7c961..2219b797e36 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -787,9 +787,10 @@ void BM_RemoveNode(::testing::benchmark::State& state) { const auto registry = OpRegistry::Global(); GraphConstructorOptions opts; for (auto s : state) { + state.PauseTiming(); Graph graph(registry); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); - testing::StartTiming(); + state.ResumeTiming(); for (Node* n : graph.op_nodes()) { graph.RemoveNode(n); }