[Grappler] Use flat_hash_set to keep removed nodes in graph view mutation

PiperOrigin-RevId: 276605620
Change-Id: I1246980a2833c6ccbe1dc05b60a101754fa70a06
This commit is contained in:
Eugene Zhulenev 2019-10-24 18:07:07 -07:00 committed by TensorFlower Gardener
parent 7785075046
commit 73b669223b
3 changed files with 39 additions and 23 deletions

View File

@ -268,7 +268,7 @@ void Mutation::AddMutation(
node->update_index_ = updated_nodes_.size();
updated_nodes_.emplace_back(graph_view_, node->node_index_);
mutate_fn(&updated_nodes_.back());
} else if (!removed_nodes_[node->node_index_]) {
} else if (!removed_nodes_.contains(node->node_index_)) {
auto& diff = updated_nodes_[node->update_index_];
mutate_fn(&diff);
}
@ -285,7 +285,7 @@ void Mutation::RemoveNode(MutableNodeView* node) {
updated_nodes_.pop_back();
update_index = internal::kMissingIndex;
}
removed_nodes_[node->node_index_] = true;
removed_nodes_.insert(node->node_index_);
}
void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
@ -414,9 +414,9 @@ void Mutation::RemoveNodeAttr(const MutationNewNode& node,
}
void Mutation::ResetInternal() {
std::vector<MutableNodeViewDiff>().swap(updated_nodes_);
std::vector<bool>(graph_view_->NumNodes()).swap(removed_nodes_);
std::vector<MutationNewNodeHolder>().swap(new_nodes_);
updated_nodes_.clear();
removed_nodes_.clear();
new_nodes_.clear();
}
void Mutation::Reset() {
@ -610,11 +610,9 @@ Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes(
}
}
for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) {
if (mutation_.removed_nodes_[i]) {
const string& node_name = nodes_[i].GetName();
node_names->emplace(node_name, i);
}
for (int node_index : mutation_.removed_nodes_) {
const string& node_name = nodes_[node_index].GetName();
node_names->emplace(node_name, node_index);
}
auto name_conflict = [](const absl::string_view node_name) {
@ -713,7 +711,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
// Check all fanouts of a single port.
MutableNodeView* fanout_view = regular_fanout.node_view();
if (fanout_view->update_index_ == internal::kMissingIndex) {
if (mutation_.removed_nodes_[fanout_view->node_index_]) {
if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
// Fanout node will be removed, this can be ignored.
continue;
} else if (!overwritten_nodes[fanout_view->node_index_]) {
@ -739,7 +737,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
MutableNodeView* fanout_view = controlled_fanout.node_view();
if (fanout_view->update_index_ == internal::kMissingIndex) {
if (mutation_.removed_nodes_[fanout_view->node_index_]) {
if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
// Fanout node will be removed, this can be ignored.
continue;
} else if (!overwritten_nodes[fanout_view->node_index_]) {
@ -918,7 +916,7 @@ void MutableGraphView::FixRenamedNodes(
nodes_[renamed.overwritten_node_index_];
ReplaceNodeFanouts(&renamed_node, &node_to_overwrite);
node_index_by_name_.erase(node_to_overwrite.GetName());
if (mutation_.removed_nodes_[node_to_overwrite.node_index_]) {
if (mutation_.removed_nodes_.contains(node_to_overwrite.node_index_)) {
(*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true;
}
} else {
@ -952,7 +950,7 @@ void MutableGraphView::AddNewNodes(
node_def->mutable_device()->swap(*new_node.node.mutable_device());
node_def->mutable_input()->Clear();
node_def->mutable_attr()->swap(*new_node.node.mutable_attr());
mutation_.removed_nodes_[node_index] = false;
mutation_.removed_nodes_.erase(node_index);
} else {
// New node.
auto* new_node_def = graph_->add_node();
@ -1303,14 +1301,12 @@ void MutableGraphView::RemoveNodesInternal(
std::vector<int> node_indices_to_remove;
node_indices_to_remove.reserve(mutation_.updated_nodes_.size() +
overwritten_nodes.size());
for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) {
if (mutation_.removed_nodes_[i]) {
auto& node = nodes_[i];
RemoveAllFaninFanoutInternal(&node);
node_indices_to_remove.push_back(i);
if (!overwritten_name_removed_nodes[i]) {
node_index_by_name_.erase(node.GetName());
}
for (int node_index : mutation_.removed_nodes_) {
auto& node = nodes_[node_index];
RemoveAllFaninFanoutInternal(&node);
node_indices_to_remove.push_back(node_index);
if (!overwritten_name_removed_nodes[node_index]) {
node_index_by_name_.erase(node.GetName());
}
}
node_indices_to_remove.insert(node_indices_to_remove.end(),

View File

@ -359,7 +359,7 @@ class Mutation {
MutableGraphView* graph_view_ = nullptr;
int mutation_counter_ = 0;
std::vector<MutableNodeViewDiff> updated_nodes_;
std::vector<bool> removed_nodes_;
absl::flat_hash_set<int> removed_nodes_;
using MutationNewNodeHolder = internal::NewNode<MutableGraphView>;
std::vector<MutationNewNodeHolder> new_nodes_;

View File

@ -2415,8 +2415,28 @@ static void BM_MutableGraphViewConstruction(int iters, int num_nodes,
num_edges_per_node);
}
static void BM_MutableGraphViewClearAttrs(int iters, int num_nodes,
int num_edges_per_node) {
testing::StopTiming();
GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
Status s;
MutableGraphView graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
utils::Mutation* mutation = graph_view.GetMutationBuilder();
for (int j = 0; j < num_nodes; ++j) {
mutation->RemoveNodeAttr(graph_view.GetNode(j), "_some_random_attr");
}
s = mutation->Apply();
}
testing::StopTiming();
}
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction);
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction);
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewClearAttrs);
#define RUN_NUM_NODE_BENCHMARK(name) \
BENCHMARK(name) \