Add minor optimization for graph copies.
Reserve input/output edgeset sizes when copying graphs. PiperOrigin-RevId: 355055758 Change-Id: Id78260cda6f8bf9ed30663ecc819b5936fff26a8
This commit is contained in:
parent
d981631696
commit
30749f263e
@ -1557,31 +1557,6 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
|
||||
}
|
||||
}
|
||||
|
||||
void CopyGraph(const Graph& src, Graph* dest) {
|
||||
dest->SetConstructionContext(src.GetConstructionContextInternal());
|
||||
|
||||
for (Node* n : dest->nodes()) {
|
||||
CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
|
||||
}
|
||||
|
||||
// Copy GraphDef versions
|
||||
dest->set_versions(src.versions());
|
||||
|
||||
// Copy the nodes.
|
||||
// "Node in src" -> "Node in *dest"
|
||||
gtl::FlatMap<const Node*, Node*> node_map;
|
||||
node_map[src.source_node()] = dest->source_node();
|
||||
node_map[src.sink_node()] = dest->sink_node();
|
||||
for (Node* n : src.op_nodes()) {
|
||||
node_map[n] = dest->CopyNode(n);
|
||||
}
|
||||
|
||||
// Copy the edges
|
||||
for (const Edge* e : src.edges()) {
|
||||
Node* src_copy = node_map[e->src()];
|
||||
Node* dst_copy = node_map[e->dst()];
|
||||
dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
|
||||
}
|
||||
}
|
||||
void CopyGraph(const Graph& src, Graph* dest) { dest->Copy(src); }
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -47,6 +47,15 @@ class EdgeSet {
|
||||
void clear();
|
||||
std::pair<iterator, bool> insert(value_type value);
|
||||
size_type erase(key_type key);
|
||||
void reserve(size_type new_size) {
|
||||
if (new_size > kInline) {
|
||||
auto s = new gtl::FlatSet<const Edge*>(new_size);
|
||||
s->insert(reinterpret_cast<const Edge**>(std::begin(ptrs_)),
|
||||
reinterpret_cast<const Edge**>(&ptrs_[0] + size()));
|
||||
ptrs_[0] = this;
|
||||
ptrs_[1] = s;
|
||||
}
|
||||
}
|
||||
|
||||
// Caller is not allowed to mutate the EdgeSet while iterating.
|
||||
const_iterator begin() const;
|
||||
|
@ -414,6 +414,37 @@ Graph::~Graph() {
|
||||
const VersionDef& Graph::versions() const { return *versions_; }
|
||||
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }
|
||||
|
||||
void Graph::Copy(const Graph& src) {
|
||||
SetConstructionContext(src.GetConstructionContextInternal());
|
||||
for (Node* n : nodes()) {
|
||||
CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
|
||||
}
|
||||
|
||||
// Copy GraphDef versions
|
||||
set_versions(src.versions());
|
||||
|
||||
// Copy the nodes.
|
||||
// "Node in src" -> "Node in *dest"
|
||||
gtl::FlatMap<const Node*, Node*> node_map;
|
||||
node_map.reserve(src.num_nodes());
|
||||
node_map[src.source_node()] = source_node();
|
||||
node_map[src.sink_node()] = sink_node();
|
||||
for (Node* n : src.op_nodes()) {
|
||||
auto copy = CopyNode(n);
|
||||
copy->in_edges_.reserve(n->in_edges().size());
|
||||
copy->out_edges_.reserve(n->out_edges().size());
|
||||
node_map[n] = copy;
|
||||
}
|
||||
|
||||
// Copy the edges
|
||||
edges_.reserve(src.num_edges());
|
||||
for (const Edge* e : src.edges()) {
|
||||
Node* src_copy = node_map[e->src()];
|
||||
Node* dst_copy = node_map[e->dst()];
|
||||
AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
|
||||
}
|
||||
}
|
||||
|
||||
Node* Graph::AddNode(NodeDef node_def, Status* status) {
|
||||
const OpRegistrationData* op_reg_data;
|
||||
status->Update(ops_.LookUp(node_def.op(), &op_reg_data));
|
||||
|
@ -537,6 +537,8 @@ class Graph {
|
||||
// REQUIRES: node->IsOp()
|
||||
void RemoveNode(Node* node);
|
||||
|
||||
void Copy(const Graph& src);
|
||||
|
||||
// Adds an edge that connects the xth output of `source` to the yth input of
|
||||
// `dest` and returns it. Does not update dest's NodeDef.
|
||||
const Edge* AddEdge(Node* source, int x, Node* dest, int y);
|
||||
|
@ -787,9 +787,10 @@ void BM_RemoveNode(::testing::benchmark::State& state) {
|
||||
const auto registry = OpRegistry::Global();
|
||||
GraphConstructorOptions opts;
|
||||
for (auto s : state) {
|
||||
state.PauseTiming();
|
||||
Graph graph(registry);
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
|
||||
testing::StartTiming();
|
||||
state.ResumeTiming();
|
||||
for (Node* n : graph.op_nodes()) {
|
||||
graph.RemoveNode(n);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user