From 40bd5a4d991cff4d6a577146278ee9f0e82e9a7c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 16 Dec 2020 10:04:29 -0800
Subject: [PATCH] Relax ValidateBoxes check to tolerate degenerated boxes in
 detection post-proc

Some decoded boxes can be degenerate, i.e. have xmin == xmax and/or
ymin == ymax. ValidateBoxes now accepts such boxes instead of rejecting them.

PiperOrigin-RevId: 347844274
Change-Id: Ib44aca22fb537deb04367eaecabed48fc01080c2
---
 .../lite/kernels/detection_postprocess.cc     |  6 +-
 .../kernels/detection_postprocess_test.cc     | 64 +++++++++++++++++++
 2 files changed, 68 insertions(+), 2 deletions(-)

diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc
index 4f1040eabf7..f746ad19498 100644
--- a/tensorflow/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/lite/kernels/detection_postprocess.cc
@@ -382,9 +382,11 @@ void SelectDetectionsAboveScoreThreshold(const std::vector<float>& values,
 
 bool ValidateBoxes(const TfLiteTensor* decoded_boxes, const int num_boxes) {
   for (int i = 0; i < num_boxes; ++i) {
-    // ymax>=ymin, xmax>=xmin
     auto& box = ReInterpretTensor<const BoxCornerEncoding*>(decoded_boxes)[i];
-    if (box.ymin >= box.ymax || box.xmin >= box.xmax) {
+    // Note: `ComputeIntersectionOverUnion` properly handles degenerated boxes
+    // (xmin == xmax and/or ymin == ymax) as it just returns 0 in case the box
+    // area is <= 0.
+    if (box.ymin > box.ymax || box.xmin > box.xmax) {
       return false;
     }
   }
diff --git a/tensorflow/lite/kernels/detection_postprocess_test.cc b/tensorflow/lite/kernels/detection_postprocess_test.cc
index b9c42e75f21..4f73098e555 100644
--- a/tensorflow/lite/kernels/detection_postprocess_test.cc
+++ b/tensorflow/lite/kernels/detection_postprocess_test.cc
@@ -187,6 +187,70 @@ TEST(DetectionPostprocessOpTest, FloatTest) {
               ElementsAreArray(ArrayFloatNear({3.0}, 1e-4)));
 }
 
+// Checks that a box collapsed to a point (xmin==xmax, ymin==ymax) is accepted.
+TEST(DetectionPostprocessOpTest, FloatTestWithDegeneratedBox) {
+  BaseDetectionPostprocessOpModel m(
+      {TensorType_FLOAT32, {1, 2, 4}}, {TensorType_FLOAT32, {1, 2, 3}},
+      {TensorType_FLOAT32, {2, 4}}, {TensorType_FLOAT32, {}},
+      {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
+      {TensorType_FLOAT32, {}});
+
+  // Input 1: two all-zero box encodings (center-size format).
+  m.SetInput1<float>({
+      0.0, 0.0, 0.0, 0.0,  // box #1
+      0.0, 0.0, 0.0, 0.0,  // box #2
+  });
+  // Input 2: class scores, one row per box (background plus two classes).
+  m.SetInput2<float>({
+      /*background*/ 0., /*class 0*/ .9, /*class 1*/ .8,  // box #1
+      /*background*/ 0., /*class 0*/ .2, /*class 1*/ .7   // box #2
+  });
+  // Input 3: two anchors in center-size encoding; the second has zero size.
+  m.SetInput3<float>({
+      0.5, 0.5, 1.0, 1.0,  // anchor #1
+      0.5, 0.5, 0.0, 0.0   // anchor #2 - DEGENERATED!
+  });
+  // The same boxes expressed in box-corner encoding:
+  // { 0.0, 0.0, 1.0, 1.0,
+  //   0.5, 0.5, 0.5, 0.5} // DEGENERATED!
+  // NOTE: `m.InvokeUnchecked()` is asserted on directly, rather than calling
+  // `m.Invoke()`, so that a failure aborts the whole test; an ASSERT_EQ inside
+  // a helper function would only return from that helper.
+  ASSERT_EQ(m.InvokeUnchecked(), kTfLiteOk);
+  // num_detections: output 4, expected to hold a single value.
+  std::vector<int> output_shape4 = m.GetOutputShape4();
+  EXPECT_THAT(output_shape4, ElementsAre(1));
+  const int num_detections = static_cast<int>(m.GetOutput4<float>()[0]);
+  EXPECT_EQ(num_detections, 2);
+  // detection_boxes: output 1.
+  std::vector<int> output_shape1 = m.GetOutputShape1();
+  // NOTE: the output holds up to 3 boxes as per the `max_detections` and
+  // `max_classes_per_detection` parameters. Since only 2 detections are
+  // actually reported (see above), just the top-2 entries are compared
+  // here and in the checks below.
+  EXPECT_THAT(output_shape1, ElementsAre(1, 3, 4));
+  std::vector<float> detection_boxes = m.GetOutput1<float>();
+  detection_boxes.resize(num_detections * 4);
+  EXPECT_THAT(detection_boxes,
+              ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 1.0,   // box #1
+                                               0.5, 0.5, 0.5, 0.5},  // box #2
+                                              1e-1)));
+  // detection_classes: output 2, truncated to the reported detections.
+  std::vector<int> output_shape2 = m.GetOutputShape2();
+  EXPECT_THAT(output_shape2, ElementsAre(1, 3));
+  std::vector<float> detection_classes = m.GetOutput2<float>();
+  detection_classes.resize(num_detections);
+  EXPECT_THAT(detection_classes,
+              ElementsAreArray(ArrayFloatNear({0, 1}, 1e-4)));
+  // detection_scores: output 3, truncated to the reported detections.
+  std::vector<int> output_shape3 = m.GetOutputShape3();
+  EXPECT_THAT(output_shape3, ElementsAre(1, 3));
+  std::vector<float> detection_scores = m.GetOutput3<float>();
+  detection_scores.resize(num_detections);
+  EXPECT_THAT(detection_scores,
+              ElementsAreArray(ArrayFloatNear({0.9, 0.7}, 1e-4)));
+}
+
 TEST(DetectionPostprocessOpTest, QuantizedTest) {
   BaseDetectionPostprocessOpModel m(
       {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0},