Enable MKL Matmul + Bias + LeakyRelu Fusion

yunfeima 2020-09-21 14:37:37 +08:00
parent c97e809422
commit 7c0164495c
8 changed files with 87 additions and 7 deletions
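For context, the pattern this change targets is the unfused MatMul -> BiasAdd -> LeakyRelu chain: Grappler's remapper rewrites it into a single _FusedMatMul (with fused_ops = [BiasAdd, LeakyRelu] and a leakyrelu_alpha attr), and the MKL layout pass then rewrites that node into _MklFusedMatMul. A minimal sketch of the unfused pattern with the C++ ops API, mirroring the remapper test below (shapes and the 0.3 slope are illustrative, not part of this commit):

// Sketch only: the unfused graph the remapper looks for.
// Assumes the usual tensorflow/cc/ops headers and the tensorflow namespace.
Scope root = Scope::NewRootScope();
auto lhs  = ops::Placeholder(root.WithOpName("lhs"), DT_FLOAT, ops::Placeholder::Shape({8, 32}));
auto rhs  = ops::Placeholder(root.WithOpName("rhs"), DT_FLOAT, ops::Placeholder::Shape({32, 64}));
auto bias = ops::Placeholder(root.WithOpName("bias"), DT_FLOAT, ops::Placeholder::Shape({64}));
auto matmul   = ops::MatMul(root.WithOpName("matmul"), lhs, rhs);
auto bias_add = ops::BiasAdd(root.WithOpName("bias_add"), matmul, bias);
auto activate = ops::internal::LeakyRelu(root.WithOpName("leakyrelu"), bias_add,
                                         ops::internal::LeakyRelu::Alpha(0.3f));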


@@ -2033,6 +2033,36 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Positive)
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Negative);
#undef REGISTER_TEST
// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph( \
"node { name: 'A' op: '" #INPUT "'}" \
"node { name: 'B' op: '" #INPUT "'}" \
"node { name: 'C' op: '" #INPUT "'}" \
"node { name: 'D' op: '_FusedMatMul'" \
" attr { key: 'T' value { type: " #T "} }" \
" attr { key: 'transpose_a' value { b: false } }" \
" attr { key: 'transpose_b' value { b: false } }" \
" attr { key: 'num_args' value { i: 1 } }" \
" attr { key: 'fused_ops'" \
" value { list: {s: 'BiasAdd', s: 'LeakyRelu'} } }" \
" attr { key: 'epsilon' value { f: 0.001 }}" \
" attr { key: 'leakyrelu_alpha' value { f: 0.3 }}" \
" input: ['A', 'B', 'C']}" \
"node { name: 'Z' op: 'Zeta'" \
" attr {key: 'T' value { type: " #T " } }" \
" input: ['D', 'C']}"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedMatMul);" \
"DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);Z(Zeta)" \
"|A->D;A:control->DMT/_0:control;A:control->DMT/_1:control;" \
"A:control->DMT/_2:control;B->D:1;C->D:2;C->Z:1;D->Z;" \
"DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_LeakyRelu_Positive);
#undef REGISTER_TEST
// Merge test for PadWithFusedConv2D Op with BiasAdd fusion
// padding is VALID type
// A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,


@@ -471,8 +471,10 @@ bool FindContractionWithBiasAndActivation(
// Currently, only matmul + bias + tanh is enabled
if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
// Currently, only conv + bias + leakyrelu is enabled
if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
// Currently, only (conv | matmul) + bias + leakyrelu is enabled
if ((!IsConv2D(*contraction_node_def) && !IsMatMul(*contraction_node_def)) &&
IsLeakyRelu(*node_def))
return false;
// Check that data type and data format are supported on assigned device.
const ContractionWithBiasAddAndActivation pattern{base.contraction,
@@ -1028,7 +1030,8 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
}
}
void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul) {
void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul,
const NodeDef* activation = nullptr) {
DCHECK(IsMatMul(matmul)) << "Input node must be a MatMul";
auto* attr = fused_matmul->mutable_attr();
@@ -1037,6 +1040,11 @@ void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul) {
(*attr)["T"] = src_attr.at("T");
(*attr)["transpose_a"] = src_attr.at("transpose_a");
(*attr)["transpose_b"] = src_attr.at("transpose_b");
// Copy LeakyRelu's attr alpha to _FusedMatMul's attr leakyrelu_alpha
if (activation != nullptr && IsLeakyRelu(*activation)) {
auto& activation_attr = activation->attr();
(*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
}
}
void SetFusedOpAttributes(NodeDef* fused,
@@ -1125,7 +1133,7 @@ Status AddFusedContractionNode(
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
} else if (IsMatMul(contraction)) {
fused_op.set_op(kFusedMatMul);
CopyMatMulAttributes(contraction, &fused_op);
CopyMatMulAttributes(contraction, &fused_op, &activation);
}
SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()});

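For reference, the node produced by AddFusedContractionNode above carries the activation name in fused_ops and the LeakyRelu slope in leakyrelu_alpha. A hand-built sketch of just those attributes via the NodeDef proto API (the node name and 0.3 value are illustrative, not taken from this commit):

// Sketch only: attribute layout of the fused node for MatMul + BiasAdd + LeakyRelu(alpha = 0.3).
tensorflow::NodeDef fused;
fused.set_name("matmul_fused");  // illustrative name
fused.set_op("_FusedMatMul");
auto* fused_ops = (*fused.mutable_attr())["fused_ops"].mutable_list();
fused_ops->add_s("BiasAdd");
fused_ops->add_s("LeakyRelu");
(*fused.mutable_attr())["leakyrelu_alpha"].set_f(0.3f);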

@@ -635,7 +635,8 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
void RunTest() {
using ::tensorflow::ops::Placeholder;
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
for (const string& activation :
{"Relu", "Relu6", "Elu", "Tanh", "LeakyRelu"}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto lhs_shape = ops::Placeholder::Shape({8, 32});
@@ -659,6 +660,12 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
} else if (activation == "Elu") {
return ops::Identity(fetch, ops::Elu(activate, bias_add));
} else if (activation == "Tanh") {
return ops::Identity(fetch, ops::Tanh(activate, bias_add));
} else if (activation == "LeakyRelu") {
auto attr = ops::internal::LeakyRelu::Alpha(0.5);
return ops::Identity(
fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
}
return ops::Identity(fetch, bias);
@@ -697,6 +704,11 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
ASSERT_EQ(fused_ops.size(), 2);
EXPECT_EQ(fused_ops[0], "BiasAdd");
EXPECT_EQ(fused_ops[1], activation);
if (activation == "LeakyRelu") {
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
}
found++;
}
}


@@ -1073,6 +1073,13 @@ class MklFusedMatMulOpTest : public OpsTestBase {
next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
}
if (std::find(fused_ops.begin(), fused_ops.end(), "LeakyRelu") !=
fused_ops.end()) {
last_op = "with_leakyrelu";
next_op =
ops::internal::LeakyRelu(root.WithOpName(last_op), next_op);
}
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
};
@@ -1148,12 +1155,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
{"BiasAdd", "Add"});
}
TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndLeakyRelu) {
const int batch = 3;
const int input_channel = 4;
const int output_channel = 5;
this->VerifyFusedMatMul(batch, input_channel, output_channel,
{"BiasAdd", "LeakyRelu"});
}
REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest, //
WithBias, //
WithBiasAndRelu, //
WithBiasAndRelu6, //
WithBiasAndElu, //
WithBiasAndTanh, //
WithBiasAndLeakyRelu, //
WithBiasAndAdd);
using MklFusedMatMulDataTypes = ::testing::Types<float>;


@@ -49,6 +49,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
OP_REQUIRES(
ctx, transpose_a_ == false,
errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
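// Read the LeakyRelu slope only when LeakyRelu is the fused activation;
// otherwise the default value declared below is kept.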
if (fused_ops_.size() == 2 && fused_ops_[1] == "LeakyRelu") {
OP_REQUIRES_OK(ctx, ctx->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
}
}
void Compute(OpKernelContext* ctx) override {
@@ -287,6 +290,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
} else if (post_op == "Add") {
params.post_op_params.push_back({"sum", {1.0}});
} else if (post_op == "LeakyRelu") {
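// The primitive interprets post-op params as {scale, alpha, beta};
// alpha carries the LeakyRelu negative slope.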
params.post_op_params.push_back(
{"leakyrelu", {1.0, leakyrelu_alpha, 0.0}});
} else {
OP_REQUIRES_OK(
ctx, errors::InvalidArgument(
@@ -299,6 +305,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
bool fuse_add_ = false;
bool transpose_a_;
bool transpose_b_;
float leakyrelu_alpha = 0.2;
std::vector<string> fused_ops_;
const int kInputIndex_Add = 3;
const int kOutputIndex_Dst = 0;


@@ -204,7 +204,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
mkldnn::post_ops post_ops;
if (!post_op_params.empty()) {
for (auto const& post_op_param : post_op_params) {
if (post_op_param.name == "relu") {
if (post_op_param.name == "relu" || post_op_param.name == "leakyrelu") {
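// oneDNN implements LeakyRelu as eltwise_relu with a non-zero alpha
// (the negative slope), so "leakyrelu" shares the ReLU handling here.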
DCHECK_EQ(post_op_param.param.size(), 3);
float op_scale = post_op_param.param[0];
float op_alpha = post_op_param.param[1];
@@ -249,6 +249,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
(post_op_param.name == "elu") ||
(post_op_param.name == "tanh") ||
(post_op_param.name == "sum") ||
(post_op_param.name == "leakyrelu") ||
(post_op_param.name == "output_scale"));
}
}
@@ -342,7 +343,8 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
// Generate keys for post-ops
for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
post_op_param.name == "elu" || post_op_param.name == "tanh") {
post_op_param.name == "elu" || post_op_param.name == "tanh" ||
post_op_param.name == "leakyrelu") {
DCHECK_EQ(post_op_param.param.size(), 3);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);

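Because oneDNN models leaky ReLU as the ordinary eltwise_relu kind with alpha acting as the negative slope, the "leakyrelu" post-op reuses the same {scale, alpha, beta} handling as "relu" above. A minimal sketch of the resulting post-op append, assuming the mkldnn 1.x C++ API this primitive already uses (the 0.3 slope is illustrative):

// Sketch only: how a {"leakyrelu", {1.0, alpha, 0.0}} entry maps onto oneDNN post-ops.
mkldnn::primitive_attr attr;
mkldnn::post_ops post_ops;
const float op_scale = 1.0f;  // param[0]
const float op_alpha = 0.3f;  // param[1], the LeakyRelu negative slope
const float op_beta = 0.0f;   // param[2], unused by eltwise_relu
post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu, op_alpha, op_beta);
attr.set_post_ops(post_ops);  // attached to the matmul/inner-product primitive descriptor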

@@ -954,6 +954,8 @@ REGISTER_OP("_FusedMatMul")
.Attr("fused_ops: list(string) = []")
// Attributes for the FusedBatchNorm ----------- //
.Attr("epsilon: float = 0.0001")
// Attributes for the LeakyRelu ----------------------------------------- //
.Attr("leakyrelu_alpha: float = 0.2")
// --------------------------------------------- //
.SetShapeFn(shape_inference::MatMulShape)
.Doc(R"doc(


@@ -295,6 +295,8 @@ REGISTER_OP("_MklFusedMatMul")
.Attr("fused_ops: list(string) = []")
// Attributes for the FusedBatchNorm ----------- //
.Attr("epsilon: float = 0.0001")
// Attributes for the LeakyRelu ----------------------------------------- //
.Attr("leakyrelu_alpha: float = 0.2")
// --------------------------------------------- //
.SetShapeFn(shape_inference::MatMulShape)
.Doc(R"doc(