Fix code format
This commit is contained in:
parent
dfb45cd2ea
commit
04e7e1da0a
@ -472,7 +472,7 @@ bool FindContractionWithBiasAndActivation(
|
||||
if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
|
||||
|
||||
// Currently, only (conv | matmul) + bias + leakyrelu is enabled
|
||||
if ((!IsConv2D(*contraction_node_def) && !IsMatMul(*contraction_node_def)) &&
|
||||
if (!(IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def)) &&
|
||||
IsLeakyRelu(*node_def))
|
||||
return false;
|
||||
|
||||
|
@ -712,7 +712,6 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
|
||||
if (activation == "LeakyRelu") {
|
||||
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
|
||||
}
|
||||
|
||||
found++;
|
||||
}
|
||||
}
|
||||
|
@ -108,8 +108,7 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
|
||||
executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
|
||||
break;
|
||||
case FusedComputationType::kBiasAddWithLeakyRelu:
|
||||
out.device(d) = lhs.contract(rhs, dim_pair,
|
||||
WithBiasAddAndLeakyRelu<T>(bias_add_args));
|
||||
executeWithOutputKernel(WithBiasAddAndLeakyRelu<T>(bias_add_args));
|
||||
break;
|
||||
case FusedComputationType::kUndefined:
|
||||
OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
|
||||
|
@ -954,7 +954,7 @@ REGISTER_OP("_FusedMatMul")
|
||||
.Attr("fused_ops: list(string) = []")
|
||||
// Attributes for the FusedBatchNorm ----------- //
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
// Attributes for the LeakyRelu ----------------------------------------- //
|
||||
// Attributes for the LeakyRelu ---------------- //
|
||||
.Attr("leakyrelu_alpha: float = 0.2")
|
||||
// --------------------------------------------- //
|
||||
.SetShapeFn(shape_inference::MatMulShape)
|
||||
|
Loading…
Reference in New Issue
Block a user