Use a variable to hold the const value

yunfeima 2020-11-13 10:18:30 +08:00
parent 0ef037fde1
commit 22e30fbdbf

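The change replaces each hard-coded 0.5 with a named local, leakyrelu_alpha, so the node-construction site and the later EXPECT_EQ check read from a single definition instead of two copies of the same literal. A minimal standalone sketch of the pattern, with hypothetical names standing in for the real remapper test harness:

// Minimal sketch (hypothetical names, not the real remapper_test.cc harness):
// hoist a repeated literal into one named local so the op attribute and the
// assertion cannot drift apart.
#include <cassert>

struct LeakyReluAttr {
  float alpha;  // stands in for ops::internal::LeakyRelu::Alpha
};

int main() {
  float leakyrelu_alpha = 0.5f;           // single source of truth
  LeakyReluAttr attr{leakyrelu_alpha};    // construction site uses the variable
  assert(attr.alpha == leakyrelu_alpha);  // check reads the same variable
  return 0;
}

With one definition per test, changing the alpha later touches one line instead of two, and the construction and verification sides can no longer disagree silently.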

@@ -555,6 +555,8 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
   auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
   auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+  float leakyrelu_alpha = 0.5;
   std::vector<int> strides = {1, 1, 1, 1};
   auto conv =
       ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
@@ -571,7 +573,7 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
     } else if (activation == "Elu") {
       return ops::Identity(fetch, ops::Elu(activate, bias_add));
     } else if (activation == "LeakyRelu") {
-      auto attr = ops::internal::LeakyRelu::Alpha(0.5);
+      auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha);
       return ops::Identity(
           fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
     }
@@ -614,7 +616,7 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
       EXPECT_EQ(fused_ops[1], activation);
       if (activation == "LeakyRelu") {
-        EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
+        EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
       }
       found++;
     }
@@ -650,6 +652,8 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
     auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
     auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+    float leakyrelu_alpha = 0.5;
     ops::Identity fetch = [&]() -> ops::Identity {
       auto activate = s.WithOpName("activation");
       auto fetch = s.WithOpName("fetch");
@@ -663,7 +667,7 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
       } else if (activation == "Tanh") {
         return ops::Identity(fetch, ops::Tanh(activate, bias_add));
       } else if (activation == "LeakyRelu") {
-        auto attr = ops::internal::LeakyRelu::Alpha(0.5);
+        auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha);
         return ops::Identity(
             fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
       }
@@ -706,7 +710,7 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
       EXPECT_EQ(fused_ops[1], activation);
       if (activation == "LeakyRelu") {
-        EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
+        EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
       }
       found++;
@@ -843,6 +847,8 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
   auto batch_norm = ops::FusedBatchNorm(s.WithOpName("batch_norm"), conv,
                                         scale, offset, mean, variance, attrs);
+  float leakyrelu_alpha = 0.5;
   ops::Identity fetch = [&]() -> ops::Identity {
     auto activate = s.WithOpName("activation");
     auto fetch = s.WithOpName("fetch");
@@ -854,7 +860,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
     } else if (activation == "Elu") {
       return ops::Identity(fetch, ops::Elu(activate, batch_norm.y));
     } else if (activation == "LeakyRelu") {
-      auto attr = ops::internal::LeakyRelu::Alpha(0.5);
+      auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha);
       return ops::Identity(
           fetch, ops::internal::LeakyRelu(activate, batch_norm.y, attr));
     }
@@ -905,7 +911,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
       EXPECT_EQ(fused_ops[1], activation);
      if (activation == "LeakyRelu") {
-        EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
+        EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
       }
       found++;
     }