[Grappler] Replace redundant variable updates (e.g., AssignAdd of an all-zero delta) with NoOp or Identity nodes. Also relax a constraint in the dependency optimizer to allow pruning of unused variable reads.

PiperOrigin-RevId: 319069467
Change-Id: I97199e464c8a9fee0055077efbcf481f2545bf8a
parent 913738f832
commit 2394a61d64

Changed paths: tensorflow/core/grappler, tensorflow/python/grappler
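For context, a minimal sketch of the pattern the new rewrite targets (illustrative only; not part of this commit). An update whose delta is provably all zeros leaves the variable unchanged, so constant folding can now turn it into a NoOp (resource/variable updates) or an Identity (ref-style updates):

import tensorflow as tf

v = tf.Variable([1.0, 2.0])

@tf.function
def step():
  # Lowers to AssignAddVariableOp with an all-zero delta; after this change,
  # Grappler's constant folding can rewrite the update into a NoOp.
  v.assign_add(tf.zeros_like(v))
  return v.read_value()

print(step().numpy())  # [1. 2.] -- the variable is unchanged either way.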
tensorflow/core/grappler/op_types.cc
@@ -429,6 +429,10 @@ bool IsReadVariableOp(const NodeDef& node) {
   return node.op() == "ReadVariableOp";
 }
 
+bool IsReadVariablesOp(const NodeDef& node) {
+  return node.op() == "_ReadVariablesOp";
+}
+
 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
 
 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
tensorflow/core/grappler/op_types.h
@@ -136,6 +136,7 @@ bool IsQueue(const NodeDef& node);
 bool IsRandomShuffle(const NodeDef& node);
 bool IsRank(const NodeDef& node);
 bool IsReadVariableOp(const NodeDef& node);
+bool IsReadVariablesOp(const NodeDef& node);
 bool IsReal(const NodeDef& node);
 bool IsRealDiv(const NodeDef& node);
 bool IsReciprocalGrad(const NodeDef& node);
tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1781,9 +1781,11 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const {
   return false;
 }
 
+// Replace an operation with Identity.
 void ConstantFolding::ReplaceOperationWithIdentity(
     int input_to_forward, const GraphProperties& properties, NodeDef* node,
     GraphDef* graph) {
+  if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
   if (dtype == DT_INVALID) return;
 
@@ -1836,6 +1838,26 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
   graph_modified_ = true;
 }
 
+// Replace a node with NoOp. Change all inputs to control dependencies.
+// If the node has non-control outputs, no change will be performed.
+void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph) {
+  if (HasRegularOutputs(*node, *node_map_)) return;
+  node->set_op("NoOp");
+  node->clear_attr();
+  // Change all inputs to control dependencies.
+  for (int i = 0; i < node->input_size(); ++i) {
+    if (IsControlInput(node->input(i))) {
+      break;
+    }
+    const string ctrl_dep =
+        AddControlDependency(node->input(i), graph, node_map_.get());
+    node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
+    node->set_input(i, ctrl_dep);
+  }
+  DedupControlInputs(node);
+  graph_modified_ = true;
+}
+
 void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
     int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
     GraphDef* graph) {
@@ -2036,6 +2058,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(
       SimplifySelect(*properties, optimized_graph, node));
+  RETURN_IF_MODIFIED(
+      RemoveRedundantVariableUpdates(*properties, optimized_graph, node));
 
   graph_modified_ = graph_modified_cached;
   return Status::OK();
@@ -2460,6 +2484,51 @@ bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
   return true;
 }
 
+void ConstantFolding::RemoveRedundantVariableUpdates(
+    const GraphProperties& properties, GraphDef* optimized_graph,
+    NodeDef* node) {
+  static const absl::flat_hash_set<string>* kVariableReadOps =
+      new absl::flat_hash_set<string>{"AssignAddVariableOp",
+                                      "AssignSubVariableOp",
+                                      "AssignAdd",
+                                      "AssignSub",
+                                      "ScatterAdd",
+                                      "ScatterSub",
+                                      "ScatterMul",
+                                      "ScatterDiv",
+                                      "ScatterNdAdd",
+                                      "ScatterNdSub",
+                                      "ScatterNdMul",
+                                      "ScatterNdDiv",
+                                      "ResourceScatterAdd",
+                                      "ResourceScatterSub",
+                                      "ResourceScatterMul",
+                                      "ResourceScatterDiv",
+                                      "ResourceScatterNdAdd",
+                                      "ResourceScatterNdSub",
+                                      "ResourceScatterNdMul",
+                                      "ResourceScatterNdDiv"};
+  if (kVariableReadOps == nullptr ||
+      kVariableReadOps->find(node->op()) == kVariableReadOps->end())
+    return;
+  const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
+  const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
+  if (delta_node == nullptr) return;
+  const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
+                             absl::StrContains(node->op(), "Sub");
+  if ((is_add_or_sub && IsZeros(*delta_node)) ||
+      (!is_add_or_sub && IsOnes(*delta_node))) {
+    VLOG(1) << "Removing redundant variable update: " << node->DebugString();
+    if (absl::StrContains(node->op(), "Variable") ||
+        absl::StrContains(node->op(), "Resource")) {
+      ReplaceOperationWithNoOp(node, optimized_graph);
+    } else {
+      ReplaceOperationWithIdentity(0 /* input_to_forward */, properties, node,
+                                   optimized_graph);
+    }
+  }
+}
+
 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
                                              NodeDef* node) {
   if (!IsEnter(*node) || node->input_size() == 0 ||
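To make the index arithmetic above concrete: for Assign{Add,Sub}[VariableOp] the delta is input 1, while the Scatter variants take (variable, indices, updates), putting the delta at input 2; Add/Sub updates are redundant when the delta is all zeros, Mul/Div updates when it is all ones. A minimal sketch (illustrative API use, not from this commit):

import tensorflow as tf

v = tf.Variable(tf.ones([4]))

@tf.function
def scatter_noop():
  # Emits ResourceScatterMul(var, indices, updates); updates (input 2) is all
  # ones here, so the rewrite above can replace the update with a NoOp.
  v.scatter_mul(tf.IndexedSlices(tf.ones([2]), tf.constant([0, 2])))
  return v.read_value()

print(scatter_noop().numpy())  # still all ones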
tensorflow/core/grappler/optimizers/constant_folding.h
@@ -106,6 +106,7 @@ class ConstantFolding : public GraphOptimizer {
   void ReplaceOperationWithSnapshot(int input_to_forward,
                                     const GraphProperties& properties,
                                     NodeDef* node, GraphDef* graph);
+  void ReplaceOperationWithNoOp(NodeDef* node, GraphDef* graph);
   void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
                                              const GraphProperties& properties,
                                              NodeDef* node, GraphDef* graph);
@@ -229,6 +230,7 @@ class ConstantFolding : public GraphOptimizer {
       const gtl::InlinedVector<TensorValue, 4>& reduction_indices_vector) const;
   // Changes a reduction into an Identity op, returning true on success.
   bool ReplaceReductionWithIdentity(NodeDef* node) const;
+
   // Simplifies a Reduction operation to an Identity/Reshape operation if
   // applicable.
   bool SimplifyReduction(GraphDef* optimized_graph,
@@ -286,6 +288,10 @@ class ConstantFolding : public GraphOptimizer {
   bool SimplifySelect(const GraphProperties& properties,
                       GraphDef* optimized_graph, NodeDef* node);
 
+  // Replaces variable updates that are effectively no-ops with NoOp nodes.
+  void RemoveRedundantVariableUpdates(const GraphProperties& properties,
+                                      GraphDef* optimized_graph, NodeDef* node);
+
   // Removes Reverse op over dimensions with size 1.
   Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
                        GraphDef* optimized_graph, NodeDef* node);
tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -94,14 +94,33 @@ bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
 bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
   if (HasRegularOutputs(node, *node_map_)) {
     // The output values of this node may be needed.
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << " to NoOp. Node has outputs.";
     return false;
   }
-  if (!fetch_nodes_known_ ||
-      nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
+  if (!fetch_nodes_known_) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << " to NoOp. Fetches unknown.";
+    return false;
+  }
+  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
+    VLOG(3) << "Not safe to convert to NoOp: " << node.name()
+            << " is in preserve set.";
     return false;
   }
-  if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node) ||
-      !IsFreeOfSideEffect(node)) {
+  if (IsMerge(node) || IsSwitch(node) || ModifiesFrameInfo(node)) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << " to NoOp. Node modifies frame info.";
+    return false;
+  }
+  // Ops reading variables are marked as stateful, but are safe to remove if
+  // redundant.
+  const bool is_variable_read = IsReadVariableOp(node) ||
+                                IsReadVariablesOp(node) ||
+                                absl::StrContains(node.op(), "Gather");
+  if (!is_variable_read && !IsFreeOfSideEffect(node)) {
+    VLOG(3) << "Not safe to convert '" << node.name()
+            << " to NoOp. Node has side effect.";
     return false;
   }
   if (node.op().rfind("Submodel", 0) == 0) {
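The effect of the relaxed constraint, sketched (illustrative, not from this commit): ReadVariableOp is marked stateful but has no side effect, so a read whose result is never consumed is now eligible for pruning rather than being kept alive:

import tensorflow as tf

v = tf.Variable(3.0)

@tf.function
def f(x):
  _ = v.read_value()  # Unused read: stateful but side-effect free, so the
                      # dependency optimizer may now prune it.
  return x * 2.0

print(f(tf.constant(1.0)).numpy())  # 2.0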
tensorflow/python/grappler/constant_folding_test.py
@@ -20,12 +20,17 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
 
 
@@ -63,6 +68,48 @@ class ConstantFoldingTest(test.TestCase):
       y_v = self.evaluate(y)
       self.assertAllEqual(np.zeros([10, 20, 30]), y_v)
 
+  # See b/159753857.
+  def testGradientGraphOptimization(self):
+
+    @def_function.function
+    def f(x, y):
+      with backprop.GradientTape() as tape:
+        z = math_ops.mul(x, array_ops.zeros_like(x))
+        l = math_ops.add(z, y)
+        l = math_ops.reduce_sum(l)
+
+      gx, gy = tape.gradient(l, [x, y])
+      x.assign_add(gx)
+      y.assign_add(gy)
+      return x + y
+
+    # XLA completely optimizes away the variable reads and
+    # assignments, so skip the test.
+    if test_util.is_xla_enabled():
+      self.skipTest('Not relevant for XLA')
+    with context.eager_mode():
+      x = resource_variable_ops.ResourceVariable(
+          np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
+      y = resource_variable_ops.ResourceVariable(
+          np.random.uniform(size=[2, 2]), dtype=dtypes.float32)
+      with context.collect_graphs(optimized=True) as graphs:
+        f(x, y).numpy()
+      self.assertLen(graphs, 1)
+      assign_count = 0
+      read_count = 0
+      for node in graphs[0].node:
+        if node.op == 'AssignAddVariableOp':
+          self.assertEqual(node.input[0], 'y')
+          assign_count += 1
+        if node.op == 'ReadVariableOp':
+          read_count += 1
+
+      # Make sure that the only variable update that remains after
+      # grappler optimization is that of y, and that we prune all
+      # but the 2 necessary variable reads.
+      self.assertEqual(assign_count, 1)
+      self.assertEqual(read_count, 2)
+
 
 if __name__ == '__main__':
   test.main()
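Why the test expects exactly one surviving AssignAddVariableOp: since z = x * zeros_like(x), the loss l = sum(z + y) has gradient zero with respect to x (making x's update redundant) and all ones with respect to y. A quick sanity check of that arithmetic (illustrative; not part of the test file):

import numpy as np
import tensorflow as tf

x = tf.Variable(np.random.uniform(size=[2, 2]), dtype=tf.float32)
y = tf.Variable(np.random.uniform(size=[2, 2]), dtype=tf.float32)
with tf.GradientTape() as tape:
  z = x * tf.zeros_like(x)
  l = tf.reduce_sum(z + y)
gx, gy = tape.gradient(l, [x, y])
print(gx.numpy())  # all zeros -> x's AssignAdd is redundant and gets removed
print(gy.numpy())  # all ones  -> y's AssignAdd must be kept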
tensorflow/python/grappler/memory_optimizer_test.py
@@ -56,6 +56,7 @@ class MemoryOptimizerSwapTest(test.TestCase):
         rewriter_config_pb2.RewriterConfig(
             disable_model_pruning=True,
             constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+            dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
             memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
     graph = tf_optimizer.OptimizeGraph(config, mg)