Add Native support for fusedmatmul with leakyrelu

This commit is contained in:
yunfeima 2020-12-21 18:01:42 +08:00
parent 22e30fbdbf
commit dfb45cd2ea

View File

@ -293,11 +293,11 @@ REGISTER_OP("_MklFusedMatMul")
.Attr("T: {bfloat16, float}") .Attr("T: {bfloat16, float}")
.Attr("num_args: int >= 0") .Attr("num_args: int >= 0")
.Attr("fused_ops: list(string) = []") .Attr("fused_ops: list(string) = []")
// Attributes for the FusedBatchNorm ----------- // // Attributes for the FusedBatchNorm ------------------------------------ //
.Attr("epsilon: float = 0.0001") .Attr("epsilon: float = 0.0001")
// Attributes for the LeakyRelu ----------------------------------------- // // Attributes for the LeakyRelu ----------------------------------------- //
.Attr("leakyrelu_alpha: float = 0.2") .Attr("leakyrelu_alpha: float = 0.2")
// --------------------------------------------- // // ---------------------------------------------------------------------- //
.SetShapeFn(shape_inference::MatMulShape) .SetShapeFn(shape_inference::MatMulShape)
.Doc(R"doc( .Doc(R"doc(
MKL version of FusedMatMul operator. Uses MKL-DNN APIs to implement MatMul MKL version of FusedMatMul operator. Uses MKL-DNN APIs to implement MatMul
@ -318,9 +318,11 @@ REGISTER_OP("_MklNativeFusedMatMul")
.Attr("T: {bfloat16, float}") .Attr("T: {bfloat16, float}")
.Attr("num_args: int >= 0") .Attr("num_args: int >= 0")
.Attr("fused_ops: list(string) = []") .Attr("fused_ops: list(string) = []")
// Attributes for the FusedBatchNorm ----------- // // Attributes for the FusedBatchNorm ------------------------------------ //
.Attr("epsilon: float = 0.0001") .Attr("epsilon: float = 0.0001")
// --------------------------------------------- // // Attributes for the LeakyRelu ----------------------------------------- //
.Attr("leakyrelu_alpha: float = 0.2")
// ---------------------------------------------------------------------- //
.SetShapeFn(shape_inference::MatMulShape) .SetShapeFn(shape_inference::MatMulShape)
.Doc(R"doc( .Doc(R"doc(
oneDNN version of FusedMatMul operator that does not depend oneDNN version of FusedMatMul operator that does not depend