[Grappler] Use flat_hash_set to keep removed nodes in graph view mutation
PiperOrigin-RevId: 276605620 Change-Id: I1246980a2833c6ccbe1dc05b60a101754fa70a06
This commit is contained in:
parent
7785075046
commit
73b669223b
@ -268,7 +268,7 @@ void Mutation::AddMutation(
|
|||||||
node->update_index_ = updated_nodes_.size();
|
node->update_index_ = updated_nodes_.size();
|
||||||
updated_nodes_.emplace_back(graph_view_, node->node_index_);
|
updated_nodes_.emplace_back(graph_view_, node->node_index_);
|
||||||
mutate_fn(&updated_nodes_.back());
|
mutate_fn(&updated_nodes_.back());
|
||||||
} else if (!removed_nodes_[node->node_index_]) {
|
} else if (!removed_nodes_.contains(node->node_index_)) {
|
||||||
auto& diff = updated_nodes_[node->update_index_];
|
auto& diff = updated_nodes_[node->update_index_];
|
||||||
mutate_fn(&diff);
|
mutate_fn(&diff);
|
||||||
}
|
}
|
||||||
@ -285,7 +285,7 @@ void Mutation::RemoveNode(MutableNodeView* node) {
|
|||||||
updated_nodes_.pop_back();
|
updated_nodes_.pop_back();
|
||||||
update_index = internal::kMissingIndex;
|
update_index = internal::kMissingIndex;
|
||||||
}
|
}
|
||||||
removed_nodes_[node->node_index_] = true;
|
removed_nodes_.insert(node->node_index_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
|
void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
|
||||||
@ -414,9 +414,9 @@ void Mutation::RemoveNodeAttr(const MutationNewNode& node,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Mutation::ResetInternal() {
|
void Mutation::ResetInternal() {
|
||||||
std::vector<MutableNodeViewDiff>().swap(updated_nodes_);
|
updated_nodes_.clear();
|
||||||
std::vector<bool>(graph_view_->NumNodes()).swap(removed_nodes_);
|
removed_nodes_.clear();
|
||||||
std::vector<MutationNewNodeHolder>().swap(new_nodes_);
|
new_nodes_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Mutation::Reset() {
|
void Mutation::Reset() {
|
||||||
@ -610,11 +610,9 @@ Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) {
|
for (int node_index : mutation_.removed_nodes_) {
|
||||||
if (mutation_.removed_nodes_[i]) {
|
const string& node_name = nodes_[node_index].GetName();
|
||||||
const string& node_name = nodes_[i].GetName();
|
node_names->emplace(node_name, node_index);
|
||||||
node_names->emplace(node_name, i);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto name_conflict = [](const absl::string_view node_name) {
|
auto name_conflict = [](const absl::string_view node_name) {
|
||||||
@ -713,7 +711,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
|
|||||||
// Check all fanouts of a single port.
|
// Check all fanouts of a single port.
|
||||||
MutableNodeView* fanout_view = regular_fanout.node_view();
|
MutableNodeView* fanout_view = regular_fanout.node_view();
|
||||||
if (fanout_view->update_index_ == internal::kMissingIndex) {
|
if (fanout_view->update_index_ == internal::kMissingIndex) {
|
||||||
if (mutation_.removed_nodes_[fanout_view->node_index_]) {
|
if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
|
||||||
// Fanout node will be removed, this can be ignored.
|
// Fanout node will be removed, this can be ignored.
|
||||||
continue;
|
continue;
|
||||||
} else if (!overwritten_nodes[fanout_view->node_index_]) {
|
} else if (!overwritten_nodes[fanout_view->node_index_]) {
|
||||||
@ -739,7 +737,7 @@ Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed(
|
|||||||
for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
|
for (const auto& controlled_fanout : node_view.GetControlledFanouts()) {
|
||||||
MutableNodeView* fanout_view = controlled_fanout.node_view();
|
MutableNodeView* fanout_view = controlled_fanout.node_view();
|
||||||
if (fanout_view->update_index_ == internal::kMissingIndex) {
|
if (fanout_view->update_index_ == internal::kMissingIndex) {
|
||||||
if (mutation_.removed_nodes_[fanout_view->node_index_]) {
|
if (mutation_.removed_nodes_.contains(fanout_view->node_index_)) {
|
||||||
// Fanout node will be removed, this can be ignored.
|
// Fanout node will be removed, this can be ignored.
|
||||||
continue;
|
continue;
|
||||||
} else if (!overwritten_nodes[fanout_view->node_index_]) {
|
} else if (!overwritten_nodes[fanout_view->node_index_]) {
|
||||||
@ -918,7 +916,7 @@ void MutableGraphView::FixRenamedNodes(
|
|||||||
nodes_[renamed.overwritten_node_index_];
|
nodes_[renamed.overwritten_node_index_];
|
||||||
ReplaceNodeFanouts(&renamed_node, &node_to_overwrite);
|
ReplaceNodeFanouts(&renamed_node, &node_to_overwrite);
|
||||||
node_index_by_name_.erase(node_to_overwrite.GetName());
|
node_index_by_name_.erase(node_to_overwrite.GetName());
|
||||||
if (mutation_.removed_nodes_[node_to_overwrite.node_index_]) {
|
if (mutation_.removed_nodes_.contains(node_to_overwrite.node_index_)) {
|
||||||
(*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true;
|
(*overwritten_name_removed_nodes)[node_to_overwrite.node_index_] = true;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -952,7 +950,7 @@ void MutableGraphView::AddNewNodes(
|
|||||||
node_def->mutable_device()->swap(*new_node.node.mutable_device());
|
node_def->mutable_device()->swap(*new_node.node.mutable_device());
|
||||||
node_def->mutable_input()->Clear();
|
node_def->mutable_input()->Clear();
|
||||||
node_def->mutable_attr()->swap(*new_node.node.mutable_attr());
|
node_def->mutable_attr()->swap(*new_node.node.mutable_attr());
|
||||||
mutation_.removed_nodes_[node_index] = false;
|
mutation_.removed_nodes_.erase(node_index);
|
||||||
} else {
|
} else {
|
||||||
// New node.
|
// New node.
|
||||||
auto* new_node_def = graph_->add_node();
|
auto* new_node_def = graph_->add_node();
|
||||||
@ -1303,14 +1301,12 @@ void MutableGraphView::RemoveNodesInternal(
|
|||||||
std::vector<int> node_indices_to_remove;
|
std::vector<int> node_indices_to_remove;
|
||||||
node_indices_to_remove.reserve(mutation_.updated_nodes_.size() +
|
node_indices_to_remove.reserve(mutation_.updated_nodes_.size() +
|
||||||
overwritten_nodes.size());
|
overwritten_nodes.size());
|
||||||
for (int i = 0; i < mutation_.removed_nodes_.size(); ++i) {
|
for (int node_index : mutation_.removed_nodes_) {
|
||||||
if (mutation_.removed_nodes_[i]) {
|
auto& node = nodes_[node_index];
|
||||||
auto& node = nodes_[i];
|
RemoveAllFaninFanoutInternal(&node);
|
||||||
RemoveAllFaninFanoutInternal(&node);
|
node_indices_to_remove.push_back(node_index);
|
||||||
node_indices_to_remove.push_back(i);
|
if (!overwritten_name_removed_nodes[node_index]) {
|
||||||
if (!overwritten_name_removed_nodes[i]) {
|
node_index_by_name_.erase(node.GetName());
|
||||||
node_index_by_name_.erase(node.GetName());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
node_indices_to_remove.insert(node_indices_to_remove.end(),
|
node_indices_to_remove.insert(node_indices_to_remove.end(),
|
||||||
|
@ -359,7 +359,7 @@ class Mutation {
|
|||||||
MutableGraphView* graph_view_ = nullptr;
|
MutableGraphView* graph_view_ = nullptr;
|
||||||
int mutation_counter_ = 0;
|
int mutation_counter_ = 0;
|
||||||
std::vector<MutableNodeViewDiff> updated_nodes_;
|
std::vector<MutableNodeViewDiff> updated_nodes_;
|
||||||
std::vector<bool> removed_nodes_;
|
absl::flat_hash_set<int> removed_nodes_;
|
||||||
|
|
||||||
using MutationNewNodeHolder = internal::NewNode<MutableGraphView>;
|
using MutationNewNodeHolder = internal::NewNode<MutableGraphView>;
|
||||||
std::vector<MutationNewNodeHolder> new_nodes_;
|
std::vector<MutationNewNodeHolder> new_nodes_;
|
||||||
|
@ -2415,8 +2415,28 @@ static void BM_MutableGraphViewConstruction(int iters, int num_nodes,
|
|||||||
num_edges_per_node);
|
num_edges_per_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void BM_MutableGraphViewClearAttrs(int iters, int num_nodes,
|
||||||
|
int num_edges_per_node) {
|
||||||
|
testing::StopTiming();
|
||||||
|
GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
|
||||||
|
|
||||||
|
Status s;
|
||||||
|
MutableGraphView graph_view(&graph_def, &s);
|
||||||
|
|
||||||
|
testing::StartTiming();
|
||||||
|
for (int i = 0; i < iters; ++i) {
|
||||||
|
utils::Mutation* mutation = graph_view.GetMutationBuilder();
|
||||||
|
for (int j = 0; j < num_nodes; ++j) {
|
||||||
|
mutation->RemoveNodeAttr(graph_view.GetNode(j), "_some_random_attr");
|
||||||
|
}
|
||||||
|
s = mutation->Apply();
|
||||||
|
}
|
||||||
|
testing::StopTiming();
|
||||||
|
}
|
||||||
|
|
||||||
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction);
|
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction);
|
||||||
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction);
|
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction);
|
||||||
|
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewClearAttrs);
|
||||||
|
|
||||||
#define RUN_NUM_NODE_BENCHMARK(name) \
|
#define RUN_NUM_NODE_BENCHMARK(name) \
|
||||||
BENCHMARK(name) \
|
BENCHMARK(name) \
|
||||||
|
Loading…
x
Reference in New Issue
Block a user