From 646bc6ecee9e3eb939ea4118ac8ff446b19d97ef Mon Sep 17 00:00:00 2001
From: Mahmoud Abuzaina
Date: Wed, 14 Oct 2020 18:08:00 -0700
Subject: [PATCH] Skip FusedBatchNorm rewrite when input is 5D tensor

---
 .../eager/mkl_eager_op_rewrite.cc             | 18 +++++-
 .../eager/mkl_eager_op_rewrite_test.cc        | 20 ++++++
 .../core/common_runtime/mkl_layout_pass.cc    | 14 ++++-
 .../common_runtime/mkl_layout_pass_test.cc    | 63 +++++++++++++++++++
 tensorflow/core/graph/mkl_graph_util.h        | 10 +++
 5 files changed, 121 insertions(+), 4 deletions(-)

diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
index 31d5e05462c..9f5eb90ab64 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
@@ -54,6 +54,9 @@ class MklEagerOpRewrite : public EagerOpRewrite {
   // Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter.
   static bool RewriteConv2D(EagerOperation* op);
 
+  // Rewrite rule for FusedBatchNormV3 and FusedBatchNormGradV3.
+  static bool RewriteFusedBatchNormV3(EagerOperation* op);
+
   // Calls op-specific rewrite function to create new MKL op.
   Status RewriteToMklOp(EagerOperation* orig_op,
                         std::unique_ptr<EagerOperation>* mkl_op);
@@ -110,9 +113,10 @@ MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
   InsertMKLEagerOps(
       {"FusedBatchNormGradV2", AlwaysRewrite, CreateGenericMklOp});
   InsertMKLEagerOps(
-      {"FusedBatchNormGradV3", AlwaysRewrite, CreateGenericMklOp});
+      {"FusedBatchNormGradV3", RewriteFusedBatchNormV3, CreateGenericMklOp});
   InsertMKLEagerOps({"FusedBatchNormV2", AlwaysRewrite, CreateGenericMklOp});
-  InsertMKLEagerOps({"FusedBatchNormV3", AlwaysRewrite, CreateGenericMklOp});
+  InsertMKLEagerOps(
+      {"FusedBatchNormV3", RewriteFusedBatchNormV3, CreateGenericMklOp});
   InsertMKLEagerOps({"MatMul", AlwaysRewrite, CreateGenericMklOp});
 };
 
@@ -246,5 +250,15 @@ bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) {
   return (padding != "EXPLICIT");
 }
 
+bool MklEagerOpRewrite::RewriteFusedBatchNormV3(EagerOperation* op) {
+  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
+  if (Check5DFormat(ndef)) {
+    VLOG(1) << "Eager Op Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
+            << "support 5D tensors.";
+    return false;
+  }
+  return true;
+}
+
 }  // namespace tensorflow
 #endif  // INTEL_MKL
diff --git a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc
index a18301231cf..b56d97428b3 100644
--- a/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc
+++ b/tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite_test.cc
@@ -130,6 +130,26 @@ REGISTER_TEST_ALL_TYPES(ConvOpsExplicitPadding_Negative);
 REGISTER_TEST_ALL_TYPES(MostOps_Positive);
 #undef REGISTER_TEST
 
+#define REGISTER_TEST(NAME, T, INPUT)                                  \
+  TEST_F(EagerOpRewriteTest, NAME##_##T) {                             \
+    std::vector<string> Fused_BN_ops = {"FusedBatchNormV3",            \
+                                        "FusedBatchNormGradV3"};       \
+    for (int i = 0; i < Fused_BN_ops.size(); ++i) {                    \
+      auto orig_op = CreateOp(Fused_BN_ops[i]);                        \
+      orig_op->MutableAttrs()->Set("T", T);                            \
+      orig_op->MutableAttrs()->Set("data_format", "" DATA_FORMAT "");  \
+      CheckRewrite(orig_op.get(), Fused_BN_ops[i]);                    \
+    }                                                                  \
+  }
+#define DATA_FORMAT "NCDHW"
+REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_1);
+#undef DATA_FORMAT
+#define DATA_FORMAT "NDHWC"
+REGISTER_TEST_ALL_TYPES(FusedBatchNormV3_5D_Negative_2);
+
+#undef DATA_FORMAT
+#undef REGISTER_TEST
+
 }  // namespace tensorflow
 #endif  // INTEL_MKL
