From 78529cfe02d52d1c69d1ca06c2948c6841fd4f48 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev <ezhulenev@google.com>
Date: Thu, 24 Oct 2019 18:44:57 -0700
Subject: [PATCH] [Grappler:GraphView] Add NodeDiff only if mutation is not a
 no-op

PiperOrigin-RevId: 276609854
Change-Id: I3508fef7f3a7ebb2c235b54b49685521d52b6e31
---
 tensorflow/core/grappler/utils/graph_view.cc  | 40 +++++++++--------
 tensorflow/core/grappler/utils/graph_view.h   |  6 ++-
 .../core/grappler/utils/graph_view_internal.h | 43 ++++++++++++-------
 3 files changed, 55 insertions(+), 34 deletions(-)

diff --git a/tensorflow/core/grappler/utils/graph_view.cc b/tensorflow/core/grappler/utils/graph_view.cc
index 6c2b1b39fdd..5b51710857b 100644
--- a/tensorflow/core/grappler/utils/graph_view.cc
+++ b/tensorflow/core/grappler/utils/graph_view.cc
@@ -262,14 +262,16 @@ MutationNewNode Mutation::AddNode(NodeDef&& node, Status* status) {
 
 void Mutation::AddMutation(
     MutableNodeView* node,
-    std::function<void(MutableNodeViewDiff*)> mutate_fn) {
+    std::function<bool(MutableNodeViewDiff*)> mutate_fn) {
   DCHECK(node->graph_view_ == graph_view_);
   if (node->update_index_ == internal::kMissingIndex) {
+    MutableNodeViewDiff diff(graph_view_, node->node_index_);
+    // If mutation is a no-op return and do not add it to the `updated_nodes_`.
+    if (!mutate_fn(&diff)) return;
     node->update_index_ = updated_nodes_.size();
-    updated_nodes_.emplace_back(graph_view_, node->node_index_);
-    mutate_fn(&updated_nodes_.back());
+    updated_nodes_.push_back(std::move(diff));
   } else if (!removed_nodes_.contains(node->node_index_)) {
-    auto& diff = updated_nodes_[node->update_index_];
+    MutableNodeViewDiff& diff = updated_nodes_[node->update_index_];
     mutate_fn(&diff);
   }
 }
@@ -290,7 +292,7 @@ void Mutation::RemoveNode(MutableNodeView* node) {
 
 void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) {
   AddMutation(node, [name](MutableNodeViewDiff* diff) {
-    internal::UpdateName(diff, name);
+    return internal::UpdateName(diff, name);
   });
 }
 
@@ -301,8 +303,9 @@ void Mutation::UpdateNodeName(const MutationNewNode& node,
 }
 
 void Mutation::UpdateNodeOp(MutableNodeView* node, absl::string_view op) {
-  AddMutation(
-      node, [op](MutableNodeViewDiff* diff) { internal::UpdateOp(diff, op); });
+  AddMutation(node, [op](MutableNodeViewDiff* diff) {
+    return internal::UpdateOp(diff, op);
+  });
 }
 
 void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) {
@@ -313,7 +316,7 @@ void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) {
 void Mutation::UpdateNodeDevice(MutableNodeView* node,
                                 absl::string_view device) {
   AddMutation(node, [device](MutableNodeViewDiff* diff) {
-    internal::UpdateDevice(diff, device);
+    return internal::UpdateDevice(diff, device);
   });
 }
 
@@ -326,7 +329,7 @@ void Mutation::UpdateNodeDevice(const MutationNewNode& node,
 void Mutation::AddOrUpdateRegularFanin(MutableNodeView* node, int index,
                                        const TensorId& fanin) {
   AddMutation(node, [index, fanin](MutableNodeViewDiff* diff) {
-    internal::AddOrUpdateRegularFanin(diff, index, fanin);
+    return internal::AddOrUpdateRegularFanin(diff, index, fanin);
   });
 }
 
@@ -340,7 +343,7 @@ void Mutation::AddOrUpdateRegularFanin(const MutationNewNode& node, int index,
 
 void Mutation::RemoveRegularFanin(MutableNodeView* node, int index) {
   AddMutation(node, [index](MutableNodeViewDiff* diff) {
-    internal::RemoveRegularFanin(diff, index);
+    return internal::RemoveRegularFanin(diff, index);
   });
 }
 
@@ -357,7 +360,7 @@ void Mutation::AddControllingFanin(MutableNodeView* node,
     const int control_index = it != node->controlling_fanins_index_.end()
                                   ? it->second
                                   : internal::kMissingIndex;
-    internal::AddControllingFanin(diff, control_index, fanin_node_name);
+    return internal::AddControllingFanin(diff, control_index, fanin_node_name);
   });
 }
 
@@ -374,7 +377,8 @@ void Mutation::RemoveControllingFanin(MutableNodeView* node,
     const int control_index = it != node->controlling_fanins_index_.end()
                                   ? it->second
                                   : internal::kMissingIndex;
-    internal::RemoveControllingFanin(diff, control_index, fanin_node_name);
+    return internal::RemoveControllingFanin(diff, control_index,
+                                            fanin_node_name);
   });
 }
 
