[Grappler] Extend ConstantPushDown optimization to also handle BiasAdd when it is safe to do so.

Fix a bug in existing constant push down: We need to clear input properties when swapping inputs.
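
As an illustrative sketch (mirroring the new unit test below; s, c_mat, x_vec,
and c_vec are placeholder names), the new rewrite enables folding patterns
such as:

    // Before: Add(BiasAdd(c_mat, x_vec), c_vec) -- the constant vector c_vec
    // sits above the BiasAdd and nothing can be folded.
    Output child = ops::BiasAdd(s.WithOpName("child"), c_mat, x_vec);
    Output parent = ops::Add(s.WithOpName("parent"), child, c_vec);
    // After the push-down: Add(BiasAdd(c_mat, c_vec), x_vec). Since c_mat is
    // constant, BiasAdd(c_mat, c_vec) now has only constant inputs and folds
    // to a single Const node.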

PiperOrigin-RevId: 274690178
A. Unique TensorFlower authored on 2019-10-14 16:48:12 -07:00; committed by TensorFlower Gardener
parent b3cbdd5b68
commit 025e871a4a
3 changed files with 386 additions and 139 deletions

tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -1549,7 +1549,9 @@ Status ConstantFolding::FoldGraph(
std::unordered_set<string> processed_nodes;
std::deque<NodeDef*> queue;
for (int i = 0; i < graph_->node_size(); i++) {
bool foldable = IsFoldable(graph_->node(i), &properties);
VLOG(2) << "foldable(" << graph_->node(i).name() << ") = " << foldable;
if (foldable) {
queue.push_back(graph_->mutable_node(i));
}
}
@@ -1988,7 +1990,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
*properties, use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
ConstantPushDown(properties, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
MulConvPushDown(optimized_graph, node, *properties));
SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
@@ -1998,6 +2000,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
MergeConcat(use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
PartialConcatConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
ConstantPushDownBiasAdd(properties, optimized_graph, node));
graph_modified_ = graph_modified_cached;
return Status::OK();
@@ -2841,7 +2845,197 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
return false;
}
bool ConstantFolding::PrepareConstantPushDown(
const NodeDef& parent, const GraphProperties& properties,
bool must_have_properties, ConstantPushDownContext* ctx) const {
if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
return false;
}
NodeDef* left_child = node_map_->GetNode(parent.input(0));
NodeDef* right_child = node_map_->GetNode(parent.input(1));
ctx->left_child_is_const = IsReallyConstant(*left_child);
ctx->right_child_is_const = IsReallyConstant(*right_child);
ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
// Nothing to do unless the parent has a constant child node.
if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
return false;
}
// Don't move nodes across devices.
if (parent.device() != ctx->op_child->device() ||
parent.device() != ctx->const_child->device()) {
return false;
}
// Make sure that it is safe to change the value of the child node result.
if (ctx->op_child->input_size() < 2 ||
nodes_to_preserve_.find(ctx->op_child->name()) !=
nodes_to_preserve_.end() ||
NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
return false;
}
// Don't apply reassociation to floating point types of low precision.
// The danger of significant numerical changes is too high.
if (!CheckAttrExists(parent, "T").ok()) return false;
DataType dtype = parent.attr().at("T").type();
if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
return false;
}
// Don't rewrite the tree if it might create cycles.
// TODO(rmlarsen): Add back handling of control dependency from op to C.
const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
if (child_output.find(ctx->const_child) != child_output.end()) {
return false;
}
// Get leaf nodes.
ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
// Child is already foldable, leave it alone.
return false;
}
// Don't move nodes across devices.
if (parent.device() != ctx->left_leaf->device() ||
parent.device() != ctx->right_leaf->device()) {
return false;
}
// Get shape and type information.
ctx->parent_input_props = &properties.GetInputProperties(parent.name());
ctx->op_child_input_props =
&properties.GetInputProperties(ctx->op_child->name());
if (must_have_properties && (ctx->parent_input_props == nullptr ||
ctx->parent_input_props->size() < 2 ||
ctx->op_child_input_props == nullptr ||
ctx->op_child_input_props->size() < 2)) {
return false;
}
VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
<< parent.op() << "(" << left_child->op() << ", " << right_child->op()
<< ")";
return true;
}
bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
GraphDef* optimized_graph,
NodeDef* node) {
// This implements constant push-down for BiasAdd. In the following "CV" is a
// constant vector (tensor of rank 1), "V" is a (possibly) non-constant
// vector, "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
// non-constant matrix, and "BA" is BiasAdd.
// For a valid input graph, the following 4 rewrites are legal:
//
// 1) + +
// / \ / \
// BA CV -- > BA V
// / \ / \
// M V M CV
//
// 2) + +
// / \ / \
// BA CM -- > BA M
// / \ / \
// M V CM V
//
// 3) BA BA
// / \ / \
// + CV -- > + V
// / \ / \
// M V M CV
//
// 4) BA BA = parent
// / \ / \
// BA CV -- > BA V = children
// / \ / \
// M V M CV = leaves
//
// Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
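//
// Concrete instance of case 1 (values illustrative): with M = [[1,2],[3,4]],
// V = [5,6], and CV = [7,8], both Add(BiasAdd(M, V), CV) and
// Add(BiasAdd(M, CV), V) add [12,14] to every row of M, so swapping V and CV
// preserves the result while moving the constant deeper into the tree.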
const bool parent_is_bias_add = IsBiasAdd(*node);
if (!parent_is_bias_add && !IsAdd(*node)) return false;
ConstantPushDownContext ctx;
if (!PrepareConstantPushDown(*node, *properties,
/*must_have_properties=*/true, &ctx)) {
return false;
}
// Special case for BiasAdd: Since the left argument to BiasAdd must be rank
// >= 2 and the leaves must be vectors, we cannot swap them.
if (ctx.left_child_is_const && parent_is_bias_add) return false;
const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
// Get properties to validate rank and dtype constraints.
if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
(*ctx.parent_input_props)[0].shape().unknown_rank() ||
(*ctx.parent_input_props)[1].shape().unknown_rank() ||
(*ctx.op_child_input_props)[0].shape().unknown_rank() ||
(*ctx.op_child_input_props)[1].shape().unknown_rank()) {
return false;
}
// Now get the ranks and types of the 3 leaf nodes.
const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
// At least one leaf must be a vector.
if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
const int matrix_idx = 1 - vector_idx;
const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
if (vector_rank != 1) return false; // this should never happen.
const DataType vector_type = vector_prop.dtype();
const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
const int matrix_rank = matrix_prop.shape().dim_size();
const DataType matrix_type = matrix_prop.dtype();
const int const_idx = ctx.left_child_is_const ? 0 : 1;
const auto& const_prop = (*ctx.parent_input_props)[const_idx];
const int const_rank = const_prop.shape().dim_size();
const DataType const_type = const_prop.dtype();
int input_to_swap = -1;
if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
const_type == matrix_type) {
// Case 2:
input_to_swap = matrix_idx;
} else if (const_rank == 1 && const_type == vector_type) {
// Cases 1, 3, and 4:
input_to_swap = vector_idx;
}
if (input_to_swap == -1) return false;
node_map_->UpdateInput(node->name(), node->input(const_idx),
ctx.op_child->input(input_to_swap));
node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
if (ctx.op_child->input(input_to_swap) !=
ctx.op_child->input(1 - input_to_swap)) {
node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
ctx.op_child->name());
}
std::swap(*node->mutable_input(const_idx),
*ctx.op_child->mutable_input(input_to_swap));
properties->ClearInputProperties(node->name());
properties->ClearInputProperties(ctx.op_child->name());
return true;
}
bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
GraphDef* optimized_graph,
NodeDef* node) {
// Consider the transformation
@@ -2864,141 +3058,78 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
// by rotating the tree locally, e.g.
// Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
// Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
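//
// For example, with C = 10, X = 2, Y = 3 (values illustrative):
// Sub(10, Add(2, 3)) = 5 and Sub(Sub(10, 3), 2) = 5, so the rotation is
// value-preserving while pairing C with the constant leaf for later folding.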
//
// Note: Don't touch BiasAdd, since it can't handle a vector as its first
// input.
// Get parent op type.
const bool is_add = IsAdd(*node);
const bool is_mul = IsMul(*node);
const bool is_sub = IsSub(*node);
const bool is_div = IsDiv(*node);
if (!(is_add || is_sub || is_mul || is_div)) return false;
const bool is_symmetric = is_add || is_mul;
ConstantPushDownContext ctx;
if (!PrepareConstantPushDown(*node, *properties,
/*must_have_properties=*/false, &ctx)) {
return false;
}
// Get child op type.
const bool is_child_add = IsAdd(*ctx.op_child);
const bool is_child_mul = IsMul(*ctx.op_child);
const bool is_child_sub = IsSub(*ctx.op_child);
const bool is_child_div = IsDiv(*ctx.op_child);
const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
if (!is_add_sub && !is_mul_div) {
return false;
}
const bool is_child_symmetric = is_child_add || is_child_mul;
// Do not rewrite integer expressions with subtraction or division.
if (!CheckAttrExists(*node, "T").ok()) return false;
DataType dtype = node->attr().at("T").type();
if (!(is_symmetric && is_child_symmetric) &&
!(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
return false;
}
const NodeDef* y_node =
ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
!ctx.op_child_input_props->empty()) {
// If we know the shapes of the nodes being swapped, make sure we don't push
// down a larger node and create more work by broadcasting earlier in the
// expression tree.
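// E.g. (shapes illustrative): swapping a [2,2] constant C for a [2] leaf X
// would put a matrix where a vector used to be, making the child op
// broadcast and produce a larger intermediate result.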
const PartialTensorShape c_shape(
(*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
const PartialTensorShape x_shape(
(*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
c_shape.num_elements() > x_shape.num_elements()) {
return false;
} else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
c_shape.dims() > 0) {
for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
if (x_shape.dim_size(idx) >= 0 &&
c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
return false;
}
}
}
}
// Get the node names corresponding to X, Y, and C.
const string input_x =
ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
const string input_y = input_x == ctx.op_child->input(0)
? ctx.op_child->input(1)
: ctx.op_child->input(0);
const string input_c =
ctx.left_child_is_const ? node->input(0) : node->input(1);
const string input_op =
ctx.left_child_is_const ? node->input(1) : node->input(0);
VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
// Now we have identified the nodes to swap; update the node map accordingly.
node_map_->UpdateInput(node->name(), input_c, input_x);
node_map_->AddOutput(input_c, ctx.op_child->name());
if (input_x != input_y) {
node_map_->RemoveOutput(input_x, ctx.op_child->name());
}
properties->ClearInputProperties(node->name());
properties->ClearInputProperties(ctx.op_child->name());
if (is_symmetric && is_child_symmetric) {
// Easy case (only commutative ops). We always write this as one of
@@ -3009,8 +3140,8 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
// C Y
node->set_input(0, input_x);
node->set_input(1, input_op);
ctx.op_child->set_input(0, input_c);
ctx.op_child->set_input(1, input_y);
} else {
// More complicated case: When there are non-commutative operations like
// subtractions or divisions involved, we may have to rotate the tree
@@ -3019,6 +3150,7 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
// Here are the final trees we want to generate for those 6 cases:
//
// (CYX signs): ++- +-- -+- --+ +-+ -++
//
// - - - - + +
// / \ / \ / \ / \ / \ / \
// + X - X - X X + X - X -
@@ -3030,22 +3162,22 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
// expression
auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
bool leaf_negated = !is_child_symmetric && is_right_leaf;
bool child_negated = !is_symmetric && (ctx.left_child_is_const);
return leaf_negated != child_negated;
};
const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
bool neg_c = !is_symmetric && !ctx.left_child_is_const;
bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
// Rewrite the parent node.
node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
node->set_input(0, neg_x ? input_op : input_x);
node->set_input(1, neg_x ? input_x : input_op);
// Rewrite the child node.
ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
ctx.op_child->set_input(0, neg_c ? input_y : input_c);
ctx.op_child->set_input(1, neg_c ? input_c : input_y);
}
return true;
}
@@ -3060,6 +3192,9 @@ bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
// X C1 C1 C2
//
// where C1 and C2 are constants and X is non-constant.
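// In other words (sketch): Mul(Conv2D(X, C1), C2) -> Conv2D(X, Mul(C1, C2)),
// after which Mul(C1, C2) has only constant inputs and can be folded.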
//
// TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
NodeDef* mul_left_child = node_map_->GetNode(node->input(0));

tensorflow/core/grappler/optimizers/constant_folding.h

@@ -143,10 +143,48 @@ class ConstantFolding : public GraphOptimizer {
// Returns true if the transformation applied successfully.
bool PartialConstPropThroughIdentityN(NodeDef* node);
struct ConstantPushDownContext {
NodeDef* op_child;
NodeDef* const_child;
bool left_child_is_const;
bool right_child_is_const;
NodeDef* left_leaf;
NodeDef* right_leaf;
bool left_leaf_is_const;
bool right_leaf_is_const;
// Shape & type information.
const std::vector<OpInfo::TensorProperties>* parent_input_props;
const std::vector<OpInfo::TensorProperties>* op_child_input_props;
};
// Populates ctx with pointers to the nodes in the expression tree for which
// the constant push-down optimization is being considered, corresponding to one of
// the following configurations:
//
// parent parent
// / \ / \
// op_child const_child const_child op_child
// / \ / \
// left_leaf right_leaf left_leaf right_leaf
//
// Returns true if the expression is possibly amenable to optimization.
// Returns false if must_have_properties is true and input properties for
// parent and op_child are not known.
bool PrepareConstantPushDown(const NodeDef& parent,
const GraphProperties& properties,
bool must_have_properties,
ConstantPushDownContext* ctx) const;
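//
// A minimal sketch of the intended calling pattern (mirrors the usage in
// constant_folding.cc):
//
//   ConstantPushDownContext ctx;
//   if (!PrepareConstantPushDown(*node, *properties,
//                                /*must_have_properties=*/false, &ctx)) {
//     return false;
//   }
//   // ...inspect ctx and rewire the inputs of node and ctx.op_child...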
// Pushes down constants on '+', '-', '*', and '/' operators if applicable.
// Returns true if the transformation applied successfully.
bool ConstantPushDown(GraphProperties* properties, GraphDef* optimized_graph,
NodeDef* node);
// Pushes down constants on '+' and 'BiasAdd' operators if applicable.
// Returns true if the graph was modified.
bool ConstantPushDownBiasAdd(GraphProperties* properties,
GraphDef* optimized_graph, NodeDef* node);
// Aggregate constants present around a conv operator. Returns true if the
// transformation was applied successfully.

tensorflow/core/grappler/optimizers/constant_folding_test.cc

@@ -400,7 +400,7 @@ TEST_F(ConstantFoldingTest, AddSubtactTree) {
}
}
TEST_F(ConstantFoldingTest, ConstantPushDown) {
for (int is_add : {true, false}) {
for (int is_parent_commutative : {true, false}) {
for (int is_child_commutative : {true, false}) {
@@ -413,28 +413,32 @@ TEST_F(ConstantFoldingTest, TreeCanonicalization) {
ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
auto get_op = [&](bool is_commutative, bool is_left_arg_const,
const string& name, const Output& const_arg,
const Output non_const_arg) -> Output {
if (is_add) {
if (is_commutative) {
return ops::Add(
s.WithOpName(name),
is_left_arg_const ? const_arg : non_const_arg,
is_left_arg_const ? non_const_arg : const_arg);
} else {
return ops::Sub(
s.WithOpName(name),
is_left_arg_const ? const_arg : non_const_arg,
is_left_arg_const ? non_const_arg : const_arg);
}
} else {
if (is_commutative) {
return ops::Mul(
s.WithOpName(name),
is_left_arg_const ? const_arg : non_const_arg,
is_left_arg_const ? non_const_arg : const_arg);
} else {
return ops::Div(
s.WithOpName(name),
is_left_arg_const ? const_arg : non_const_arg,
is_left_arg_const ? non_const_arg : const_arg);
}
}
};
@@ -472,6 +476,76 @@ TEST_F(ConstantFoldingTest, TreeCanonicalization) {
}
}
TEST_F(ConstantFoldingTest, ConstantPushDownBiasAdd) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c_mat = ops::Const(s.WithOpName("c_mat"), 2.0f, {2, 2});
Output c_vec = ops::Const(s.WithOpName("c_vec"), 3.0f, {2});
Output x_mat = ops::Placeholder(s.WithOpName("x_mat"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2, 2})));
Output x_vec = ops::Placeholder(s.WithOpName("x_vec"), DT_FLOAT,
ops::Placeholder::Shape(TensorShape({2})));
// Rewrite expected for cases 1 through 3 and their symmetric equivalents,
// and case 4.
Output child1 = ops::BiasAdd(s.WithOpName("child1"), c_mat, x_vec);
Output parent1 = ops::Add(s.WithOpName("parent1"), child1, c_vec);
Output child1a = ops::BiasAdd(s.WithOpName("child1a"), c_mat, x_vec);
Output parent1a = ops::Add(s.WithOpName("parent1a"), c_vec, child1a);
Output child2 = ops::BiasAdd(s.WithOpName("child2"), x_mat, c_vec);
Output parent2 = ops::Add(s.WithOpName("parent2"), child2, c_mat);
Output child2a = ops::BiasAdd(s.WithOpName("child2a"), x_mat, c_vec);
Output parent2a = ops::Add(s.WithOpName("parent2a"), c_mat, child2a);
Output child3 = ops::Add(s.WithOpName("child3"), c_mat, x_vec);
Output parent3 = ops::BiasAdd(s.WithOpName("parent3"), child3, c_vec);
Output child3a = ops::Add(s.WithOpName("child3a"), x_vec, c_mat);
Output parent3a = ops::BiasAdd(s.WithOpName("parent3a"), child3a, c_vec);
Output child4 = ops::BiasAdd(s.WithOpName("child4"), c_mat, x_vec);
Output parent4 = ops::BiasAdd(s.WithOpName("parent4"), child4, c_vec);
// No rewrite expected.
Output child5 = ops::Add(s.WithOpName("child5"), x_vec, x_vec);
Output parent5 = ops::BiasAdd(s.WithOpName("parent5"), c_mat, child5);
Output child6 = ops::Add(s.WithOpName("child6"), x_vec, c_vec);
Output parent6 = ops::BiasAdd(s.WithOpName("parent6"), c_mat, child6);
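// parent5 and parent6 pass their constant as the first (matrix) argument of
// BiasAdd; since BiasAdd's first input must have rank >= 2, it cannot be
// swapped with a vector leaf, so no push-down applies.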
GrapplerItem item;
item.fetch = {"parent1", "parent2", "parent3", "parent1a", "parent2a",
"parent3a", "parent4", "parent5", "parent6"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
ConstantFolding optimizer(/*cpu_device=*/nullptr);
GraphDef output;
Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(22, output.node_size());
for (const auto& node : output.node()) {
if (node.name() == "child1" || node.name() == "child1a" ||
node.name() == "child2" || node.name() == "child2a" ||
node.name() == "child3" || node.name() == "child3a" ||
node.name() == "child4") {
EXPECT_EQ(node.op(), "Const") << " node: " << node.name();
} else if (node.name() != "c_mat" && node.name() != "c_vec") {
EXPECT_NE(node.op(), "Const") << " node: " << node.name();
}
}
// Check that the result nodes have the expected value.
auto x_mat_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
auto x_vec_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
std::vector<string> fetch = item.fetch;
auto tensor_expected = EvaluateNodes(
item.graph, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
ASSERT_EQ(fetch.size(), tensor_expected.size());
auto tensors =
EvaluateNodes(output, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
ASSERT_EQ(fetch.size(), tensors.size());
for (int i = 0; i < fetch.size(); i++) {
test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
}
}
TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
for (string data_format : {
"NHWC",