diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 94741a11ffa..625780e7c91 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -247,16 +247,10 @@ namespace tensorflow {
 //
 //        P = Conv2DWithBiasBackpropBias(O, O_m)
 //
-// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is
-// the context matching depth. If _MklConv2DWithBias is not within the context
-// matching depth, then we do not rewrite BiasAddGrad.
-
-// How many hops do we search for matching node in the backward dataflow graph?
-// We use maxhop of 10 based on empirical observations. Also, these are
-// maxhops in backward data-flow graph. Since input of forward nodes (Conv2D)
-// directly goes to backward nodes, we do not expect the hop-distance
-// would be more than few nodes.
-static size_t kNodeMergeContextMaxDepth = 10;
+// Rewrite of BiasAddGrad into Conv2DWithBiasBackpropBias takes place depending
+// on the matching 'context'. The term context loosely refers to which forward
+// op is _associated_ with BiasAddGrad. If that op is _MklConv2DWithBias, we
+// consider it a Conv2D context; if it is MatMul, it is a MatMul context.
 
 class MklLayoutRewritePass : public GraphOptimizationPass {
  public:
@@ -280,6 +274,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.max_pool = "MaxPool";
     csinfo_.max_pool_grad = "MaxPoolGrad";
     csinfo_.mkl_conv2d = "_MklConv2D";
+    csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
+    csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
     csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
     csinfo_.mkl_conv2d_with_bias_backprop_bias =
         "_MklConv2DWithBiasBackpropBias";
@@ -360,16 +356,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
                       csinfo_.mkl_conv2d_with_bias});
 
-    // We use maxhop of 10 based on empirical observations. Also, these are
-    // maxhops in backward data-flow graph. Since input of forward nodes
-    // (Conv2D) directly goes to backward nodes, we do not expect the
-    // hop-distance would be more than few nodes.
     biasaddgrad_matmul_context_ = {csinfo_.bias_add_grad, csinfo_.matmul,
-                                   kNodeMergeContextMaxDepth};
+                                   IsBiasAddGradInMatMulContext};
 
     biasaddgrad_conv2dwithbias_context_ = {csinfo_.bias_add_grad,
                                            csinfo_.mkl_conv2d_with_bias,
-                                           kNodeMergeContextMaxDepth};
+                                           IsBiasAddGradInConv2DWithBiasContext};
 
     cinfo_.push_back(&biasaddgrad_matmul_context_);
     cinfo_.push_back(&biasaddgrad_conv2dwithbias_context_);
@@ -392,9 +384,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string node;     // Name of the node to be rewritten
     string fwd;      // Name of the node in the forward pass that this node
                      // corresponds to
-    size_t max_hop;  // Maximum number of hops the fwd is located
-                     // from this node. If the fwd is farther than max_hop
-                     // then we do not rewrite the node.
+    std::function<bool(const Node*, const Node**, void*)> context_match_fn;
   } ContextInfo;
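[Editor's note: the hunk above replaces the graph-wide hop bound with a per-context callback. A minimal sketch of the new contract, with a hypothetical DummyMatchFn standing in for the real predicates (the actual registrations above use IsBiasAddGradInMatMulContext and IsBiasAddGradInConv2DWithBiasContext, defined later in this patch):

  // Sketch only: illustrates the callback contract that ContextInfo now
  // carries. A matcher receives the candidate node 'n' and the ContextInfo
  // itself via 'ci'; on success it may set '*fwd_node' to the associated
  // forward node and returns true, otherwise it returns false.
  static bool DummyMatchFn(const Node* n, const Node** fwd_node, void* ci) {
    *fwd_node = nullptr;
    return false;
  }

  ContextInfo example_context = {"BiasAddGrad", "_MklConv2DWithBias",
                                 DummyMatchFn};

SearchMatchingContext (rewritten further below) walks cinfo_ and returns the first entry whose context_match_fn returns true.]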
 
   /// Structure to specify the name of an original node, its new name after
@@ -438,7 +428,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
 
   /// Structure to store all constant strings
   /// NOTE: names are alphabetically sorted.
-  struct {
+  typedef struct {
     string avg_pool;
     string avg_pool_grad;
     string bias_add;
@@ -457,13 +447,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string max_pool;
     string max_pool_grad;
     string mkl_conv2d;
+    string mkl_conv2d_grad_input;
+    string mkl_conv2d_grad_filter;
     string mkl_conv2d_with_bias;
     string mkl_conv2d_with_bias_backprop_bias;
     string relu;
     string relu_grad;
     string reshape;
    string split;
-  } csinfo_;
+  } ConstStringsInfo;
 
  private:
  /// Maintain info about nodes to rewrite
@@ -478,6 +470,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   /// Maintain info about nodes to rewrite
   static std::vector<ContextInfo*> cinfo_;
 
+  /// Maintain structure of constant strings
+  static ConstStringsInfo csinfo_;
+
   /// Context variables used in referencing rules
   static ContextInfo biasaddgrad_matmul_context_;
   static ContextInfo biasaddgrad_conv2dwithbias_context_;
@@ -629,6 +624,173 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return false;
   }
 
+  // Is the BiasAddGrad node 'n' associated with the Conv2DWithBias node
+  // specified in contextinfo 'ci'? The function updates fwd_node to point
+  // to the Conv2DWithBias node if 'n' is associated with Conv2DWithBias.
+  //
+  // Association checks for one of the following graphs:
+  //
+  // Graph A:
+  //
+  // _ = Conv2DWithBias(F, I, _)
+  // ..
+  // _ = Conv2DBackpropFilter(F, _, G)
+  // _ = Conv2DBackpropInput(_, I, G)
+  // _ = BiasAddGrad(G)
+  //
+  // OR
+  //
+  // Graph B:
+  //
+  // _ = Conv2DWithBias(F, _, _)
+  // ..
+  // _ = Conv2DBackpropFilter(F, _, G)
+  // _ = BiasAddGrad(G)
+  //
+  // Here F, G, and I are graph nodes; _ represents graph nodes that we
+  // do not care about here.
+  //
+  // @return - true (if BiasAddGrad is associated with Conv2DWithBias);
+  //           false otherwise.
+  static bool IsBiasAddGradInConv2DWithBiasContext(const Node* n,
+                                                   const Node** fwd_node,
+                                                   void* ci) {
+    CHECK_NOTNULL(n);
+    CHECK_NOTNULL(fwd_node);
+    CHECK_NOTNULL(ci);
+    *fwd_node = nullptr;
+
+    CHECK_EQ(n->type_string(), csinfo_.bias_add_grad);
+
+    // Get the sole input of BiasAddGrad.
+    CHECK_EQ(n->num_inputs(), 1);
+    const Node* bias_add_grad_inp = nullptr;
+    TF_CHECK_OK(n->input_node(0, &bias_add_grad_inp));
+    CHECK_NOTNULL(bias_add_grad_inp);
+
+    // Check if this input also goes to BackpropFilter and BackpropInput
+    // as their 3rd input.
+    bool found_backprop_input = false;
+    bool found_backprop_filter = false;
+    Node* backprop_filter_node = nullptr;
+    Node* backprop_input_node = nullptr;
+
+    for (const Edge* e : bias_add_grad_inp->out_edges()) {
+      Node* third_input = nullptr;
+      if (e->dst()->type_string() == csinfo_.conv2d_grad_input ||
+          e->dst()->type_string() == csinfo_.mkl_conv2d_grad_input) {
+        // The 3rd input (index 2) of BackpropInput must be the same as the
+        // input of BiasAddGrad.
+        TF_CHECK_OK(e->dst()->input_node(2, &third_input));
+        if (third_input == bias_add_grad_inp) {
+          found_backprop_input = true;
+          backprop_input_node = e->dst();
+        }
+      }
+
+      if (e->dst()->type_string() == csinfo_.conv2d_grad_filter ||
+          e->dst()->type_string() == csinfo_.mkl_conv2d_grad_filter) {
+        // The 3rd input (index 2) of BackpropFilter must be the same as the
+        // input of BiasAddGrad.
+        TF_CHECK_OK(e->dst()->input_node(2, &third_input));
+        if (third_input == bias_add_grad_inp) {
+          found_backprop_filter = true;
+          backprop_filter_node = e->dst();
+        }
+      }
+
+      // If we found both nodes, we can stop the search.
+      if (found_backprop_input && found_backprop_filter) {
+        break;
+      }
+    }
+
+    // If the BackpropFilter node is not found, then this is not a
+    // Conv2DWithBias context. For the 2nd graph (Graph B) in the example
+    // above, only BackpropFilter would be present.
+    if (!found_backprop_filter) {
+      return false;
+    }
+
+    // Otherwise, we found the nodes.
+    CHECK_NOTNULL(backprop_filter_node);
+    if (found_backprop_input) {
+      CHECK_NOTNULL(backprop_input_node);
+    }
+
+    // Now that we have confirmed that this is a Conv2DWithBias context, we
+    // need to get access to the forward node (Conv2DWithBias). The 2nd input
+    // of Conv2DWithBias is the same as the 2nd input of Conv2DBackpropInput;
+    // the 1st input of Conv2DWithBias is the same as the 1st input of
+    // Conv2DBackpropFilter. (This follows from the definition of the
+    // gradient computation for Conv2D.)
+    if (found_backprop_input) {
+      // Graph A in the example.
+      Node* second_inp_of_input = nullptr;
+      Node* first_inp_of_filter = nullptr;
+      TF_CHECK_OK(backprop_input_node->input_node(1, &second_inp_of_input));
+      TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
+      CHECK_NOTNULL(second_inp_of_input);
+      CHECK_NOTNULL(first_inp_of_filter);
+
+      // Now we need to find the Conv2DWithBias node from these input nodes.
+      // The Conv2DWithBias node is the node that accepts second_inp_of_input
+      // and first_inp_of_filter in its 2nd and 1st input slots, respectively.
+      for (const Edge* fe : first_inp_of_filter->out_edges()) {
+        if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+            fe->dst_input() == 0) {
+          for (const Edge* ie : second_inp_of_input->out_edges()) {
+            if (ie->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+                ie->dst_input() == 1 && fe->dst() == ie->dst()) {
+              VLOG(1) << "MklLayoutRewritePass: found "
+                      << fe->dst()->DebugString()
+                      << " as the forward node for matching context, backward"
+                      << " node is: " << n->DebugString();
+              *fwd_node = fe->dst();
+              return true;
+            }
+          }
+        }
+      }
+    } else {
+      // We did not find BackpropInput, so we work with BackpropFilter only.
+      // Graph B in the example.
+      Node* first_inp_of_filter = nullptr;
+      TF_CHECK_OK(backprop_filter_node->input_node(0, &first_inp_of_filter));
+      CHECK_NOTNULL(first_inp_of_filter);
+
+      // Now we need to find the Conv2DWithBias node from the 1st input of
+      // BackpropFilter. The Conv2DWithBias node is the node that accepts
+      // first_inp_of_filter in its 1st input slot.
+      for (const Edge* fe : first_inp_of_filter->out_edges()) {
+        if (fe->dst()->type_string() == csinfo_.mkl_conv2d_with_bias &&
+            fe->dst_input() == 0) {
+          VLOG(1) << "MklLayoutRewritePass: found "
+                  << fe->dst()->DebugString()
+                  << " as the forward node for matching context, backward"
+                  << " node is: " << n->DebugString();
+          *fwd_node = fe->dst();
+          return true;
+        }
+      }
+    }
+
+    return false;
+  }
+
+  // Is the BiasAddGrad node 'n' associated with the MatMul node specified
+  // in contextinfo 'ci'? The function does not update fwd_node. We treat
+  // 'n' as being in a MatMul context whenever it is not in a
+  // Conv2DWithBias context.
+  //
+  // @return - true (if BiasAddGrad is associated with MatMul);
+  //           false otherwise.
+  static bool IsBiasAddGradInMatMulContext(const Node* n,
+                                           const Node** fwd_node,
+                                           void* ci) {
+    return (!IsBiasAddGradInConv2DWithBiasContext(n, fwd_node, ci));
+  }
+
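[Editor's note: both branches above repeat one idiom: confirm that a given node feeds a consumer of a given op type at a specific input slot by scanning its out_edges(). A compact sketch of that check in isolation (hypothetical helper, not part of this patch):

  // Returns true if 'src' is consumed by a node of type 'op' at input slot
  // 'slot', storing that consumer in '*consumer'. Sketch only; the pass
  // inlines this logic rather than calling a helper like this.
  static bool FeedsOpAtSlot(const Node* src, const string& op, int slot,
                            Node** consumer) {
    for (const Edge* e : src->out_edges()) {
      if (e->dst()->type_string() == op && e->dst_input() == slot) {
        *consumer = e->dst();
        return true;
      }
    }
    return false;
  }

With it, the Graph A check reads roughly as: the input of BiasAddGrad must feed both Conv2DBackpropFilter and Conv2DBackpropInput at slot 2, and the backprop nodes' data inputs must feed the same _MklConv2DWithBias at slots 0 and 1.]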
   // Rewrite rule that uses context-information for matching,
   // used in scenario 2.
   //
@@ -639,8 +801,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   static bool ContextMatchRewrite(const Node* n, const ContextInfo* c);
 
   // Helper function that searches the matching contextinfo for the node.
-  // Implements depth-first search in the data dependence graph for the
-  // gradient op in the backward direction.
   //
   // @input n - Node (gradient op) whose contextinfo is to be searched,
   //        fwd_node - pointer to node from the forward pass that this node
@@ -788,6 +948,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                                   Node* orig_node);
 };
 
+MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
 MklLayoutRewritePass::ContextInfo
     MklLayoutRewritePass::biasaddgrad_conv2dwithbias_context_;
 MklLayoutRewritePass::ContextInfo
@@ -1667,12 +1828,12 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
   const ContextInfo* ci = nullptr;
   bool is_context_based_rewrite = false;
   if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) {
-    CHECK_NOTNULL(fwd_node);
     is_context_based_rewrite = true;
 
     // Sanity checks for context-based rewrite (if any)
     if (orig_node->type_string() == csinfo_.bias_add_grad &&
         ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
+      CHECK_NOTNULL(fwd_node);
       DataType orig_T, ctx_T;
       string orig_data_format, ctx_data_format;
       TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
@@ -1784,69 +1945,17 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
   CHECK_NOTNULL(fwd_node);
   *fwd_node = nullptr;
 
-  // Search for matching contextinfo based on node name.
-  // There could be more than one matching contextinfos.
-  bool is_matching_cinfo_found = false;
-  std::vector<const ContextInfo*> mci;
+  // Search for a matching contextinfo based on node name, and invoke the
+  // callback function of the matching contextinfo.
+  // There could be more than one matching contextinfo, but the first one
+  // that matches is returned.
   for (auto ci = cinfo_.cbegin(); ci != cinfo_.cend(); ++ci) {
-    if (n->type_string() == (*ci)->node) {
-      mci.push_back(*ci);
-      is_matching_cinfo_found = true;
+    if (n->type_string() == (*ci)->node &&
+        (*ci)->context_match_fn(n, fwd_node, *ci)) {
+      VLOG(1) << "Found matching context: " << (*ci)->fwd;
+      return *ci;
     }
   }
-
-  // If no matching contextinfo is found, return immediately.
-  if (!is_matching_cinfo_found) {
-    return nullptr;
-  }
-
-  VLOG(1) << "MklLayoutRewritePass: Searching graph for: " << n->type_string()
-          << " in backwards.";
-
-  // Now we will check for forward op name for context info in data
-  // flow graph. Get the max hops we should search for the fwd node.
-  // We are now going to search (breadth-first) backwards in data
-  // dependence graph (for up to max hops) from n for the node
-  // specified in fwd.
-  // queue to maintain nodes to be visited and depth info for
-  // breadth-first search
-  std::queue<std::pair<const Node*, size_t>> nqueue;
-  const Node* curr_node = n;
-  size_t curr_depth = 0;
-  nqueue.push(std::make_pair(curr_node, curr_depth));
-
-  while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
-    std::pair<const Node*, size_t> curr_pair = nqueue.front();
-    nqueue.pop();
-
-    std::set<const Node*> visited_nodes;
-    curr_node = curr_pair.first;
-    curr_depth = curr_pair.second;
-    CHECK_NOTNULL(curr_node);
-
-    VLOG(1) << "MklLayoutRewritePass: Visiting node: "
-            << curr_node->type_string() << " at depth: " << curr_depth
-            << " for node: " << n->type_string();
-
-    // If we find a match, we return immediately.
-    for (const ContextInfo* ci : mci) {
-      if (curr_node->type_string() == ci->fwd) {
-        *fwd_node = curr_node;
-        return ci;
-      }
-    }
-
-    // Else we explore backward edges from current node.
-    // Add the source nodes of all incoming edges of the node to the queue.
-    for (const Edge* e : curr_node->in_edges()) {
-      // We do not visit already visited node.
-      if (visited_nodes.find(e->src()) == visited_nodes.end()) {
-        // Depth of these nodes is 1 more than the depth of current node.
-        nqueue.push(std::make_pair(e->src(), curr_depth + 1));
-        visited_nodes.insert(e->src());
-      }
-    }
-  } /* while */
-
   return nullptr;
 }
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 3c4a5263afd..efbe2134e0f 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -345,7 +345,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
 
 // Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
 // rewrite tests
-// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
@@ -364,16 +365,255 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
       "node { name: 'E' op: 'Sub'"
       " attr {key: 'T' value { type: DT_FLOAT } }"
       " input: ['D', 'A']}"
-      "node { name: 'F' op: 'BiasAddGrad'"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'Int32Input'}"
+      "node { name: 'I' op: '_MklConv2DBackpropInput'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+      "node { name: 'J' op: 'BiasAddGrad'"
       " attr { key: 'T' value { type: DT_FLOAT } }"
       " attr { key: 'data_format' value { s: 'NCHW' } }"
       " input: ['E'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
-            "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);"
-            "N(_MklInput);O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;"
-            "DMT/_0->F:1;E->F;E:control->DMT/_0:control;M->D:3;N->D:4;"
-            "O->D:5");
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+            "I(_MklConv2DBackpropInput);J(_MklConv2DWithBiasBackpropBias);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G;B->D:1;"
+            "B->I:1;C->D:2;D->E;DMT/_0->J:1;E->G:2;E->I:2;E->J;"
+            "E:control->DMT/_0:control;F->G:1;H->I;M->D:3;M->G:3;M->I:3;"
+            "N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
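[Editor's note: before the negative variants, it helps to trace how the positive case above satisfies the new matcher (editorial walkthrough, not part of the test file):

// J = BiasAddGrad(E), so E is the sole input of J.
// E feeds G (_MklConv2DBackpropFilter) at input slot 2 and
// I (_MklConv2DBackpropInput) at input slot 2  -> Graph A applies.
// G's 1st input is A, and A also feeds D (_MklConv2DWithBias) at slot 0;
// I's 2nd input is B, and B also feeds D at slot 1  -> D is the forward
// node, so J is rewritten to _MklConv2DWithBiasBackpropBias.]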
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput. Negative case: the incoming gradient (E) feeds
+// BackpropFilter at input slot 0 instead of slot 2, so the nodes do not
+// match the rewrite criteria and no rewrite should happen.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative1) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['E', 'F', 'A', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'Int32Input'}"
+      "node { name: 'I' op: '_MklConv2DBackpropInput'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+      "node { name: 'J' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+            "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;"
+            "B->I:1;C->D:2;D->E;E->G;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
+            "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter
+// and BackpropInput. Negative case: the 1st input of BackpropFilter (A)
+// feeds _MklConv2DWithBias at input slot 1 rather than slot 0, so the
+// forward-node match fails and no rewrite should happen.
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'Int32Input'}"
+      "node { name: 'I' op: '_MklConv2DBackpropInput'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['H', 'B', 'E', 'M', 'N', 'O']}"
+      "node { name: 'J' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(Int32Input);"
+            "I(_MklConv2DBackpropInput);J(BiasAddGrad);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
+            "B->I:1;C->D:2;D->E;E->G:2;E->I:2;E->J;F->G:1;H->I;M->D:3;M->G:3;"
+            "M->I:3;N->D:4;N->G:4;N->I:4;O->D:5;O->G:5;O->I:5");
+}
+
+// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
input: ['E'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);" + "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);" + "H(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);" + "O(_MklInput)|A->D;A->E:1;A->G;B->D:1;C->D:2;D->E;DMT/_0->H:1;" + "E->G:2;E->H;E:control->DMT/_0:control;F->G:1;M->D:3;M->G:3;" + "N->D:4;N->G:4;O->D:5;O->G:5"); +} + +// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only +// But BackpropFilter node inputs do not satisfy criteria for rewrite. +TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative1) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'M' op: '_MklInput'}" + "node { name: 'N' op: '_MklInput'}" + "node { name: 'O' op: '_MklInput'}" + "node { name: 'D' op: '_MklConv2DWithBias'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['A', 'B', 'C', 'M', 'N', 'O']}" + "node { name: 'E' op: 'Sub'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['D', 'A']}" + "node { name: 'F' op: 'Int32Input'}" + "node { name: 'G' op: '_MklConv2DBackpropFilter'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " attr { key: 'use_cudnn_on_gpu' value { b: false } }" + " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }" + " attr { key: 'padding' value { s: 'SAME' } }" + " input: ['E', 'F', 'A', 'M', 'N', 'O'] }" + "node { name: 'H' op: 'BiasAddGrad'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'data_format' value { s: 'NCHW' } }" + " input: ['E'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);" + "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);" + "M(_MklInput);N(_MklInput);O(_MklInput)|A->D;A->E:1;A->G:2;B->D:1;" + "C->D:2;D->E;E->G;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;" + "O->G:5"); +} + +// BiasAddGrad rewrite to BackpropBias in the presence of BackpropFilter only +// But BackpropFilter node inputs do not satisfy criteria for rewrite. 
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_BpropFilter_Negative2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['B', 'A', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'Int32Input'}"
+      "node { name: 'G' op: '_MklConv2DBackpropFilter'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " input: ['A', 'F', 'E', 'M', 'N', 'O'] }"
+      "node { name: 'H' op: 'BiasAddGrad'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);"
+            "E(Sub);F(Int32Input);G(_MklConv2DBackpropFilter);H(BiasAddGrad);"
+            "M(_MklInput);N(_MklInput);O(_MklInput)|A->D:1;A->E:1;A->G;B->D;"
+            "C->D:2;D->E;E->G:2;E->H;F->G:1;M->D:3;M->G:3;N->D:4;N->G:4;O->D:5;"
+            "O->G:5");
+}
 
 // No _MklConv2DWithBias in context, but _MklConv2D in context.