From 0ef037fde1c80856219c907c8cbca8d6aee67a3d Mon Sep 17 00:00:00 2001
From: yunfeima
Date: Mon, 21 Sep 2020 15:22:27 +0800
Subject: [PATCH] Enable Eigen MatMul + Bias + LeakyRelu fusion

---
 tensorflow/core/kernels/BUILD              |  1 +
 tensorflow/core/kernels/matmul_op_fused.cc | 22 +++++++++++++++++-----
 tensorflow/core/kernels/matmul_op_test.cc  | 11 +++++++----
 3 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 2ea2a8aff85..a21366d8929 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3551,6 +3551,7 @@ tf_cuda_cc_test(
         ":ops_util",
         ":quantized_ops",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:client_session",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc
index 9ba9ed6c63f..b24797da901 100644
--- a/tensorflow/core/kernels/matmul_op_fused.cc
+++ b/tensorflow/core/kernels/matmul_op_fused.cc
@@ -86,7 +86,12 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
 
     BiasAddArgs<T> bias_add_args;
     if (BiasAddArgs<T>::IsSupported(fusion)) {
-      OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
+      if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
+        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
+                                                &fusion_args.leakyrelu_alpha));
+      } else {
+        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
+      }
     }
 
     switch (fusion) {
@@ -102,6 +107,10 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
       case FusedComputationType::kBiasAddWithElu:
         executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
         break;
+      case FusedComputationType::kBiasAddWithLeakyRelu:
+        out.device(d) = lhs.contract(rhs, dim_pair,
+                                     WithBiasAddAndLeakyRelu<T>(bias_add_args));
+        break;
       case FusedComputationType::kUndefined:
         OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
         break;
@@ -148,10 +157,13 @@ class FusedMatMulOp : public OpKernel {
 
     using FCT = FusedComputationType;
     if (std::is_same<Device, CPUDevice>::value) {
-      patterns = {{FCT::kBiasAdd, {"BiasAdd"}},
-                  {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
-                  {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
-                  {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}};
+      patterns = {
+          {FCT::kBiasAdd, {"BiasAdd"}},
+          {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
+          {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
+          {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
+          {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
+      };
     }
 
     OP_REQUIRES_OK(context, InitializeFusedComputation(
diff --git a/tensorflow/core/kernels/matmul_op_test.cc b/tensorflow/core/kernels/matmul_op_test.cc
index 4f986e34acd..a18ec3916ba 100644
--- a/tensorflow/core/kernels/matmul_op_test.cc
+++ b/tensorflow/core/kernels/matmul_op_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "absl/algorithm/container.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -137,6 +138,8 @@ class FusedMatMulOpTest : public OpsTestBase {
       ops::Relu6(root.WithOpName("with_activation"), with_bias);
     } else if (activation_type == "Elu") {
       ops::Elu(root.WithOpName("with_activation"), with_bias);
+    } else if (activation_type == "LeakyRelu") {
+      ops::internal::LeakyRelu(root.WithOpName("with_activation"), with_bias);
     } else {
       ops::Identity(root.WithOpName("with_activation"), with_bias);
     }
@@ -291,7 +294,7 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1) {
 }
 
 TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256WithActivation) {
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBiasAndActivation(256, 256, 256, false, false,
                                             activation);
     this->VerifyConv2DWithBiasAndActivation(256, 256, 256, true, false,
@@ -304,21 +307,21 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256WithActivation) {
 }
 
 TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256WithActivation) {
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBiasAndActivation(1, 256, 256, false, false,
                                             activation);
   }
 }
 
 TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1WithActivation) {
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBiasAndActivation(256, 256, 1, false, false,
                                             activation);
   }
 }
 
 TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1WithActivation) {
-  for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+  for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
     this->VerifyConv2DWithBiasAndActivation(1, 256, 1, false, false,
                                             activation);
   }
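The fused computation introduced above is y = LeakyRelu(MatMul(a, b) + bias),
where LeakyRelu(v) = v for v > 0 and alpha * v otherwise; TensorFlow's
LeakyRelu op uses alpha = 0.2 by default, and the patch threads the op's
alpha through fusion_args.leakyrelu_alpha into the bias-add arguments. As a
rough standalone sketch of the epilogue arithmetic only (not the Eigen
output-kernel path the patch uses; the helper name BiasAddLeakyRelu and the
column-major layout are assumptions for illustration):

#include <cstdio>

// Applies bias + LeakyRelu to a rows x cols matmul result stored
// column-major, with one bias value per output column.
void BiasAddLeakyRelu(float* out, const float* bias, int rows, int cols,
                      float alpha) {
  for (int c = 0; c < cols; ++c) {
    for (int r = 0; r < rows; ++r) {
      float v = out[c * rows + r] + bias[c];         // BiasAdd
      out[c * rows + r] = v > 0.0f ? v : alpha * v;  // LeakyRelu
    }
  }
}

int main() {
  float out[4] = {1.0f, -2.0f, 3.0f, -4.0f};  // 2x2 matmul result
  float bias[2] = {0.5f, -0.5f};
  BiasAddLeakyRelu(out, bias, /*rows=*/2, /*cols=*/2, /*alpha=*/0.2f);
  for (float v : out) std::printf("%g ", v);  // prints: 1.5 -0.3 2.5 -0.9
  return 0;
}

Applying the activation as part of the contraction avoids materializing the
intermediate MatMul and BiasAdd tensors, matching how the existing
Relu/Relu6/Elu fusions in this file are implemented.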