[Grappler] Replace redundant variable updates (e.g. an AssignAdd whose delta is all zeros, or a ScatterMul whose update is all ones) with NoOp or Identity nodes. Also relax a constraint in the dependency optimizer to allow pruning of unused variable reads.

PiperOrigin-RevId: 319069467
Change-Id: I97199e464c8a9fee0055077efbcf481f2545bf8a
A. Unique TensorFlower authored 2020-06-30 12:01:25 -07:00; committed by TensorFlower Gardener
parent 913738f832
commit 2394a61d64
7 changed files with 151 additions and 4 deletions
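As a concrete illustration of the pattern this change targets, here is a minimal sketch using the public TF 2.x API (the variable and function names are hypothetical, not part of this commit): a gradient that is identically zero yields an AssignAddVariableOp whose delta is all zeros, which Grappler can now rewrite to a NoOp.

import tensorflow as tf

v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

@tf.function
def train_step():
  # The delta is a constant all-zero tensor, so this AssignAddVariableOp
  # cannot change v; with this change Grappler replaces it with a NoOp.
  v.assign_add(tf.zeros_like(v))
  return v.read_value()

print(train_step().numpy())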

tensorflow/core/grappler/op_types.cc

@@ -429,6 +429,10 @@ bool IsReadVariableOp(const NodeDef& node) {
  return node.op() == "ReadVariableOp";
}

bool IsReadVariablesOp(const NodeDef& node) {
  return node.op() == "_ReadVariablesOp";
}

bool IsReal(const NodeDef& node) { return node.op() == "Real"; }

bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }

tensorflow/core/grappler/op_types.h

@@ -136,6 +136,7 @@ bool IsQueue(const NodeDef& node);
bool IsRandomShuffle(const NodeDef& node);
bool IsRank(const NodeDef& node);
bool IsReadVariableOp(const NodeDef& node);
bool IsReadVariablesOp(const NodeDef& node);
bool IsReal(const NodeDef& node);
bool IsRealDiv(const NodeDef& node);
bool IsReciprocalGrad(const NodeDef& node);

tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -1781,9 +1781,11 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
  return false;
}

// Replace an operation with Identity.
void ConstantFolding::ReplaceOperationWithIdentity(
    int input_to_forward, const GraphProperties& properties, NodeDef* node,
    GraphDef* graph) {
  if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
  if (dtype == DT_INVALID) return;
@@ -1836,6 +1838,26 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
  graph_modified_ = true;
}

// Replace a node with NoOp. Change all inputs to control dependencies.
// If the node has non-control outputs, no change will be performed.
void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph) {
  if (HasRegularOutputs(*node, *node_map_)) return;
  node->set_op("NoOp");
  node->clear_attr();
  // Change all inputs to control dependencies.
  for (int i = 0; i < node->input_size(); ++i) {
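    // Control inputs always come after regular inputs in a NodeDef, so the
    // first control input marks the end of the regular inputs.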
    if (IsControlInput(node->input(i))) {
      break;
    }
    const string ctrl_dep =
        AddControlDependency(node->input(i), graph, node_map_.get());
    node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
    node->set_input(i, ctrl_dep);
  }
  DedupControlInputs(node);
  graph_modified_ = true;
}

void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
    int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
    GraphDef* graph) {
@@ -2036,6 +2058,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
  SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifySelect(*properties, optimized_graph, node));
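  // RemoveRedundantVariableUpdates returns void and sets graph_modified_
  // itself, hence RETURN_IF_MODIFIED instead of SET_AND_RETURN_IF_MODIFIED.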
  RETURN_IF_MODIFIED(
      RemoveRedundantVariableUpdates(*properties, optimized_graph, node));

  graph_modified_ = graph_modified_cached;
  return Status::OK();
@@ -2460,6 +2484,51 @@ bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
  return true;
}

void ConstantFolding::RemoveRedundantVariableUpdates(
    const GraphProperties& properties, GraphDef* optimized_graph,
    NodeDef* node) {
  static const absl::flat_hash_set<string>* kVariableReadOps =
      new absl::flat_hash_set<string>{"AssignAddVariableOp",
                                      "AssignSubVariableOp",
                                      "AssignAdd",
                                      "AssignSub",
                                      "ScatterAdd",
                                      "ScatterSub",
                                      "ScatterMul",
                                      "ScatterDiv",
                                      "ScatterNdAdd",
                                      "ScatterNdSub",
                                      "ScatterNdMul",
                                      "ScatterNdDiv",
                                      "ResourceScatterAdd",
                                      "ResourceScatterSub",
                                      "ResourceScatterMul",
                                      "ResourceScatterDiv",
                                      "ResourceScatterNdAdd",
                                      "ResourceScatterNdSub",
                                      "ResourceScatterNdMul",
                                      "ResourceScatterNdDiv"};
  if (kVariableReadOps == nullptr ||
      kVariableReadOps->find(node->op()) == kVariableReadOps->end())
    return;
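  // Scatter* ops take (variable, indices, updates), so the delta is input 2;
  // the Assign* ops take (variable, delta), so the delta is input 1.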
  const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
  const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
  if (delta_node == nullptr) return;
  const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
                             absl::StrContains(node->op(), "Sub");
  if ((is_add_or_sub && IsZeros(*delta_node)) ||
      (!is_add_or_sub && IsOnes(*delta_node))) {
    VLOG(1) << "Removing redundant variable update: " << node->DebugString();
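    // Ops with "Variable" or "Resource" in their name produce no regular
    // output, so they can become NoOps; ref-style ops return the variable
    // ref, which Identity can forward from input 0.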
    if (absl::StrContains(node->op(), "Variable") ||
        absl::StrContains(node->op(), "Resource")) {
      ReplaceOperationWithNoOp(node, optimized_graph);
    } else {
      ReplaceOperationWithIdentity(0 /* input_to_forward */, properties, node,
                                   optimized_graph);
    }
  }
}
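To see the multiplicative branch above in action, here is a minimal sketch in TF1-style graph mode (hypothetical names, not part of this commit): a ScatterMul whose updates operand is all ones hits the IsOnes branch, and as a ref-style op with a regular output it is rewritten to Identity rather than NoOp.

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

v = tf.compat.v1.Variable(np.ones([4], dtype=np.float32))  # ref-style variable
# Input 2 (updates) is all ones, so this multiplicative update is a no-op.
updated = tf.compat.v1.scatter_mul(v, [0, 2], tf.ones([2]))

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  # Grappler optimizes the session graph; the ScatterMul cannot change v,
  # so it can be replaced by an Identity forwarding the variable ref.
  print(sess.run(updated))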
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
                                             NodeDef* node) {
  if (!IsEnter(*node) || node->input_size() == 0 ||

tensorflow/core/grappler/optimizers/constant_folding.h

@@ -106,6 +106,7 @@ class ConstantFolding : public GraphOptimizer {
  void ReplaceOperationWithSnapshot(int input_to_forward,
                                    const GraphProperties& properties,
                                    NodeDef* node, GraphDef* graph);
  void ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph);
  void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
                                             const GraphProperties& properties,
                                             NodeDef* node, GraphDef* graph);
@@ -229,6 +230,7 @@ class ConstantFolding : public GraphOptimizer {
      const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
  // Changes a reduction into an Identity op, returning true on success.
  bool ReplaceReductionWithIdentity(NodeDef* node) const;
  // Simplifies a Reduction operation to an Identity/Reshape operation if
  // applicable.
  bool SimplifyReduction(GraphDef* optimized_graph,
@@ -286,6 +288,10 @@ class ConstantFolding : public GraphOptimizer {
  bool SimplifySelect(const GraphProperties& properties,
                      GraphDef* optimized_graph, NodeDef* node);
  // Replaces variable updates that are effectively no-ops with NoOp or
  // Identity nodes.
  void RemoveRedundantVariableUpdates(const GraphProperties& properties,
                                      GraphDef* optimized_graph, NodeDef* node);
  // Removes Reverse op over dimensions with size 1.
  Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node);

tensorflow/core/grappler/optimizers/dependency_optimizer.cc

@@ -94,14 +94,33 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
   if (HasRegularOutputs(node, *node_map_)) {
     // The output values of this node may be needed.
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << "' to NoOp. Node has outputs.";
     return false;
   }
-  if (!fetch_nodes_known_ ||
-      nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
+  if (!fetch_nodes_known_) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << "' to NoOp. Fetches unknown.";
     return false;
   }
+  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
+    VLOG(3) << "Not safe to convert to NoOp: " << node.name()
+            << " is in preserve set.";
+    return false;
+  }
-  if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) ||
-      !IsFreeOfSideEffect(node)) {
+  if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << "' to NoOp. Node modifies frame info.";
     return false;
   }
+  // Ops reading variables are marked as stateful, but are safe to remove if
+  // redundant.
+  const bool is_variable_read = IsReadVariableOp(node) ||
+                                IsReadVariablesOp(node) ||
+                                absl::StrContains(node.op(), "Gather");
+  if (!is_variable_read && !IsFreeOfSideEffect(node)) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << "' to NoOp. Node has side effect.";
+    return false;
+  }
   if (node.op().rfind("Submodel", 0) == 0) {
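The relaxed constraint matters for graphs like the following minimal sketch (hypothetical names, public TF 2.x API, not from this commit): the ReadVariableOp behind read_value() is stateful but side-effect free, so once its value goes unused the dependency optimizer can now convert it to a NoOp and prune it.

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def f(x):
  unused = v.read_value()  # stateful but side-effect free; now prunable
  return x * 2.0

print(f(tf.constant(3.0)).numpy())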

tensorflow/python/grappler/constant_folding_test.py

@@ -20,12 +20,17 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
@@ -63,6 +68,48 @@ class ConstantFoldingTest(test.TestCase):
    y_v = self.evaluate(y)
    self.assertAllEqual(np.zeros([10, 20, 30]), y_v)

  # See b/159753857.
  def testGradientGraphOptimization(self):

    @def_function.function
    def f(x, y):
      with backprop.GradientTape() as tape:
        z = math_ops.mul(x, array_ops.zeros_like(x))
        l = math_ops.add(z, y)
        l = math_ops.reduce_sum(l)
      gx, gy = tape.gradient(l, [x, y])
      x.assign_add(gx)
      y.assign_add(gy)
      return x + y

    # XLA completely optimizes away the variable reads and
    # assignments, so skip the test.
    if test_util.is_xla_enabled():
      self.skipTest('Not relevant for XLA')

    with context.eager_mode():
      x = resource_variable_ops.ResourceVariable(
          np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
      y = resource_variable_ops.ResourceVariable(
          np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
      with context.collect_graphs(optimized=True) as graphs:
        f(x, y).numpy()
      self.assertLen(graphs, 1)
      assign_count = 0
      read_count = 0
      for node in graphs[0].node:
        if node.op == 'AssignAddVariableOp':
          self.assertEqual(node.input[0], 'y')
          assign_count += 1
        if node.op == 'ReadVariableOp':
          read_count += 1
      # Make sure that the only variable update that remains after
      # grappler optimization is that of y, and that we prune all
      # but the 2 necessary variable reads.
      self.assertEqual(assign_count, 1)
      self.assertEqual(read_count, 2)


if __name__ == '__main__':
  test.main()

tensorflow/python/grappler/memory_optimizer_test.py

@@ -56,6 +56,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
        rewriter_config_pb2.RewriterConfig(
            disable_model_pruning=True,
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
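            # Presumably needed because the relaxed dependency optimizer could
            # otherwise prune variable reads this test's graph relies on.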
            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    graph = tf_optimizer.OptimizeGraph(config, mg)