@@ -388,7 +392,7 @@ void Mutation::AddOrUpdateNodeAttr(MutableNodeView* node,
                                    absl::string_view attr_name,
                                    const AttrValue& attr_value) {
   AddMutation(node, [attr_name, attr_value](MutableNodeViewDiff* diff) {
-    internal::AddOrUpdateAttribute(diff, attr_name, attr_value);
+    return internal::AddOrUpdateAttribute(diff, attr_name, attr_value);
   });
 }
 
@@ -403,7 +407,7 @@ void Mutation::AddOrUpdateNodeAttr(const MutationNewNode& node,
 void Mutation::RemoveNodeAttr(MutableNodeView* node,
                               absl::string_view attr_name) {
   AddMutation(node, [attr_name](MutableNodeViewDiff* diff) {
-    internal::RemoveAttribute(diff, attr_name);
+    return internal::RemoveAttribute(diff, attr_name);
   });
 }
 
@@ -807,10 +811,10 @@ Status MutableGraphView::CheckKernelRegisteredForNodes() {
     diff.processed_attrs =
         AttrValueMap(node->attr().begin(), node->attr().end());
     for (const auto& attr_to_remove : diff.attrs_to_remove) {
-      diff.processed_attrs.erase(attr_to_remove);
+      (*diff.processed_attrs).erase(attr_to_remove);
     }
     for (const auto& attr_to_add : diff.attrs_to_add) {
-      gtl::InsertOrUpdate(&diff.processed_attrs, attr_to_add.first,
+      gtl::InsertOrUpdate(&(*diff.processed_attrs), attr_to_add.first,
                           attr_to_add.second);
     }
     const string& device = diff.update_device ? diff.device : node->device();
@@ -823,7 +827,7 @@ Status MutableGraphView::CheckKernelRegisteredForNodes() {
                                   node->has_experimental_debug_info(),
                                   node->experimental_debug_info(),
                                   diff.update_op ? diff.op : node->op(), device,
-                                  AttrSlice(&diff.processed_attrs));
+                                  AttrSlice(&(*diff.processed_attrs)));
     if (!s.ok()) {
       LOG(WARNING) << s.error_message();
     }
@@ -1189,7 +1193,7 @@ void MutableGraphView::ApplyNodeUpdates() {
     if (diff.update_device) {
       node_def->set_device(diff.device);
     }
-    node_def->mutable_attr()->swap(diff.processed_attrs);
+    node_def->mutable_attr()->swap((*diff.processed_attrs));
 
     // Updated fanins. Only one of `regular_inputs_to_remove_` or
     // `regular_inputs_to_add_` can be set.
diff --git a/tensorflow/core/grappler/utils/graph_view.h b/tensorflow/core/grappler/utils/graph_view.h
index fc8c58ef703..575df428d6d 100644
--- a/tensorflow/core/grappler/utils/graph_view.h
+++ b/tensorflow/core/grappler/utils/graph_view.h
@@ -353,8 +353,12 @@ class Mutation {
   void ResetInternal();
 
   using MutableNodeViewDiff = internal::NodeViewDiff<MutableGraphView>;
+
+  // Adds a mutation to the `node`. Mutation function `mutate_fn` must return
+  // `true` if it actually does any mutations. If it returns `false` mutation
+  // will be ignored.
   void AddMutation(MutableNodeView* node,
-                   std::function<void(MutableNodeViewDiff*)> mutate_fn);
+                   std::function<bool(MutableNodeViewDiff*)> mutate_fn);
 
   MutableGraphView* graph_view_ = nullptr;
   int mutation_counter_ = 0;
diff --git a/tensorflow/core/grappler/utils/graph_view_internal.h b/tensorflow/core/grappler/utils/graph_view_internal.h
index 837c05ecdbd..fe91e597e7b 100644
--- a/tensorflow/core/grappler/utils/graph_view_internal.h
+++ b/tensorflow/core/grappler/utils/graph_view_internal.h
@@ -389,13 +389,15 @@ struct NodeViewDiff {
   std::set<int> controlling_inputs_to_remove;
   absl::flat_hash_map<string, AttrValue> attrs_to_add;
   absl::flat_hash_set<string> attrs_to_remove;
-  AttrValueMap processed_attrs;
+  // AttrValueMap constructor and destructor are very expensive, we will
+  // initialize it lazily only if needed.
+  absl::optional<AttrValueMap> processed_attrs;
 };
 
 // Updates node name. If `name` is the same as the name in the original node,
 // the field will be cleared in the diff.
 template <typename GraphViewT>
-inline void UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) {
+inline bool UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) {
   if (diff->graph_view->GetNode(diff->node_index)->GetName() == name) {
     diff->name.clear();
     diff->update_name = false;
@@ -403,12 +405,13 @@ inline void UpdateName(NodeViewDiff<GraphViewT>* diff, absl::string_view name) {
     diff->name = string(name);
     diff->update_name = true;
   }
+  return true;
 }
 
 // Updates node op. If `op` is the same as the op in the original node, the
 // field will be cleared in the diff.
 template <typename GraphViewT>
-inline void UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) {
+inline bool UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) {
   if (diff->graph_view->GetNode(diff->node_index)->GetOp() == op) {
     diff->op.clear();
     diff->update_op = false;
@@ -416,12 +419,13 @@ inline void UpdateOp(NodeViewDiff<GraphViewT>* diff, absl::string_view op) {
     diff->op = string(op);
     diff->update_op = true;
   }
+  return true;
 }
 
 // Updates node device. If `device` is the same as the device in the original
 // node, the field will be cleared in the diff.
 template <typename GraphViewT>
-inline void UpdateDevice(NodeViewDiff<GraphViewT>* diff,
+inline bool UpdateDevice(NodeViewDiff<GraphViewT>* diff,
                          absl::string_view device) {
   if (diff->graph_view->GetNode(diff->node_index)->GetDevice() == device) {
     diff->device.clear();
@@ -430,6 +434,7 @@ inline void UpdateDevice(NodeViewDiff<GraphViewT>* diff,
     diff->device = string(device);
     diff->update_device = true;
   }
+  return true;
 }
 
 // Adds or updates value in vector `v` at index `i`. This will also resize the
@@ -476,11 +481,11 @@ inline bool CheckNodeNameExists(
 // differs. If `index` is greater than or equal to the number of regular fanins,
 // `fanin` will be added beyond the end of regular fanins at `index`.
 template <typename GraphViewT>
-inline void AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index,
+inline bool AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index,
                                     const TensorId& fanin) {
   if (index < 0) {
     // Not a valid index for regular fanins.
-    return;
+    return false;
   }
   auto* node_view = diff->graph_view->GetNode(diff->node_index);
   const int num_regular_fanins = node_view->NumRegularFanins();
@@ -511,15 +516,16 @@ inline void AddOrUpdateRegularFanin(NodeViewDiff<GraphViewT>* diff, int index,
       ++diff->num_regular_inputs_to_add;
     }
   }
+  return true;
 }
 
 // Remove regular fanin at `index` of regular fanins. This can remove existing
 // fanins and updated/added fanins via AddOrUpdateRegularFanins.
 template <typename GraphViewT>
-inline void RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) {
+inline bool RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) {
   if (index < 0) {
     // Not a valid index for regular fanins.
-    return;
+    return false;
   }
   auto* node_view = diff->graph_view->GetNode(diff->node_index);
   const int num_regular_fanins = node_view->NumRegularFanins();
@@ -541,19 +547,20 @@ inline void RemoveRegularFanin(NodeViewDiff<GraphViewT>* diff, int index) {
         IsEmptyTensorId(diff->regular_inputs_to_add[relative_add_index])) {
       // At relative index, appended regular fanin was already marked for
       // removal.
-      return;
+      return false;
     }
     // Remove added fanin.
     diff->regular_inputs_to_add[relative_add_index] = EmptyTensorId();
     --diff->num_regular_inputs_to_add;
   }
+  return true;
 }
 
 // Adds controlling fanin. If the controlling fanin already exists in the
 // original node, it will be dedupped. If the controlling fanin is marked for
 // removal, this will reverse it.
 template <typename GraphViewT>
-inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
+inline bool AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
                                 int control_index,
                                 absl::string_view fanin_node_name) {
   if (control_index == kMissingIndex) {
@@ -561,6 +568,7 @@ inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
   } else {
     diff->controlling_inputs_to_remove.erase(control_index);
   }
+  return true;
 }
 
 // Remove controlling fanin. If the controlling fanin does not exist in the
@@ -568,7 +576,7 @@ inline void AddControllingFanin(NodeViewDiff<GraphViewT>* diff,
 // in the diff, it will be removed. Otherwise the controlling fanin will be
 // marked for removal from the original node.
 template <typename GraphViewT>
-inline void RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff,
+inline bool RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff,
                                    int control_index,
                                    absl::string_view fanin_node_name) {
   if (control_index == kMissingIndex) {
@@ -576,28 +584,33 @@ inline void RemoveControllingFanin(NodeViewDiff<GraphViewT>* diff,
   } else {
     diff->controlling_inputs_to_remove.emplace(control_index);
   }
+  return true;
 }
 
 // Adds or updates an attribute by name. If an attribute exist in the original
 // node or diff (including those marked for removal), this will overwrite it.
 template <typename GraphViewT>
-inline void AddOrUpdateAttribute(NodeViewDiff<GraphViewT>* diff,
+inline bool AddOrUpdateAttribute(NodeViewDiff<GraphViewT>* diff,
                                  absl::string_view attr_name,
                                  const AttrValue& attr_value) {
-  diff->attrs_to_remove.erase(attr_name);
+  diff->attrs_to_add.empty() ? 0 : diff->attrs_to_remove.erase(attr_name);
   gtl::InsertOrUpdate(&diff->attrs_to_add, string(attr_name), attr_value);
+  return true;
 }
 
 // Removes an attribute by name. If an attribute exist in the original node or
 // diff, this will remove it.
 template <typename GraphViewT>
-inline void RemoveAttribute(NodeViewDiff<GraphViewT>* diff,
+inline bool RemoveAttribute(NodeViewDiff<GraphViewT>* diff,
                             absl::string_view attr_name) {
-  diff->attrs_to_add.erase(attr_name);
+  const size_t num_erased =
+      diff->attrs_to_add.empty() ? 0 : diff->attrs_to_add.erase(attr_name);
   auto* node_view = diff->graph_view->GetNode(diff->node_index);
   if (node_view->HasAttr(attr_name)) {
     diff->attrs_to_remove.emplace(attr_name);
+    return true;
   }
+  return num_erased > 0;
 }
 
 // Removes trailing values in vector `v` for values equal to `value`.