Fix code format
This commit is contained in:
parent
dfb45cd2ea
commit
04e7e1da0a
@ -472,7 +472,7 @@ bool FindContractionWithBiasAndActivation(
|
|||||||
if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
|
if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
|
||||||
|
|
||||||
// Currently, only (conv | matmul) + bias + leakyrelu is enabled
|
// Currently, only (conv | matmul) + bias + leakyrelu is enabled
|
||||||
if ((!IsConv2D(*contraction_node_def) && !IsMatMul(*contraction_node_def)) &&
|
if (!(IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def)) &&
|
||||||
IsLeakyRelu(*node_def))
|
IsLeakyRelu(*node_def))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
@ -712,7 +712,6 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
|
|||||||
if (activation == "LeakyRelu") {
|
if (activation == "LeakyRelu") {
|
||||||
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
|
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
|
||||||
}
|
}
|
||||||
|
|
||||||
found++;
|
found++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -108,8 +108,7 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
|
|||||||
executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
|
executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
|
||||||
break;
|
break;
|
||||||
case FusedComputationType::kBiasAddWithLeakyRelu:
|
case FusedComputationType::kBiasAddWithLeakyRelu:
|
||||||
out.device(d) = lhs.contract(rhs, dim_pair,
|
executeWithOutputKernel(WithBiasAddAndLeakyRelu<T>(bias_add_args));
|
||||||
WithBiasAddAndLeakyRelu<T>(bias_add_args));
|
|
||||||
break;
|
break;
|
||||||
case FusedComputationType::kUndefined:
|
case FusedComputationType::kUndefined:
|
||||||
OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
|
OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
|
||||||
|
@ -954,7 +954,7 @@ REGISTER_OP("_FusedMatMul")
|
|||||||
.Attr("fused_ops: list(string) = []")
|
.Attr("fused_ops: list(string) = []")
|
||||||
// Attributes for the FusedBatchNorm ----------- //
|
// Attributes for the FusedBatchNorm ----------- //
|
||||||
.Attr("epsilon: float = 0.0001")
|
.Attr("epsilon: float = 0.0001")
|
||||||
// Attributes for the LeakyRelu ----------------------------------------- //
|
// Attributes for the LeakyRelu ---------------- //
|
||||||
.Attr("leakyrelu_alpha: float = 0.2")
|
.Attr("leakyrelu_alpha: float = 0.2")
|
||||||
// --------------------------------------------- //
|
// --------------------------------------------- //
|
||||||
.SetShapeFn(shape_inference::MatMulShape)
|
.SetShapeFn(shape_inference::MatMulShape)
|
||||||
|
Loading…
Reference in New Issue
Block a user