Don't push down constant tensors if they increase the amount of broadcasting, unless they are guaranteed to be constant folded.
PiperOrigin-RevId: 267185933
This commit is contained in:
parent
357615e12a
commit
b883c28ffe
@ -1987,7 +1987,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
|||||||
RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
|
RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
|
||||||
*properties, use_shape_info, optimized_graph, node));
|
*properties, use_shape_info, optimized_graph, node));
|
||||||
SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
|
SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
|
||||||
SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node));
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
|
ConstantPushDown(*properties, optimized_graph, node));
|
||||||
SET_AND_RETURN_IF_MODIFIED(
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
MulConvPushDown(optimized_graph, node, *properties));
|
MulConvPushDown(optimized_graph, node, *properties));
|
||||||
SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
|
SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
|
||||||
@ -2840,7 +2841,8 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
|
||||||
|
GraphDef* optimized_graph,
|
||||||
NodeDef* node) {
|
NodeDef* node) {
|
||||||
// Consider the transformation
|
// Consider the transformation
|
||||||
//
|
//
|
||||||
@ -2853,7 +2855,7 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
|||||||
// where C is constant, X is non-constant, Y may be constant or non-constant,
|
// where C is constant, X is non-constant, Y may be constant or non-constant,
|
||||||
// and '+' denotes an associative and commutative operator like addition or
|
// and '+' denotes an associative and commutative operator like addition or
|
||||||
// multiplication. This optimization pushes constants down in the tree to
|
// multiplication. This optimization pushes constants down in the tree to
|
||||||
// canonicalize it. Moreoever, in cases where the child node has a second
|
// canonicalize it. Moreover, in cases where the child node has a second
|
||||||
// constant input Y we will create a leaf node that can be folded, e.g.
|
// constant input Y we will create a leaf node that can be folded, e.g.
|
||||||
//
|
//
|
||||||
// Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
|
// Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
|
||||||
@ -2944,6 +2946,7 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
|||||||
node->device() != right_leaf->device()) {
|
node->device() != right_leaf->device()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the node names corresponding to X, Y, and C.
|
// Get the node names corresponding to X, Y, and C.
|
||||||
const string input_x =
|
const string input_x =
|
||||||
left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
|
left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
|
||||||
@ -2957,6 +2960,38 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
|||||||
VLOG(1) << "\n++++++++ Reordering node " << node->name() << ": " << node->op()
|
VLOG(1) << "\n++++++++ Reordering node " << node->name() << ": " << node->op()
|
||||||
<< "(" << left_child->op() << ", " << right_child->op() << ")\n";
|
<< "(" << left_child->op() << ", " << right_child->op() << ")\n";
|
||||||
|
|
||||||
|
const NodeDef* y_node = left_leaf_is_constant ? left_leaf : right_leaf;
|
||||||
|
if (!IsReallyConstant(*y_node)) {
|
||||||
|
// Now make sure that we do not push down a tensor that is larger than the
|
||||||
|
// tensor it replaces, since that would create more broadcasting and
|
||||||
|
// increase work.
|
||||||
|
const std::vector<OpInfo::TensorProperties>& root_props =
|
||||||
|
properties.GetInputProperties(node->name());
|
||||||
|
const std::vector<OpInfo::TensorProperties>& op_props =
|
||||||
|
properties.GetInputProperties(op_child->name());
|
||||||
|
if (!root_props.empty() && !op_props.empty()) {
|
||||||
|
DCHECK_EQ(2, root_props.size()) << node->DebugString();
|
||||||
|
DCHECK_EQ(2, op_props.size()) << op_child->DebugString();
|
||||||
|
const PartialTensorShape c_shape(
|
||||||
|
root_props[left_child_is_constant ? 0 : 1].shape());
|
||||||
|
const PartialTensorShape x_shape(
|
||||||
|
op_props[left_leaf_is_constant ? 0 : 1].shape());
|
||||||
|
if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
|
||||||
|
c_shape.num_elements() > x_shape.num_elements()) {
|
||||||
|
return false;
|
||||||
|
} else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
|
||||||
|
c_shape.dims() > 0) {
|
||||||
|
for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims());
|
||||||
|
++idx) {
|
||||||
|
if (x_shape.dim_size(idx) >= 0 &&
|
||||||
|
c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Now we have identified the nodes to swap (non_const_leaf_input and
|
// Now we have identified the nodes to swap (non_const_leaf_input and
|
||||||
// const_child).
|
// const_child).
|
||||||
node_map_->UpdateInput(node->name(), input_c, input_x);
|
node_map_->UpdateInput(node->name(), input_c, input_x);
|
||||||
|
@ -145,7 +145,8 @@ class ConstantFolding : public GraphOptimizer {
|
|||||||
|
|
||||||
// Pushes down constants on '+' and '*' operators if applicable. Returns true
|
// Pushes down constants on '+' and '*' operators if applicable. Returns true
|
||||||
// if the transformation was applied successfully.
|
// if the transformation was applied successfully.
|
||||||
bool ConstantPushDown(GraphDef* optimized_graph, NodeDef* node);
|
bool ConstantPushDown(const GraphProperties& properties,
|
||||||
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
// Aggregate constants present around a conv operator. Returns true if the
|
// Aggregate constants present around a conv operator. Returns true if the
|
||||||
// transformation was applied successfully.
|
// transformation was applied successfully.
|
||||||
|
Loading…
Reference in New Issue
Block a user