From 0af8346d1fde62d94ff1824ba2ca9a01848a18b6 Mon Sep 17 00:00:00 2001 From: "Varghese, Jojimon" Date: Fri, 25 Sep 2020 14:37:05 -0700 Subject: [PATCH 1/2] Fix for concatv2 unit test --- .../core/common_runtime/mkl_layout_pass.cc | 13 +++++++++-- .../common_runtime/mkl_layout_pass_test.cc | 23 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index 176670c8aa5..e046eb8ae1d 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -407,7 +407,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.concatv2, mkl_op_registry::GetMklOpName(csinfo_.concatv2), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()}); rinfo_.push_back( {csinfo_.conjugate_transpose, mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose), @@ -549,7 +549,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.quantized_concatv2, mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2), - CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); + CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()}); rinfo_.push_back({csinfo_.quantized_conv2d, mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d), CopyAttrsQuantizedConv2D, AlwaysRewrite, @@ -1496,6 +1496,15 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } return false; } + // For NKL-DNN only int32 supported for axis data type + static bool ConcatV2Rewrite(const Node *n) { + DataType T; + GetNodeAttr(n->def(),"Tidx", &T); + if (T != DT_INT32) { + return false; + } + return true; + } static bool DequantizeRewrite(const Node* n) { DCHECK(n); diff --git a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc index bbe20f4436d..51ba0d913ea 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc @@ -4258,6 +4258,29 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); } + +TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_idxtest) { + InitGraph( + "node { name: 'A' op: 'Const' " + " attr { key: 'dtype' value { type: DT_INT64 } }" + " attr { key: 'value' value { " + " tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } " + " int_val: 0 } } } }" + "node { name: 'B' op: 'InputList'" + " attr { key: 'N' value { i: 2 } }}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: 'ConcatV2'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'Tidx' value { type: DT_INT64 } }" + " attr { key: 'N' value { i: 2 } }" + " input: ['B:0', 'B:1', 'A']}" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['C', 'D'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Const);B(InputList);C(Input);D(ConcatV2);E(Zeta)|" + "A->D:2;B->D;B:1->D:1;C->E;D->E:1"); +} + TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNorm_DeviceTest) { InitGraph( "node { name: 'A' op: 'Input'}" From 36038c881b57e9bd651cdec3aa761e5979b9cc18 Mon Sep 17 00:00:00 2001 From: xiaohong1031 Date: Mon, 5 Oct 2020 12:34:29 -0700 Subject: [PATCH 2/2] update per Google code review feedback --- tensorflow/core/common_runtime/mkl_layout_pass.cc | 7 ++----- tensorflow/core/common_runtime/mkl_layout_pass_test.cc | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc index e046eb8ae1d..c920b5c563c 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc @@ -1496,14 +1496,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass { } return false; } - // For NKL-DNN only int32 supported for axis data type + // For oneDNN, only int32 is supported for axis data type static bool ConcatV2Rewrite(const Node *n) { DataType T; GetNodeAttr(n->def(),"Tidx", &T); - if (T != DT_INT32) { - return false; - } - return true; + return (T == DT_INT32); } static bool DequantizeRewrite(const Node* n) { diff --git a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc index 51ba0d913ea..174be5a9036 100644 --- a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc +++ b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc @@ -4259,7 +4259,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_DeviceTest) { } -TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_idxtest) { +TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_IndexTest) { InitGraph( "node { name: 'A' op: 'Const' " " attr { key: 'dtype' value { type: DT_INT64 } }"