From 025e871a4aac109e6e168ff7f2690a9093364199 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 14 Oct 2019 16:48:12 -0700
Subject: [PATCH] [Grappler] Extend ConstantPushDown optimization to also
 handle BiasAdd when it is safe to do so.

Fix a bug in existing constant push down: We need to clear input properties
when swapping inputs.

PiperOrigin-RevId: 274690178
---
 .../grappler/optimizers/constant_folding.cc   | 377 ++++++++++++------
 .../grappler/optimizers/constant_folding.h    |  46 ++-
 .../optimizers/constant_folding_test.cc       | 102 ++++-
 3 files changed, 386 insertions(+), 139 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 37eb4ff3bce..4ac3a611623 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1549,7 +1549,9 @@ Status ConstantFolding::FoldGraph(
   std::unordered_set<string> processed_nodes;
   std::deque<NodeDef*> queue;
   for (int i = 0; i < graph_->node_size(); i++) {
-    if (IsFoldable(graph_->node(i), &properties)) {
+    bool foldable = IsFoldable(graph_->node(i), &properties);
+    VLOG(2) << "foldable(" << graph_->node(i).name() << ") = " << foldable;
+    if (foldable) {
       queue.push_back(graph_->mutable_node(i));
     }
   }
@@ -1988,7 +1990,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
                          *properties, use_shape_info, optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(
-      ConstantPushDown(*properties, optimized_graph, node));
+      ConstantPushDown(properties, optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(
       MulConvPushDown(optimized_graph, node, *properties));
   SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
@@ -1998,6 +2000,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
       MergeConcat(use_shape_info, optimized_graph, node));
   SET_AND_RETURN_IF_MODIFIED(
       PartialConcatConstFolding(optimized_graph, properties, node));
+  SET_AND_RETURN_IF_MODIFIED(
+      ConstantPushDownBiasAdd(properties, optimized_graph, node));
 
   graph_modified_ = graph_modified_cached;
   return Status::OK();
@@ -2841,7 +2845,197 @@ bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
   return false;
 }
 
-bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
+bool ConstantFolding::PrepareConstantPushDown(
+    const NodeDef& parent, const GraphProperties& properties,
+    bool must_have_properties, ConstantPushDownContext* ctx) const {
+  if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
+    return false;
+  }
+
+  NodeDef* left_child = node_map_->GetNode(parent.input(0));
+  NodeDef* right_child = node_map_->GetNode(parent.input(1));
+  ctx->left_child_is_const = IsReallyConstant(*left_child);
+  ctx->right_child_is_const = IsReallyConstant(*right_child);
+  ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
+  ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
+
+  // Nothing to do unless the parent has a constant child node.
+  if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
+    return false;
+  }
+
+  // Don't move nodes across devices.
+  if (parent.device() != ctx->op_child->device() ||
+      parent.device() != ctx->const_child->device()) {
+    return false;
+  }
+
+  // Make sure that it is safe to change the value of the child node result.
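+  // The rewrite changes the value computed by op_child, so op_child must
+  // have both of its data inputs, must not be a node whose value we have to
+  // preserve, and must not feed any non-control consumer other than the
+  // parent.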
+  if (ctx->op_child->input_size() < 2 ||
+      nodes_to_preserve_.find(ctx->op_child->name()) !=
+          nodes_to_preserve_.end() ||
+      NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
+    return false;
+  }
+
+  // Don't apply reassociation to floating point types of low precision.
+  // The danger of significant numerical changes is too high.
+  if (!CheckAttrExists(parent, "T").ok()) return false;
+  DataType dtype = parent.attr().at("T").type();
+  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
+    return false;
+  }
+
+  // Don't rewrite the tree if it might create cycles.
+  // TODO(rmlarsen): Add back handling of control dependency from op to C.
+  const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
+  if (child_output.find(ctx->const_child) != child_output.end()) {
+    return false;
+  }
+
+  // Get leaf nodes.
+  ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
+  ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
+  ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
+  ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
+
+  if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
+    // Child is already foldable, leave it alone.
+    return false;
+  }
+
+  // Don't move nodes across devices.
+  if (parent.device() != ctx->left_leaf->device() ||
+      parent.device() != ctx->right_leaf->device()) {
+    return false;
+  }
+
+  // Get shape and type information.
+  ctx->parent_input_props = &properties.GetInputProperties(parent.name());
+  ctx->op_child_input_props =
+      &properties.GetInputProperties(ctx->op_child->name());
+  if (must_have_properties && (ctx->parent_input_props == nullptr ||
+                               ctx->parent_input_props->size() < 2 ||
+                               ctx->op_child_input_props == nullptr ||
+                               ctx->op_child_input_props->size() < 2)) {
+    return false;
+  }
+
+  VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
+          << parent.op() << "(" << left_child->op() << ", "
+          << right_child->op() << ")";
+
+  return true;
+}
+
+bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
+                                              GraphDef* optimized_graph,
+                                              NodeDef* node) {
+  // This implements constant push-down for BiasAdd. In the following "CV" is a
+  // constant vector (tensor of rank 1), "V" is a (possibly) non-constant
+  // vector, "CM" is a constant matrix (tensor of rank >= 2), "M" is a
+  // (possibly) non-constant matrix, and "BA" is BiasAdd.
+  // For a valid input graph, the following 4 rewrites are legal:
+  //
+  // 1)                  +                +
+  //                    / \              / \
+  //                   BA  CV    -- >   BA  V
+  //                  / \              / \
+  //                 M   V            M   CV
+  //
+  // 2)                  +                +
+  //                    / \              / \
+  //                   BA  CM    -- >   BA  M
+  //                  / \              / \
+  //                 M   V            CM  V
+  //
+  // 3)                  BA               BA
+  //                    / \              / \
+  //                   +   CV    -- >   +   V
+  //                  / \              / \
+  //                 M   V            M   CV
+  //
+  // 4)                  BA               BA      = parent
+  //                    / \              / \
+  //                   BA  CV    -- >   BA  V     = children
+  //                  / \              / \
+  //                 M   V            M   CV      = leaves
+  //
+  // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
+
+  const bool parent_is_bias_add = IsBiasAdd(*node);
+  if (!parent_is_bias_add && !IsAdd(*node)) return false;
+  ConstantPushDownContext ctx;
+  if (!PrepareConstantPushDown(*node, *properties,
+                               /*must_have_properties=*/true, &ctx)) {
+    return false;
+  }
+  // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
+  // >= 2 and the leaves must be vectors, we cannot swap them.
+  if (ctx.left_child_is_const && parent_is_bias_add) return false;
+  const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
+  if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
+
+  // Get properties to validate rank and dtype constraints.
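+  // The vector/matrix classification below depends on ranks, so all four
+  // input shapes must have a known rank before we can proceed.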
+  if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
+      (*ctx.parent_input_props)[0].shape().unknown_rank() ||
+      (*ctx.parent_input_props)[1].shape().unknown_rank() ||
+      (*ctx.op_child_input_props)[0].shape().unknown_rank() ||
+      (*ctx.op_child_input_props)[1].shape().unknown_rank()) {
+    return false;
+  }
+
+  // Now get the ranks and types of the 3 leaf nodes.
+  const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
+  const int right_leaf_rank =
+      (*ctx.op_child_input_props)[1].shape().dim_size();
+  // At least one leaf must be a vector.
+  if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
+  const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
+  const int matrix_idx = 1 - vector_idx;
+
+  const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
+  const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
+  if (vector_rank != 1) return false;  // this should never happen.
+  const DataType vector_type = vector_prop.dtype();
+
+  const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
+  const int matrix_rank = matrix_prop.shape().dim_size();
+  const DataType matrix_type = matrix_prop.dtype();
+
+  const int const_idx = ctx.left_child_is_const ? 0 : 1;
+  const auto& const_prop = (*ctx.parent_input_props)[const_idx];
+  const int const_rank = const_prop.shape().dim_size();
+  const DataType const_type = const_prop.dtype();
+
+  int input_to_swap = -1;
+
+  if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
+      const_type == matrix_type) {
+    // Case 2:
+    input_to_swap = matrix_idx;
+  } else if (const_rank == 1 && const_type == vector_type) {
+    // Cases 1, 3, and 4:
+    input_to_swap = vector_idx;
+  }
+  if (input_to_swap == -1) return false;
+
+  node_map_->UpdateInput(node->name(), node->input(const_idx),
+                         ctx.op_child->input(input_to_swap));
+  node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
+  if (ctx.op_child->input(input_to_swap) !=
+      ctx.op_child->input(1 - input_to_swap)) {
+    node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
+                            ctx.op_child->name());
+  }
+  std::swap(*node->mutable_input(const_idx),
+            *ctx.op_child->mutable_input(input_to_swap));
+  properties->ClearInputProperties(node->name());
+  properties->ClearInputProperties(ctx.op_child->name());
+
+  return true;
+}
+
+bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
                                        GraphDef* optimized_graph,
                                        NodeDef* node) {
   // Consider the transformation
@@ -2864,141 +3058,78 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
   // by rotating the tree locally, e.g.
   //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
   //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
-  //
-  // Note: Don't touch BiasAdd since they can't handle vectors as their first
-  // inputs.
 
   // Get parent op type.
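+  // Only Add, Sub, Mul, and Div can be reassociated here; any other parent
+  // op is rejected by the early bail-out that follows.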
   const bool is_add = IsAdd(*node);
   const bool is_mul = IsMul(*node);
   const bool is_sub = IsSub(*node);
   const bool is_div = IsDiv(*node);
+  if (!(is_add || is_sub || is_mul || is_div)) return false;
   const bool is_symmetric = is_add || is_mul;
-  if (!has_fetch_ || !(is_add || is_sub || is_mul || is_div) ||
-      NumNonControlInputs(*node) != 2) {
+
+  ConstantPushDownContext ctx;
+  if (!PrepareConstantPushDown(*node, *properties,
+                               /*must_have_properties=*/false, &ctx)) {
     return false;
   }
-  NodeDef* left_child = node_map_->GetNode(node->input(0));
-  NodeDef* right_child = node_map_->GetNode(node->input(1));
-
-  const bool left_child_is_constant = IsReallyConstant(*left_child);
-  const bool right_child_is_constant = IsReallyConstant(*right_child);
-  if (!left_child_is_constant && !right_child_is_constant) {
-    return false;
-  }
-  // Don't move nodes across devices.
-  if (node->device() != left_child->device() ||
-      node->device() != right_child->device()) {
-    return false;
-  }
-  NodeDef* op_child = left_child_is_constant ? right_child : left_child;
-  NodeDef* const_child = left_child_is_constant ? left_child : right_child;
-  // Don't rewrite the tree if it might create cycles.
-  // TODO(rmlarsen): Add back handling of control dependency from op to C.
-  const auto& child_output = node_map_->GetOutputs(op_child->name());
-  if (child_output.find(const_child) != child_output.end()) {
-    return false;
-  }
 
   // Get child op type.
-  const bool is_child_add = IsAdd(*op_child);
-  const bool is_child_mul = IsMul(*op_child);
-  const bool is_child_sub = IsSub(*op_child);
-  const bool is_child_div = IsDiv(*op_child);
+  const bool is_child_add = IsAdd(*ctx.op_child);
+  const bool is_child_mul = IsMul(*ctx.op_child);
+  const bool is_child_sub = IsSub(*ctx.op_child);
+  const bool is_child_div = IsDiv(*ctx.op_child);
   const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
   const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
   if (!is_add_sub && !is_mul_div) {
     return false;
   }
   const bool is_child_symmetric = is_child_add || is_child_mul;
-  // Make sure that it is safe to change the value of the child node result.
-  if (op_child->input_size() < 2 ||
-      nodes_to_preserve_.find(op_child->name()) != nodes_to_preserve_.end() ||
-      NumNonControlOutputs(*op_child, *node_map_) > 1) {
-    return false;
-  }
-  // Do not rewrite integer expressions with subtraction or division.
-  // if (node->name().find("filter_boxes") != std::string::npos) return false;
-  if (!CheckAttrExists(*node, "T").ok()) return false;
-  DataType dtype = node->attr().at("T").type();
-  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
-    // Don't apply reassociation to floating point types of low precision.
-    // The danger of significant numerical changes is too high.
-    return false;
-  }
-  if (!(is_symmetric && is_child_symmetric) &&
-      !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
-    return false;
-  }
+  const NodeDef* y_node =
+      ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
+  if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
+      !ctx.op_child_input_props->empty()) {
+    // If we know the shapes of the nodes being swapped, make sure we don't
+    // push down a larger node and create more work by broadcasting earlier
+    // in the expression tree.
+    const PartialTensorShape c_shape(
+        (*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
+    const PartialTensorShape x_shape(
+        (*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
 
-  // Identify the nodes to swap.
-  NodeDef* left_leaf = node_map_->GetNode(op_child->input(0));
-  NodeDef* right_leaf = node_map_->GetNode(op_child->input(1));
-  const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
-  const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
-  if (left_leaf_is_constant && right_leaf_is_constant) {
-    // Child is already foldable, leave it alone.
-    return false;
-  }
-  // Don't move nodes across devices.
-  if (node->device() != left_leaf->device() ||
-      node->device() != right_leaf->device()) {
-    return false;
-  }
-
-  // Get the node names corresponding to X, Y, and C.
-  const string input_x =
-      left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
-  const string input_y =
-      input_x == op_child->input(0) ? op_child->input(1) : op_child->input(0);
-  const string input_c =
-      left_child_is_constant ? node->input(0) : node->input(1);
-  const string input_op =
-      left_child_is_constant ? node->input(1) : node->input(0);
-
-  VLOG(1) << "\n++++++++ Reordering node " << node->name() << ": " << node->op()
-          << "(" << left_child->op() << ", " << right_child->op() << ")\n";
-
-  const NodeDef* y_node = left_leaf_is_constant ? left_leaf : right_leaf;
-  if (!IsReallyConstant(*y_node)) {
-    // Now make sure that we do not push a tensor that is larger than the
-    // tensor it replaces down, since that would create more broadcasting and
-    // increase work.
-    const std::vector<OpInfo::TensorProperties>& root_props =
-        properties.GetInputProperties(node->name());
-    const std::vector<OpInfo::TensorProperties>& op_props =
-        properties.GetInputProperties(op_child->name());
-    if (!root_props.empty() && !op_props.empty()) {
-      DCHECK_EQ(2, root_props.size()) << node->DebugString();
-      DCHECK_EQ(2, op_props.size()) << op_child->DebugString();
-      const PartialTensorShape c_shape(
-          root_props[left_child_is_constant ? 0 : 1].shape());
-      const PartialTensorShape x_shape(
-          op_props[left_leaf_is_constant ? 0 : 1].shape());
-      if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
-          c_shape.num_elements() > x_shape.num_elements()) {
-        return false;
-      } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
-                 c_shape.dims() > 0) {
-        for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims());
-             ++idx) {
-          if (x_shape.dim_size(idx) >= 0 &&
-              c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
-            return false;
-          }
+    if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
+        c_shape.num_elements() > x_shape.num_elements()) {
+      return false;
+    } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
+               c_shape.dims() > 0) {
+      for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
+        if (x_shape.dim_size(idx) >= 0 &&
+            c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
+          return false;
        }
      }
    }
  }

-  // Now we have identified the nodes to swap (non_const_leaf_input and
-  // const_child).
+  // Get the node names corresponding to X, Y, and C.
+  const string input_x =
+      ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
+  const string input_y = input_x == ctx.op_child->input(0)
+                             ? ctx.op_child->input(1)
+                             : ctx.op_child->input(0);
+  const string input_c =
+      ctx.left_child_is_const ? node->input(0) : node->input(1);
+  const string input_op =
+      ctx.left_child_is_const ? node->input(1) : node->input(0);
+  VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
+
+  // Now we have identified the nodes to swap, update the node map accordingly.
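+  // The swap also invalidates any input properties cached for the two
+  // rewritten nodes, so they are cleared below; omitting that invalidation
+  // is the bug fixed by this change.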
   node_map_->UpdateInput(node->name(), input_c, input_x);
-  node_map_->AddOutput(input_c, op_child->name());
+  node_map_->AddOutput(input_c, ctx.op_child->name());
   if (input_x != input_y) {
-    node_map_->RemoveOutput(input_x, op_child->name());
+    node_map_->RemoveOutput(input_x, ctx.op_child->name());
   }
+  properties->ClearInputProperties(node->name());
+  properties->ClearInputProperties(ctx.op_child->name());
 
   if (is_symmetric && is_child_symmetric) {
     // Easy case (only commutative ops). We always write this as one of
@@ -3009,8 +3140,8 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
     //   C   Y
     node->set_input(0, input_x);
     node->set_input(1, input_op);
-    op_child->set_input(0, input_c);
-    op_child->set_input(1, input_y);
+    ctx.op_child->set_input(0, input_c);
+    ctx.op_child->set_input(1, input_y);
   } else {
     // More complicated case: When there are non-commutative operations like
     // subtractions or divisions involved, we may have to rotate the tree
@@ -3019,6 +3150,7 @@ bool ConstantFolding::ConstantPushDown(const GraphProperties& properties,
     // Here are the final trees we want to generate for those 6 cases:
    //
    // (CYX signs):   ++-      +--      -+-      --+      +-+      -++
+    //
    //        -        -        -        -        +        +
    //       / \      / \      / \      / \      / \      / \
    //      +   X    -   X    -   X    X   +    X   -    X   -
    //     / \      / \      / \      / \      / \      / \
    //    C   Y    C   Y    Y   C    Y   C    C   Y    Y   C
    // expression
    auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
      bool leaf_negated = !is_child_symmetric && is_right_leaf;
-      bool child_negated = !is_symmetric && (op_child == right_child);
+      bool child_negated = !is_symmetric && (ctx.left_child_is_const);
      return leaf_negated != child_negated;
    };
    const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
    const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
-    bool neg_c = !is_symmetric && (const_child == right_child);
-    bool neg_x = is_leaf_negated(left_leaf_is_constant);
-    bool neg_y = is_leaf_negated(!left_leaf_is_constant);
+    bool neg_c = !is_symmetric && !ctx.left_child_is_const;
+    bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
+    bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
    // Rewrite the parent node.
    node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
    node->set_input(0, neg_x ? input_op : input_x);
    node->set_input(1, neg_x ? input_x : input_op);
    // Rewrite the child node.
-    op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
-    op_child->set_input(0, neg_c ? input_y : input_c);
-    op_child->set_input(1, neg_c ? input_c : input_y);
+    ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
+    ctx.op_child->set_input(0, neg_c ? input_y : input_c);
+    ctx.op_child->set_input(1, neg_c ? input_c : input_y);
  }
  return true;
}
@@ -3060,6 +3192,9 @@ bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
   //                 X  C1                      C1  C2
   //
   // where C1 and C2 are constants and X is non-constant.
+  //
+  // TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
+
   if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
   NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 21ad5144c24..5c29591d939 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -143,10 +143,48 @@ class ConstantFolding : public GraphOptimizer {
   // Returns true if the transformation applied successfully.
   bool PartialConstPropThroughIdentityN(NodeDef* node);
 
-  // Pushes down constants on '+' and '*' operators if applicable. Returns true
-  // the transformation applied successfully.
-  bool ConstantPushDown(const GraphProperties& properties,
-                        GraphDef* optimized_graph, NodeDef* node);
+  struct ConstantPushDownContext {
+    NodeDef* op_child;
+    NodeDef* const_child;
+    bool left_child_is_const;
+    bool right_child_is_const;
+    NodeDef* left_leaf;
+    NodeDef* right_leaf;
+    bool left_leaf_is_const;
+    bool right_leaf_is_const;
+
+    // Shape & type information.
+    const std::vector<OpInfo::TensorProperties>* parent_input_props;
+    const std::vector<OpInfo::TensorProperties>* op_child_input_props;
+  };
+
+  // Populates ctx with pointers to the nodes in the expression tree for which
+  // constant pushdown optimization is being considered, corresponding to one
+  // of the following configurations:
+  //
+  //              parent                     parent
+  //             /      \                   /      \
+  //       op_child   const_child     const_child   op_child
+  //       /      \                                 /      \
+  //  left_leaf  right_leaf                  left_leaf  right_leaf
+  //
+  // Returns true if the expression is possibly amenable to optimization.
+  // Returns false if must_have_properties is true and input properties for
+  // parent and op_child are not known.
+  bool PrepareConstantPushDown(const NodeDef& parent,
+                               const GraphProperties& properties,
+                               bool must_have_properties,
+                               ConstantPushDownContext* ctx) const;
+
+  // Pushes down constants on '+', '-', '*', and '/' operators if applicable.
+  // Returns true if the transformation applied successfully.
+  bool ConstantPushDown(GraphProperties* properties, GraphDef* optimized_graph,
+                        NodeDef* node);
+
+  // Pushes down constants on '+' and 'BiasAdd' operators if applicable.
+  // Returns true if the graph was modified.
+  bool ConstantPushDownBiasAdd(GraphProperties* properties,
+                               GraphDef* optimized_graph, NodeDef* node);
 
   // Aggregate constants present around a conv operator. Returns true if the
   // transformation was applied successfully.
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 7bcae29c63a..8f0c9f930a5 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -400,7 +400,7 @@ TEST_F(ConstantFoldingTest, AddSubtactTree) {
   }
 }
 
-TEST_F(ConstantFoldingTest, TreeCanonicalization) {
+TEST_F(ConstantFoldingTest, ConstantPushDown) {
   for (int is_add : {true, false}) {
     for (int is_parent_commutative : {true, false}) {
       for (int is_child_commutative : {true, false}) {
@@ -413,28 +413,32 @@ TEST_F(ConstantFoldingTest, TreeCanonicalization) {
             ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
 
-        auto get_op = [&](bool is_commutative, bool is_left_arg_cont,
+        auto get_op = [&](bool is_commutative, bool is_left_arg_const,
                           const string& name, const Output& const_arg,
                           const Output non_const_arg) -> Output {
           if (is_add) {
             if (is_commutative) {
-              return ops::Add(s.WithOpName(name),
-                              is_left_arg_cont ? const_arg : non_const_arg,
-                              is_left_arg_cont ? non_const_arg : const_arg);
+              return ops::Add(
+                  s.WithOpName(name),
+                  is_left_arg_const ? const_arg : non_const_arg,
+                  is_left_arg_const ? non_const_arg : const_arg);
             } else {
-              return ops::Sub(s.WithOpName(name),
-                              is_left_arg_cont ? const_arg : non_const_arg,
-                              is_left_arg_cont ? non_const_arg : const_arg);
+              return ops::Sub(
+                  s.WithOpName(name),
+                  is_left_arg_const ? const_arg : non_const_arg,
+                  is_left_arg_const ? non_const_arg : const_arg);
             }
           } else {
            if (is_commutative) {
-              return ops::Mul(s.WithOpName(name),
-                              is_left_arg_cont ? const_arg : non_const_arg,
-                              is_left_arg_cont ? non_const_arg : const_arg);
+              return ops::Mul(
+                  s.WithOpName(name),
+                  is_left_arg_const ? const_arg : non_const_arg,
+                  is_left_arg_const ? non_const_arg : const_arg);
            } else {
-              return ops::Div(s.WithOpName(name),
-                              is_left_arg_cont ? const_arg : non_const_arg,
-                              is_left_arg_cont ? non_const_arg : const_arg);
+              return ops::Div(
+                  s.WithOpName(name),
+                  is_left_arg_const ? const_arg : non_const_arg,
+                  is_left_arg_const ? non_const_arg : const_arg);
            }
          }
        };
@@ -472,6 +476,76 @@
   }
 }
 
+TEST_F(ConstantFoldingTest, ConstantPushDownBiasAdd) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output c_mat = ops::Const(s.WithOpName("c_mat"), 2.0f, {2, 2});
+  Output c_vec = ops::Const(s.WithOpName("c_vec"), 3.0f, {2});
+  Output x_mat = ops::Placeholder(s.WithOpName("x_mat"), DT_FLOAT,
+                                  ops::Placeholder::Shape(TensorShape({2, 2})));
+  Output x_vec = ops::Placeholder(s.WithOpName("x_vec"), DT_FLOAT,
+                                  ops::Placeholder::Shape(TensorShape({2})));
+  // Rewrite expected for cases 1 through 3 and their symmetric equivalents,
+  // and case 4.
+  Output child1 = ops::BiasAdd(s.WithOpName("child1"), c_mat, x_vec);
+  Output parent1 = ops::Add(s.WithOpName("parent1"), child1, c_vec);
+  Output child1a = ops::BiasAdd(s.WithOpName("child1a"), c_mat, x_vec);
+  Output parent1a = ops::Add(s.WithOpName("parent1a"), c_vec, child1a);
+
+  Output child2 = ops::BiasAdd(s.WithOpName("child2"), x_mat, c_vec);
+  Output parent2 = ops::Add(s.WithOpName("parent2"), child2, c_mat);
+  Output child2a = ops::BiasAdd(s.WithOpName("child2a"), x_mat, c_vec);
+  Output parent2a = ops::Add(s.WithOpName("parent2a"), c_mat, child2a);
+
+  Output child3 = ops::Add(s.WithOpName("child3"), c_mat, x_vec);
+  Output parent3 = ops::BiasAdd(s.WithOpName("parent3"), child3, c_vec);
+  Output child3a = ops::Add(s.WithOpName("child3a"), x_vec, c_mat);
+  Output parent3a = ops::BiasAdd(s.WithOpName("parent3a"), child3a, c_vec);
+
+  Output child4 = ops::BiasAdd(s.WithOpName("child4"), c_mat, x_vec);
+  Output parent4 = ops::BiasAdd(s.WithOpName("parent4"), child4, c_vec);
+
+  // No rewrite expected.
+  Output child5 = ops::Add(s.WithOpName("child5"), x_vec, x_vec);
+  Output parent5 = ops::BiasAdd(s.WithOpName("parent5"), c_mat, child5);
+  Output child6 = ops::Add(s.WithOpName("child6"), x_vec, c_vec);
+  Output parent6 = ops::BiasAdd(s.WithOpName("parent6"), c_mat, child6);
+
+  GrapplerItem item;
+  item.fetch = {"parent1",  "parent2", "parent3", "parent1a", "parent2a",
+                "parent3a", "parent4", "parent5", "parent6"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding optimizer(/*cpu_device=*/nullptr);
+  GraphDef output;
+  Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(22, output.node_size());
+  for (const auto& node : output.node()) {
+    if (node.name() == "child1" || node.name() == "child1a" ||
+        node.name() == "child2" || node.name() == "child2a" ||
+        node.name() == "child3" || node.name() == "child3a" ||
+        node.name() == "child4") {
+      EXPECT_EQ(node.op(), "Const") << " node: " << node.name();
+    } else if (node.name() != "c_mat" && node.name() != "c_vec") {
+      EXPECT_NE(node.op(), "Const") << " node: " << node.name();
+    }
+  }
+  // Check that the result nodes have the expected value.
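+  // Evaluate the original and the optimized graph on identical random inputs
+  // and require the fetched outputs to match exactly.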
+  auto x_mat_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+  auto x_vec_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2}));
+  std::vector<string> fetch = item.fetch;
+  auto tensor_expected = EvaluateNodes(
+      item.graph, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
+  ASSERT_EQ(fetch.size(), tensor_expected.size());
+  auto tensors =
+      EvaluateNodes(output, fetch, {{"x_vec", x_vec_t}, {"x_mat", x_mat_t}});
+  ASSERT_EQ(fetch.size(), tensors.size());
+  for (int i = 0; i < fetch.size(); i++) {
+    test::ExpectTensorEqual<float>(tensor_expected[i], tensors[i]);
+  }
+}
+
 TEST_F(ConstantFoldingTest, MulConvPushDownTest_Conv2D_ScalarConst) {
   for (string data_format : {
            "NHWC",