Extracts the following optimizations into methods:

PartialConstPropThroughIdentityN
ConstantPushDown

PiperOrigin-RevId: 196520167
This commit is contained in:
A. Unique TensorFlower 2018-05-14 09:45:42 -07:00 committed by TensorFlower Gardener
parent 0c59fdb949
commit 6d41d9fb0c
2 changed files with 44 additions and 22 deletions

View File

@ -2157,6 +2157,30 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
return Status::OK();
}
if (ConstantPushDown(node)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialConstPropThroughIdentityN(node)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialConcatConstFolding(optimized_graph, properties, node)) {
graph_modified_ = true;
return Status::OK();
}
return Status::OK();
}
bool ConstantFolding::ConstantPushDown(NodeDef* node) {
// Consider the transformation
//
// + + = parent
@ -2178,22 +2202,22 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
// division/multiplication.
// Don't touch BiasAdd since they can't handle vectors as their first
// inputs.
if (has_fetch_ && (IsAdd(*node) || is_mul) &&
if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
NumNonControlInputs(*node) == 2) {
NodeDef* left_child = node_map_->GetNode(node->input(0));
NodeDef* right_child = node_map_->GetNode(node->input(1));
// One child must be constant, and the other the same op as the parent.
if (node->op() != left_child->op() && node->op() != right_child->op()) {
return Status::OK();
return false;
}
const bool left_child_is_constant = IsReallyConstant(*left_child);
const bool right_child_is_constant = IsReallyConstant(*right_child);
if (!left_child_is_constant && !right_child_is_constant) {
return Status::OK();
return false;
}
if (node->device() != left_child->device() ||
node->device() != right_child->device()) {
return Status::OK();
return false;
}
NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
NodeDef* const_child_node =
@ -2203,7 +2227,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
nodes_to_preserve_.find(op_child_node->name()) !=
nodes_to_preserve_.end() ||
NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
return Status::OK();
return false;
}
// Identify the nodes to swap.
@ -2213,7 +2237,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
if (left_leaf_is_constant && right_leaf_is_constant) {
// Child is already foldable, leave it alone.
return Status::OK();
return false;
}
const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
const int parent_const_input = left_child_is_constant ? 0 : 1;
@ -2238,10 +2262,12 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
node->input(parent_const_input));
std::swap(*node->mutable_input(parent_const_input),
*op_child_node->mutable_input(non_const_leaf_input));
graph_modified_ = true;
return Status::OK();
return true;
}
return false;
}
bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
// Partial constant propagation through IdentityN.
if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) {
const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
@ -2294,22 +2320,10 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
for (NodeDef* consumer : consumers) {
DedupControlInputs(consumer);
}
graph_modified_ = true;
return Status::OK();
return true;
}
}
if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialConcatConstFolding(optimized_graph, properties, node)) {
graph_modified_ = true;
return Status::OK();
}
return Status::OK();
return false;
}
bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,

View File

@ -113,6 +113,14 @@ class ConstantFolding : public GraphOptimizer {
bool PartialAssocOpConstFolding(GraphDef* optimized_graph,
GraphProperties* properties, NodeDef* node);
// Applies partial constant propagation through IdentityN operator.
// Returns true if the transformation applied successfully.
bool PartialConstPropThroughIdentityN(NodeDef* node);
// Pushes down constants on '+' and '*' operators if applicable. Returns true
// the transformation applied successfully.
bool ConstantPushDown(NodeDef* node);
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;