[Grappler] Update optimized_graph instead of graph_ in ConstantPushDown and MulConvPushDown of ConstantFolding optimizer.

PiperOrigin-RevId: 226931428
This commit is contained in:
Andy Ly 2018-12-26 10:30:24 -08:00 committed by TensorFlower Gardener
parent c1d57ad8e1
commit 87af907e8c
2 changed files with 17 additions and 12 deletions

View File

@ -1697,12 +1697,12 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK(); return Status::OK();
} }
if (ConstantPushDown(node)) { if (ConstantPushDown(optimized_graph, node)) {
graph_modified_ = true; graph_modified_ = true;
return Status::OK(); return Status::OK();
} }
if (MulConvPushDown(node, *properties)) { if (MulConvPushDown(optimized_graph, node, *properties)) {
graph_modified_ = true; graph_modified_ = true;
return Status::OK(); return Status::OK();
} }
@ -2612,7 +2612,8 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
return false; return false;
} }
bool ConstantFolding::ConstantPushDown(NodeDef* node) { bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
NodeDef* node) {
// Consider the transformation // Consider the transformation
// //
// + + = parent // + + = parent
@ -2680,10 +2681,10 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) {
// edge. We can replace such a control edge with a control edge from A // edge. We can replace such a control edge with a control edge from A
// to C. // to C.
CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node, CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
graph_, node_map_.get())); optimized_graph, node_map_.get()));
string other_leaf_input = left_leaf_is_constant ? op_child_node->input(0) string other_leaf_input = left_leaf_is_constant ? op_child_node->input(0)
: op_child_node->input(1); : op_child_node->input(1);
MaybeAddControlInput(other_leaf_input, const_child_node, graph_, MaybeAddControlInput(other_leaf_input, const_child_node, optimized_graph,
node_map_.get()); node_map_.get());
} }
@ -2700,7 +2701,7 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) {
return false; return false;
} }
bool ConstantFolding::MulConvPushDown(NodeDef* node, bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
const GraphProperties& properties) { const GraphProperties& properties) {
// Push down multiplication on ConvND. // Push down multiplication on ConvND.
// * ConvND // * ConvND
@ -2792,12 +2793,13 @@ bool ConstantFolding::MulConvPushDown(NodeDef* node,
} }
// Make sure we don't introduce loops in the graph by removing control // Make sure we don't introduce loops in the graph by removing control
// dependencies from the conv2d node to c2. // dependencies from the conv2d node to c2.
NodeDef* conv_const_node = string conv_const_input =
conv_left_is_constant ? conv_left_child : conv_right_child; conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
if (MaybeRemoveControlInput(conv_node->name(), const_node, graph_, if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
node_map_.get())) { node_map_.get())) {
// Add a control dep from c1 to c2 to ensure c2 is in the right frame // Add a control dep from c1 to c2 to ensure c2 is in the right frame
*const_node->add_input() = AsControlDependency(*conv_const_node); MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
node_map_.get());
} }
conv_node->set_name(node->name()); conv_node->set_name(node->name());
@ -2809,6 +2811,8 @@ bool ConstantFolding::MulConvPushDown(NodeDef* node,
node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name); node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
conv_node->set_input(1, mul_new_name); conv_node->set_input(1, mul_new_name);
} }
NodeDef* conv_const_node =
conv_left_is_constant ? conv_left_child : conv_right_child;
if (left_child_is_constant) { if (left_child_is_constant) {
node->set_input(1, conv_const_node->name()); node->set_input(1, conv_const_node->name());
} else { } else {

View File

@ -124,11 +124,12 @@ class ConstantFolding : public GraphOptimizer {
// Pushes down constants on '+' and '*' operators if applicable. Returns true // Pushes down constants on '+' and '*' operators if applicable. Returns true
// the transformation applied successfully. // the transformation applied successfully.
bool ConstantPushDown(NodeDef* node); bool ConstantPushDown(GraphDef* optimized_graph, NodeDef* node);
// Aggregate constants present around a conv operator. Returns true if the // Aggregate constants present around a conv operator. Returns true if the
// transformation was applied successfully. // transformation was applied successfully.
bool MulConvPushDown(NodeDef* node, const GraphProperties& properties); bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
const GraphProperties& properties);
// Strength reduces floating point division by a constant Div(x, const) to // Strength reduces floating point division by a constant Div(x, const) to
// multiplication by the reciprocal Mul(x, Reciprocal(const)). // multiplication by the reciprocal Mul(x, Reciprocal(const)).