Use variable to hold the const value
This commit is contained in:
parent
0ef037fde1
commit
22e30fbdbf
@ -555,6 +555,8 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
|
||||
auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
|
||||
auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
|
||||
|
||||
float leakyrelu_alpha = 0.5;
|
||||
|
||||
std::vector<int> strides = {1, 1, 1, 1};
|
||||
auto conv =
|
||||
ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME");
|
||||
@ -571,7 +573,7 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
|
||||
} else if (activation == "Elu") {
|
||||
return ops::Identity(fetch, ops::Elu(activate, bias_add));
|
||||
} else if (activation == "LeakyRelu") {
|
||||
auto attr = ops::internal::LeakyRelu::Alpha(0.5);
|
||||
auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha);
|
||||
return ops::Identity(
|
||||
fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
|
||||
}
|
||||
@ -614,7 +616,7 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
|
||||
EXPECT_EQ(fused_ops[1], activation);
|
||||
|
||||
if (activation == "LeakyRelu") {
|
||||
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
|
||||
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
|
||||
}
|
||||
found++;
|
||||
}
|
||||
@ -650,6 +652,8 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
|
||||
auto matmul = ops::MatMul(s.WithOpName("matmul"), lhs, rhs);
|
||||
auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
|
||||
|
||||
float leakyrelu_alpha = 0.5;
|
||||
|
||||
ops::Identity fetch = [&]() -> ops::Identity {
|
||||
auto activate = s.WithOpName("activation");
|
||||
auto fetch = s.WithOpName("fetch");
|
||||
@ -663,7 +667,7 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
|
||||
} else if (activation == "Tanh") {
|
||||
return ops::Identity(fetch, ops::Tanh(activate, bias_add));
|
||||
} else if (activation == "LeakyRelu") {
|
||||
auto attr = ops::internal::LeakyRelu::Alpha(0.5);
|
||||
auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha);
|
||||
return ops::Identity(
|
||||
fetch, ops::internal::LeakyRelu(activate, bias_add, attr));
|
||||
}
|
||||
@ -706,7 +710,7 @@ class RemapperFuseMatMulWithBiasAndActivationTest : public RemapperTest {
|
||||
EXPECT_EQ(fused_ops[1], activation);
|
||||
|
||||
if (activation == "LeakyRelu") {
|
||||
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
|
||||
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
|
||||
}
|
||||
|
||||
found++;
|
||||
@ -843,6 +847,8 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
|
||||
auto batch_norm = ops::FusedBatchNorm(s.WithOpName("batch_norm"), conv,
|
||||
scale, offset, mean, variance, attrs);
|
||||
|
||||
float leakyrelu_alpha = 0.5;
|
||||
|
||||
ops::Identity fetch = [&]() -> ops::Identity {
|
||||
auto activate = s.WithOpName("activation");
|
||||
auto fetch = s.WithOpName("fetch");
|
||||
@ -854,7 +860,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
|
||||
} else if (activation == "Elu") {
|
||||
return ops::Identity(fetch, ops::Elu(activate, batch_norm.y));
|
||||
} else if (activation == "LeakyRelu") {
|
||||
auto attr = ops::internal::LeakyRelu::Alpha(0.5);
|
||||
auto attr = ops::internal::LeakyRelu::Alpha(leakyrelu_alpha);
|
||||
return ops::Identity(
|
||||
fetch, ops::internal::LeakyRelu(activate, batch_norm.y, attr));
|
||||
}
|
||||
@ -905,7 +911,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
|
||||
EXPECT_EQ(fused_ops[1], activation);
|
||||
|
||||
if (activation == "LeakyRelu") {
|
||||
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), 0.5);
|
||||
EXPECT_EQ(node.attr().at("leakyrelu_alpha").f(), leakyrelu_alpha);
|
||||
}
|
||||
found++;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user