From 73b669223b7c60d90bfec0e294e021656a77a165 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Thu, 24 Oct 2019 18:07:07 -0700 Subject: [PATCH] [Grappler] Use flat_hash_set to keep removed nodes in graph view mutation PiperOrigin-RevId: 276605620 Change-Id: I1246980a2833c6ccbe1dc05b60a101754fa70a06 --- tensorflow/core/grappler/utils/graph_view.cc | 40 +++++++++---------- tensorflow/core/grappler/utils/graph_view.h | 2 +- .../core/grappler/utils/graph_view_test.cc | 20 ++++++++++ 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/grappler/utils/graph_view.cc b/tensorflow/core/grappler/utils/graph_view.cc index c1a443333ee..6c2b1b39fdd 100644 --- a/tensorflow/core/grappler/utils/graph_view.cc +++ b/tensorflow/core/grappler/utils/graph_view.cc @@ -268,7 +268,7 @@ void Mutation::AddMutation( node->update_index_ = updated_nodes_.size(); updated_nodes_.emplace_back(graph_view_, node->node_index_); mutate_fn(&updated_nodes_.back()); - } else if (!removed_nodes_[node->node_index_]) { + } else if (!removed_nodes_.contains(node->node_index_)) { auto& diff = updated_nodes_[node->update_index_]; mutate_fn(&diff); } @@ -285,7 +285,7 @@ void Mutation::RemoveNode(MutableNodeView* node) { updated_nodes_.pop_back(); update_index = internal::kMissingIndex; } - removed_nodes_[node->node_index_] = true; + removed_nodes_.insert(node->node_index_); } void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) { @@ -414,9 +414,9 @@ void Mutation::RemoveNodeAttr(const MutationNewNode& node, } void Mutation::ResetInternal() { - std::vector().swap(updated_nodes_); - std::vector(graph_view_->NumNodes()).swap(removed_nodes_); - std::vector().swap(new_nodes_); + updated_nodes_.clear(); + removed_nodes_.clear(); + new_nodes_.clear(); } void Mutation::Reset() { @@ -610,11 +610,9 @@ Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes( } } - for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) { - if (mutation_.removed_nodes_[i]) { - const string& node_name = nodes_[i].GetName(); - node_names->emplace(node_name, i); - } + for (int node_index : mutation_.removed_nodes_) { + const string& node_name = nodes_[node_index].GetName(); + node_names->emplace(node_name, node_index); } auto name_conflict = [](const absl::string_view node_name) { @@ -713,7 +711,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed( // Check all fanouts of a single port. MutableNodeView* fanout_view = regular_fanout.node_view(); if (fanout_view->update_index_ == internal::kMissingIndex) { - if (mutation_.removed_nodes_[fanout_view->node_index_]) { + if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) { // Fanout node will be removed, this can be ignored. continue; } else if (!overwritten_nodes[fanout_view->node_index_]) { @@ -739,7 +737,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed( for (const auto& controlled_fanout : node_view.GetControlledFanouts()) { MutableNodeView* fanout_view = controlled_fanout.node_view(); if (fanout_view->update_index_ == internal::kMissingIndex) { - if (mutation_.removed_nodes_[fanout_view->node_index_]) { + if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) { // Fanout node will be removed, this can be ignored. continue; } else if (!overwritten_nodes[fanout_view->node_index_]) { @@ -918,7 +916,7 @@ void MutableGraphView::FixRenamedNodes( nodes_[renamed.overwritten_node_index_]; ReplaceNodeFanouts(&renamed_node, &node_to_overwrite); node_index_by_name_.erase(node_to_overwrite.GetName()); - if (mutation_.removed_nodes_[node_to_overwrite.node_index_]) { + if (mutation_.removed_nodes_.contains(node_to_overwrite.node_index_)) { (*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true; } } else { @@ -952,7 +950,7 @@ void MutableGraphView::AddNewNodes( node_def->mutable_device()->swap(*new_node.node.mutable_device()); node_def->mutable_input()->Clear(); node_def->mutable_attr()->swap(*new_node.node.mutable_attr()); - mutation_.removed_nodes_[node_index] = false; + mutation_.removed_nodes_.erase(node_index); } else { // New node. auto* new_node_def = graph_->add_node(); @@ -1303,14 +1301,12 @@ void MutableGraphView::RemoveNodesInternal( std::vector node_indices_to_remove; node_indices_to_remove.reserve(mutation_.updated_nodes_.size() + overwritten_nodes.size()); - for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) { - if (mutation_.removed_nodes_[i]) { - auto& node = nodes_[i]; - RemoveAllFaninFanoutInternal(&node); - node_indices_to_remove.push_back(i); - if (!overwritten_name_removed_nodes[i]) { - node_index_by_name_.erase(node.GetName()); - } + for (int node_index : mutation_.removed_nodes_) { + auto& node = nodes_[node_index]; + RemoveAllFaninFanoutInternal(&node); + node_indices_to_remove.push_back(node_index); + if (!overwritten_name_removed_nodes[node_index]) { + node_index_by_name_.erase(node.GetName()); } } node_indices_to_remove.insert(node_indices_to_remove.end(), diff --git a/tensorflow/core/grappler/utils/graph_view.h b/tensorflow/core/grappler/utils/graph_view.h index 456c68a30e9..fc8c58ef703 100644 --- a/tensorflow/core/grappler/utils/graph_view.h +++ b/tensorflow/core/grappler/utils/graph_view.h @@ -359,7 +359,7 @@ class Mutation { MutableGraphView* graph_view_ = nullptr; int mutation_counter_ = 0; std::vector updated_nodes_; - std::vector removed_nodes_; + absl::flat_hash_set removed_nodes_; using MutationNewNodeHolder = internal::NewNode; std::vector new_nodes_; diff --git a/tensorflow/core/grappler/utils/graph_view_test.cc b/tensorflow/core/grappler/utils/graph_view_test.cc index a8f4b65c415..ce196d366ed 100644 --- a/tensorflow/core/grappler/utils/graph_view_test.cc +++ b/tensorflow/core/grappler/utils/graph_view_test.cc @@ -2415,8 +2415,28 @@ static void BM_MutableGraphViewConstruction(int iters, int num_nodes, num_edges_per_node); } +static void BM_MutableGraphViewClearAttrs(int iters, int num_nodes, + int num_edges_per_node) { + testing::StopTiming(); + GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node); + + Status s; + MutableGraphView graph_view(&graph_def, &s); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + utils::Mutation* mutation = graph_view.GetMutationBuilder(); + for (int j = 0; j < num_nodes; ++j) { + mutation->RemoveNodeAttr(graph_view.GetNode(j), "_some_random_attr"); + } + s = mutation->Apply(); + } + testing::StopTiming(); +} + RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction); RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction); +RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewClearAttrs); #define RUN_NUM_NODE_BENCHMARK(name) \ BENCHMARK(name) \