Enable Eigen MatMul + Bias + LeakyRelu fusion
This commit is contained in:
parent
7c0164495c
commit
0ef037fde1
tensorflow/core/kernels
@@ -3551,6 +3551,7 @@ tf_cuda_cc_test(
|
||||
":ops_util",
|
||||
":quantized_ops",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/cc:client_session",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
|
@@ -86,7 +86,12 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
|
||||
|
||||
BiasAddArgs<T> bias_add_args;
|
||||
if (BiasAddArgs<T>::IsSupported(fusion)) {
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
|
||||
if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
|
||||
&fusion_args.leakyrelu_alpha));
|
||||
} else {
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
|
||||
}
|
||||
}
|
||||
|
||||
switch (fusion) {
|
||||
@@ -102,6 +107,10 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
|
||||
case FusedComputationType::kBiasAddWithElu:
|
||||
executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
|
||||
break;
|
||||
case FusedComputationType::kBiasAddWithLeakyRelu:
|
||||
out.device(d) = lhs.contract(rhs, dim_pair,
|
||||
WithBiasAddAndLeakyRelu<T>(bias_add_args));
|
||||
break;
|
||||
case FusedComputationType::kUndefined:
|
||||
OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
|
||||
break;
|
||||
@@ -148,10 +157,13 @@ class FusedMatMulOp : public OpKernel {
|
||||
|
||||
using FCT = FusedComputationType;
|
||||
if (std::is_same<Device, CPUDevice>::value) {
|
||||
patterns = {{FCT::kBiasAdd, {"BiasAdd"}},
|
||||
{FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
|
||||
{FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
|
||||
{FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}}};
|
||||
patterns = {
|
||||
{FCT::kBiasAdd, {"BiasAdd"}},
|
||||
{FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
|
||||
{FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
|
||||
{FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
|
||||
{FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
|
||||
};
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context, InitializeFusedComputation(
|
||||
|
@@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/cc/ops/nn_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@@ -137,6 +138,8 @@ class FusedMatMulOpTest : public OpsTestBase {
|
||||
ops::Relu6(root.WithOpName("with_activation"), with_bias);
|
||||
} else if (activation_type == "Elu") {
|
||||
ops::Elu(root.WithOpName("with_activation"), with_bias);
|
||||
} else if (activation_type == "LeakyRelu") {
|
||||
ops::internal::LeakyRelu(root.WithOpName("with_activation"), with_bias);
|
||||
} else {
|
||||
ops::Identity(root.WithOpName("with_activation"), with_bias);
|
||||
}
|
||||
@@ -291,7 +294,7 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1) {
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256WithActivation) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBiasAndActivation(256, 256, 256, false, false,
|
||||
activation);
|
||||
this->VerifyConv2DWithBiasAndActivation(256, 256, 256, true, false,
|
||||
@@ -304,21 +307,21 @@ TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256WithActivation) {
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256WithActivation) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBiasAndActivation(1, 256, 256, false, false,
|
||||
activation);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1WithActivation) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBiasAndActivation(256, 256, 1, false, false,
|
||||
activation);
|
||||
}
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1WithActivation) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBiasAndActivation(1, 256, 1, false, false,
|
||||
activation);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user