[Grappler] Extend ConstantPushDown optimization to also handle BiasAdd when it is safe to do so.

Fix a bug in the existing constant push-down: we need to clear input properties when swapping inputs.

PiperOrigin-RevId: 274690178

commit 025e871a4a
parent b3cbdd5b68
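In graph terms, the rewrite moves a constant past a non-constant operand so that two constants end up under the same node and become foldable. A minimal sketch of the BiasAdd case, mirroring the new unit test at the bottom of this diff (c_mat, c_vec, x_vec are the names used there):

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c_mat = ops::Const(s.WithOpName("c_mat"), 2.0f, {2, 2});
  Output c_vec = ops::Const(s.WithOpName("c_vec"), 3.0f, {2});
  Output x_vec = ops::Placeholder(s.WithOpName("x_vec"), DT_FLOAT,
                                  ops::Placeholder::Shape(TensorShape({2})));
  // Before: Add(BiasAdd(c_mat, x_vec), c_vec). Nothing folds, because the
  // non-constant x_vec separates the two constants.
  Output child = ops::BiasAdd(s.WithOpName("child"), c_mat, x_vec);
  Output parent = ops::Add(s.WithOpName("parent"), child, c_vec);
  // After the push-down: Add(BiasAdd(c_mat, c_vec), x_vec). The BiasAdd now
  // has only constant inputs and folds to a single Const node on a later pass.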
@@ -1549,7 +1549,9 @@ Status ConstantFolding::FoldGraph(
   std::unordered_set<string> processed_nodes;
   std::deque<NodeDef*> queue;
   for (int i = 0; i < graph_->node_size(); i++) {
-    if (IsFoldable(graph_->node(i), &properties)) {
+    bool foldable = IsFoldable(graph_->node(i), &properties);
+    VLOG(2) << "foldable(" << graph_->node(i).name() << ") = " << foldable;
+    if (foldable) {
       queue.push_back(graph_->mutable_node(i));
     }
   }
@@ -1988,7 +1990,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
       *properties, use_shape_info, optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(
-      ConstantPushDown(*properties, optimized_graph, node));
+      ConstantPushDown(properties, optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(
       MulConvPushDown(optimized_graph, node, *properties));
   SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
@@ -1998,6 +2000,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
       MergeConcat(use_shape_info, optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(
       PartialConcatConstFolding(optimized_graph, properties, node));
+  SET_AND_RETURN_IF_MODIFIED(
+      ConstantPushDownBiasAdd(properties, optimized_graph, node));

   graph_modified_ = graph_modified_cached;
   return Status::OK();
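The SET_AND_RETURN_IF_MODIFIED macro that registers each rewrite is defined elsewhere in constant_folding.cc and is not part of this diff. A hypothetical sketch of its shape, just to show why every rewrite returns a bool:

  #define SET_AND_RETURN_IF_MODIFIED(EXPR) \
    do {                                   \
      if (EXPR) {                          \
        graph_modified_ = true;            \
        return Status::OK();               \
      }                                    \
    } while (false)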
@@ -2841,7 +2845,197 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
   return false;
 }
 
-bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
+bool ConstantFolding::PrepareConstantPushDown(
+    const NodeDef& parent, const GraphProperties& properties,
+    bool must_have_properties, ConstantPushDownContext* ctx) const {
+  if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
+    return false;
+  }
+
+  NodeDef* left_child = node_map_->GetNode(parent.input(0));
+  NodeDef* right_child = node_map_->GetNode(parent.input(1));
+  ctx->left_child_is_const = IsReallyConstant(*left_child);
+  ctx->right_child_is_const = IsReallyConstant(*right_child);
+  ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
+  ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
+
+  // Nothing to do unless the parent has a constant child node.
+  if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
+    return false;
+  }
+
+  // Don't move nodes across devices.
+  if (parent.device() != ctx->op_child->device() ||
+      parent.device() != ctx->const_child->device()) {
+    return false;
+  }
+
+  // Make sure that it is safe to change the value of the child node result.
+  if (ctx->op_child->input_size() < 2 ||
+      nodes_to_preserve_.find(ctx->op_child->name()) !=
+          nodes_to_preserve_.end() ||
+      NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
+    return false;
+  }
+
+  // Don't apply reassociation to floating point types of low precision.
+  // The danger of significant numerical changes is too high.
+  if (!CheckAttrExists(parent, "T").ok()) return false;
+  DataType dtype = parent.attr().at("T").type();
+  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
+    return false;
+  }
+
+  // Don't rewrite the tree if it might create cycles.
+  // TODO(rmlarsen): Add back handling of control dependency from op to C.
+  const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
+  if (child_output.find(ctx->const_child) != child_output.end()) {
+    return false;
+  }
+
+  // Get leaf nodes.
+  ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
+  ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
+  ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
+  ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
+
+  if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
+    // Child is already foldable, leave it alone.
+    return false;
+  }
+
+  // Don't move nodes across devices.
+  if (parent.device() != ctx->left_leaf->device() ||
+      parent.device() != ctx->right_leaf->device()) {
+    return false;
+  }
+
+  // Get shape and type information.
+  ctx->parent_input_props = &properties.GetInputProperties(parent.name());
+  ctx->op_child_input_props =
+      &properties.GetInputProperties(ctx->op_child->name());
+  if (must_have_properties && (ctx->parent_input_props == nullptr ||
+                               ctx->parent_input_props->size() < 2 ||
+                               ctx->op_child_input_props == nullptr ||
+                               ctx->op_child_input_props->size() < 2)) {
+    return false;
+  }
+
+  VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
+          << parent.op() << "(" << left_child->op() << ", " << right_child->op()
+          << ")";
+
+  return true;
+}
+
+bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
+                                              GraphDef* optimized_graph,
+                                              NodeDef* node) {
+  // This implements constant push-down for BiasAdd. In the following "CV" is a
+  // constant vector (tensor of rank 1), "V" is a (possibly) non-constant
+  // vector, "CM" is a constant matrix (tensor of rank >= 2), "M" is a
+  // (possibly) non-constant matrix, and "BA" is BiasAdd.
+  // For a valid input graph, the following 4 rewrites are legal:
+  //
+  //  1)                +               +
+  //                   / \             / \
+  //                 BA   CV   -->   BA   V
+  //                /  \            /  \
+  //               M    V          M    CV
+  //
+  //  2)                +               +
+  //                   / \             / \
+  //                 BA   CM   -->   BA   M
+  //                /  \            /  \
+  //               M    V          CM   V
+  //
+  //  3)                BA              BA
+  //                   /  \            /  \
+  //                  +    CV  -->    +    V
+  //                 / \             / \
+  //                M   V           M   CV
+  //
+  //  4)                BA              BA      = parent
+  //                   /  \            /  \
+  //                 BA    CV  -->   BA    V    = children
+  //                /  \            /  \
+  //               M    V          M    CV      = leaves
+  //
+  // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
+
+  const bool parent_is_bias_add = IsBiasAdd(*node);
+  if (!parent_is_bias_add && !IsAdd(*node)) return false;
+  ConstantPushDownContext ctx;
+  if (!PrepareConstantPushDown(*node, *properties,
+                               /*must_have_properties=*/true, &ctx)) {
+    return false;
+  }
+  // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
+  // >= 2 and the leaves must be vectors, we cannot swap them.
+  if (ctx.left_child_is_const && parent_is_bias_add) return false;
+  const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
+  if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
+
+  // Get properties to validate rank and dtype constraints.
+  if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
+      (*ctx.parent_input_props)[0].shape().unknown_rank() ||
+      (*ctx.parent_input_props)[1].shape().unknown_rank() ||
+      (*ctx.op_child_input_props)[0].shape().unknown_rank() ||
+      (*ctx.op_child_input_props)[1].shape().unknown_rank()) {
+    return false;
+  }
+
+  // Now get the ranks and types of the 3 leaf nodes.
+  const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
+  const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
+  // At least one leaf must be a vector.
+  if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
+  const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
+  const int matrix_idx = 1 - vector_idx;
+
+  const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
+  const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
+  if (vector_rank != 1) return false;  // this should never happen.
+  const DataType vector_type = vector_prop.dtype();
+
+  const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
+  const int matrix_rank = matrix_prop.shape().dim_size();
+  const DataType matrix_type = matrix_prop.dtype();
+
+  const int const_idx = ctx.left_child_is_const ? 0 : 1;
+  const auto& const_prop = (*ctx.parent_input_props)[const_idx];
+  const int const_rank = const_prop.shape().dim_size();
+  const DataType const_type = const_prop.dtype();
+
+  int input_to_swap = -1;
+
+  if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
+      const_type == matrix_type) {
+    // Case 2:
+    input_to_swap = matrix_idx;
+  } else if (const_rank == 1 && const_type == vector_type) {
+    // Cases 1, 3, and 4:
+    input_to_swap = vector_idx;
+  }
+  if (input_to_swap == -1) return false;
+
+  node_map_->UpdateInput(node->name(), node->input(const_idx),
+                         ctx.op_child->input(input_to_swap));
+  node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
+  if (ctx.op_child->input(input_to_swap) !=
+      ctx.op_child->input(1 - input_to_swap)) {
+    node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
+                            ctx.op_child->name());
+  }
+  std::swap(*node->mutable_input(const_idx),
+            *ctx.op_child->mutable_input(input_to_swap));
+  properties->ClearInputProperties(node->name());
+  properties->ClearInputProperties(ctx.op_child->name());
+
+  return true;
+}
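To make the case analysis concrete, here is an editorial trace of case 1 on the "parent1" graph built in the new unit test (a sketch, not part of the commit):

  // parent1 = Add(BiasAdd(c_mat, x_vec), c_vec)
  // ctx.op_child    = BiasAdd(c_mat, x_vec)  (the non-constant child)
  // ctx.const_child = c_vec, so const_rank == 1.
  // The child's vector leaf is x_vec, and const_type == vector_type, so
  // input_to_swap selects x_vec. After the swap:
  // parent1 = Add(BiasAdd(c_mat, c_vec), x_vec)
  // BiasAdd(c_mat, c_vec) now has only constant inputs and is folded to a
  // Const later, which is why the test expects node "child1" to become Const.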
+
+bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
                                        GraphDef* optimized_graph,
                                        NodeDef* node) {
   // Consider the transformation
@@ -2864,141 +3058,78 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
   // by rotating the tree locally, e.g.
   //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
   //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
   //
+  // Note: Don't touch BiasAdd since they can't handle vectors as their first
+  // inputs.
 
   // Get parent op type.
   const bool is_add = IsAdd(*node);
   const bool is_mul = IsMul(*node);
   const bool is_sub = IsSub(*node);
   const bool is_div = IsDiv(*node);
+  if (!(is_add || is_sub || is_mul || is_div)) return false;
   const bool is_symmetric = is_add || is_mul;
-  if (!has_fetch_ || !(is_add || is_sub || is_mul || is_div) ||
-      NumNonControlInputs(*node) != 2) {
+
+  ConstantPushDownContext ctx;
+  if (!PrepareConstantPushDown(*node, *properties,
+                               /*must_have_properties=*/false, &ctx)) {
     return false;
   }
-
-  NodeDef* left_child = node_map_->GetNode(node->input(0));
-  NodeDef* right_child = node_map_->GetNode(node->input(1));
-  const bool left_child_is_constant = IsReallyConstant(*left_child);
-  const bool right_child_is_constant = IsReallyConstant(*right_child);
-  if (!left_child_is_constant && !right_child_is_constant) {
-    return false;
-  }
-  // Don't move nodes across devices.
-  if (node->device() != left_child->device() ||
-      node->device() != right_child->device()) {
-    return false;
-  }
-  NodeDef* op_child = left_child_is_constant ? right_child : left_child;
-  NodeDef* const_child = left_child_is_constant ? left_child : right_child;
-  // Don't rewrite the tree if it might create cycles.
-  // TODO(rmlarsen): Add back handling of control dependency from op to C.
-  const auto& child_output = node_map_->GetOutputs(op_child->name());
-  if (child_output.find(const_child) != child_output.end()) {
-    return false;
-  }
+
   // Get child op type.
-  const bool is_child_add = IsAdd(*op_child);
-  const bool is_child_mul = IsMul(*op_child);
-  const bool is_child_sub = IsSub(*op_child);
-  const bool is_child_div = IsDiv(*op_child);
+  const bool is_child_add = IsAdd(*ctx.op_child);
+  const bool is_child_mul = IsMul(*ctx.op_child);
+  const bool is_child_sub = IsSub(*ctx.op_child);
+  const bool is_child_div = IsDiv(*ctx.op_child);
   const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
   const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
   if (!is_add_sub && !is_mul_div) {
     return false;
   }
 
   const bool is_child_symmetric = is_child_add || is_child_mul;
-  // Make sure that it is safe to change the value of the child node result.
-  if (op_child->input_size() < 2 ||
-      nodes_to_preserve_.find(op_child->name()) != nodes_to_preserve_.end() ||
-      NumNonControlOutputs(*op_child, *node_map_) > 1) {
-    return false;
-  }
+
   // Do not rewrite integer expressions with subtraction or division.
   // if (node->name().find("filter_boxes") != std::string::npos) return false;
   if (!CheckAttrExists(*node, "T").ok()) return false;
   DataType dtype = node->attr().at("T").type();
-  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
-    // Don't apply reassociation to floating point types of low precision.
-    // The danger of significant numerical changes is too high.
-    return false;
-  }
   if (!(is_symmetric && is_child_symmetric) &&
       !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
     return false;
   }
+
+  const NodeDef* y_node =
+      ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
+  if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
+      !ctx.op_child_input_props->empty()) {
+    // If we know the shapes of the nodes being swapped, make sure we don't
+    // push down a larger node and create more work by broadcasting earlier
+    // in the expression tree.
+    const PartialTensorShape c_shape(
+        (*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
+    const PartialTensorShape x_shape(
+        (*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
 
-  // Identify the nodes to swap.
-  NodeDef* left_leaf = node_map_->GetNode(op_child->input(0));
-  NodeDef* right_leaf = node_map_->GetNode(op_child->input(1));
-  const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
-  const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
-  if (left_leaf_is_constant && right_leaf_is_constant) {
-    // Child is already foldable, leave it alone.
-    return false;
-  }
-  // Don't move nodes across devices.
-  if (node->device() != left_leaf->device() ||
-      node->device() != right_leaf->device()) {
-    return false;
-  }
-
-  // Get the node names corresponding to X, Y, and C.
-  const string input_x =
-      left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
-  const string input_y =
-      input_x == op_child->input(0) ? op_child->input(1) : op_child->input(0);
-  const string input_c =
-      left_child_is_constant ? node->input(0) : node->input(1);
-  const string input_op =
-      left_child_is_constant ? node->input(1) : node->input(0);
-
-  VLOG(1) << "\n++++++++ Reordering node " << node->name() << ": " << node->op()
-          << "(" << left_child->op() << ", " << right_child->op() << ")\n";
-
-  const NodeDef* y_node = left_leaf_is_constant ? left_leaf : right_leaf;
-  if (!IsReallyConstant(*y_node)) {
-    // Now make sure that we do not push a tensor that is larger than the
-    // tensor it replaces down, since that would create more broadcasting and
-    // increase work.
-    const std::vector<OpInfo::TensorProperties>& root_props =
-        properties.GetInputProperties(node->name());
-    const std::vector<OpInfo::TensorProperties>& op_props =
-        properties.GetInputProperties(op_child->name());
-    if (!root_props.empty() && !op_props.empty()) {
-      DCHECK_EQ(2, root_props.size()) << node->DebugString();
-      DCHECK_EQ(2, op_props.size()) << op_child->DebugString();
-      const PartialTensorShape c_shape(
-          root_props[left_child_is_constant ? 0 : 1].shape());
-      const PartialTensorShape x_shape(
-          op_props[left_leaf_is_constant ? 0 : 1].shape());
-      if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
-          c_shape.num_elements() > x_shape.num_elements()) {
-        return false;
-      } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
-                 c_shape.dims() > 0) {
-        for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims());
-             ++idx) {
-          if (x_shape.dim_size(idx) >= 0 &&
-              c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
-            return false;
-          }
-        }
-      }
-    }
-  }
+    if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
+        c_shape.num_elements() > x_shape.num_elements()) {
+      return false;
+    } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
+               c_shape.dims() > 0) {
+      for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
+        if (x_shape.dim_size(idx) >= 0 &&
+            c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
+          return false;
+        }
+      }
+    }
+  }
 
+  // Now we have identified the nodes to swap (non_const_leaf_input and
+  // const_child).
+  // Get the node names corresponding to X, Y, and C.
+  const string input_x =
+      ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
+  const string input_y = input_x == ctx.op_child->input(0)
+                             ? ctx.op_child->input(1)
+                             : ctx.op_child->input(0);
+  const string input_c =
+      ctx.left_child_is_const ? node->input(0) : node->input(1);
+  const string input_op =
+      ctx.left_child_is_const ? node->input(1) : node->input(0);
+  VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
+
+  // Now that the nodes to swap are identified, update the node map
+  // accordingly.
   node_map_->UpdateInput(node->name(), input_c, input_x);
-  node_map_->AddOutput(input_c, op_child->name());
+  node_map_->AddOutput(input_c, ctx.op_child->name());
   if (input_x != input_y) {
-    node_map_->RemoveOutput(input_x, op_child->name());
+    node_map_->RemoveOutput(input_x, ctx.op_child->name());
   }
+  properties->ClearInputProperties(node->name());
+  properties->ClearInputProperties(ctx.op_child->name());
 
   if (is_symmetric && is_child_symmetric) {
     // Easy case (only commutative ops). We always write this as one of
@@ -3009,8 +3140,8 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
     //     C   Y
     node->set_input(0, input_x);
     node->set_input(1, input_op);
-    op_child->set_input(0, input_c);
-    op_child->set_input(1, input_y);
+    ctx.op_child->set_input(0, input_c);
+    ctx.op_child->set_input(1, input_y);
   } else {
     // More complicated case: When there are non-commutative operations like
     // subtractions or divisions involved, we may have to rotate the tree
@@ -3019,6 +3150,7 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
     // Here are the final trees we want to generate for those 6 cases:
     //
     // (CYX signs):   ++-      +--      -+-      --+      +-+      -++
     //
     //        -        -        -        -        +        +
     //       / \      / \      / \      / \      / \      / \
     //      +   X    -   X    -   X    X   +    X   -    X   -
@@ -3030,22 +3162,22 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
     // expression
     auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
       bool leaf_negated = !is_child_symmetric && is_right_leaf;
-      bool child_negated = !is_symmetric && (op_child == right_child);
+      bool child_negated = !is_symmetric && (ctx.left_child_is_const);
       return leaf_negated != child_negated;
     };
     const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
     const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
-    bool neg_c = !is_symmetric && (const_child == right_child);
-    bool neg_x = is_leaf_negated(left_leaf_is_constant);
-    bool neg_y = is_leaf_negated(!left_leaf_is_constant);
+    bool neg_c = !is_symmetric && !ctx.left_child_is_const;
+    bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
+    bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
     // Rewrite the parent node.
     node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
     node->set_input(0, neg_x ? input_op : input_x);
     node->set_input(1, neg_x ? input_x : input_op);
     // Rewrite the child node.
-    op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
-    op_child->set_input(0, neg_c ? input_y : input_c);
-    op_child->set_input(1, neg_c ? input_c : input_y);
+    ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
+    ctx.op_child->set_input(0, neg_c ? input_y : input_c);
+    ctx.op_child->set_input(1, neg_c ? input_c : input_y);
   }
   return true;
 }
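A worked trace of the sign bookkeeping above, using the example from the comment at the top of the function (an editorial sketch, assuming neither leaf is constant):

  // node = Sub(C, Add(X, Y)), i.e. C - (X + Y), with C the left (const) child.
  // neg_c = !is_symmetric && !ctx.left_child_is_const = false: C keeps its
  // sign. The child Add is symmetric, but the whole child sits under the
  // parent's minus sign, so is_leaf_negated() returns true for both leaves:
  //   neg_x = neg_y = true.
  // Parent: neg_x || (neg_c && neg_y) is true -> op "Sub", inputs (op_child, X).
  // Child:  neg_c != neg_y is true            -> op "Sub", inputs (C, Y).
  // Result: Sub(Sub(C, Y), X) = (C - Y) - X, matching the comment above.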
@@ -3060,6 +3192,9 @@ bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
   //                 X  C1                       C1  C2
   //
   // where C1 and C2 are constants and X is non-constant.
+  //
+  // TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
+
   if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
 
   NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
@@ -143,10 +143,48 @@ class ConstantFolding : public GraphOptimizer {
   // Returns true if the transformation applied successfully.
   bool PartialConstPropThroughIdentityN(NodeDef* node);
 
-  // Pushes down constants on '+' and '*' operators if applicable. Returns true
-  // if the transformation applied successfully.
-  bool ConstantPushDown(const GraphProperties& properties,
-                        GraphDef* optimized_graph, NodeDef* node);
+  struct ConstantPushDownContext {
+    NodeDef* op_child;
+    NodeDef* const_child;
+    bool left_child_is_const;
+    bool right_child_is_const;
+    NodeDef* left_leaf;
+    NodeDef* right_leaf;
+    bool left_leaf_is_const;
+    bool right_leaf_is_const;
+
+    // Shape & type information.
+    const std::vector<OpInfo::TensorProperties>* parent_input_props;
+    const std::vector<OpInfo::TensorProperties>* op_child_input_props;
+  };
+
+  // Populates ctx with pointers to the nodes in the expression tree for which
+  // the constant push-down optimization is being considered, corresponding to
+  // one of the following configurations:
+  //
+  //          parent                          parent
+  //         /      \                        /      \
+  //     op_child  const_child      const_child    op_child
+  //     /      \                                  /      \
+  //  left_leaf  right_leaf                left_leaf  right_leaf
+  //
+  // Returns true if the expression is possibly amenable to optimization.
+  // Returns false if must_have_properties is true and input properties for
+  // parent and op_child are not known.
+  bool PrepareConstantPushDown(const NodeDef& parent,
+                               const GraphProperties& properties,
+                               bool must_have_properties,
+                               ConstantPushDownContext* ctx) const;
+
+  // Pushes down constants on '+', '-', '*', and '/' operators if applicable.
+  // Returns true if the transformation applied successfully.
+  bool ConstantPushDown(GraphProperties* properties, GraphDef* optimized_graph,
+                        NodeDef* node);
+
+  // Pushes down constants on '+' and 'BiasAdd' operators if applicable.
+  // Returns true if the graph was modified.
+  bool ConstantPushDownBiasAdd(GraphProperties* properties,
+                               GraphDef* optimized_graph, NodeDef* node);
 
   // Aggregate constants present around a conv operator. Returns true if the
   // transformation was applied successfully.
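For reference, both rewrites drive this helper the same way; the call pattern below is lifted from the .cc changes above (the BiasAdd variant passes must_have_properties=true):

  ConstantPushDownContext ctx;
  if (!PrepareConstantPushDown(*node, *properties,
                               /*must_have_properties=*/false, &ctx)) {
    return false;
  }
  // On success, ctx.op_child / ctx.const_child identify the parent's two
  // children, ctx.left_leaf / ctx.right_leaf the grandchildren, and the
  // *_is_const flags plus the cached input properties are filled in.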
@@ -400,7 +400,7 @@ TEST_F(ConstantFoldingTest, AddSubtactTree) {
   }
 }
 
-TEST_F(ConstantFoldingTest, TreeCanonicalization) {
+TEST_F(ConstantFoldingTest, ConstantPushDown) {
   for (int is_add : {true, false}) {
     for (int is_parent_commutative : {true, false}) {
       for (int is_child_commutative : {true, false}) {
@@ -413,28 +413,32 @@ TEST_F(ConstantFoldingTest, TreeCanonicalization) {
             ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
 
-        auto get_op = [&](bool is_commutative, bool is_left_arg_cont,
+        auto get_op = [&](bool is_commutative, bool is_left_arg_const,
                           const string& name, const Output& const_arg,
                           const Output non_const_arg) -> Output {
           if (is_add) {
             if (is_commutative) {
-              return ops::Add(s.WithOpName(name),
-                              is_left_arg_cont ? const_arg : non_const_arg,
-                              is_left_arg_cont ? non_const_arg : const_arg);
+              return ops::Add(
+                  s.WithOpName(name),
+                  is_left_arg_const ? const_arg : non_const_arg,
+                  is_left_arg_const ? non_const_arg : const_arg);
             } else {
-              return ops::Sub(s.WithOpName(name),
-                              is_left_arg_cont ? const_arg : non_const_arg,
-                              is_left_arg_cont ? non_const_arg : const_arg);
+              return ops::Sub(
+                  s.WithOpName(name),
+                  is_left_arg_const ? const_arg : non_const_arg,
+                  is_left_arg_const ? non_const_arg : const_arg);
             }
           } else {
             if (is_commutative) {
-              return ops::Mul(s.WithOpName(name),
-                              is_left_arg_cont ? const_arg : non_const_arg,
-                              is_left_arg_cont ? non_const_arg : const_arg);
+              return ops::Mul(
+                  s.WithOpName(name),
+                  is_left_arg_const ? const_arg : non_const_arg,
+                  is_left_arg_const ? non_const_arg : const_arg);
             } else {
-              return ops::Div(s.WithOpName(name),
-                              is_left_arg_cont ? const_arg : non_const_arg,
-                              is_left_arg_cont ? non_const_arg : const_arg);
+              return ops::Div(
+                  s.WithOpName(name),
+                  is_left_arg_const ? const_arg : non_const_arg,
+                  is_left_arg_const ? non_const_arg : const_arg);
             }
           }
         };
@@ -472,6 +476,76 @@ TEST_F(ConstantFoldingTest, TreeCanonicalization) {
   }
 }
 
+TEST_F(ConstantFoldingTest, ConstantPushDownBiasAdd) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output c_mat = ops::Const(s.WithOpName("c_mat"), 2.0f, {2, 2});
+  Output c_vec = ops::Const(s.WithOpName("c_vec"), 3.0f, {2});
+  Output x_mat = ops::Placeholder(s.WithOpName("x_mat"), DT_FLOAT,
+                                  ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output x_vec = ops::Placeholder(s.WithOpName("x_vec"), DT_FLOAT,
+                                  ops::Placeholder::Shape(TensorShape({2})));
+  // Rewrite expected for cases 1 through 3 and their symmetric equivalents,
+  // and case 4.
+  Output child1 = ops::BiasAdd(s.WithOpName("child1"), c_mat, x_vec);
+  Output parent1 = ops::Add(s.WithOpName("parent1"), child1, c_vec);
+  Output child1a = ops::BiasAdd(s.WithOpName("child1a"), c_mat, x_vec);
+  Output parent1a = ops::Add(s.WithOpName("parent1a"), c_vec, child1a);
+
+  Output child2 = ops::BiasAdd(s.WithOpName("child2"), x_mat, c_vec);
+  Output parent2 = ops::Add(s.WithOpName("parent2"), child2, c_mat);
+  Output child2a = ops::BiasAdd(s.WithOpName("child2a"), x_mat, c_vec);
+  Output parent2a = ops::Add(s.WithOpName("parent2a"), c_mat, child2a);
+
+  Output child3 = ops::Add(s.WithOpName("child3"), c_mat, x_vec);
+  Output parent3 = ops::BiasAdd(s.WithOpName("parent3"), child3, c_vec);
+  Output child3a = ops::Add(s.WithOpName("child3a"), x_vec, c_mat);
+  Output parent3a = ops::BiasAdd(s.WithOpName("parent3a"), child3a, c_vec);
+
+  Output child4 = ops::BiasAdd(s.WithOpName("child4"), c_mat, x_vec);
+  Output parent4 = ops::BiasAdd(s.WithOpName("parent4"), child4, c_vec);
+
+  // No rewrite expected.
+  Output child5 = ops::Add(s.WithOpName("child5"), x_vec, x_vec);
+  Output parent5 = ops::BiasAdd(s.WithOpName("parent5"), c_mat, child5);
+  Output child6 = ops::Add(s.WithOpName("child6"), x_vec, c_vec);
+  Output parent6 = ops::BiasAdd(s.WithOpName("parent6"), c_mat, child6);
+
+  GrapplerItem item;
+  item.fetch = {"parent1",  "parent2", "parent3", "parent1a", "parent2a",
+                "parent3a", "parent4", "parent5", "parent6"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding optimizer(/*cpu_device=*/nullptr);
+  GraphDef output;
+  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(22, output.node_size());
+  for (const auto& node : output.node()) {
+    if (node.name() == "child1" || node.name() == "child1a" ||
+        node.name() == "child2" || node.name() == "child2a" ||
+        node.name() == "child3" || node.name() == "child3a" ||
+        node.name() == "child4") {
+      EXPECT_EQ(node.op(), "Const") << " node: " << node.name();
+    } else if (node.name() != "c_mat" && node.name() != "c_vec") {
+      EXPECT_NE(node.op(), "Const") << " node: " << node.name();
+    }
+  }
+
+  // Check that the result nodes have the expected value.
+  auto x_mat_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+  auto x_vec_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
+  std::vector<string> fetch = item.fetch;
+  auto tensor_expected = EvaluateNodes(
+      item.graph, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
+  ASSERT_EQ(fetch.size(), tensor_expected.size());
+  auto tensors =
+      EvaluateNodes(output, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
+  ASSERT_EQ(fetch.size(), tensors.size());
+  for (int i = 0; i < fetch.size(); i++) {
+    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
+  }
+}
+
 TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
   for (string data_format : {
            "NHWC",