Don't push down constant tensors if they increase the amount of broadcasting, unless they are guaranteed to be constant folded.

PiperOrigin-RevId: 267185933
This commit is contained in:
A. Unique TensorFlower 2019-09-04 10:51:02 -07:00 committed by TensorFlower Gardener
parent 357615e12a
commit b883c28ffe
2 changed files with 40 additions and 4 deletions

View File

@ -1987,7 +1987,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
*properties, use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
ConstantPushDown(*properties, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
MulConvPushDown(optimized_graph, node, *properties));
SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
@ -2840,7 +2841,8 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
return false;
}
bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
GraphDef* optimized_graph,
NodeDef* node) {
// Consider the transformation
//
@ -2853,7 +2855,7 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
// where C is constant, X is non-constant, Y may be constant or non-constant,
// and '+' denotes an associative and commutative operator like addition or
// multiplication. This optimization pushes constants down in the tree to
// canonicalize it. Moreoever, in cases where the child node has a second
// canonicalize it. Moreover, in cases where the child node has a second
// constant input Y we will create a leaf node that can be folded, e.g.
//
// Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
@ -2944,6 +2946,7 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
node->device() != right_leaf->device()) {
return false;
}
// Get the node names corresponding to X, Y, and C.
const string input_x =
left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
@ -2957,6 +2960,38 @@ bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
VLOG(1) << "\n++++++++ Reordering node " << node->name() << ": " << node->op()
<< "(" << left_child->op() << ", " << right_child->op() << ")\n";
const NodeDef* y_node = left_leaf_is_constant ? left_leaf : right_leaf;
if (!IsReallyConstant(*y_node)) {
// Now make sure that we do not push down a tensor that is larger than the
// tensor it replaces, since that would create more broadcasting and increase
// work.
const std::vector<OpInfo::TensorProperties>& root_props =
properties.GetInputProperties(node->name());
const std::vector<OpInfo::TensorProperties>& op_props =
properties.GetInputProperties(op_child->name());
if (!root_props.empty() && !op_props.empty()) {
DCHECK_EQ(2, root_props.size()) << node->DebugString();
DCHECK_EQ(2, op_props.size()) << op_child->DebugString();
const PartialTensorShape c_shape(
root_props[left_child_is_constant ? 0 : 1].shape());
const PartialTensorShape x_shape(
op_props[left_leaf_is_constant ? 0 : 1].shape());
if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
c_shape.num_elements() > x_shape.num_elements()) {
return false;
} else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
c_shape.dims() > 0) {
for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims());
++idx) {
if (x_shape.dim_size(idx) >= 0 &&
c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
return false;
}
}
}
}
}
// Now we have identified the nodes to swap (non_const_leaf_input and
// const_child).
node_map_->UpdateInput(node->name(), input_c, input_x);

View File

@ -145,7 +145,8 @@ class ConstantFolding : public GraphOptimizer {
// Pushes down constants on '+' and '*' operators if applicable. Returns true
// if the transformation was applied successfully.
bool ConstantPushDown(GraphDef* optimized_graph, NodeDef* node);
bool ConstantPushDown(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node);
// Aggregate constants present around a conv operator. Returns true if the
// transformation was applied successfully.