From 2394a61d64fd198965a0b027b7b840eaca0d6f25 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 30 Jun 2020 12:01:25 -0700
Subject: [PATCH] [Grappler] Replace redundant variable updates (e.g. AssignAdd
 of all zero deltas) with NoOps or Identity. Also relax a constraint in
 dependency optimizer to allow pruning of unused variable reads.

PiperOrigin-RevId: 319069467
Change-Id: I97199e464c8a9fee0055077efbcf481f2545bf8a
---
 tensorflow/core/grappler/op_types.cc          |  4 ++
 tensorflow/core/grappler/op_types.h           |  1 +
 .../grappler/optimizers/constant_folding.cc   | 69 +++++++++++++++++++
 .../grappler/optimizers/constant_folding.h    |  6 ++
 .../optimizers/dependency_optimizer.cc        | 27 ++++++--
 .../python/grappler/constant_folding_test.py  | 47 +++++++++++++
 .../python/grappler/memory_optimizer_test.py  |  1 +
 7 files changed, 151 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index efd23b6005e..9d30f24e047 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -429,6 +429,10 @@ bool IsReadVariableOp(const NodeDef& node) {
   return node.op() == "ReadVariableOp";
 }
 
+bool IsReadVariablesOp(const NodeDef& node) {
+  return node.op() == "_ReadVariablesOp";
+}
+
 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
 
 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 59fc68daba5..141eda7415a 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -136,6 +136,7 @@ bool IsQueue(const NodeDef& node);
 bool IsRandomShuffle(const NodeDef& node);
 bool IsRank(const NodeDef& node);
 bool IsReadVariableOp(const NodeDef& node);
+bool IsReadVariablesOp(const NodeDef& node);
 bool IsReal(const NodeDef& node);
 bool IsRealDiv(const NodeDef& node);
 bool IsReciprocalGrad(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 50a3daf379f..b8c2958b6bd 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1781,9 +1781,11 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   return false;
 }
 
+// Replace an operation with Identity.
 void ConstantFolding::ReplaceOperationWithIdentity(
     int input_to_forward, const GraphProperties& properties, NodeDef* node,
     GraphDef* graph) {
+  if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
   if (dtype == DT_INVALID) return;
 
@@ -1836,6 +1838,26 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
   graph_modified_ = true;
 }
 
+// Replace a node with NoOp. Change all inputs to control dependencies.
+// If the node has non-control outputs, no change will be performed.
+void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph) {
+  if (HasRegularOutputs(*node, *node_map_)) return;
+  node->set_op("NoOp");
+  node->clear_attr();
+  // Change all inputs to control dependencies.
+  for (int i = 0; i < node->input_size(); ++i) {
+    if (IsControlInput(node->input(i))) {
+      break;
+    }
+    const string ctrl_dep =
+        AddControlDependency(node->input(i), graph, node_map_.get());
+    node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
+    node->set_input(i, ctrl_dep);
+  }
+  DedupControlInputs(node);
+  graph_modified_ = true;
+}
+
 void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
     int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
     GraphDef* graph) {
@@ -2036,6 +2058,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(
       SimplifySelect(*properties, optimized_graph, node));
+  RETURN_IF_MODIFIED(
+      RemoveRedundantVariableUpdates(*properties, optimized_graph, node));
 
   graph_modified_ = graph_modified_cached;
   return Status::OK();
@@ -2460,6 +2484,51 @@ bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
   return true;
 }
 
