[Grappler] Update optimized_graph instead of graph_ in ConstantPushDown and MulConvPushDown of ConstantFolding optimizer.
PiperOrigin-RevId: 226931428
This commit is contained in:
parent
c1d57ad8e1
commit
87af907e8c
@ -1697,12 +1697,12 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ConstantPushDown(node)) {
|
||||
if (ConstantPushDown(optimized_graph, node)) {
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (MulConvPushDown(node, *properties)) {
|
||||
if (MulConvPushDown(optimized_graph, node, *properties)) {
|
||||
graph_modified_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -2612,7 +2612,8 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConstantFolding::ConstantPushDown(NodeDef* node) {
|
||||
bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
||||
NodeDef* node) {
|
||||
// Consider the transformation
|
||||
//
|
||||
// + + = parent
|
||||
@ -2680,10 +2681,10 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) {
|
||||
// edge. We can replace such a control edge with a control edge from A
|
||||
// to C.
|
||||
CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
|
||||
graph_, node_map_.get()));
|
||||
optimized_graph, node_map_.get()));
|
||||
string other_leaf_input = left_leaf_is_constant ? op_child_node->input(0)
|
||||
: op_child_node->input(1);
|
||||
MaybeAddControlInput(other_leaf_input, const_child_node, graph_,
|
||||
MaybeAddControlInput(other_leaf_input, const_child_node, optimized_graph,
|
||||
node_map_.get());
|
||||
}
|
||||
|
||||
@ -2700,7 +2701,7 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConstantFolding::MulConvPushDown(NodeDef* node,
|
||||
bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
|
||||
const GraphProperties& properties) {
|
||||
// Push down multiplication on ConvND.
|
||||
// * ConvND
|
||||
@ -2792,12 +2793,13 @@ bool ConstantFolding::MulConvPushDown(NodeDef* node,
|
||||
}
|
||||
// Make sure we don't introduce loops in the graph by removing control
|
||||
// dependencies from the conv2d node to c2.
|
||||
NodeDef* conv_const_node =
|
||||
conv_left_is_constant ? conv_left_child : conv_right_child;
|
||||
if (MaybeRemoveControlInput(conv_node->name(), const_node, graph_,
|
||||
string conv_const_input =
|
||||
conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
|
||||
if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
|
||||
node_map_.get())) {
|
||||
// Add a control dep from c1 to c2 to ensure c2 is in the right frame
|
||||
*const_node->add_input() = AsControlDependency(*conv_const_node);
|
||||
MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
|
||||
node_map_.get());
|
||||
}
|
||||
|
||||
conv_node->set_name(node->name());
|
||||
@ -2809,6 +2811,8 @@ bool ConstantFolding::MulConvPushDown(NodeDef* node,
|
||||
node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
|
||||
conv_node->set_input(1, mul_new_name);
|
||||
}
|
||||
NodeDef* conv_const_node =
|
||||
conv_left_is_constant ? conv_left_child : conv_right_child;
|
||||
if (left_child_is_constant) {
|
||||
node->set_input(1, conv_const_node->name());
|
||||
} else {
|
||||
|
@ -124,11 +124,12 @@ class ConstantFolding : public GraphOptimizer {
|
||||
|
||||
// Pushes down constants on '+' and '*' operators if applicable. Returns true
|
||||
// the transformation applied successfully.
|
||||
bool ConstantPushDown(NodeDef* node);
|
||||
bool ConstantPushDown(GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
// Aggregate constants present around a conv operator. Returns true if the
|
||||
// transformation was applied successfully.
|
||||
bool MulConvPushDown(NodeDef* node, const GraphProperties& properties);
|
||||
bool MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
|
||||
const GraphProperties& properties);
|
||||
|
||||
// Strength reduces floating point division by a constant Div(x, const) to
|
||||
// multiplication by the reciprocal Mul(x, Reciprocal(const)).
|
||||
|
Loading…
Reference in New Issue
Block a user