From 78529cfe02d52d1c69d1ca06c2948c6841fd4f48 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev <ezhulenev@google.com> Date: Thu, 24 Oct 2019 18:44:57 -0700 Subject: [PATCH] [Grappler:GraphView] Add NodeDiff only if mutation is not a no-op PiperOrigin-RevId: 276609854 Change-Id: I3508fef7f3a7ebb2c235b54b49685521d52b6e31 --- tensorflow/core/grappler/utils/graph_view.cc | 40 +++++++++-------- tensorflow/core/grappler/utils/graph_view.h | 6 ++- .../core/grappler/utils/graph_view_internal.h | 43 ++++++++++++------- 3 files changed, 55 insertions(+), 34 deletions(-) diff --git a/tensorflow/core/grappler/utils/graph_view.cc b/tensorflow/core/grappler/utils/graph_view.cc index 6c2b1b39fdd..5b51710857b 100644 --- a/tensorflow/core/grappler/utils/graph_view.cc +++ b/tensorflow/core/grappler/utils/graph_view.cc @@ -262,14 +262,16 @@ MutationNewNode Mutation::AddNode(NodeDef&& node, Status* status) { void Mutation::AddMutation( MutableNodeView* node, - std::function<void(MutableNodeViewDiff*)> mutate_fn) { + std::function<bool(MutableNodeViewDiff*)> mutate_fn) { DCHECK(node->graph_view_ == graph_view_); if (node->update_index_ == internal::kMissingIndex) { + MutableNodeViewDiff diff(graph_view_, node->node_index_); + // If mutation is a no-op return and do not add it to the `updated_nodes_`. + if (!mutate_fn(&diff)) return; node->update_index_ = updated_nodes_.size(); - updated_nodes_.emplace_back(graph_view_, node->node_index_); - mutate_fn(&updated_nodes_.back()); + updated_nodes_.push_back(std::move(diff)); } else if (!removed_nodes_.contains(node->node_index_)) { - auto& diff = updated_nodes_[node->update_index_]; + MutableNodeViewDiff& diff = updated_nodes_[node->update_index_]; mutate_fn(&diff); } } @@ -290,7 +292,7 @@ void Mutation::RemoveNode(MutableNodeView* node) { void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) { AddMutation(node, [name](MutableNodeViewDiff* diff) { - internal::UpdateName(diff, name); + return internal::UpdateName(diff, name); }); } @@ -301,8 +303,9 @@ void Mutation::UpdateNodeName(const MutationNewNode& node, } void Mutation::UpdateNodeOp(MutableNodeView* node, absl::string_view op) { - AddMutation( - node, [op](MutableNodeViewDiff* diff) { internal::UpdateOp(diff, op); }); + AddMutation(node, [op](MutableNodeViewDiff* diff) { + return internal::UpdateOp(diff, op); + }); } void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) { @@ -313,7 +316,7 @@ void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) { void Mutation::UpdateNodeDevice(MutableNodeView* node, absl::string_view device) { AddMutation(node, [device](MutableNodeViewDiff* diff) { - internal::UpdateDevice(diff, device); + return internal::UpdateDevice(diff, device); }); } @@ -326,7 +329,7 @@ void Mutation::UpdateNodeDevice(const MutationNewNode& node, void Mutation::AddOrUpdateRegularFanin(MutableNodeView* node, int index, const TensorId& fanin) { AddMutation(node, [index, fanin](MutableNodeViewDiff* diff) { - internal::AddOrUpdateRegularFanin(diff, index, fanin); + return internal::AddOrUpdateRegularFanin(diff, index, fanin); }); } @@ -340,7 +343,7 @@ void Mutation::AddOrUpdateRegularFanin(const MutationNewNode& node, int index, void Mutation::RemoveRegularFanin(MutableNodeView* node, int index) { AddMutation(node, [index](MutableNodeViewDiff* diff) { - internal::RemoveRegularFanin(diff, index); + return internal::RemoveRegularFanin(diff, index); }); } @@ -357,7 +360,7 @@ void Mutation::AddControllingFanin(MutableNodeView* node, const int control_index = it != node->controlling_fanins_index_.end() ? it->second : internal::kMissingIndex; - internal::AddControllingFanin(diff, control_index, fanin_node_name); + return internal::AddControllingFanin(diff, control_index, fanin_node_name); }); } @@ -374,7 +377,8 @@ void Mutation::RemoveControllingFanin(MutableNodeView* node, const int control_index = it != node->controlling_fanins_index_.end() ? it->second : internal::kMissingIndex; - internal::RemoveControllingFanin(diff, control_index, fanin_node_name); + return internal::RemoveControllingFanin(diff, control_index, + fanin_node_name); }); } @@ -388,7 +392,7 @@ void Mutation::AddOrUpdateNodeAttr(MutableNodeView* node, absl::string_view attr_name, const AttrValue& attr_value) { AddMutation(node, [attr_name, attr_value](MutableNodeViewDiff* diff) { - internal::AddOrUpdateAttribute(diff, attr_name, attr_value); + return internal::AddOrUpdateAttribute(diff, attr_name, attr_value); }); } @@ -403,7 +407,7 @@ void Mutation::AddOrUpdateNodeAttr(const MutationNewNode& node, void Mutation::RemoveNodeAttr(MutableNodeView* node, absl::string_view attr_name) { AddMutation(node, [attr_name](MutableNodeViewDiff* diff) { - internal::RemoveAttribute(diff, attr_name); + return internal::RemoveAttribute(diff, attr_name); }); } @@ -807,10 +811,10 @@ Status MutableGraphView::CheckKernelRegisteredForNodes() { diff.processed_attrs = AttrValueMap(node->attr().begin(), node->attr().end()); for (const auto& attr_to_remove : diff.attrs_to_remove) { - diff.processed_attrs.erase(attr_to_remove); + (*diff.processed_attrs).erase(attr_to_remove); } for (const auto& attr_to_add : diff.attrs_to_add) { - gtl::InsertOrUpdate(&diff.processed_attrs, attr_to_add.first, + gtl::InsertOrUpdate(&(*diff.processed_attrs), attr_to_add.first, attr_to_add.second); } const string& device = diff.update_device ? diff.device : node->device(); @@ -823,7 +827,7 @@ Status MutableGraphView::CheckKernelRegisteredForNodes() { node->has_experimental_debug_info(), node->experimental_debug_info(), diff.update_op ? diff.op : node->op(), device, - AttrSlice(&diff.processed_attrs)); + AttrSlice(&(*diff.processed_attrs))); if (!s.ok()) { LOG(WARNING) << s.error_message(); } @@ -1189,7 +1193,7 @@ void MutableGraphView::ApplyNodeUpdates() { if (diff.update_device) { node_def->set_device(diff.device); } - node_def->mutable_attr()->swap(diff.processed_attrs); + node_def->mutable_attr()->swap((*diff.processed_attrs)); // Updated fanins. Only one of `regular_inputs_to_remove_` or // `regular_inputs_to_add_` can be set. diff --git a/tensorflow/core/grappler/utils/graph_view.h b/tensorflow/core/grappler/utils/graph_view.h index fc8c58ef703..575df428d6d 100644 --- a/tensorflow/core/grappler/utils/graph_view.h +++ b/tensorflow/core/grappler/utils/graph_view.h @@ -353,8 +353,12 @@ class Mutation { void ResetInternal(); using MutableNodeViewDiff = internal::NodeViewDiff<MutableGraphView>; + + // Adds a mutation to the `node`. Mutation function `mutate_fn` must return + // `true` if it actually does any mutations. If it returns `false` mutation + // will be ignored. void AddMutation(MutableNodeView* node, - std::function<void(MutableNodeViewDiff*)> mutate_fn); + std::function<bool(MutableNodeViewDiff*)> mutate_fn); MutableGraphView* graph_view_ = nullptr; int mutation_counter_ = 0; diff --git a/tensorflow/core/grappler/utils/graph_view_internal.h b/tensorflow/core/grappler/utils/graph_view_internal.h index 837c05ecdbd..fe91e597e7b 100644 --- a/tensorflow/core/grappler/utils/graph_view_internal.h +++ b/tensorflow/core/grappler/utils/graph_view_internal.h @@ -389,13 +389,15 @@ struct NodeViewDiff { std::set<int> controlling_inputs_to_remove; absl::flat_hash_map<string, AttrValue> attrs_to_add; absl::flat_hash_set<string> attrs_to_remove; - AttrValueMap processed_attrs; + // AttrValueMap constructor and destructor are very expensive, we will + // initialize it lazily only if needed. + absl::optional<AttrValueMap> processed_attrs; }; // Updates node name. If `name` is the same as the name in the original node, // the field will be cleared in the diff. template <typename GraphViewT> -inline void UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) { +inline bool UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) { if (diff->graph_view->GetNode(diff->node_index)->GetName() == name) { diff->name.clear(); diff->update_name = false; @@ -403,12 +405,13 @@ inline void UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) { diff->name = string(name); diff->update_name = true; } + return true; } // Updates node op. If `op` is the same as the op in the original node, the // field will be cleared in the diff. template <typename GraphViewT> -inline void UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) { +inline bool UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) { if (diff->graph_view->GetNode(diff->node_index)->GetOp() == op) { diff->op.clear(); diff->update_op = false; @@ -416,12 +419,13 @@ inline void UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) { diff->op = string(op); diff->update_op = true; } + return true; } // Updates node device. If `device` is the same as the device in the original // node, the field will be cleared in the diff. template <typename GraphViewT> -inline void UpdateDevice(NodeViewDiff<GraphViewT>* diff, +inline bool UpdateDevice(NodeViewDiff<GraphViewT>* diff, absl::string_view device) { if (diff->graph_view->GetNode(diff->node_index)->GetDevice() == device) { diff->device.clear(); @@ -430,6 +434,7 @@ inline void UpdateDevice(NodeViewDiff<GraphViewT>* diff, diff->device = string(device); diff->update_device = true; } + return true; } // Adds or updates value in vector `v` at index `i`. This will also resize the @@ -476,11 +481,11 @@ inline bool CheckNodeNameExists( // differs. If `index` is greater than or equal to the number of regular fanins, // `fanin` will be added beyond the end of regular fanins at `index`. template <typename GraphViewT> -inline void AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index, +inline bool AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index, const TensorId& fanin) { if (index < 0) { // Not a valid index for regular fanins. - return; + return false; } auto* node_view = diff->graph_view->GetNode(diff->node_index); const int num_regular_fanins = node_view->NumRegularFanins(); @@ -511,15 +516,16 @@ inline void AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index, ++diff->num_regular_inputs_to_add; } } + return true; } // Remove regular fanin at `index` of regular fanins. This can remove existing // fanins and updated/added fanins via AddOrUpdateRegularFanins. template <typename GraphViewT> -inline void RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) { +inline bool RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) { if (index < 0) { // Not a valid index for regular fanins. - return; + return false; } auto* node_view = diff->graph_view->GetNode(diff->node_index); const int num_regular_fanins = node_view->NumRegularFanins(); @@ -541,19 +547,20 @@ inline void RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) { IsEmptyTensorId(diff->regular_inputs_to_add[relative_add_index])) { // At relative index, appended regular fanin was already marked for // removal. - return; + return false; } // Remove added fanin. diff->regular_inputs_to_add[relative_add_index] = EmptyTensorId(); --diff->num_regular_inputs_to_add; } + return true; } // Adds controlling fanin. If the controlling fanin already exists in the // original node, it will be dedupped. If the controlling fanin is marked for // removal, this will reverse it. template <typename GraphViewT> -inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff, +inline bool AddControllingFanin(NodeViewDiff<GraphViewT>* diff, int control_index, absl::string_view fanin_node_name) { if (control_index == kMissingIndex) { @@ -561,6 +568,7 @@ inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff, } else { diff->controlling_inputs_to_remove.erase(control_index); } + return true; } // Remove controlling fanin. If the controlling fanin does not exist in the @@ -568,7 +576,7 @@ inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff, // in the diff, it will be removed. Otherwise the controlling fanin will be // marked for removal from the original node. template <typename GraphViewT> -inline void RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff, +inline bool RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff, int control_index, absl::string_view fanin_node_name) { if (control_index == kMissingIndex) { @@ -576,28 +584,33 @@ inline void RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff, } else { diff->controlling_inputs_to_remove.emplace(control_index); } + return true; } // Adds or updates an attribute by name. If an attribute exist in the original // node or diff (including those marked for removal), this will overwrite it. template <typename GraphViewT> -inline void AddOrUpdateAttribute(NodeViewDiff<GraphViewT>* diff, +inline bool AddOrUpdateAttribute(NodeViewDiff<GraphViewT>* diff, absl::string_view attr_name, const AttrValue& attr_value) { - diff->attrs_to_remove.erase(attr_name); + diff->attrs_to_add.empty() ? 0 : diff->attrs_to_remove.erase(attr_name); gtl::InsertOrUpdate(&diff->attrs_to_add, string(attr_name), attr_value); + return true; } // Removes an attribute by name. If an attribute exist in the original node or // diff, this will remove it. template <typename GraphViewT> -inline void RemoveAttribute(NodeViewDiff<GraphViewT>* diff, +inline bool RemoveAttribute(NodeViewDiff<GraphViewT>* diff, absl::string_view attr_name) { - diff->attrs_to_add.erase(attr_name); + const size_t num_erased = + diff->attrs_to_add.empty() ? 0 : diff->attrs_to_add.erase(attr_name); auto* node_view = diff->graph_view->GetNode(diff->node_index); if (node_view->HasAttr(attr_name)) { diff->attrs_to_remove.emplace(attr_name); + return true; } + return num_erased > 0; } // Removes trailing values in vector `v` for values equal to `value`.