diff --git a/tensorflow/core/common_runtime/mkl_layout_pass.cc b/tensorflow/core/common_runtime/mkl_layout_pass.cc
index 99c41e4c75e..350aacee834 100644
--- a/tensorflow/core/common_runtime/mkl_layout_pass.cc
+++ b/tensorflow/core/common_runtime/mkl_layout_pass.cc
@@ -475,11 +475,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back(
         {csinfo_.fused_batch_norm_v3,
          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
-         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
+         CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
     rinfo_.push_back(
         {csinfo_.fused_batch_norm_grad_v3,
          mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
-         CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
+         CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
 #ifdef ENABLE_MKLDNN_V1
     rinfo_.push_back({csinfo_.fused_batch_norm_ex,
                       native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
@@ -1705,6 +1705,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     return do_rewrite;
   }
 
+  static bool FusedBatchNormV3Rewrite(const Node* n) {
+    DCHECK(n);
+    if (Check5DFormat(n->def())) {
+      VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
+              << "support 5D tensors.";
+      return false;
+    }
+    return true;
+  }
+
   static bool FusedBatchNormExRewrite(const Node* n) {
     DCHECK(n);
 
diff --git a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
index 4366a7892d3..fda5ad93352 100644
--- a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
+++ b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
@@ -3394,6 +3394,37 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV2_Negative) {
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_Positive);
 #undef REGISTER_TEST
 
+#define REGISTER_TEST(NAME, T, INPUT)                                         \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                     \
+    InitGraph(                                                                \
+        "node { name: 'A' op: '" #INPUT "'}"                                  \
+        "node { name: 'B' op: 'Float32Input'}"                                \
+        "node { name: 'C' op: 'Float32Input'}"                                \
+        "node { name: 'D' op: 'Float32Input'}"                                \
+        "node { name: 'E' op: 'Float32Input'}"                                \
+        "node { name: 'F' op: 'FusedBatchNormV3'"                             \
+        " attr { key: 'T' value { type: " #T " } }"                           \
+        " attr { key: 'U' value { type: DT_FLOAT } }"                         \
+        " attr { key: 'data_format' value { s: " DATA_FORMAT " } }"           \
+        " attr { key: 'epsilon' value { f: 0.0001 } }"                        \
+        " attr { key: 'is_training' value { b: true } }"                      \
+        " input: ['A', 'B', 'C', 'D', 'E'] }"                                 \
+        "node { name: 'G' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
+        " input: ['A', 'F'] }");                                              \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                  \
+              "A(" #INPUT ");B(Float32Input);C(Float32Input);"                \
+              "D(Float32Input);E(Float32Input);F(FusedBatchNormV3);G(Zeta)"   \
+              "|A->F;A->G;B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");               \
+  }
+#define DATA_FORMAT "'NCDHW'"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_1);
+#undef DATA_FORMAT
+#define DATA_FORMAT "'NDHWC'"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormV3_5D_Negative_2);
+
+#undef DATA_FORMAT
+#undef REGISTER_TEST
+
 TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
   DCHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
   InitGraph(
@@ -3417,6 +3448,38 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedBatchNormV3_Negative) {
   "B->F:1;C->F:2;D->F:3;E->F:4;F->G:1");
 }
 
+#define REGISTER_TEST(NAME, T, INPUT)                                         \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                     \
+    InitGraph(                                                                \
+        "node { name: 'A' op: '" #INPUT "'}"                                  \
+        "node { name: 'B' op: '" #INPUT "'}"                                  \
+        "node { name: 'C' op: 'Float32Input'}"                                \
+        "node { name: 'D' op: 'Float32Input'}"                                \
+        "node { name: 'E' op: 'Float32Input'}"                                \
+        "node { name: 'F' op: 'Float32Input'}"                                \
+        "node { name: 'G' op: 'FusedBatchNormGradV3'"                         \
+        " attr { key: 'T' value { type: " #T " } }"                           \
+        " attr { key: 'U' value { type: DT_FLOAT } }"                         \
+        " attr { key: 'data_format' value { s: " DATA_FORMAT " } }"           \
+        " attr { key: 'epsilon' value { f: 0.0001 } }"                        \
+        " attr { key: 'is_training' value { b: true } }"                      \
+        " input: ['A', 'B', 'C', 'D', 'E', 'F'] }"                            \
+        "node { name: 'H' op: 'Zeta' attr { key: 'T' value { type: " #T " } }" \
+        " input: ['A', 'G'] }");                                              \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                  \
+              "A(" #INPUT ");B(" #INPUT ");C(Float32Input);D(Float32Input);"  \
+              "E(Float32Input);F(Float32Input);G(FusedBatchNormGradV3);H(Zeta)" \
+              "|A->G;A->H;B->G:1;C->G:2;D->G:3;E->G:4;F->G:5;G->H:1");        \
+  }
+#define DATA_FORMAT "'NCDHW'"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_1);
+#undef DATA_FORMAT
+#define DATA_FORMAT "'NDHWC'"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedBatchNormGradV3_5D_Negative_2);
+
+#undef DATA_FORMAT
+#undef REGISTER_TEST
+
 #ifdef ENABLE_MKLDNN_V1
 #define REGISTER_TEST(NAME, T, INPUT)                                         \
   TEST_F(MklLayoutPassTest, NAME##_##T) {                                     \
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 9d9727bdb46..efa1cfb0d3c 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -102,6 +102,16 @@ bool inline NativeFormatEnabled() {
   return native_fmt_enabled;
 }
 
+// Check if the data_format attribute in the node def represents a 5D tensor.
+bool inline Check5DFormat(const NodeDef& ndef) {
+  string data_format;
+  TF_CHECK_OK(GetNodeAttr(ndef, "data_format", &data_format));
+  if (data_format.compare("NCDHW") == 0 || data_format.compare("NDHWC") == 0) {
+    return true;
+  }
+  return false;
+}
+
 namespace mkl_op_registry {
 // MKL operators whose kernels are registered with 'MklLayoutDependentOp' label
 // (e.g., MklConv2D) understand input tensors in MKL layout. These operators