Relax ValidateBoxes check to tolerate degenerated boxes in detection post-proc
It could happen that some boxes have the same value for xmin and xmax, and for ymin and ymax. PiperOrigin-RevId: 347844274 Change-Id: Ib44aca22fb537deb04367eaecabed48fc01080c2
This commit is contained in:
parent
01c2fafda5
commit
40bd5a4d99
@ -382,9 +382,11 @@ void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
|
||||
|
||||
bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
|
||||
for (int i = 0; i < num_boxes; ++i) {
|
||||
// ymax>=ymin, xmax>=xmin
|
||||
auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
|
||||
if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
|
||||
// Note: `ComputeIntersectionOverUnion` properly handles degenerated boxes
|
||||
// (xmin == xmax and/or ymin == ymax) as it just returns 0 in case the box
|
||||
// area is <= 0.
|
||||
if (box.ymin > box.ymax || box.xmin > box.xmax) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -187,6 +187,70 @@ TEST(DetectionPostprocessOpTest, FloatTest) {
|
||||
ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
|
||||
}
|
||||
|
||||
// Tests the case when a box degenerates to a point (xmin==xmax, ymin==ymax).
|
||||
TEST(DetectionPostprocessOpTest, FloatTestWithDegeneratedBox) {
|
||||
BaseDetectionPostprocessOpModel m(
|
||||
{TensorType_FLOAT32, {1, 2, 4}}, {TensorType_FLOAT32, {1, 2, 3}},
|
||||
{TensorType_FLOAT32, {2, 4}}, {TensorType_FLOAT32, {}},
|
||||
{TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
|
||||
{TensorType_FLOAT32, {}});
|
||||
|
||||
// two boxes in center-size encoding
|
||||
m.SetInput1<float>({
|
||||
0.0, 0.0, 0.0, 0.0, // box #1
|
||||
0.0, 0.0, 0.0, 0.0, // box #2
|
||||
});
|
||||
// class scores - two classes with background
|
||||
m.SetInput2<float>({
|
||||
/*background*/ 0., /*class 0*/ .9, /*class 1*/ .8, // box #1
|
||||
/*background*/ 0., /*class 0*/ .2, /*class 1*/ .7 // box #2
|
||||
});
|
||||
// two anchors in center-size encoding
|
||||
m.SetInput3<float>({
|
||||
0.5, 0.5, 1.0, 1.0, // anchor #1
|
||||
0.5, 0.5, 0.0, 0.0 // anchor #2 - DEGENERATED!
|
||||
});
|
||||
// Same boxes in box-corner encoding:
|
||||
// { 0.0, 0.0, 1.0, 1.0,
|
||||
// 0.5, 0.5, 0.5, 0.5} // DEGENERATED!
|
||||
// NOTE: this is used instead of `m.Invoke()` to make sure the entire test
|
||||
// gets aborted if an error occurs (which does not happen when e.g. ASSERT_EQ
|
||||
// is used in such a helper function).
|
||||
ASSERT_EQ(m.InvokeUnchecked(), kTfLiteOk);
|
||||
// num_detections
|
||||
std::vector<int> output_shape4 = m.GetOutputShape4();
|
||||
EXPECT_THAT(output_shape4, ElementsAre(1));
|
||||
const int num_detections = static_cast<int>(m.GetOutput4<float>()[0]);
|
||||
EXPECT_EQ(num_detections, 2);
|
||||
// detection_boxes
|
||||
std::vector<int> output_shape1 = m.GetOutputShape1();
|
||||
// NOTE: there are up to 3 detected boxes as per `max_detections` and
|
||||
// `max_classes_per_detection` parameters. But since the actual number of
|
||||
// detections is 2 (see above) only the top-2 results are tested
|
||||
// here and below.
|
||||
EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4));
|
||||
std::vector<float> detection_boxes = m.GetOutput1<float>();
|
||||
detection_boxes.resize(num_detections * 4);
|
||||
EXPECT_THAT(detection_boxes,
|
||||
ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 1.0, // box #1
|
||||
0.5, 0.5, 0.5, 0.5}, // box #2
|
||||
1e-1)));
|
||||
// detection_classes
|
||||
std::vector<int> output_shape2 = m.GetOutputShape2();
|
||||
EXPECT_THAT(output_shape2, ElementsAre(1, 3));
|
||||
std::vector<float> detection_classes = m.GetOutput2<float>();
|
||||
detection_classes.resize(num_detections);
|
||||
EXPECT_THAT(detection_classes,
|
||||
ElementsAreArray(ArrayFloatNear({0, 1}, 1e-4)));
|
||||
// detection_scores
|
||||
std::vector<int> output_shape3 = m.GetOutputShape3();
|
||||
EXPECT_THAT(output_shape3, ElementsAre(1, 3));
|
||||
std::vector<float> detection_scores = m.GetOutput3<float>();
|
||||
detection_scores.resize(num_detections);
|
||||
EXPECT_THAT(detection_scores,
|
||||
ElementsAreArray(ArrayFloatNear({0.9, 0.7}, 1e-4)));
|
||||
}
|
||||
|
||||
TEST(DetectionPostprocessOpTest, QuantizedTest) {
|
||||
BaseDetectionPostprocessOpModel m(
|
||||
{TensorType_UINT8, {1, 6, 4}, -1.0, 1.0},
|
||||
|
Loading…
Reference in New Issue
Block a user