Improve floating-point NMS tests to use smaller error thresholds
PiperOrigin-RevId: 316717898 Change-Id: Iab097dcf4ac3feca17c6d54ad84a2437341d0bb3
This commit is contained in:
parent
9a0838f66a
commit
30357d1d9b
@ -174,17 +174,17 @@ TEST(DetectionPostprocessOpTest, FloatTest) {
|
|||||||
std::vector<int> output_shape2 = m.GetOutputShape2();
|
std::vector<int> output_shape2 = m.GetOutputShape2();
|
||||||
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput2<float>(),
|
EXPECT_THAT(m.GetOutput2<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
|
||||||
// detection_scores
|
// detection_scores
|
||||||
std::vector<int> output_shape3 = m.GetOutputShape3();
|
std::vector<int> output_shape3 = m.GetOutputShape3();
|
||||||
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput3<float>(),
|
EXPECT_THAT(m.GetOutput3<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-4)));
|
||||||
// num_detections
|
// num_detections
|
||||||
std::vector<int> output_shape4 = m.GetOutputShape4();
|
std::vector<int> output_shape4 = m.GetOutputShape4();
|
||||||
EXPECT_THAT(output_shape4, ElementsAre(1));
|
EXPECT_THAT(output_shape4, ElementsAre(1));
|
||||||
EXPECT_THAT(m.GetOutput4<float>(),
|
EXPECT_THAT(m.GetOutput4<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DetectionPostprocessOpTest, QuantizedTest) {
|
TEST(DetectionPostprocessOpTest, QuantizedTest) {
|
||||||
@ -385,17 +385,17 @@ TEST(DetectionPostprocessOpTest, FloatTestFastNMS) {
|
|||||||
std::vector<int> output_shape2 = m.GetOutputShape2();
|
std::vector<int> output_shape2 = m.GetOutputShape2();
|
||||||
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput2<float>(),
|
EXPECT_THAT(m.GetOutput2<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
|
||||||
// detection_scores
|
// detection_scores
|
||||||
std::vector<int> output_shape3 = m.GetOutputShape3();
|
std::vector<int> output_shape3 = m.GetOutputShape3();
|
||||||
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput3<float>(),
|
EXPECT_THAT(m.GetOutput3<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-4)));
|
||||||
// num_detections
|
// num_detections
|
||||||
std::vector<int> output_shape4 = m.GetOutputShape4();
|
std::vector<int> output_shape4 = m.GetOutputShape4();
|
||||||
EXPECT_THAT(output_shape4, ElementsAre(1));
|
EXPECT_THAT(output_shape4, ElementsAre(1));
|
||||||
EXPECT_THAT(m.GetOutput4<float>(),
|
EXPECT_THAT(m.GetOutput4<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DetectionPostprocessOpTest, QuantizedTestFastNMS) {
|
TEST(DetectionPostprocessOpTest, QuantizedTestFastNMS) {
|
||||||
@ -492,22 +492,22 @@ TEST(DetectionPostprocessOpTest, FloatTestRegularNMS) {
|
|||||||
EXPECT_THAT(m.GetOutput1<float>(),
|
EXPECT_THAT(m.GetOutput1<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({0.0, 10.0, 1.0, 11.0, 0.0, 10.0,
|
ElementsAreArray(ArrayFloatNear({0.0, 10.0, 1.0, 11.0, 0.0, 10.0,
|
||||||
1.0, 11.0, 0.0, 0.0, 0.0, 0.0},
|
1.0, 11.0, 0.0, 0.0, 0.0, 0.0},
|
||||||
3e-1)));
|
3e-4)));
|
||||||
// detection_classes
|
// detection_classes
|
||||||
std::vector<int> output_shape2 = m.GetOutputShape2();
|
std::vector<int> output_shape2 = m.GetOutputShape2();
|
||||||
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput2<float>(),
|
EXPECT_THAT(m.GetOutput2<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
|
||||||
// detection_scores
|
// detection_scores
|
||||||
std::vector<int> output_shape3 = m.GetOutputShape3();
|
std::vector<int> output_shape3 = m.GetOutputShape3();
|
||||||
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput3<float>(),
|
EXPECT_THAT(m.GetOutput3<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({0.95, 0.93, 0.0}, 1e-4)));
|
||||||
// num_detections
|
// num_detections
|
||||||
std::vector<int> output_shape4 = m.GetOutputShape4();
|
std::vector<int> output_shape4 = m.GetOutputShape4();
|
||||||
EXPECT_THAT(output_shape4, ElementsAre(1));
|
EXPECT_THAT(output_shape4, ElementsAre(1));
|
||||||
EXPECT_THAT(m.GetOutput4<float>(),
|
EXPECT_THAT(m.GetOutput4<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({2.0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({2.0}, 1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DetectionPostprocessOpTest, QuantizedTestRegularNMS) {
|
TEST(DetectionPostprocessOpTest, QuantizedTestRegularNMS) {
|
||||||
@ -666,17 +666,17 @@ TEST(DetectionPostprocessOpTest, FloatTestwithBackgroundClassAndKeypoints) {
|
|||||||
std::vector<int> output_shape2 = m.GetOutputShape2();
|
std::vector<int> output_shape2 = m.GetOutputShape2();
|
||||||
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput2<float>(),
|
EXPECT_THAT(m.GetOutput2<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
|
||||||
// detection_scores
|
// detection_scores
|
||||||
std::vector<int> output_shape3 = m.GetOutputShape3();
|
std::vector<int> output_shape3 = m.GetOutputShape3();
|
||||||
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput3<float>(),
|
EXPECT_THAT(m.GetOutput3<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-4)));
|
||||||
// num_detections
|
// num_detections
|
||||||
std::vector<int> output_shape4 = m.GetOutputShape4();
|
std::vector<int> output_shape4 = m.GetOutputShape4();
|
||||||
EXPECT_THAT(output_shape4, ElementsAre(1));
|
EXPECT_THAT(output_shape4, ElementsAre(1));
|
||||||
EXPECT_THAT(m.GetOutput4<float>(),
|
EXPECT_THAT(m.GetOutput4<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(DetectionPostprocessOpTest,
|
TEST(DetectionPostprocessOpTest,
|
||||||
@ -780,17 +780,17 @@ TEST(DetectionPostprocessOpTest, FloatTestwithNoBackgroundClassAndKeypoints) {
|
|||||||
std::vector<int> output_shape2 = m.GetOutputShape2();
|
std::vector<int> output_shape2 = m.GetOutputShape2();
|
||||||
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput2<float>(),
|
EXPECT_THAT(m.GetOutput2<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({1, 0, 0}, 1e-4)));
|
||||||
// detection_scores
|
// detection_scores
|
||||||
std::vector<int> output_shape3 = m.GetOutputShape3();
|
std::vector<int> output_shape3 = m.GetOutputShape3();
|
||||||
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
||||||
EXPECT_THAT(m.GetOutput3<float>(),
|
EXPECT_THAT(m.GetOutput3<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({0.95, 0.9, 0.3}, 1e-4)));
|
||||||
// num_detections
|
// num_detections
|
||||||
std::vector<int> output_shape4 = m.GetOutputShape4();
|
std::vector<int> output_shape4 = m.GetOutputShape4();
|
||||||
EXPECT_THAT(output_shape4, ElementsAre(1));
|
EXPECT_THAT(output_shape4, ElementsAre(1));
|
||||||
EXPECT_THAT(m.GetOutput4<float>(),
|
EXPECT_THAT(m.GetOutput4<float>(),
|
||||||
ElementsAreArray(ArrayFloatNear({3.0}, 1e-1)));
|
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace custom
|
} // namespace custom
|
||||||
|
Loading…
Reference in New Issue
Block a user