Add minor optimization for graph copies.

Reserve input/output edgeset sizes when copying graphs.

PiperOrigin-RevId: 355055758
Change-Id: Id78260cda6f8bf9ed30663ecc819b5936fff26a8
This commit is contained in:
Russell Power 2021-02-01 16:48:44 -08:00 committed by TensorFlower Gardener
parent d981631696
commit 30749f263e
5 changed files with 45 additions and 27 deletions

View File

@ -1557,31 +1557,6 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
}
}
void CopyGraph(const Graph& src, Graph* dest) {
dest->SetConstructionContext(src.GetConstructionContextInternal());
for (Node* n : dest->nodes()) {
CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
}
// Copy GraphDef versions
dest->set_versions(src.versions());
// Copy the nodes.
// "Node in src" -> "Node in *dest"
gtl::FlatMap<const Node*, Node*> node_map;
node_map[src.source_node()] = dest->source_node();
node_map[src.sink_node()] = dest->sink_node();
for (Node* n : src.op_nodes()) {
node_map[n] = dest->CopyNode(n);
}
// Copy the edges
for (const Edge* e : src.edges()) {
Node* src_copy = node_map[e->src()];
Node* dst_copy = node_map[e->dst()];
dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
}
}
void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
} // namespace tensorflow

View File

@ -47,6 +47,15 @@ class EdgeSet {
void clear();
std::pair<iterator, bool> insert(value_type value);
size_type erase(key_type key);
void reserve(size_type new_size) {
if (new_size > kInline) {
auto s = new gtl::FlatSet<const Edge*>(new_size);
s->insert(reinterpret_cast<const Edge**>(std::begin(ptrs_)),
reinterpret_cast<const Edge**>(&ptrs_[0] + size()));
ptrs_[0] = this;
ptrs_[1] = s;
}
}
// Caller is not allowed to mutate the EdgeSet while iterating.
const_iterator begin() const;

View File

@ -414,6 +414,37 @@ Graph::~Graph() {
const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
void Graph::Copy(const Graph& src) {
SetConstructionContext(src.GetConstructionContextInternal());
for (Node* n : nodes()) {
CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
}
// Copy GraphDef versions
set_versions(src.versions());
// Copy the nodes.
// "Node in src" -> "Node in *dest"
gtl::FlatMap<const Node*, Node*> node_map;
node_map.reserve(src.num_nodes());
node_map[src.source_node()] = source_node();
node_map[src.sink_node()] = sink_node();
for (Node* n : src.op_nodes()) {
auto copy = CopyNode(n);
copy->in_edges_.reserve(n->in_edges().size());
copy->out_edges_.reserve(n->out_edges().size());
node_map[n] = copy;
}
// Copy the edges
edges_.reserve(src.num_edges());
for (const Edge* e : src.edges()) {
Node* src_copy = node_map[e->src()];
Node* dst_copy = node_map[e->dst()];
AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
}
}
Node* Graph::AddNode(NodeDef node_def, Status* status) {
const OpRegistrationData* op_reg_data;
status->Update(ops_.LookUp(node_def.op(), &op_reg_data));

View File

@ -537,6 +537,8 @@ class Graph {
// REQUIRES: node->IsOp()
void RemoveNode(Node* node);
void Copy(const Graph& src);
// Adds an edge that connects the xth output of `source` to the yth input of
// `dest` and returns it. Does not update dest's NodeDef.
const Edge* AddEdge(Node* source, int x, Node* dest, int y);

View File

@ -787,9 +787,10 @@ void BM_RemoveNode(::testing::benchmark::State& state) {
const auto registry = OpRegistry::Global();
GraphConstructorOptions opts;
for (auto s : state) {
state.PauseTiming();
Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
testing::StartTiming();
state.ResumeTiming();
for (Node* n : graph.op_nodes()) {
graph.RemoveNode(n);
}