diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc index 4f1040eabf7..f746ad19498 100644 --- a/tensorflow/lite/kernels/detection_postprocess.cc +++ b/tensorflow/lite/kernels/detection_postprocess.cc @@ -382,9 +382,11 @@ void SelectDetectionsAboveScoreThreshold(const std::vector& values, bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) { for (int i = 0; i < num_boxes; ++i) { - // ymax>=ymin, xmax>=xmin auto& box = ReInterpretTensor(decoded_boxes)[i]; - if (box.ymin >= box.ymax || box.xmin >= box.xmax) { + // Note: `ComputeIntersectionOverUnion` properly handles degenerated boxes + // (xmin == xmax and/or ymin == ymax) as it just returns 0 in case the box + // area is <= 0. + if (box.ymin > box.ymax || box.xmin > box.xmax) { return false; } } diff --git a/tensorflow/lite/kernels/detection_postprocess_test.cc b/tensorflow/lite/kernels/detection_postprocess_test.cc index b9c42e75f21..4f73098e555 100644 --- a/tensorflow/lite/kernels/detection_postprocess_test.cc +++ b/tensorflow/lite/kernels/detection_postprocess_test.cc @@ -187,6 +187,70 @@ TEST(DetectionPostprocessOpTest, FloatTest) { ElementsAreArray(ArrayFloatNear({3.0}, 1e-4))); } +// Tests the case when a box degenerates to a point (xmin==xmax, ymin==ymax). 
+TEST(DetectionPostprocessOpTest, FloatTestWithDegeneratedBox) {
+  BaseDetectionPostprocessOpModel m(
+      {TensorType_FLOAT32, {1, 2, 4}}, {TensorType_FLOAT32, {1, 2, 3}},
+      {TensorType_FLOAT32, {2, 4}}, {TensorType_FLOAT32, {}},
+      {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
+      {TensorType_FLOAT32, {}});
+
+  // Two boxes in center-size encoding (every value is zero).
+  m.SetInput1({
+      0.0, 0.0, 0.0, 0.0,  // box #1
+      0.0, 0.0, 0.0, 0.0,  // box #2
+  });
+  // Class scores: a background score followed by two class scores per box.
+  m.SetInput2({
+      /*background*/ 0., /*class 0*/ .9, /*class 1*/ .8,  // box #1: best non-background score is class 0
+      /*background*/ 0., /*class 0*/ .2, /*class 1*/ .7   // box #2: best non-background score is class 1
+  });
+  // Two anchors in center-size encoding.
+  m.SetInput3({
+      0.5, 0.5, 1.0, 1.0,  // anchor #1
+      0.5, 0.5, 0.0, 0.0   // anchor #2 - DEGENERATED (zero size)!
+  });
+  // Same boxes in box-corner encoding (also the detection boxes expected below):
+  // { 0.0, 0.0, 1.0, 1.0,
+  //   0.5, 0.5, 0.5, 0.5} // DEGENERATED!
+  // NOTE: `InvokeUnchecked()` is asserted on directly instead of calling
+  // `m.Invoke()` so that the entire test gets aborted if an error occurs (which
+  // does not happen when e.g. ASSERT_EQ is used inside such a helper function).
+  ASSERT_EQ(m.InvokeUnchecked(), kTfLiteOk);
+  // num_detections: output #4 holds a single element, expected to be 2.
+  std::vector output_shape4 = m.GetOutputShape4();
+  EXPECT_THAT(output_shape4, ElementsAre(1));
+  const int num_detections = static_cast(m.GetOutput4()[0]);  // NOTE(review): cast target type is not visible in this copy - confirm upstream
+  EXPECT_EQ(num_detections, 2);
+  // detection_boxes
+  std::vector output_shape1 = m.GetOutputShape1();
+  // NOTE: there are up to 3 detected boxes as per `max_detections` and
+  // `max_classes_per_detection` parameters. But since the actual number of
+  // detections is 2 (see above) only the top-2 results are tested
+  // here and below.
+  EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4));
+  std::vector detection_boxes = m.GetOutput1();
+  detection_boxes.resize(num_detections * 4);  // keep only the reported detections (4 values per box)
+  EXPECT_THAT(detection_boxes,
+              ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 1.0,    // box #1
+                                               0.5, 0.5, 0.5, 0.5},   // box #2
+                                              1e-1)));  // looser tolerance than the 1e-4 used for classes and scores
+  // detection_classes: class 0 expected for box #1, class 1 for box #2.
+  std::vector output_shape2 = m.GetOutputShape2();
+  EXPECT_THAT(output_shape2, ElementsAre(1, 3));
+  std::vector detection_classes = m.GetOutput2();
+  detection_classes.resize(num_detections);
+  EXPECT_THAT(detection_classes,
+              ElementsAreArray(ArrayFloatNear({0, 1}, 1e-4)));
+  // detection_scores: the highest class score of each box (.9 and .7 above).
+  std::vector output_shape3 = m.GetOutputShape3();
+  EXPECT_THAT(output_shape3, ElementsAre(1, 3));
+  std::vector detection_scores = m.GetOutput3();
+  detection_scores.resize(num_detections);
+  EXPECT_THAT(detection_scores,
+              ElementsAreArray(ArrayFloatNear({0.9, 0.7}, 1e-4)));
+}
+
 TEST(DetectionPostprocessOpTest, QuantizedTest) {
   BaseDetectionPostprocessOpModel m(
       {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0},