Improve floating-point NMS tests to use smaller error thresholds

PiperOrigin-RevId: 316717898
Change-Id: Iab097dcf4ac3feca17c6d54ad84a2437341d0bb3
This commit is contained in:
Sachin Joglekar 2020-06-16 11:08:54 -07:00 committed by TensorFlower Gardener
parent 9a0838f66a
commit 30357d1d9b

View File

@ -174,17 +174,17 @@ TEST(DetectionPostprocessOpTest, FloatTest) {
std::vector<int> output_shape2 = m.GetOutputShape2();
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput2<float>(),
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
// detection_scores
std::vector<int> output_shape3 = m.GetOutputShape3();
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput3<float>(),
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-4)));
// num_detections
std::vector<int> output_shape4 = m.GetOutputShape4();
EXPECT_THAT(output_shape4, ElementsAre(1));
EXPECT_THAT(m.GetOutput4<float>(),
ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
}
TEST(DetectionPostprocessOpTest, QuantizedTest) {
@ -385,17 +385,17 @@ TEST(DetectionPostprocessOpTest, FloatTestFastNMS) {
std::vector<int> output_shape2 = m.GetOutputShape2();
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput2<float>(),
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
// detection_scores
std::vector<int> output_shape3 = m.GetOutputShape3();
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput3<float>(),
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-4)));
// num_detections
std::vector<int> output_shape4 = m.GetOutputShape4();
EXPECT_THAT(output_shape4, ElementsAre(1));
EXPECT_THAT(m.GetOutput4<float>(),
ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
}
TEST(DetectionPostprocessOpTest, QuantizedTestFastNMS) {
@ -492,22 +492,22 @@ TEST(DetectionPostprocessOpTest, FloatTestRegularNMS) {
EXPECT_THAT(m.GetOutput1<float>(),
ElementsAreArray(ArrayFloatNear({0.0, 10.0, 1.0, 11.0, 0.0, 10.0,
1.0, 11.0, 0.0, 0.0, 0.0, 0.0},
3e-1)));
3e-4)));
// detection_classes
std::vector<int> output_shape2 = m.GetOutputShape2();
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput2<float>(),
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
// detection_scores
std::vector<int> output_shape3 = m.GetOutputShape3();
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput3<float>(),
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({0.95, 0.93, 0.0}, 1e-4)));
// num_detections
std::vector<int> output_shape4 = m.GetOutputShape4();
EXPECT_THAT(output_shape4, ElementsAre(1));
EXPECT_THAT(m.GetOutput4<float>(),
ElementsAreArray(ArrayFloatNear({2.0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({2.0}, 1e-4)));
}
TEST(DetectionPostprocessOpTest, QuantizedTestRegularNMS) {
@ -666,17 +666,17 @@ TEST(DetectionPostprocessOpTest, FloatTestwithBackgroundClassAndKeypoints) {
std::vector<int> output_shape2 = m.GetOutputShape2();
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput2<float>(),
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
// detection_scores
std::vector<int> output_shape3 = m.GetOutputShape3();
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput3<float>(),
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-4)));
// num_detections
std::vector<int> output_shape4 = m.GetOutputShape4();
EXPECT_THAT(output_shape4, ElementsAre(1));
EXPECT_THAT(m.GetOutput4<float>(),
ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
}
TEST(DetectionPostprocessOpTest,
@ -780,17 +780,17 @@ TEST(DetectionPostprocessOpTest, FloatTestwithNoBackgroundClassAndKeypoints) {
std::vector<int> output_shape2 = m.GetOutputShape2();
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput2<float>(),
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
// detection_scores
std::vector<int> output_shape3 = m.GetOutputShape3();
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
EXPECT_THAT(m.GetOutput3<float>(),
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-4)));
// num_detections
std::vector<int> output_shape4 = m.GetOutputShape4();
EXPECT_THAT(output_shape4, ElementsAre(1));
EXPECT_THAT(m.GetOutput4<float>(),
ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
}
} // namespace
} // namespace custom