Don't push down constant tensors if they increase the amount of broadcasting, unless they are guaranteed to be constant folded.
PiperOrigin-RevId: 267185933
This commit is contained in:
parent
357615e12a
commit
b883c28ffe
@ -1987,7 +1987,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
||||
RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
|
||||
*properties, use_shape_info, optimized_graph, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
ConstantPushDown(*properties, optimized_graph, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
MulConvPushDown(optimized_graph, node, *properties));
|
||||
SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
|
||||
@ -2840,7 +2841,8 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
||||
bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
|
||||
GraphDef* optimized_graph,
|
||||
NodeDef* node) {
|
||||
// Consider the transformation
|
||||
//
|
||||
@ -2853,7 +2855,7 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
||||
// where C is constant, X is non-constant, Y may be constant or non-constant,
|
||||
// and '+' denotes an associative and commutative operator like addition or
|
||||
// multiplication. This optimization pushes constants down in the tree to
|
||||
// canonicalize it. Moreoever, in cases where the child node has a second
|
||||
// canonicalize it. Moreover, in cases where the child node has a second
|
||||
// constant input Y we will create a leaf node that can be folded, e.g.
|
||||
//
|
||||
// Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
|
||||
@ -2944,6 +2946,7 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
||||
node->device() != right_leaf->device()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get the node names corresponding to X, Y, and C.
|
||||
const string input_x =
|
||||
left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
|
||||
@ -2957,6 +2960,38 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
|
||||
VLOG(1) << "\n++++++++ Reordering node " << node->name() << ": " << node->op()
|
||||
<< "(" << left_child->op() << ", " << right_child->op() << ")\n";
|
||||
|
||||
const NodeDef* y_node = left_leaf_is_constant ? left_leaf : right_leaf;
|
||||
if (!IsReallyConstant(*y_node)) {
|
||||
// Now make sure that we do not push a tensor that is larger than the tensor
|
||||
// it replaces down, since that would create more broadcasting and increase
|
||||
// work.
|
||||
const std::vector<OpInfo::TensorProperties>& root_props =
|
||||
properties.GetInputProperties(node->name());
|
||||
const std::vector<OpInfo::TensorProperties>& op_props =
|
||||
properties.GetInputProperties(op_child->name());
|
||||
if (!root_props.empty() && !op_props.empty()) {
|
||||
DCHECK_EQ(2, root_props.size()) << node->DebugString();
|
||||
DCHECK_EQ(2, op_props.size()) << op_child->DebugString();
|
||||
const PartialTensorShape c_shape(
|
||||
root_props[left_child_is_constant ? 0 : 1].shape());
|
||||
const PartialTensorShape x_shape(
|
||||
op_props[left_leaf_is_constant ? 0 : 1].shape());
|
||||
if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
|
||||
c_shape.num_elements() > x_shape.num_elements()) {
|
||||
return false;
|
||||
} else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
|
||||
c_shape.dims() > 0) {
|
||||
for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims());
|
||||
++idx) {
|
||||
if (x_shape.dim_size(idx) >= 0 &&
|
||||
c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now we have identified the nodes to swap (non_const_leaf_input and
|
||||
// const_child).
|
||||
node_map_->UpdateInput(node->name(), input_c, input_x);
|
||||
|
@ -145,7 +145,8 @@ class ConstantFolding : public GraphOptimizer {
|
||||
|
||||
// Pushes down constants on '+' and '*' operators if applicable. Returns true
|
||||
// the transformation applied successfully.
|
||||
bool ConstantPushDown(GraphDef* optimized_graph, NodeDef* node);
|
||||
bool ConstantPushDown(const GraphProperties& properties,
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
// Aggregate constants present around a conv operator. Returns true if the
|
||||
// transformation was applied successfully.
|
||||
|
Loading…
Reference in New Issue
Block a user