From 7c0164495c9a94c0babdd3efd7670c315b2d98f5 Mon Sep 17 00:00:00 2001
From: yunfeima
Date: Mon, 21 Sep 2020 14:37:37 +0800
Subject: [PATCH] Enable MKL Matmul + Bias + LeakyRelu Fusion

---
 .../common_runtime/mkl_layout_pass_test.cc    | 30 +++++++++++++++++++
 .../core/grappler/optimizers/remapper.cc      | 16 +++++++---
 .../core/grappler/optimizers/remapper_test.cc | 14 ++++++++-
 .../core/kernels/mkl/mkl_fused_ops_test.cc    | 17 +++++++++++
 .../core/kernels/mkl/mkl_matmul_op_fused.cc   |  7 +++++
 .../core/kernels/mkl/mkl_matmul_ops_common.h  |  6 ++--
 tensorflow/core/ops/math_ops.cc               |  2 ++
 tensorflow/core/ops/mkl_nn_ops.cc             |  2 ++
 8 files changed, 87 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
index fda5ad93352..7ae254366a4 100644
--- a/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
+++ b/tensorflow/core/common_runtime/mkl_layout_pass_test.cc
@@ -2033,6 +2033,36 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Positive)
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Negative);
 #undef REGISTER_TEST
 
+// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
+#define REGISTER_TEST(NAME, T, INPUT)                                         \
+  TEST_F(MklLayoutPassTest, NAME##_##T) {                                     \
+    InitGraph(                                                                \
+        "node { name: 'A' op: '" #INPUT "'}"                                  \
+        "node { name: 'B' op: '" #INPUT "'}"                                  \
+        "node { name: 'C' op: '" #INPUT "'}"                                  \
+        "node { name: 'D' op: '_FusedMatMul'"                                 \
+        " attr { key: 'T' value { type: " #T "} }"                            \
+        " attr { key: 'transpose_a' value { b: false } }"                     \
+        " attr { key: 'transpose_b' value { b: false } }"                     \
+        " attr { key: 'num_args' value { i: 1 } }"                            \
+        " attr { key: 'fused_ops'"                                            \
+        "        value { list: {s: 'BiasAdd', s: 'LeakyRelu'} } }"            \
+        " attr { key: 'epsilon' value { f: 0.001 }}"                          \
+        " attr { key: 'leakyrelu_alpha' value { f: 0.3 }}"                    \
+        " input: ['A', 'B', 'C']}"                                            \
+        "node { name: 'Z' op: 'Zeta'"                                         \
+        " attr {key: 'T' value { type: " #T " } }"                            \
+        " input: ['D', 'C']}");                                               \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(),                                  \
+              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedMatMul);" \
+              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);Z(Zeta)"             \
+              "|A->D;A:control->DMT/_0:control;A:control->DMT/_1:control;"    \
+              "A:control->DMT/_2:control;B->D:1;C->D:2;C->Z:1;D->Z;"          \
+              "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");                         \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_LeakyRelu_Positive);
+#undef REGISTER_TEST
+
 // Merge test for PadWithFusedConv2D Op with BiasAdd fusion
 // padding is VALID type
 // A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
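For reference, the node this pass rewrites computes LeakyRelu(MatMul(A, B) + bias). The sketch below is not part of the patch: it is a plain scalar reference for row-major inputs, not the oneDNN-backed kernel, and the function and parameter names are illustrative assumptions.

    // Reference semantics of _FusedMatMul with fused_ops = [BiasAdd, LeakyRelu].
    // out[m][n] = LeakyRelu(sum_k a[m][k] * b[k][n] + bias[n], alpha)
    #include <vector>

    std::vector<float> FusedMatMulBiasLeakyRelu(const std::vector<float>& a,
                                                const std::vector<float>& b,
                                                const std::vector<float>& bias,
                                                int M, int K, int N, float alpha) {
      std::vector<float> out(M * N);
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          float acc = bias[n];
          for (int k = 0; k < K; ++k) acc += a[m * K + k] * b[k * N + n];
          // LeakyRelu: x for x >= 0, alpha * x otherwise.
          out[m * N + n] = acc >= 0.0f ? acc : alpha * acc;
        }
      }
      return out;
    }
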
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index d7705e91f52..88011daa291 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -471,8 +471,10 @@ bool FindContractionWithBiasAndActivation(
   // Currently, only matmul + bias + tanh is enable
   if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
 
-  // Currently, only conv + bias + leakyrelu is enabled
-  if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
+  // Currently, only (conv | matmul) + bias + leakyrelu is enabled
+  if ((!IsConv2D(*contraction_node_def) && !IsMatMul(*contraction_node_def)) &&
+      IsLeakyRelu(*node_def))
+    return false;
 
   // Check that data type and data format are supported on assigned device.
   const ContractionWithBiasAddAndActivation pattern{base.contraction,
@@ -1028,7 +1030,8 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
   }
 }
 
-void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul) {
+void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul,
+                          const NodeDef* activation = nullptr) {
   DCHECK(IsMatMul(matmul)) << "Input node must be a MatMul";
 
   auto* attr = fused_matmul->mutable_attr();
@@ -1037,6 +1040,11 @@ void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul) {
   (*attr)["T"] = src_attr.at("T");
   (*attr)["transpose_a"] = src_attr.at("transpose_a");
   (*attr)["transpose_b"] = src_attr.at("transpose_b");
+  // Copy LeakyRelu's attr alpha to _FusedMatMul's attr leakyrelu_alpha
+  if (activation != nullptr && IsLeakyRelu(*activation)) {
+    auto& activation_attr = activation->attr();
+    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
+  }
 }
 
 void SetFusedOpAttributes(NodeDef* fused,
@@ -1125,7 +1133,7 @@ Status AddFusedContractionNode(
     CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
   } else if (IsMatMul(contraction)) {
     fused_op.set_op(kFusedMatMul);
-    CopyMatMulAttributes(contraction, &fused_op);
+    CopyMatMulAttributes(contraction, &fused_op, &activation);
   }
 
   SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()});
diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc
index 2aa564104db..20bbf3c62bd 100644
--- a/tensorflow/core/grappler/optimizers/remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -635,7 +635,8 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
   void RunTest() {
     using ::tensorflow::ops::Placeholder;
 
-    for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+    for (const string& activation :
+         {"Relu", "Relu6", "Elu", "Tanh", "LeakyRelu"}) {
       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
       auto lhs_shape = ops::Placeholder::Shape({8, 32});
@@ -659,6 +660,12 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
           return ops::Identity(fetch, ops::Relu6(activate, bias_add));
         } else if (activation == "Elu") {
           return ops::Identity(fetch, ops::Elu(activate, bias_add));
+        } else if (activation == "Tanh") {
+          return ops::Identity(fetch, ops::Tanh(activate, bias_add));
+        } else if (activation == "LeakyRelu") {
+          auto attr = ops::internal::LeakyRelu::Alpha(0.5);
+          return ops::Identity(
+              fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
         }
 
         return ops::Identity(fetch, bias);
@@ -697,6 +704,11 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
         ASSERT_EQ(fused_ops.size(), 2);
         EXPECT_EQ(fused_ops[0], "BiasAdd");
         EXPECT_EQ(fused_ops[1], activation);
+
+        if (activation == "LeakyRelu") {
+          EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
+        }
+
         found++;
       }
     }
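The remapper change above widens the activation gating: Tanh may only fuse when the contraction is a MatMul, and LeakyRelu may now fuse when the contraction is either a Conv2D or a MatMul. The standalone sketch below mirrors that gating for illustration only; ActivationFusable is a made-up helper, not the real FindContractionWithBiasAndActivation.

    // Simplified view of the contraction/activation pairs accepted after this
    // patch. String-based sketch; the real code inspects NodeDefs.
    #include <string>

    bool ActivationFusable(const std::string& contraction,
                           const std::string& activation) {
      const bool is_conv2d = contraction == "Conv2D";
      const bool is_matmul = contraction == "MatMul";
      if (activation == "Tanh" && !is_matmul) return false;
      if (activation == "LeakyRelu" && !is_conv2d && !is_matmul) return false;
      // Relu/Relu6/Elu remain accepted for all supported contractions.
      return true;
    }
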
diff --git a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc
index 9bb26535cbf..58c7f2c4734 100644
--- a/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc
+++ b/tensorflow/core/kernels/mkl/mkl_fused_ops_test.cc
@@ -1073,6 +1073,13 @@ class MklFusedMatMulOpTest : public OpsTestBase {
       next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
     }
 
+    if (std::find(fused_ops.begin(), fused_ops.end(), "LeakyRelu") !=
+        fused_ops.end()) {
+      last_op = "with_leakyrelu";
+      next_op =
+          ops::internal::LeakyRelu(root.WithOpName(last_op), next_op);
+    }
+
     CommonTestUtilities::RunAndFetch(root, last_op, output);
   };
 
@@ -1148,12 +1155,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
                           {"BiasAdd", "Add"});
 }
 
+TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndLeakyRelu) {
+  const int batch = 3;
+  const int input_channel = 4;
+  const int output_channel = 5;
+
+  this->VerifyFusedMatMul(batch, input_channel, output_channel,
+                          {"BiasAdd", "LeakyRelu"});
+}
+
 REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest,  //
                             WithBias,              //
                             WithBiasAndRelu,       //
                             WithBiasAndRelu6,      //
                             WithBiasAndElu,        //
                             WithBiasAndTanh,       //
+                            WithBiasAndLeakyRelu,  //
                             WithBiasAndAdd);
 
 using MklFusedMatMulDataTypes = ::testing::Types<float>;
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
index 246efacb615..6b8a280e97b 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
@@ -49,6 +49,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
     OP_REQUIRES(
         ctx, transpose_a_ == false,
         errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
+    if (fused_ops_.size() == 2 && fused_ops_[1] == "LeakyRelu") {
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
+    }
   }
 
   void Compute(OpKernelContext* ctx) override {
@@ -287,6 +290,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
         params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
       } else if (post_op == "Add") {
         params.post_op_params.push_back({"sum", {1.0}});
+      } else if (post_op == "LeakyRelu") {
+        params.post_op_params.push_back(
+            {"leakyrelu", {1.0, leakyrelu_alpha, 0.0}});
       } else {
         OP_REQUIRES_OK(
             ctx, errors::InvalidArgument(
@@ -299,6 +305,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
   bool fuse_add_ = false;
   bool transpose_a_;
   bool transpose_b_;
+  float leakyrelu_alpha = 0.2;
   std::vector<string> fused_ops_;
   const int kInputIndex_Add = 3;
   const int kOutputIndex_Dst = 0;
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
index 375047d290f..bf841a8b8fc 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
@@ -204,7 +204,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
     mkldnn::post_ops post_ops;
     if (!post_op_params.empty()) {
       for (auto const& post_op_param : post_op_params) {
-        if (post_op_param.name == "relu") {
+        if (post_op_param.name == "relu" || post_op_param.name == "leakyrelu") {
           DCHECK_EQ(post_op_param.param.size(), 3);
           float op_scale = post_op_param.param[0];
           float op_alpha = post_op_param.param[1];
@@ -249,6 +249,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
                  (post_op_param.name == "elu") ||
                  (post_op_param.name == "tanh") ||
                  (post_op_param.name == "sum") ||
+                 (post_op_param.name == "leakyrelu") ||
                  (post_op_param.name == "output_scale"));
         }
       }
@@ -342,7 +343,8 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory {
     // Generate keys for post-ops
     for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
       if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
-          post_op_param.name == "elu" || post_op_param.name == "tanh") {
+          post_op_param.name == "elu" || post_op_param.name == "tanh" ||
+          post_op_param.name == "leakyrelu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
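In the kernel, the new {"leakyrelu", {scale, alpha, beta}} post-op entry rides on the existing "relu" branch because oneDNN's eltwise_relu with a nonzero alpha computes leaky ReLU. The sketch below shows that lowering under the assumption of the oneDNN v1.x post_ops API this header builds on; PostOpParam and AppendEltwise are illustrative names, not code from the patch.

    // Append a (leaky) ReLU eltwise post-op to a oneDNN post_ops list.
    #include <string>
    #include <vector>
    #include "mkldnn.hpp"

    struct PostOpParam {
      std::string name;
      std::vector<float> param;  // {scale, alpha, beta} for eltwise entries.
    };

    void AppendEltwise(mkldnn::post_ops* post_ops, const PostOpParam& p) {
      if (p.name == "relu" || p.name == "leakyrelu") {
        // param[1] is 0.0 for plain ReLU and leakyrelu_alpha for LeakyRelu.
        post_ops->append_eltwise(p.param[0], mkldnn::algorithm::eltwise_relu,
                                 p.param[1], p.param[2]);
      }
    }
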
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index d6fde7248ab..7edf7f7a843 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -954,6 +954,8 @@ REGISTER_OP("_FusedMatMul")
     .Attr("fused_ops: list(string) = []")
     // Attributes for the FusedBatchNorm ----------- //
     .Attr("epsilon: float = 0.0001")
+    // Attributes for the LeakyRelu ----------------------------------------- //
+    .Attr("leakyrelu_alpha: float = 0.2")
     // --------------------------------------------- //
     .SetShapeFn(shape_inference::MatMulShape)
     .Doc(R"doc(
diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc
index 1604527b941..e60f6710d26 100644
--- a/tensorflow/core/ops/mkl_nn_ops.cc
+++ b/tensorflow/core/ops/mkl_nn_ops.cc
@@ -295,6 +295,8 @@ REGISTER_OP("_MklFusedMatMul")
     .Attr("fused_ops: list(string) = []")
     // Attributes for the FusedBatchNorm ----------- //
     .Attr("epsilon: float = 0.0001")
+    // Attributes for the LeakyRelu ----------------------------------------- //
+    .Attr("leakyrelu_alpha: float = 0.2")
     // --------------------------------------------- //
     .SetShapeFn(shape_inference::MatMulShape)
     .Doc(R"doc(
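With the registrations above, both _FusedMatMul and _MklFusedMatMul carry a leakyrelu_alpha attribute that defaults to 0.2; the remapper copies the LeakyRelu node's alpha into it and the kernel reads it back with GetAttr. The snippet below is a minimal illustration of writing and reading that attribute on a NodeDef; it is not code from the patch.

    // Build a _FusedMatMul NodeDef carrying leakyrelu_alpha, then read it back.
    #include "tensorflow/core/framework/node_def.pb.h"

    tensorflow::NodeDef MakeFusedMatMulNode(float alpha) {
      tensorflow::NodeDef node;
      node.set_op("_FusedMatMul");
      auto* attr = node.mutable_attr();
      (*attr)["fused_ops"].mutable_list()->add_s("BiasAdd");
      (*attr)["fused_ops"].mutable_list()->add_s("LeakyRelu");
      (*attr)["leakyrelu_alpha"].set_f(alpha);  // what the remapper writes
      return node;
    }

    // Equivalent to what the kernel obtains via ctx->GetAttr("leakyrelu_alpha", ...).
    float ReadAlpha(const tensorflow::NodeDef& node) {
      return node.attr().at("leakyrelu_alpha").f();
    }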