Enable Eigen MatMul + Bias + LeakyRelu fusion

This commit is contained in:
yunfeima 2020-09-21 15:22:27 +08:00
parent 7c0164495c
commit 0ef037fde1
3 changed files with 25 additions and 9 deletions

View File

@@ -3551,6 +3551,7 @@ tf_cuda_cc_test(
":ops_util",
":quantized_ops",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:client_session",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",

View File

@@ -86,7 +86,12 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
BiasAddArgs<T> bias_add_args;
if (BiasAddArgs<T>::IsSupported(fusion)) {
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
&fusion_args.leakyrelu_alpha));
} else {
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
}
}
switch (fusion) {
@@ -102,6 +107,10 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
case FusedComputationType::kBiasAddWithElu:
executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
break;
case FusedComputationType::kBiasAddWithLeakyRelu:
out.device(d) = lhs.contract(rhs, dim_pair,
WithBiasAddAndLeakyRelu<T>(bias_add_args));
break;
case FusedComputationType::kUndefined:
OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
break;
@@ -148,10 +157,13 @@ class FusedMatMulOp : public OpKernel {
using FCT = FusedComputationType;
if (std::is_same<Device, CPUDevice>::value) {
patterns = {{FCT::kBiasAdd, {"BiasAdd"}},
{FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
{FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
{FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}};
patterns = {
{FCT::kBiasAdd, {"BiasAdd"}},
{FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
{FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
{FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
{FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
};
}
OP_REQUIRES_OK(context, InitializeFusedComputation(

View File

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "absl/algorithm/container.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/tensor.h"
@@ -137,6 +138,8 @@ class FusedMatMulOpTest : public OpsTestBase {
ops::Relu6(root.WithOpName("with_activation"), with_bias);
} else if (activation_type == "Elu") {
ops::Elu(root.WithOpName("with_activation"), with_bias);
} else if (activation_type == "LeakyRelu") {
ops::internal::LeakyRelu(root.WithOpName("with_activation"), with_bias);
} else {
ops::Identity(root.WithOpName("with_activation"), with_bias);
}
@@ -291,7 +294,7 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1) {
}
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256WithActivation) {
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
this->VerifyConv2DWithBiasAndActivation(256, 256, 256, false, false,
activation);
this->VerifyConv2DWithBiasAndActivation(256, 256, 256, true, false,
@@ -304,21 +307,21 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256WithActivation) {
}
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256WithActivation) {
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
this->VerifyConv2DWithBiasAndActivation(1, 256, 256, false, false,
activation);
}
}
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1WithActivation) {
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
this->VerifyConv2DWithBiasAndActivation(256, 256, 1, false, false,
activation);
}
}
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1WithActivation) {
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
this->VerifyConv2DWithBiasAndActivation(1, 256, 1, false, false,
activation);
}