Extracts the following optimizations into methods:
PartialConstPropThroughIdentityN ConstantPushDown PiperOrigin-RevId: 196520167
This commit is contained in:
parent
0c59fdb949
commit
6d41d9fb0c
@ -2157,6 +2157,30 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ConstantPushDown(node)) {
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (PartialConstPropThroughIdentityN(node)) {
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (PartialConcatConstFolding(optimized_graph, properties, node)) {
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool ConstantFolding::ConstantPushDown(NodeDef* node) {
|
||||
// Consider the transformation
|
||||
//
|
||||
// + + = parent
|
||||
@ -2178,22 +2202,22 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
|
||||
// division/multiplication.
|
||||
// Don't touch BiasAdd since they can't handle vectors as their first
|
||||
// inputs.
|
||||
if (has_fetch_ && (IsAdd(*node) || is_mul) &&
|
||||
if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
|
||||
NumNonControlInputs(*node) == 2) {
|
||||
NodeDef* left_child = node_map_->GetNode(node->input(0));
|
||||
NodeDef* right_child = node_map_->GetNode(node->input(1));
|
||||
// One child must be constant, and the other the same op as the parent.
|
||||
if (node->op() != left_child->op() && node->op() != right_child->op()) {
|
||||
return Status::OK();
|
||||
return false;
|
||||
}
|
||||
const bool left_child_is_constant = IsReallyConstant(*left_child);
|
||||
const bool right_child_is_constant = IsReallyConstant(*right_child);
|
||||
if (!left_child_is_constant && !right_child_is_constant) {
|
||||
return Status::OK();
|
||||
return false;
|
||||
}
|
||||
if (node->device() != left_child->device() ||
|
||||
node->device() != right_child->device()) {
|
||||
return Status::OK();
|
||||
return false;
|
||||
}
|
||||
NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
|
||||
NodeDef* const_child_node =
|
||||
@ -2203,7 +2227,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
|
||||
nodes_to_preserve_.find(op_child_node->name()) !=
|
||||
nodes_to_preserve_.end() ||
|
||||
NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
|
||||
return Status::OK();
|
||||
return false;
|
||||
}
|
||||
|
||||
// Identify the nodes to swap.
|
||||
@ -2213,7 +2237,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
|
||||
const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
|
||||
if (left_leaf_is_constant && right_leaf_is_constant) {
|
||||
// Child is already foldable, leave it alone.
|
||||
return Status::OK();
|
||||
return false;
|
||||
}
|
||||
const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
|
||||
const int parent_const_input = left_child_is_constant ? 0 : 1;
|
||||
@ -2238,10 +2262,12 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
|
||||
node->input(parent_const_input));
|
||||
std::swap(*node->mutable_input(parent_const_input),
|
||||
*op_child_node->mutable_input(non_const_leaf_input));
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
|
||||
// Partial constant propagation through IdentityN.
|
||||
if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) {
|
||||
const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
|
||||
@ -2294,22 +2320,10 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
|
||||
for (NodeDef* consumer : consumers) {
|
||||
DedupControlInputs(consumer);
|
||||
}
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (PartialConcatConstFolding(optimized_graph, properties, node)) {
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
|
||||
|
@ -113,6 +113,14 @@ class ConstantFolding : public GraphOptimizer {
|
||||
bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
|
||||
GraphProperties* properties, NodeDef* node);
|
||||
|
||||
// Applies partial constant propagation through IdentityN operator.
|
||||
// Returns true if the transformation applied successfully.
|
||||
bool PartialConstPropThroughIdentityN(NodeDef* node);
|
||||
|
||||
// Pushes down constants on '+' and '*' operators if applicable. Returns true
|
||||
// the transformation applied successfully.
|
||||
bool ConstantPushDown(NodeDef* node);
|
||||
|
||||
// Points to an externally provided device or to owned_device_;
|
||||
RewriterConfig::Toggle opt_level_;
|
||||
DeviceBase* cpu_device_;
|
||||
|
Loading…
Reference in New Issue
Block a user