From 2394a61d64fd198965a0b027b7b840eaca0d6f25 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Tue, 30 Jun 2020 12:01:25 -0700 Subject: [PATCH] [Grappler] Replace redundant variable updates (e.g. AssignAdd of all zero deltas) with NoOps or Identity. Also relax a constraint in dependency optimizer to allow pruning of unused variable reads. PiperOrigin-RevId: 319069467 Change-Id: I97199e464c8a9fee0055077efbcf481f2545bf8a --- tensorflow/core/grappler/op_types.cc | 4 ++ tensorflow/core/grappler/op_types.h | 1 + .../grappler/optimizers/constant_folding.cc | 69 +++++++++++++++++++ .../grappler/optimizers/constant_folding.h | 6 ++ .../optimizers/dependency_optimizer.cc | 27 ++++++-- .../python/grappler/constant_folding_test.py | 47 +++++++++++++ .../python/grappler/memory_optimizer_test.py | 1 + 7 files changed, 151 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index efd23b6005e..9d30f24e047 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -429,6 +429,10 @@ bool IsReadVariableOp(const NodeDef& node) { return node.op() == "ReadVariableOp"; } +bool IsReadVariablesOp(const NodeDef& node) { + return node.op() == "_ReadVariablesOp"; +} + bool IsReal(const NodeDef& node) { return node.op() == "Real"; } bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 59fc68daba5..141eda7415a 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -136,6 +136,7 @@ bool IsQueue(const NodeDef& node); bool IsRandomShuffle(const NodeDef& node); bool IsRank(const NodeDef& node); bool IsReadVariableOp(const NodeDef& node); +bool IsReadVariablesOp(const NodeDef& node); bool IsReal(const NodeDef& node); bool IsRealDiv(const NodeDef& node); bool IsReciprocalGrad(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 50a3daf379f..b8c2958b6bd 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1781,9 +1781,11 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { return false; } +// Replace an operation with Identity. void ConstantFolding::ReplaceOperationWithIdentity( int input_to_forward, const GraphProperties& properties, NodeDef* node, GraphDef* graph) { + if (input_to_forward < 0 || input_to_forward >= node->input_size()) return; const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties); if (dtype == DT_INVALID) return; @@ -1836,6 +1838,26 @@ void ConstantFolding::ReplaceOperationWithSnapshot( graph_modified_ = true; } +// Replace a node with NoOp. Change all inputs to control dependencies. +// If the node has non-control outputs, no change will be performed. +void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph) { + if (HasRegularOutputs(*node, *node_map_)) return; + node->set_op("NoOp"); + node->clear_attr(); + // Change all inputs to control dependencies. + for (int i = 0; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + break; + } + const string ctrl_dep = + AddControlDependency(node->input(i), graph, node_map_.get()); + node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep); + node->set_input(i, ctrl_dep); + } + DedupControlInputs(node); + graph_modified_ = true; +} + void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo( int input_to_broadcast, const GraphProperties& properties, NodeDef* node, GraphDef* graph) { @@ -2036,6 +2058,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node)); SET_AND_RETURN_IF_MODIFIED( SimplifySelect(*properties, optimized_graph, node)); + RETURN_IF_MODIFIED( + RemoveRedundantVariableUpdates(*properties, optimized_graph, node)); graph_modified_ = graph_modified_cached; return Status::OK(); @@ -2460,6 +2484,51 @@ bool ConstantFolding::SimplifySelect(const GraphProperties& properties, return true; } +void ConstantFolding::RemoveRedundantVariableUpdates( + const GraphProperties& properties, GraphDef* optimized_graph, + NodeDef* node) { + static const absl::flat_hash_set<string>* kVariableReadOps = + new absl::flat_hash_set<string>{"AssignAddVariableOp", + "AssignSubVariableOp", + "AssignAdd", + "AssignSub", + "ScatterAdd", + "ScatterSub", + "ScatterMul", + "ScatterDiv", + "ScatterNdAdd", + "ScatterNdSub", + "ScatterNdMul", + "ScatterNdDiv", + "ResourceScatterAdd", + "ResourceScatterSub", + "ResourceScatterMul", + "ResourceScatterDiv", + "ResourceScatterNdAdd", + "ResourceScatterNdSub", + "ResourceScatterNdMul", + "ResourceScatterNdDiv"}; + if (kVariableReadOps == nullptr || + kVariableReadOps->find(node->op()) == kVariableReadOps->end()) + return; + const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1; + const NodeDef* delta_node = node_map_->GetNode(node->input(value_index)); + if (delta_node == nullptr) return; + const bool is_add_or_sub = absl::StrContains(node->op(), "Add") || + absl::StrContains(node->op(), "Sub"); + if ((is_add_or_sub && IsZeros(*delta_node)) || + (!is_add_or_sub && IsOnes(*delta_node))) { + VLOG(1) << "Removing redundant variable update: " << node->DebugString(); + if (absl::StrContains(node->op(), "Variable") || + absl::StrContains(node->op(), "Resource")) { + ReplaceOperationWithNoOp(node, optimized_graph); + } else { + ReplaceOperationWithIdentity(0 /* input_to_forward */, properties, node, + optimized_graph); + } + } +} + bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node) { if (!IsEnter(*node) || node->input_size() == 0 || diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 7a06cfc1e1a..79ef82c9a0a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -106,6 +106,7 @@ class ConstantFolding : public GraphOptimizer { void ReplaceOperationWithSnapshot(int input_to_forward, const GraphProperties& properties, NodeDef* node, GraphDef* graph); + void ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph); void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast, const GraphProperties& properties, NodeDef* node, GraphDef* graph); @@ -229,6 +230,7 @@ class ConstantFolding : public GraphOptimizer { const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const; // Changes a reduction into an Identity op, returning true on success. bool ReplaceReductionWithIdentity(NodeDef* node) const; + // Simplifies a Reduction operation to an Identity/Reshape operation if // applicable. bool SimplifyReduction(GraphDef* optimized_graph, @@ -286,6 +288,10 @@ class ConstantFolding : public GraphOptimizer { bool SimplifySelect(const GraphProperties& properties, GraphDef* optimized_graph, NodeDef* node); + // Replaces variable updates that are effectively no-ops with NoOp nodes. + void RemoveRedundantVariableUpdates(const GraphProperties& properties, + GraphDef* optimized_graph, NodeDef* node); + // Removes Reverse op over dimensions with size 1. Status RemoveReverse(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node); diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 58ef14e3d3d..10914860710 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -94,14 +94,33 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const { bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const { if (HasRegularOutputs(node, *node_map_)) { // The output values of this node may be needed. + VLOG(3) << "Not safe to convert '" << node.name() + << " to NoOp. Node has outputs."; return false; } - if (!fetch_nodes_known_ || - nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { + if (!fetch_nodes_known_) { + VLOG(3) << "Not safe to convert '" << node.name() + << " to NoOp. Fetches unknown."; return false; } - if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) || - !IsFreeOfSideEffect(node)) { + if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { + VLOG(3) << "Not safe to convert to NoOp: " << node.name() + << " is in preserve set."; + return false; + } + if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) { + VLOG(3) << "Not safe to convert '" << node.name() + << " to NoOp. Node modifies frame info."; + return false; + } + // Ops reading variables are marked as stateful, but are safe to remove if + // redundant. + const bool is_variable_read = IsReadVariableOp(node) || + IsReadVariablesOp(node) || + absl::StrContains(node.op(), "Gather"); + if (!is_variable_read && !IsFreeOfSideEffect(node)) { + VLOG(3) << "Not safe to convert '" << node.name() + << " to NoOp. Node has side effect."; return false; } if (node.op().rfind("Submodel", 0) == 0) { diff --git a/tensorflow/python/grappler/constant_folding_test.py b/tensorflow/python/grappler/constant_folding_test.py index 3ba5b7418a7..3336d3f7e8f 100644 --- a/tensorflow/python/grappler/constant_folding_test.py +++ b/tensorflow/python/grappler/constant_folding_test.py @@ -20,12 +20,17 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -63,6 +68,48 @@ class ConstantFoldingTest(test.TestCase): y_v = self.evaluate(y) self.assertAllEqual(np.zeros([10, 20, 30]), y_v) + # See b/159753857. + def testGradientGraphOptimization(self): + + @def_function.function + def f(x, y): + with backprop.GradientTape() as tape: + z = math_ops.mul(x, array_ops.zeros_like(x)) + l = math_ops.add(z, y) + l = math_ops.reduce_sum(l) + + gx, gy = tape.gradient(l, [x, y]) + x.assign_add(gx) + y.assign_add(gy) + return x + y + + # XLA completely optimizes away the variable reads and + # assignments, so skip the test. + if test_util.is_xla_enabled(): + self.skipTest('Not relevant for XLA') + with context.eager_mode(): + x = resource_variable_ops.ResourceVariable( + np.random.uniform(size=[2, 2]), dtype=dtypes.float32) + y = resource_variable_ops.ResourceVariable( + np.random.uniform(size=[2, 2]), dtype=dtypes.float32) + with context.collect_graphs(optimized=True) as graphs: + f(x, y).numpy() + self.assertLen(graphs, 1) + assign_count = 0 + read_count = 0 + for node in graphs[0].node: + if node.op == 'AssignAddVariableOp': + self.assertEqual(node.input[0], 'y') + assign_count += 1 + if node.op == 'ReadVariableOp': + read_count += 1 + + # Make sure that the only variable update that remains after + # grappler optimization is that of y, and that we prune all + # but the 2 necessary variable reads. + self.assertEqual(assign_count, 1) + self.assertEqual(read_count, 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py index 2beed594479..446ef128c26 100644 --- a/tensorflow/python/grappler/memory_optimizer_test.py +++ b/tensorflow/python/grappler/memory_optimizer_test.py @@ -56,6 +56,7 @@ class MemoryOptimizerSwapTest(test.TestCase): rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, constant_folding=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)) graph = tf_optimizer.OptimizeGraph(config, mg)