Fix code format

This commit is contained in:
yunfeima 2020-12-23 16:47:36 +08:00
parent dfb45cd2ea
commit 04e7e1da0a
4 changed files with 3 additions and 5 deletions

View File

@ -472,7 +472,7 @@ bool FindContractionWithBiasAndActivation(
  if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
  // Currently, only (conv | matmul) + bias + leakyrelu is enabled
- if ((!IsConv2D(*contraction_node_def) && !IsMatMul(*contraction_node_def)) &&
+ if (!(IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def)) &&
      IsLeakyRelu(*node_def))
    return false;

View File

@ -712,7 +712,6 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
  if (activation == "LeakyRelu") {
    EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
  }
  found++;
  }
  }

View File

@ -108,8 +108,7 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
  executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
  break;
  case FusedComputationType::kBiasAddWithLeakyRelu:
-   out.device(d) = lhs.contract(rhs, dim_pair,
-                                WithBiasAddAndLeakyRelu<T>(bias_add_args));
+   executeWithOutputKernel(WithBiasAddAndLeakyRelu<T>(bias_add_args));
  break;
  case FusedComputationType::kUndefined:
  OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));

View File

@ -954,7 +954,7 @@ REGISTER_OP("_FusedMatMul")
  .Attr("fused_ops: list(string) = []")
  // Attributes for the FusedBatchNorm ----------- //
  .Attr("epsilon: float = 0.0001")
- // Attributes for the LeakyRelu ----------------------------------------- //
+ // Attributes for the LeakyRelu ---------------- //
  .Attr("leakyrelu_alpha: float = 0.2")
  // --------------------------------------------- //
  .SetShapeFn(shape_inference::MatMulShape)