+void ConstantFolding::RemoveRedundantVariableUpdates(
+    const GraphProperties& properties, GraphDef* optimized_graph,
+    NodeDef* node) {
+  static const absl::flat_hash_set<string>* kVariableUpdateOps =
+      new absl::flat_hash_set<string>{"AssignAddVariableOp",
+                                      "AssignSubVariableOp",
+                                      "AssignAdd",
+                                      "AssignSub",
+                                      "ScatterAdd",
+                                      "ScatterSub",
+                                      "ScatterMul",
+                                      "ScatterDiv",
+                                      "ScatterNdAdd",
+                                      "ScatterNdSub",
+                                      "ScatterNdMul",
+                                      "ScatterNdDiv",
+                                      "ResourceScatterAdd",
+                                      "ResourceScatterSub",
+                                      "ResourceScatterMul",
+                                      "ResourceScatterDiv",
+                                      "ResourceScatterNdAdd",
+                                      "ResourceScatterNdSub",
+                                      "ResourceScatterNdMul",
+                                      "ResourceScatterNdDiv"};
+  if (kVariableUpdateOps == nullptr ||
+      kVariableUpdateOps->find(node->op()) == kVariableUpdateOps->end())
+    return;
+  const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
+  const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
+  if (delta_node == nullptr) return;
+  const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
+                             absl::StrContains(node->op(), "Sub");
+  if ((is_add_or_sub && IsZeros(*delta_node)) ||
+      (!is_add_or_sub && IsOnes(*delta_node))) {
+    VLOG(1) << "Removing redundant variable update: " << node->DebugString();
+    if (absl::StrContains(node->op(), "Variable") ||
+        absl::StrContains(node->op(), "Resource")) {
+      ReplaceOperationWithNoOp(node, optimized_graph);
+    } else {
+      ReplaceOperationWithIdentity(0 /* input_to_forward */, properties, node,
+                                   optimized_graph);
+    }
+  }
+}
+
 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
                                              NodeDef* node) {
   if (!IsEnter(*node) || node->input_size() == 0 ||
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 7a06cfc1e1a..79ef82c9a0a 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -106,6 +106,7 @@ class ConstantFolding : public GraphOptimizer {
   void ReplaceOperationWithSnapshot(int input_to_forward,
                                     const GraphProperties& properties,
                                     NodeDef* node, GraphDef* graph);
+  void ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph);
   void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
                                              const GraphProperties& properties,
                                              NodeDef* node, GraphDef* graph);
@@ -229,6 +230,7 @@ class ConstantFolding : public GraphOptimizer {
       const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
   // Changes a reduction into an Identity op, returning true on success.
   bool ReplaceReductionWithIdentity(NodeDef* node) const;
+
   // Simplifies a Reduction operation to an Identity/Reshape operation if
   // applicable.
   bool SimplifyReduction(GraphDef* optimized_graph,
@@ -286,6 +288,10 @@ class ConstantFolding : public GraphOptimizer {
   bool SimplifySelect(const GraphProperties& properties,
                       GraphDef* optimized_graph, NodeDef* node);
 
+  // Replaces variable updates that are effectively no-ops with NoOp nodes.
+  void RemoveRedundantVariableUpdates(const GraphProperties& properties,
+                                      GraphDef* optimized_graph, NodeDef* node);
+
   // Removes Reverse op over dimensions with size 1.
   Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
                        GraphDef* optimized_graph, NodeDef* node);
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 58ef14e3d3d..10914860710 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -94,14 +94,33 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
   if (HasRegularOutputs(node, *node_map_)) {
     // The output values of this node may be needed.
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << "' to NoOp. Node has outputs.";
     return false;
   }
-  if (!fetch_nodes_known_ ||
-      nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
+  if (!fetch_nodes_known_) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << "' to NoOp. Fetches unknown.";
     return false;
   }
-  if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) ||
-      !IsFreeOfSideEffect(node)) {
+  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
+    VLOG(3) << "Not safe to convert to NoOp: " << node.name()
+            << " is in preserve set.";
+    return false;
+  }
+  if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << "' to NoOp. Node is Merge or Switch, or modifies frame info.";
+    return false;
+  }
+  // Ops reading variables are marked as stateful, but are safe to remove if
+  // redundant.
+  const bool is_variable_read = IsReadVariableOp(node) ||
+                                IsReadVariablesOp(node) ||
+                                absl::StrContains(node.op(), "Gather");
+  if (!is_variable_read && !IsFreeOfSideEffect(node)) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << "' to NoOp. Node has side effect.";
     return false;
   }
   if (node.op().rfind("Submodel", 0) == 0) {
diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py
index 3ba5b7418a7..3336d3f7e8f 100644
--- a/tensorflow/python/grappler/constant_folding_test.py
+++ b/tensorflow/python/grappler/constant_folding_test.py
@@ -20,12 +20,17 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
 
 
@@ -63,6 +68,48 @@ class ConstantFoldingTest(test.TestCase):
       y_v = self.evaluate(y)
       self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
 
+  # See b/159753857.
+  def testGradientGraphOptimization(self):
+
+    @def_function.function
+    def f(x, y):
+      with backprop.GradientTape() as tape:
+        z = math_ops.mul(x, array_ops.zeros_like(x))
+        l = math_ops.add(z, y)
+        l = math_ops.reduce_sum(l)
+
+      gx, gy = tape.gradient(l, [x, y])
+      x.assign_add(gx)
+      y.assign_add(gy)
+      return x + y
+
+    # XLA completely optimizes away the variable reads and
+    # assignments, so skip the test.
+    if test_util.is_xla_enabled():
+      self.skipTest('Not relevant for XLA')
+    with context.eager_mode():
+      x = resource_variable_ops.ResourceVariable(
+          np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
+      y = resource_variable_ops.ResourceVariable(
+          np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
+      with context.collect_graphs(optimized=True) as graphs:
+        f(x, y).numpy()
+    self.assertLen(graphs, 1)
+    assign_count = 0
+    read_count = 0
+    for node in graphs[0].node:
+      if node.op == 'AssignAddVariableOp':
+        self.assertEqual(node.input[0], 'y')
+        assign_count += 1
+      if node.op == 'ReadVariableOp':
+        read_count += 1
+
+    # Make sure that the only variable update that remains after
+    # grappler optimization is that of y, and that we prune all
+    # but the 2 necessary variable reads.
+    self.assertEqual(assign_count, 1)
+    self.assertEqual(read_count, 2)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index 2beed594479..446ef128c26 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -56,6 +56,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
         rewriter_config_pb2.RewriterConfig(
             disable_model_pruning=True,
             constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
             memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
     graph = tf_optimizer.OptimizeGraph(config, mg)