Enable MKL Matmul + Bias + LeakyRelu Fusion
commit 7c0164495c
parent c97e809422
Changed under tensorflow/core: common_runtime, grappler/optimizers, kernels/mkl, ops
@@ -2033,6 +2033,36 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Positive)
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_Negative);
 #undef REGISTER_TEST
 
+// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
+#define REGISTER_TEST(NAME, T, INPUT) \
+  TEST_F(MklLayoutPassTest, NAME##_##T) { \
+    InitGraph( \
+        "node { name: 'A' op: '" #INPUT "'}" \
+        "node { name: 'B' op: '" #INPUT "'}" \
+        "node { name: 'C' op: '" #INPUT "'}" \
+        "node { name: 'D' op: '_FusedMatMul'" \
+        " attr { key: 'T' value { type: " #T "} }" \
+        " attr { key: 'transpose_a' value { b: false } }" \
+        " attr { key: 'transpose_b' value { b: false } }" \
+        " attr { key: 'num_args' value { i: 1 } }" \
+        " attr { key: 'fused_ops'" \
+        " value { list: {s: 'BiasAdd', s: 'LeakyRelu'} } }" \
+        " attr { key: 'epsilon' value { f: 0.001 }}" \
+        " attr { key: 'leakyrelu_alpha' value { f: 0.3 }}" \
+        " input: ['A', 'B', 'C']}" \
+        "node { name: 'Z' op: 'Zeta'" \
+        " attr {key: 'T' value { type: " #T " } }" \
+        " input: ['D', 'C']}"); \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");D(_MklFusedMatMul);" \
+              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);Z(Zeta)" \
+              "|A->D;A:control->DMT/_0:control;A:control->DMT/_1:control;" \
+              "A:control->DMT/_2:control;B->D:1;C->D:2;C->Z:1;D->Z;" \
+              "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedMatMul_LeakyRelu_Positive);
+#undef REGISTER_TEST
+
 // Merge test for PadWithFusedConv2D Op with BiasAdd fusion
 // padding is VALID type
 // A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D,
@@ -471,8 +471,10 @@ bool FindContractionWithBiasAndActivation(
   // Currently, only matmul + bias + tanh is enable
   if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
 
-  // Currently, only conv + bias + leakyrelu is enabled
-  if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
+  // Currently, only (conv | matmul) + bias + leakyrelu is enabled
+  if ((!IsConv2D(*contraction_node_def) && !IsMatMul(*contraction_node_def)) &&
+      IsLeakyRelu(*node_def))
+    return false;
 
   // Check that data type and data format are supported on assigned device.
   const ContractionWithBiasAddAndActivation pattern{base.contraction,
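For orientation, a minimal sketch (illustrative names only, mirroring the remapper test further down in this commit) of the subgraph that now passes this check and can be folded into _FusedMatMul: MatMul -> BiasAdd -> LeakyRelu.

#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

void BuildMatMulBiasAddLeakyRelu() {
  using namespace tensorflow;  // for brevity in this sketch
  Scope s = Scope::NewRootScope();
  auto lhs = ops::Placeholder(s.WithOpName("lhs"), DT_FLOAT);
  auto rhs = ops::Placeholder(s.WithOpName("rhs"), DT_FLOAT);
  auto bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT);
  auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
  // LeakyRelu's alpha is what CopyMatMulAttributes (below) carries over as
  // the fused node's leakyrelu_alpha attribute.
  auto attr = ops::internal::LeakyRelu::Alpha(0.5f);
  auto out = ops::internal::LeakyRelu(s.WithOpName("leakyrelu"), bias_add, attr);
}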
@@ -1028,7 +1030,8 @@ void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
   }
 }
 
-void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul) {
+void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul,
+                          const NodeDef* activation = nullptr) {
   DCHECK(IsMatMul(matmul)) << "Input node must be a MatMul";
 
   auto* attr = fused_matmul->mutable_attr();
@@ -1037,6 +1040,11 @@ void CopyMatMulAttributes(const NodeDef& matmul, NodeDef* fused_matmul) {
   (*attr)["T"] = src_attr.at("T");
   (*attr)["transpose_a"] = src_attr.at("transpose_a");
   (*attr)["transpose_b"] = src_attr.at("transpose_b");
+  // Copy LeakyRelu's attr alpha to _FusedMatMul's attr leakyrelu_alpha
+  if (activation != nullptr && IsLeakyRelu(*activation)) {
+    auto& activation_attr = activation->attr();
+    (*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
+  }
 }
 
 void SetFusedOpAttributes(NodeDef* fused,
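For reference (not part of the patch), the element-wise function the fused kernel has to reproduce once alpha is carried over:

// Reference LeakyRelu: identity for positive inputs, alpha-scaled otherwise.
inline float LeakyReluRef(float x, float alpha) {
  return x > 0.0f ? x : alpha * x;
}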
@@ -1125,7 +1133,7 @@ Status AddFusedContractionNode(
     CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
   } else if (IsMatMul(contraction)) {
     fused_op.set_op(kFusedMatMul);
-    CopyMatMulAttributes(contraction, &fused_op);
+    CopyMatMulAttributes(contraction, &fused_op, &activation);
   }
 
   SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()});
@@ -635,7 +635,8 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
   void RunTest() {
     using ::tensorflow::ops::Placeholder;
 
-    for (const string& activation : {"Relu", "Relu6", "Elu"}) {
+    for (const string& activation :
+         {"Relu", "Relu6", "Elu", "Tanh", "LeakyRelu"}) {
       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
 
       auto lhs_shape = ops::Placeholder::Shape({8, 32});
@@ -659,6 +660,12 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
          return ops::Identity(fetch, ops::Relu6(activate, bias_add));
        } else if (activation == "Elu") {
          return ops::Identity(fetch, ops::Elu(activate, bias_add));
+        } else if (activation == "Tanh") {
+          return ops::Identity(fetch, ops::Tanh(activate, bias_add));
+        } else if (activation == "LeakyRelu") {
+          auto attr = ops::internal::LeakyRelu::Alpha(0.5);
+          return ops::Identity(
+              fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
        }
 
        return ops::Identity(fetch, bias);
@@ -697,6 +704,11 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
        ASSERT_EQ(fused_ops.size(), 2);
        EXPECT_EQ(fused_ops[0], "BiasAdd");
        EXPECT_EQ(fused_ops[1], activation);
+
+        if (activation == "LeakyRelu") {
+          EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
+        }
+
        found++;
      }
    }
@@ -1073,6 +1073,13 @@ class MklFusedMatMulOpTest : public OpsTestBase {
        next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
      }
 
+      if (std::find(fused_ops.begin(), fused_ops.end(), "LeakyRelu") !=
+          fused_ops.end()) {
+        last_op = "with_leakyrelu";
+        next_op =
+            ops::internal::LeakyRelu(root.WithOpName(last_op), next_op);
+      }
+
      CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
    };
 
@@ -1148,12 +1155,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
                           {"BiasAdd", "Add"});
 }
 
+TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndLeakyRelu) {
+  const int batch = 3;
+  const int input_channel = 4;
+  const int output_channel = 5;
+
+  this->VerifyFusedMatMul(batch, input_channel, output_channel,
+                          {"BiasAdd", "LeakyRelu"});
+}
+
 REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest,  //
                             WithBias,              //
                             WithBiasAndRelu,       //
                             WithBiasAndRelu6,      //
                             WithBiasAndElu,        //
                             WithBiasAndTanh,       //
+                            WithBiasAndLeakyRelu,  //
                             WithBiasAndAdd);
 
 using MklFusedMatMulDataTypes = ::testing::Types<float>;
@@ -49,6 +49,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     OP_REQUIRES(
         ctx, transpose_a_ == false,
         errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
+    if (fused_ops_.size() == 2 && fused_ops_[1] == "LeakyRelu") {
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
+    }
   }
 
   void Compute(OpKernelContext* ctx) override {
@@ -287,6 +290,9 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
        params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
      } else if (post_op == "Add") {
        params.post_op_params.push_back({"sum", {1.0}});
+      } else if (post_op == "LeakyRelu") {
+        params.post_op_params.push_back(
+            {"leakyrelu", {1.0, leakyrelu_alpha, 0.0}});
      } else {
        OP_REQUIRES_OK(
            ctx, errors::InvalidArgument(
@@ -299,6 +305,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
   bool fuse_add_ = false;
   bool transpose_a_;
   bool transpose_b_;
+  float leakyrelu_alpha = 0.2;
   std::vector<string> fused_ops_;
   const int kInputIndex_Add = 3;
   const int kOutputIndex_Dst = 0;
@@ -204,7 +204,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
     mkldnn::post_ops post_ops;
     if (!post_op_params.empty()) {
       for (auto const& post_op_param : post_op_params) {
-        if (post_op_param.name == "relu") {
+        if (post_op_param.name == "relu" || post_op_param.name == "leakyrelu") {
          DCHECK_EQ(post_op_param.param.size(), 3);
          float op_scale = post_op_param.param[0];
          float op_alpha = post_op_param.param[1];
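Downstream, both names reach the same oneDNN eltwise post-op. A minimal sketch (assuming the mkldnn.hpp API already used in this header) of what the "leakyrelu" triple {scale, alpha, beta} turns into:

#include "mkldnn.hpp"

// Illustrative only: "relu" and "leakyrelu" both map to eltwise_relu, the
// leaky variant simply passing a non-zero alpha (negative-input slope).
void AppendLeakyReluPostOp(mkldnn::primitive_attr* attr, float alpha) {
  mkldnn::post_ops post_ops;
  post_ops.append_eltwise(/*scale=*/1.0f, mkldnn::algorithm::eltwise_relu,
                          /*alpha=*/alpha, /*beta=*/0.0f);
  attr->set_post_ops(post_ops);
}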
@@ -249,6 +249,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
                 (post_op_param.name == "elu") ||
                 (post_op_param.name == "tanh") ||
                 (post_op_param.name == "sum") ||
+                 (post_op_param.name == "leakyrelu") ||
                 (post_op_param.name == "output_scale"));
        }
      }
@@ -342,7 +343,8 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
     // Generate keys for post-ops
     for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
       if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
-          post_op_param.name == "elu" || post_op_param.name == "tanh") {
+          post_op_param.name == "elu" || post_op_param.name == "tanh" ||
+          post_op_param.name == "leakyrelu") {
        DCHECK_EQ(post_op_param.param.size(), 3);
        key_creator.AddAsKey(post_op_param.name);
        key_creator.AddAsKey(post_op_param.param[0]);
@@ -954,6 +954,8 @@ REGISTER_OP("_FusedMatMul")
     .Attr("fused_ops: list(string) = []")
     // Attributes for the FusedBatchNorm ----------- //
     .Attr("epsilon: float = 0.0001")
+    // Attributes for the LeakyRelu ----------------------------------------- //
+    .Attr("leakyrelu_alpha: float = 0.2")
     // --------------------------------------------- //
     .SetShapeFn(shape_inference::MatMulShape)
     .Doc(R"doc(
@@ -295,6 +295,8 @@ REGISTER_OP("_MklFusedMatMul")
     .Attr("fused_ops: list(string) = []")
     // Attributes for the FusedBatchNorm ----------- //
     .Attr("epsilon: float = 0.0001")
+    // Attributes for the LeakyRelu ----------------------------------------- //
+    .Attr("leakyrelu_alpha: float = 0.2")
     // --------------------------------------------- //
     .SetShapeFn(shape_inference::MatMulShape)
     .Doc(R"doc(
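Both registrations give leakyrelu_alpha a 0.2 default, so existing graphs that never set it remain valid. A minimal, test-style sketch (illustrative only; FakeInput stands in for real producers, as in the existing kernel tests) of a node that exercises the new attribute:

#include <vector>

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"

tensorflow::Status BuildFusedMatMulDef(tensorflow::NodeDef* def) {
  std::vector<tensorflow::string> fused_ops = {"BiasAdd", "LeakyRelu"};
  return tensorflow::NodeDefBuilder("fused_matmul", "_FusedMatMul")
      .Input(tensorflow::FakeInput(tensorflow::DT_FLOAT))     // a
      .Input(tensorflow::FakeInput(tensorflow::DT_FLOAT))     // b
      .Input(tensorflow::FakeInput(1, tensorflow::DT_FLOAT))  // bias (args)
      .Attr("num_args", 1)
      .Attr("fused_ops", fused_ops)
      .Attr("leakyrelu_alpha", 0.3f)
      .Finalize(def);
}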