[Grappler] Replace redundant variable updates (e.g. AssignAdd of all-zero deltas) with NoOp or Identity nodes. Also relax a constraint in the dependency optimizer to allow pruning of unused variable reads.
PiperOrigin-RevId: 319069467 Change-Id: I97199e464c8a9fee0055077efbcf481f2545bf8a
This commit is contained in:
parent
913738f832
commit
2394a61d64
@ -429,6 +429,10 @@ bool IsReadVariableOp(const NodeDef& node) {
|
|||||||
return node.op() == "ReadVariableOp";
|
return node.op() == "ReadVariableOp";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsReadVariablesOp(const NodeDef& node) {
|
||||||
|
return node.op() == "_ReadVariablesOp";
|
||||||
|
}
|
||||||
|
|
||||||
bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
|
bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
|
||||||
|
|
||||||
bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
|
bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
|
||||||
|
@ -136,6 +136,7 @@ bool IsQueue(const NodeDef& node);
|
|||||||
bool IsRandomShuffle(const NodeDef& node);
|
bool IsRandomShuffle(const NodeDef& node);
|
||||||
bool IsRank(const NodeDef& node);
|
bool IsRank(const NodeDef& node);
|
||||||
bool IsReadVariableOp(const NodeDef& node);
|
bool IsReadVariableOp(const NodeDef& node);
|
||||||
|
bool IsReadVariablesOp(const NodeDef& node);
|
||||||
bool IsReal(const NodeDef& node);
|
bool IsReal(const NodeDef& node);
|
||||||
bool IsRealDiv(const NodeDef& node);
|
bool IsRealDiv(const NodeDef& node);
|
||||||
bool IsReciprocalGrad(const NodeDef& node);
|
bool IsReciprocalGrad(const NodeDef& node);
|
||||||
|
@ -1781,9 +1781,11 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replace an operation with Identity.
|
||||||
void ConstantFolding::ReplaceOperationWithIdentity(
|
void ConstantFolding::ReplaceOperationWithIdentity(
|
||||||
int input_to_forward, const GraphProperties& properties, NodeDef* node,
|
int input_to_forward, const GraphProperties& properties, NodeDef* node,
|
||||||
GraphDef* graph) {
|
GraphDef* graph) {
|
||||||
|
if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
|
||||||
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
|
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
|
||||||
if (dtype == DT_INVALID) return;
|
if (dtype == DT_INVALID) return;
|
||||||
|
|
||||||
@ -1836,6 +1838,26 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
|
|||||||
graph_modified_ = true;
|
graph_modified_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replace a node with NoOp. Change all inputs to control dependencies.
|
||||||
|
// If the node has non-control outputs, no change will be performed.
|
||||||
|
void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph) {
|
||||||
|
if (HasRegularOutputs(*node, *node_map_)) return;
|
||||||
|
node->set_op("NoOp");
|
||||||
|
node->clear_attr();
|
||||||
|
// Change all inputs to control dependencies.
|
||||||
|
for (int i = 0; i < node->input_size(); ++i) {
|
||||||
|
if (IsControlInput(node->input(i))) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const string ctrl_dep =
|
||||||
|
AddControlDependency(node->input(i), graph, node_map_.get());
|
||||||
|
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
|
||||||
|
node->set_input(i, ctrl_dep);
|
||||||
|
}
|
||||||
|
DedupControlInputs(node);
|
||||||
|
graph_modified_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
|
void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
|
||||||
int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
|
int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
|
||||||
GraphDef* graph) {
|
GraphDef* graph) {
|
||||||
@ -2036,6 +2058,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
|||||||
SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
|
SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
|
||||||
SET_AND_RETURN_IF_MODIFIED(
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
SimplifySelect(*properties, optimized_graph, node));
|
SimplifySelect(*properties, optimized_graph, node));
|
||||||
|
RETURN_IF_MODIFIED(
|
||||||
|
RemoveRedundantVariableUpdates(*properties, optimized_graph, node));
|
||||||
|
|
||||||
graph_modified_ = graph_modified_cached;
|
graph_modified_ = graph_modified_cached;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -2460,6 +2484,51 @@ bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ConstantFolding::RemoveRedundantVariableUpdates(
|
||||||
|
const GraphProperties& properties, GraphDef* optimized_graph,
|
||||||
|
NodeDef* node) {
|
||||||
|
static const absl::flat_hash_set<string>* kVariableReadOps =
|
||||||
|
new absl::flat_hash_set<string>{"AssignAddVariableOp",
|
||||||
|
"AssignSubVariableOp",
|
||||||
|
"AssignAdd",
|
||||||
|
"AssignSub",
|
||||||
|
"ScatterAdd",
|
||||||
|
"ScatterSub",
|
||||||
|
"ScatterMul",
|
||||||
|
"ScatterDiv",
|
||||||
|
"ScatterNdAdd",
|
||||||
|
"ScatterNdSub",
|
||||||
|
"ScatterNdMul",
|
||||||
|
"ScatterNdDiv",
|
||||||
|
"ResourceScatterAdd",
|
||||||
|
"ResourceScatterSub",
|
||||||
|
"ResourceScatterMul",
|
||||||
|
"ResourceScatterDiv",
|
||||||
|
"ResourceScatterNdAdd",
|
||||||
|
"ResourceScatterNdSub",
|
||||||
|
"ResourceScatterNdMul",
|
||||||
|
"ResourceScatterNdDiv"};
|
||||||
|
if (kVariableReadOps == nullptr ||
|
||||||
|
kVariableReadOps->find(node->op()) == kVariableReadOps->end())
|
||||||
|
return;
|
||||||
|
const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
|
||||||
|
const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
|
||||||
|
if (delta_node == nullptr) return;
|
||||||
|
const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
|
||||||
|
absl::StrContains(node->op(), "Sub");
|
||||||
|
if ((is_add_or_sub && IsZeros(*delta_node)) ||
|
||||||
|
(!is_add_or_sub && IsOnes(*delta_node))) {
|
||||||
|
VLOG(1) << "Removing redundant variable update: " << node->DebugString();
|
||||||
|
if (absl::StrContains(node->op(), "Variable") ||
|
||||||
|
absl::StrContains(node->op(), "Resource")) {
|
||||||
|
ReplaceOperationWithNoOp(node, optimized_graph);
|
||||||
|
} else {
|
||||||
|
ReplaceOperationWithIdentity(0 /* input_to_forward */, properties, node,
|
||||||
|
optimized_graph);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
|
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
|
||||||
NodeDef* node) {
|
NodeDef* node) {
|
||||||
if (!IsEnter(*node) || node->input_size() == 0 ||
|
if (!IsEnter(*node) || node->input_size() == 0 ||
|
||||||
|
@ -106,6 +106,7 @@ class ConstantFolding : public GraphOptimizer {
|
|||||||
void ReplaceOperationWithSnapshot(int input_to_forward,
|
void ReplaceOperationWithSnapshot(int input_to_forward,
|
||||||
const GraphProperties& properties,
|
const GraphProperties& properties,
|
||||||
NodeDef* node, GraphDef* graph);
|
NodeDef* node, GraphDef* graph);
|
||||||
|
// Replaces `node` with a NoOp and changes all of its inputs to control
// dependencies. No change is made if the node has non-control outputs.
void ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph);
|
||||||
void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
|
void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
|
||||||
const GraphProperties& properties,
|
const GraphProperties& properties,
|
||||||
NodeDef* node, GraphDef* graph);
|
NodeDef* node, GraphDef* graph);
|
||||||
@ -229,6 +230,7 @@ class ConstantFolding : public GraphOptimizer {
|
|||||||
const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
|
const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
|
||||||
// Changes a reduction into an Identity op, returning true on success.
|
// Changes a reduction into an Identity op, returning true on success.
|
||||||
bool ReplaceReductionWithIdentity(NodeDef* node) const;
|
bool ReplaceReductionWithIdentity(NodeDef* node) const;
|
||||||
|
|
||||||
// Simplifies a Reduction operation to an Identity/Reshape operation if
|
// Simplifies a Reduction operation to an Identity/Reshape operation if
|
||||||
// applicable.
|
// applicable.
|
||||||
bool SimplifyReduction(GraphDef* optimized_graph,
|
bool SimplifyReduction(GraphDef* optimized_graph,
|
||||||
@ -286,6 +288,10 @@ class ConstantFolding : public GraphOptimizer {
|
|||||||
bool SimplifySelect(const GraphProperties& properties,
|
bool SimplifySelect(const GraphProperties& properties,
|
||||||
GraphDef* optimized_graph, NodeDef* node);
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
|
// Replaces variable updates that are effectively no-ops (adding or
// subtracting all zeros, multiplying or dividing by all ones) with NoOp
// nodes, or with Identity nodes for ops whose name contains neither
// "Variable" nor "Resource".
|
||||||
|
void RemoveRedundantVariableUpdates(const GraphProperties& properties,
|
||||||
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
// Removes Reverse op over dimensions with size 1.
|
// Removes Reverse op over dimensions with size 1.
|
||||||
Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
|
Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node);
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
@ -94,14 +94,33 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
|
|||||||
bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
|
bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
|
||||||
if (HasRegularOutputs(node, *node_map_)) {
|
if (HasRegularOutputs(node, *node_map_)) {
|
||||||
// The output values of this node may be needed.
|
// The output values of this node may be needed.
|
||||||
|
VLOG(3) << "Not safe to convert '" << node.name()
|
||||||
|
<< " to NoOp. Node has outputs.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!fetch_nodes_known_ ||
|
if (!fetch_nodes_known_) {
|
||||||
nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
|
VLOG(3) << "Not safe to convert '" << node.name()
|
||||||
|
<< " to NoOp. Fetches unknown.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) ||
|
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
|
||||||
!IsFreeOfSideEffect(node)) {
|
VLOG(3) << "Not safe to convert to NoOp: " << node.name()
|
||||||
|
<< " is in preserve set.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
|
||||||
|
VLOG(3) << "Not safe to convert '" << node.name()
|
||||||
|
<< " to NoOp. Node modifies frame info.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// Ops reading variables are marked as stateful, but are safe to remove if
|
||||||
|
// redundant.
|
||||||
|
const bool is_variable_read = IsReadVariableOp(node) ||
|
||||||
|
IsReadVariablesOp(node) ||
|
||||||
|
absl::StrContains(node.op(), "Gather");
|
||||||
|
if (!is_variable_read && !IsFreeOfSideEffect(node)) {
|
||||||
|
VLOG(3) << "Not safe to convert '" << node.name()
|
||||||
|
<< " to NoOp. Node has side effect.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (node.op().rfind("Submodel", 0) == 0) {
|
if (node.op().rfind("Submodel", 0) == 0) {
|
||||||
|
@ -20,12 +20,17 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.eager import backprop
|
||||||
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import functional_ops
|
from tensorflow.python.ops import functional_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -63,6 +68,48 @@ class ConstantFoldingTest(test.TestCase):
|
|||||||
y_v = self.evaluate(y)
|
y_v = self.evaluate(y)
|
||||||
self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
|
self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
|
||||||
|
|
||||||
|
# See b/159753857.
# Checks the optimized graph of `f`: only the update of y may remain as an
# AssignAddVariableOp, and exactly two ReadVariableOp nodes may remain.
|
||||||
|
def testGradientGraphOptimization(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x, y):
|
||||||
|
with backprop.GradientTape() as tape:
|
||||||
|
# z is x multiplied by zeros, so the loss depends on x only through a
# zero factor.
z = math_ops.mul(x, array_ops.zeros_like(x))
|
||||||
|
l = math_ops.add(z, y)
|
||||||
|
l = math_ops.reduce_sum(l)
|
||||||
|
|
||||||
|
gx, gy = tape.gradient(l, [x, y])
|
||||||
|
x.assign_add(gx)
|
||||||
|
y.assign_add(gy)
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
# XLA completely optimizes away the variable reads and
|
||||||
|
# assignments, so skip the test.
|
||||||
|
if test_util.is_xla_enabled():
|
||||||
|
self.skipTest('Not relevant for XLA')
|
||||||
|
with context.eager_mode():
|
||||||
|
x = resource_variable_ops.ResourceVariable(
|
||||||
|
np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
|
||||||
|
y = resource_variable_ops.ResourceVariable(
|
||||||
|
np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
|
||||||
|
# Collect graphs (requested with optimized=True) while f runs once.
with context.collect_graphs(optimized=True) as graphs:
|
||||||
|
f(x, y).numpy()
|
||||||
|
self.assertLen(graphs, 1)
|
||||||
|
assign_count = 0
|
||||||
|
read_count = 0
|
||||||
|
for node in graphs[0].node:
|
||||||
|
if node.op == 'AssignAddVariableOp':
|
||||||
|
# Every surviving update must have 'y' as its first input.
self.assertEqual(node.input[0], 'y')
|
||||||
|
assign_count += 1
|
||||||
|
if node.op == 'ReadVariableOp':
|
||||||
|
read_count += 1
|
||||||
|
|
||||||
|
# Make sure that the only variable update that remains after
|
||||||
|
# grappler optimization is that of y, and that we prune all
|
||||||
|
# but the 2 necessary variable reads.
|
||||||
|
self.assertEqual(assign_count, 1)
|
||||||
|
|
||||||
|
self.assertEqual(read_count, 2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -56,6 +56,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
|
|||||||
rewriter_config_pb2.RewriterConfig(
|
rewriter_config_pb2.RewriterConfig(
|
||||||
disable_model_pruning=True,
|
disable_model_pruning=True,
|
||||||
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
|
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
|
||||||
|
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
|
||||||
memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
|
memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
|
||||||
graph = tf_optimizer.OptimizeGraph(config, mg)
|
graph = tf_optimizer.OptimizeGraph(config, mg)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user