[Grappler] Use flat_hash_set to keep removed nodes in graph view mutation

PiperOrigin-RevId: 276605620
Change-Id: I1246980a2833c6ccbe1dc05b60a101754fa70a06
This commit is contained in:
Eugene Zhulenev 2019-10-24 18:07:07 -07:00 committed by TensorFlower Gardener
parent 7785075046
commit 73b669223b
3 changed files with 39 additions and 23 deletions

View File

@ -268,7 +268,7 @@ void Mutation::AddMutation(
node->update_index_ = updated_nodes_.size(); node->update_index_ = updated_nodes_.size();
updated_nodes_.emplace_back(graph_view_, node->node_index_); updated_nodes_.emplace_back(graph_view_, node->node_index_);
mutate_fn(&updated_nodes_.back()); mutate_fn(&updated_nodes_.back());
} else if (!removed_nodes_[node->node_index_]) { } else if (!removed_nodes_.contains(node->node_index_)) {
auto& diff = updated_nodes_[node->update_index_]; auto& diff = updated_nodes_[node->update_index_];
mutate_fn(&diff); mutate_fn(&diff);
} }
@ -285,7 +285,7 @@ void Mutation::RemoveNode(MutableNodeView* node) {
updated_nodes_.pop_back(); updated_nodes_.pop_back();
update_index = internal::kMissingIndex; update_index = internal::kMissingIndex;
} }
removed_nodes_[node->node_index_] = true; removed_nodes_.insert(node->node_index_);
} }
void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) { void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
@ -414,9 +414,9 @@ void Mutation::RemoveNodeAttr(const MutationNewNode& node,
} }
void Mutation::ResetInternal() { void Mutation::ResetInternal() {
std::vector<MutableNodeViewDiff>().swap(updated_nodes_); updated_nodes_.clear();
std::vector<bool>(graph_view_->NumNodes()).swap(removed_nodes_); removed_nodes_.clear();
std::vector<MutationNewNodeHolder>().swap(new_nodes_); new_nodes_.clear();
} }
void Mutation::Reset() { void Mutation::Reset() {
@ -610,11 +610,9 @@ Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes(
} }
} }
for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) { for (int node_index : mutation_.removed_nodes_) {
if (mutation_.removed_nodes_[i]) { const string& node_name = nodes_[node_index].GetName();
const string& node_name = nodes_[i].GetName(); node_names->emplace(node_name, node_index);
node_names->emplace(node_name, i);
}
} }
auto name_conflict = [](const absl::string_view node_name) { auto name_conflict = [](const absl::string_view node_name) {
@ -713,7 +711,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
// Check all fanouts of a single port. // Check all fanouts of a single port.
MutableNodeView* fanout_view = regular_fanout.node_view(); MutableNodeView* fanout_view = regular_fanout.node_view();
if (fanout_view->update_index_ == internal::kMissingIndex) { if (fanout_view->update_index_ == internal::kMissingIndex) {
if (mutation_.removed_nodes_[fanout_view->node_index_]) { if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
// Fanout node will be removed, this can be ignored. // Fanout node will be removed, this can be ignored.
continue; continue;
} else if (!overwritten_nodes[fanout_view->node_index_]) { } else if (!overwritten_nodes[fanout_view->node_index_]) {
@ -739,7 +737,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
for (const auto& controlled_fanout : node_view.GetControlledFanouts()) { for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
MutableNodeView* fanout_view = controlled_fanout.node_view(); MutableNodeView* fanout_view = controlled_fanout.node_view();
if (fanout_view->update_index_ == internal::kMissingIndex) { if (fanout_view->update_index_ == internal::kMissingIndex) {
if (mutation_.removed_nodes_[fanout_view->node_index_]) { if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
// Fanout node will be removed, this can be ignored. // Fanout node will be removed, this can be ignored.
continue; continue;
} else if (!overwritten_nodes[fanout_view->node_index_]) { } else if (!overwritten_nodes[fanout_view->node_index_]) {
@ -918,7 +916,7 @@ void MutableGraphView::FixRenamedNodes(
nodes_[renamed.overwritten_node_index_]; nodes_[renamed.overwritten_node_index_];
ReplaceNodeFanouts(&renamed_node, &node_to_overwrite); ReplaceNodeFanouts(&renamed_node, &node_to_overwrite);
node_index_by_name_.erase(node_to_overwrite.GetName()); node_index_by_name_.erase(node_to_overwrite.GetName());
if (mutation_.removed_nodes_[node_to_overwrite.node_index_]) { if (mutation_.removed_nodes_.contains(node_to_overwrite.node_index_)) {
(*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true; (*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true;
} }
} else { } else {
@ -952,7 +950,7 @@ void MutableGraphView::AddNewNodes(
node_def->mutable_device()->swap(*new_node.node.mutable_device()); node_def->mutable_device()->swap(*new_node.node.mutable_device());
node_def->mutable_input()->Clear(); node_def->mutable_input()->Clear();
node_def->mutable_attr()->swap(*new_node.node.mutable_attr()); node_def->mutable_attr()->swap(*new_node.node.mutable_attr());
mutation_.removed_nodes_[node_index] = false; mutation_.removed_nodes_.erase(node_index);
} else { } else {
// New node. // New node.
auto* new_node_def = graph_->add_node(); auto* new_node_def = graph_->add_node();
@ -1303,14 +1301,12 @@ void MutableGraphView::RemoveNodesInternal(
std::vector<int> node_indices_to_remove; std::vector<int> node_indices_to_remove;
node_indices_to_remove.reserve(mutation_.updated_nodes_.size() + node_indices_to_remove.reserve(mutation_.updated_nodes_.size() +
overwritten_nodes.size()); overwritten_nodes.size());
for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) { for (int node_index : mutation_.removed_nodes_) {
if (mutation_.removed_nodes_[i]) { auto& node = nodes_[node_index];
auto& node = nodes_[i]; RemoveAllFaninFanoutInternal(&node);
RemoveAllFaninFanoutInternal(&node); node_indices_to_remove.push_back(node_index);
node_indices_to_remove.push_back(i); if (!overwritten_name_removed_nodes[node_index]) {
if (!overwritten_name_removed_nodes[i]) { node_index_by_name_.erase(node.GetName());
node_index_by_name_.erase(node.GetName());
}
} }
} }
node_indices_to_remove.insert(node_indices_to_remove.end(), node_indices_to_remove.insert(node_indices_to_remove.end(),

View File

@ -359,7 +359,7 @@ class Mutation {
MutableGraphView* graph_view_ = nullptr; MutableGraphView* graph_view_ = nullptr;
int mutation_counter_ = 0; int mutation_counter_ = 0;
std::vector<MutableNodeViewDiff> updated_nodes_; std::vector<MutableNodeViewDiff> updated_nodes_;
std::vector<bool> removed_nodes_; absl::flat_hash_set<int> removed_nodes_;
using MutationNewNodeHolder = internal::NewNode<MutableGraphView>; using MutationNewNodeHolder = internal::NewNode<MutableGraphView>;
std::vector<MutationNewNodeHolder> new_nodes_; std::vector<MutationNewNodeHolder> new_nodes_;

View File

@ -2415,8 +2415,28 @@ static void BM_MutableGraphViewConstruction(int iters, int num_nodes,
num_edges_per_node); num_edges_per_node);
} }
static void BM_MutableGraphViewClearAttrs(int iters, int num_nodes,
int num_edges_per_node) {
testing::StopTiming();
GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
Status s;
MutableGraphView graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
utils::Mutation* mutation = graph_view.GetMutationBuilder();
for (int j = 0; j < num_nodes; ++j) {
mutation->RemoveNodeAttr(graph_view.GetNode(j), "_some_random_attr");
}
s = mutation->Apply();
}
testing::StopTiming();
}
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction); RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction);
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction); RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction);
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewClearAttrs);
#define RUN_NUM_NODE_BENCHMARK(name) \ #define RUN_NUM_NODE_BENCHMARK(name) \
BENCHMARK(name) \ BENCHMARK(name) \