From cb49525a25e865c90f37001da6ca77431fc2df9d Mon Sep 17 00:00:00 2001 From: Sachin Joglekar Date: Tue, 3 Sep 2019 15:34:04 -0700 Subject: [PATCH] Adds builtin ops for NMS & Soft NMS in TFLite PiperOrigin-RevId: 267028434 --- tensorflow/lite/builtin_ops.h | 2 + .../lite/core/api/flatbuffer_conversions.cc | 2 + .../lite/g3doc/guide/ops_compatibility.md | 34 +++ tensorflow/lite/kernels/BUILD | 14 + .../internal/non_max_suppression_test.cc | 21 ++ .../internal/reference/non_max_suppression.h | 4 +- .../lite/kernels/non_max_suppression.cc | 252 ++++++++++++++++++ .../lite/kernels/non_max_suppression_test.cc | 176 ++++++++++++ tensorflow/lite/kernels/register.cc | 6 + tensorflow/lite/schema/schema.fbs | 12 +- tensorflow/lite/schema/schema_generated.h | 240 ++++++++++++++++- 11 files changed, 753 insertions(+), 10 deletions(-) create mode 100644 tensorflow/lite/kernels/non_max_suppression.cc create mode 100644 tensorflow/lite/kernels/non_max_suppression_test.cc diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h index 75bfd9f2f6c..918f7246b05 100644 --- a/tensorflow/lite/builtin_ops.h +++ b/tensorflow/lite/builtin_ops.h @@ -146,6 +146,8 @@ typedef enum { kTfLiteBuiltinHardSwish = 117, kTfLiteBuiltinIf = 118, kTfLiteBuiltinWhile = 119, + kTfLiteBuiltinNonMaxSuppressionV4 = 120, + kTfLiteBuiltinNonMaxSuppressionV5 = 121, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 369c48aa4a3..a379cd82903 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -816,6 +816,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_WHERE: case BuiltinOperator_RANK: case BuiltinOperator_QUANTIZE: + case BuiltinOperator_NON_MAX_SUPPRESSION_V4: + case BuiltinOperator_NON_MAX_SUPPRESSION_V5: break; } return kTfLiteOk; diff --git 
a/tensorflow/lite/g3doc/guide/ops_compatibility.md b/tensorflow/lite/g3doc/guide/ops_compatibility.md index dd75913e7be..8d7e3baa0fb 100644 --- a/tensorflow/lite/g3doc/guide/ops_compatibility.md +++ b/tensorflow/lite/g3doc/guide/ops_compatibility.md @@ -641,6 +641,40 @@ Outputs { } ``` +**NON_MAX_SUPPRESSION_V4** + +``` +Inputs { + 0: boxes in format [y1, x1, y2, x2] + 1: scores + 2: max number of detections + 3: IOU threshold + 4: score threshold +} +Outputs { + 0: selected indices + 1: number of selected indices +} +``` + +**NON_MAX_SUPPRESSION_V5** + +``` +Inputs { + 0: boxes in format [y1, x1, y2, x2] + 1: scores + 2: max number of detections + 3: IOU threshold + 4: score threshold + 5: soft NMS sigma +} +Outputs { + 0: selected indices + 1: selected scores + 2: number of selected indices +} +``` + **PACK** ``` diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index fe299747105..2de938a40ca 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -466,6 +466,7 @@ cc_library( "mirror_pad.cc", "mul.cc", "neg.cc", + "non_max_suppression.cc", "one_hot.cc", "pack.cc", "pad.cc", @@ -1853,6 +1854,19 @@ cc_test( ], ) +cc_test( + name = "non_max_suppression_test", + size = "small", + srcs = ["non_max_suppression_test.cc"], + deps = [ + ":builtin_ops", + ":test_main", + ":test_util", + "//tensorflow/lite:framework", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/lite/kernels/internal/non_max_suppression_test.cc b/tensorflow/lite/kernels/internal/non_max_suppression_test.cc index 6fb24555e9c..8a778fd388d 100644 --- a/tensorflow/lite/kernels/internal/non_max_suppression_test.cc +++ b/tensorflow/lite/kernels/internal/non_max_suppression_test.cc @@ -300,5 +300,26 @@ TEST(NonMaxSuppression, TestSelectFromThreeClustersWithSoftNMS) { EXPECT_EQ(num_selected_indices, 4); MatchFirstNElements(4, selected_indices, {3, 0, 1, 5}); } + +TEST(NonMaxSuppression, 
TestNullSelectedScoresOutput) { +  // Inputs +  std::vector<float> boxes; +  std::vector<float> scores; +  InitializeCandidates(&boxes, &scores); +  const float iou_threshold = 0.5; +  const float score_threshold = 0.4; +  int max_output_size; + +  // Outputs +  std::vector<int> selected_indices(6); +  int num_selected_indices = -1; + +  max_output_size = 100; +  reference_ops::NonMaxSuppression( +      boxes.data(), kNumBoxes, scores.data(), max_output_size, iou_threshold, +      score_threshold, /**sigma=**/ 0.0, selected_indices.data(), +      /**selected_scores=**/ nullptr, &num_selected_indices); +  EXPECT_EQ(num_selected_indices, 2); +} }  // namespace }  // namespace tflite diff --git a/tensorflow/lite/kernels/internal/reference/non_max_suppression.h b/tensorflow/lite/kernels/internal/reference/non_max_suppression.h index 1bb3eb74c05..5d3823788ef 100644 --- a/tensorflow/lite/kernels/internal/reference/non_max_suppression.h +++ b/tensorflow/lite/kernels/internal/reference/non_max_suppression.h @@ -172,7 +172,9 @@ inline void NonMaxSuppression(const float* boxes, const int num_boxes, if (next_candidate.score == original_score) { // Suppression has not occurred, so select next_candidate. selected_indices[*num_selected_indices] = next_candidate.index; - selected_scores[*num_selected_indices] = next_candidate.score; + if (selected_scores) { + selected_scores[*num_selected_indices] = next_candidate.score; + } ++*num_selected_indices; } if (next_candidate.score > score_threshold) { diff --git a/tensorflow/lite/kernels/non_max_suppression.cc b/tensorflow/lite/kernels/non_max_suppression.cc new file mode 100644 index 00000000000..0e0bf9c1246 --- /dev/null +++ b/tensorflow/lite/kernels/non_max_suppression.cc @@ -0,0 +1,252 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/non_max_suppression.h" + +#include + +#include +#include + +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_internal.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace non_max_suppression { + +// Boxes in format [y1, x1, y2, x2]. Shape: [num_boxes, 4] +// Type: Float. +constexpr int kInputTensorBoxes = 0; +// Shape: [num_boxes] +// Type: Float. +constexpr int kInputTensorScores = 1; +// Max number of boxes to output. Actual output can be smaller. +// Type: Int32. +constexpr int kInputTensorMaxOutputSize = 2; +// Type: Float. +constexpr int kInputTensorIouThreshold = 3; +// Type: Float. +constexpr int kInputTensorScoreThreshold = 4; +// Only applies to NON_MAX_SUPPRESSION_V5. +// Type: Float. +constexpr int kInputTensorSigma = 5; + +// Indices of selected boxes. Shape: [num_selected_indices] +// Type: Int32. +constexpr int kNMSOutputTensorSelectedIndices = 0; +// Type: Int32. +constexpr int kNMSOutputTensorNumSelectedIndices = 1; + +// Indices of selected boxes. Shape: [num_selected_indices] +// Type: Int32. +constexpr int kSoftNMSOutputTensorSelectedIndices = 0; +// Scores of selected boxes. Shape: [num_selected_indices] +// Type: Float. 
+constexpr int kSoftNMSOutputTensorSelectedScores = 1; +// Type: Int32. +constexpr int kSoftNMSOutputTensorNumSelectedIndices = 2; + +TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor, + std::initializer_list values) { + TfLiteIntArray* size = TfLiteIntArrayCreate(values.size()); + int index = 0; + for (const auto& v : values) { + size->data[index++] = v; + } + return context->ResizeTensor(context, tensor, size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const int num_inputs = NumInputs(node); + const bool is_soft_nms = num_inputs == 6; + if (num_inputs != 5 && num_inputs != 6) { + context->ReportError(context, "Found NMS op with invalid num inputs: %d", + NumInputs(node)); + return kTfLiteError; + } + + // Boxes & Scores. + const TfLiteTensor* input_boxes = GetInput(context, node, kInputTensorBoxes); + TF_LITE_ENSURE_EQ(context, input_boxes->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_boxes), 2); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input_boxes, 1), 4); + const int num_boxes = SizeOfDimension(input_boxes, 0); + const TfLiteTensor* input_scores = + GetInput(context, node, kInputTensorScores); + TF_LITE_ENSURE_EQ(context, input_scores->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_scores), 1); + TF_LITE_ENSURE_EQ(context, num_boxes, SizeOfDimension(input_scores, 0)); + + // Max output size. + const TfLiteTensor* input_max_output_size = + GetInput(context, node, kInputTensorMaxOutputSize); + TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0); + // TODO(b/265135869): Add support for non-constant max_output_size by making + // output dynamic? 
+  if (!IsConstantTensor(input_max_output_size)) { +    context->ReportError(context, "Max output size should be a constant"); +    return kTfLiteError; +  } +  int max_output_size_value = *GetTensorData<int>(input_max_output_size); +  TF_LITE_ENSURE(context, (max_output_size_value >= 0)); + +  // IoU & Score thresholds. +  const TfLiteTensor* input_iou_threshold = +      GetInput(context, node, kInputTensorIouThreshold); +  TF_LITE_ENSURE_EQ(context, input_iou_threshold->type, kTfLiteFloat32); +  TF_LITE_ENSURE_EQ(context, NumDimensions(input_iou_threshold), 0); +  const TfLiteTensor* input_score_threshold = +      GetInput(context, node, kInputTensorScoreThreshold); +  TF_LITE_ENSURE_EQ(context, input_score_threshold->type, kTfLiteFloat32); +  TF_LITE_ENSURE_EQ(context, NumDimensions(input_score_threshold), 0); + +  if (is_soft_nms) { +    const TfLiteTensor* input_sigma = +        GetInput(context, node, kInputTensorSigma); +    TF_LITE_ENSURE_EQ(context, input_sigma->type, kTfLiteFloat32); +    TF_LITE_ENSURE_EQ(context, NumDimensions(input_sigma), 0); + +    TF_LITE_ENSURE_EQ(context, NumOutputs(node), 3); +    TfLiteTensor* output_selected_indices = +        GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices); +    output_selected_indices->type = kTfLiteInt32; +    SetTensorSizes(context, output_selected_indices, {max_output_size_value}); +    TfLiteTensor* output_selected_scores = +        GetOutput(context, node, kSoftNMSOutputTensorSelectedScores); +    output_selected_scores->type = kTfLiteFloat32; +    SetTensorSizes(context, output_selected_scores, {max_output_size_value}); +    TfLiteTensor* output_num_selected_indices = +        GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices); +    output_num_selected_indices->type = kTfLiteInt32; +    SetTensorSizes(context, output_num_selected_indices, {}); +  } else { +    TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); +    TfLiteTensor* output_selected_indices = +        GetOutput(context, node, kNMSOutputTensorSelectedIndices); +    output_selected_indices->type = kTfLiteInt32; + 
SetTensorSizes(context, output_selected_indices, {max_output_size_value}); + TfLiteTensor* output_num_selected_indices = + GetOutput(context, node, kNMSOutputTensorNumSelectedIndices); + output_num_selected_indices->type = kTfLiteInt32; + SetTensorSizes(context, output_num_selected_indices, {}); + } + + return kTfLiteOk; +} + +// If num_selected_indices < max_output_size, the output tensor can contain +// garbage values initially present in memory. This causes segfault in +// downstream ops such as GATHER, since one of the outputs denotes indices and +// int garbage values can be pretty large. This method zeroes-out the remaining +// values. +// NOTE: We ensure memory being reset is valid, by setting pertinent output +// tensors to max_output_size length in Prepare. +void ResetUnusedElementsToZeroes(const int max_output_size, + const int num_selected_indices, + int* selected_indices, + float* selected_scores) { + for (int i = num_selected_indices; i < max_output_size; ++i) { + selected_indices[i] = 0; + if (selected_scores) { + selected_scores[i] = 0.0; + } + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const bool is_soft_nms = NumInputs(node) == 6; + + const TfLiteTensor* input_boxes = GetInput(context, node, kInputTensorBoxes); + const int num_boxes = SizeOfDimension(input_boxes, 0); + const TfLiteTensor* input_scores = + GetInput(context, node, kInputTensorScores); + const TfLiteTensor* input_max_output_size = + GetInput(context, node, kInputTensorMaxOutputSize); + const int max_output_size_value = *GetTensorData(input_max_output_size); + const TfLiteTensor* input_iou_threshold = + GetInput(context, node, kInputTensorIouThreshold); + const float iou_threshold = *GetTensorData(input_iou_threshold); + const TfLiteTensor* input_score_threshold = + GetInput(context, node, kInputTensorScoreThreshold); + const float score_threshold = *GetTensorData(input_score_threshold); + + TfLiteTensor* output_selected_indices = nullptr; + TfLiteTensor* 
output_selected_scores = nullptr; + TfLiteTensor* output_num_selected_indices = nullptr; + + if (is_soft_nms) { + const TfLiteTensor* input_sigma = + GetInput(context, node, kInputTensorSigma); + const float soft_nms_sigma = *GetTensorData(input_sigma); + if (soft_nms_sigma < 0) { + context->ReportError(context, "Invalid sigma value for soft NMS: %f", + soft_nms_sigma); + return kTfLiteError; + } + + output_selected_indices = + GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices); + output_selected_scores = + GetOutput(context, node, kSoftNMSOutputTensorSelectedScores); + output_num_selected_indices = + GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices); + reference_ops::NonMaxSuppression( + input_boxes->data.f, num_boxes, input_scores->data.f, + max_output_size_value, iou_threshold, score_threshold, soft_nms_sigma, + output_selected_indices->data.i32, output_selected_scores->data.f, + output_num_selected_indices->data.i32); + ResetUnusedElementsToZeroes( + max_output_size_value, *output_num_selected_indices->data.i32, + output_selected_indices->data.i32, output_selected_scores->data.f); + } else { + output_selected_indices = + GetOutput(context, node, kNMSOutputTensorSelectedIndices); + output_num_selected_indices = + GetOutput(context, node, kNMSOutputTensorNumSelectedIndices); + reference_ops::NonMaxSuppression( + input_boxes->data.f, num_boxes, input_scores->data.f, + max_output_size_value, iou_threshold, score_threshold, /**sigma=**/ 0.0, + output_selected_indices->data.i32, /**selected_scores=**/ nullptr, + output_num_selected_indices->data.i32); + ResetUnusedElementsToZeroes(max_output_size_value, + *output_num_selected_indices->data.i32, + output_selected_indices->data.i32, nullptr); + } + + return kTfLiteOk; +} +} // namespace non_max_suppression + +TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V4() { + static TfLiteRegistration r = {nullptr, nullptr, non_max_suppression::Prepare, + non_max_suppression::Eval}; + return &r; +} + 
+TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V5() { + static TfLiteRegistration r = {nullptr, nullptr, non_max_suppression::Prepare, + non_max_suppression::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/lite/kernels/non_max_suppression_test.cc b/tensorflow/lite/kernels/non_max_suppression_test.cc new file mode 100644 index 00000000000..dd8efc0a300 --- /dev/null +++ b/tensorflow/lite/kernels/non_max_suppression_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseNMSOp : public SingleOpModel { + public: + void SetScores(std::initializer_list data) { + PopulateTensor(input_scores_, data); + } + + void SetScoreThreshold(float score_threshold) { + PopulateTensor(input_score_threshold_, {score_threshold}); + } + + std::vector GetSelectedIndices() { + return ExtractVector(output_selected_indices_); + } + + std::vector GetSelectedScores() { + return ExtractVector(output_selected_scores_); + } + + std::vector GetNumSelectedIndices() { + return ExtractVector(output_num_selected_indices_); + } + + protected: + int input_boxes_; + int input_scores_; + int input_max_output_size_; + int input_iou_threshold_; + int input_score_threshold_; + int input_sigma_; + + int output_selected_indices_; + int output_selected_scores_; + int output_num_selected_indices_; +}; + +class NonMaxSuppressionV4OpModel : public BaseNMSOp { + public: + explicit NonMaxSuppressionV4OpModel(const int max_output_size, + const float iou_threshold) { + const int num_boxes = 6; + input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}}); + input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}}); + input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size}); + input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold}); + input_score_threshold_ = AddInput({TensorType_FLOAT32, {}}); + + output_selected_indices_ = AddOutput(TensorType_INT32); + + output_num_selected_indices_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_NON_MAX_SUPPRESSION_V4, + BuiltinOptions_NonMaxSuppressionV4Options, + CreateNonMaxSuppressionV4Options(builder_).Union()); + 
BuildInterpreter({GetShape(input_boxes_), GetShape(input_scores_), + GetShape(input_max_output_size_), + GetShape(input_iou_threshold_), + GetShape(input_score_threshold_)}); + + // Default data. + PopulateTensor(input_boxes_, { + 1, 1, 0, 0, // Box 0 + 0, 0.1, 1, 1.1, // Box 1 + 0, .9f, 1, -0.1, // Box 2 + 0, 10, 1, 11, // Box 3 + 1, 10.1f, 0, 11.1, // Box 4 + 1, 101, 0, 100 // Box 5 + }); + } +}; + +TEST(NonMaxSuppressionV4OpModel, TestOutput) { + NonMaxSuppressionV4OpModel nms(6, 0.5); + nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3}); + nms.SetScoreThreshold(0.4); + nms.Invoke(); + EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2})); + EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 0, 0, 0, 0})); + + nms.SetScoreThreshold(0.99); + nms.Invoke(); + EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0})); + // The first two indices should be zeroed-out. + EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0})); +} + +TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) { + NonMaxSuppressionV4OpModel nms(0, 0.5); + nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3}); + nms.SetScoreThreshold(0.4); + nms.Invoke(); + EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0})); +} + +class NonMaxSuppressionV5OpModel : public BaseNMSOp { + public: + explicit NonMaxSuppressionV5OpModel(const int max_output_size, + const float iou_threshold, + const float sigma) { + const int num_boxes = 6; + input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}}); + input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}}); + input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size}); + input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold}); + input_score_threshold_ = AddInput({TensorType_FLOAT32, {}}); + input_sigma_ = AddConstInput(TensorType_FLOAT32, {sigma}); + + output_selected_indices_ = AddOutput(TensorType_INT32); + output_selected_scores_ = AddOutput(TensorType_FLOAT32); + 
output_num_selected_indices_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_NON_MAX_SUPPRESSION_V5, + BuiltinOptions_NonMaxSuppressionV5Options, + CreateNonMaxSuppressionV5Options(builder_).Union()); + + BuildInterpreter( + {GetShape(input_boxes_), GetShape(input_scores_), + GetShape(input_max_output_size_), GetShape(input_iou_threshold_), + GetShape(input_score_threshold_), GetShape(input_sigma_)}); + + // Default data. + PopulateTensor(input_boxes_, { + 1, 1, 0, 0, // Box 0 + 0, 0.1, 1, 1.1, // Box 1 + 0, .9f, 1, -0.1, // Box 2 + 0, 10, 1, 11, // Box 3 + 1, 10.1f, 0, 11.1, // Box 4 + 1, 101, 0, 100 // Box 5 + }); + } +}; + +TEST(NonMaxSuppressionV5OpModel, TestOutput) { + NonMaxSuppressionV5OpModel nms(6, 0.5, 0.5); + nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3}); + nms.SetScoreThreshold(0.0); + nms.Invoke(); + EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({3})); + EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 5, 0, 0, 0})); + EXPECT_THAT(nms.GetSelectedScores(), + ElementsAreArray({0.95, 0.9, 0.3, 0.0, 0.0, 0.0})); + + // No candidate gets selected. But the outputs should be zeroed out. 
+ nms.SetScoreThreshold(0.99); + nms.Invoke(); + EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0})); + EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0})); + EXPECT_THAT(nms.GetSelectedScores(), + ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0})); +} +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index da727c9fefc..c09525ff740 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -144,6 +144,8 @@ TfLiteRegistration* Register_MATRIX_DIAG(); TfLiteRegistration* Register_QUANTIZE(); TfLiteRegistration* Register_MATRIX_SET_DIAG(); TfLiteRegistration* Register_HARD_SWISH(); +TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V4(); +TfLiteRegistration* Register_NON_MAX_SUPPRESSION_V5(); const TfLiteRegistration* BuiltinOpResolver::FindOp(tflite::BuiltinOperator op, int version) const { @@ -382,6 +384,10 @@ BuiltinOpResolver::BuiltinOpResolver() { /* min_version */ 1, /* max_version */ 2); AddBuiltin(BuiltinOperator_MATRIX_SET_DIAG, Register_MATRIX_SET_DIAG()); + AddBuiltin(BuiltinOperator_NON_MAX_SUPPRESSION_V4, + Register_NON_MAX_SUPPRESSION_V4()); + AddBuiltin(BuiltinOperator_NON_MAX_SUPPRESSION_V5, + Register_NON_MAX_SUPPRESSION_V5()); // WARNING: Control flow ops are experimental and subject to change. AddBuiltin(BuiltinOperator_IF, tflite::ops::custom::Register_IF()); diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs index d6338603576..2a7c22286c3 100644 --- a/tensorflow/lite/schema/schema.fbs +++ b/tensorflow/lite/schema/schema.fbs @@ -233,6 +233,8 @@ enum BuiltinOperator : byte { HARD_SWISH = 117, IF = 118, WHILE = 119, + NON_MAX_SUPPRESSION_V4 = 120, + NON_MAX_SUPPRESSION_V5 = 121, } // Options for the builtin operators. 
@@ -330,7 +332,9 @@ union BuiltinOptions { HardSwishOptions, IfOptions, WhileOptions, - DepthToSpaceOptions + DepthToSpaceOptions, + NonMaxSuppressionV4Options, + NonMaxSuppressionV5Options } enum Padding : byte { SAME, VALID } @@ -802,6 +806,12 @@ table WhileOptions { body_subgraph_index:int; } +table NonMaxSuppressionV4Options { +} + +table NonMaxSuppressionV5Options { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h index b4509e694a6..fa249d38f6c 100755 --- a/tensorflow/lite/schema/schema_generated.h +++ b/tensorflow/lite/schema/schema_generated.h @@ -313,6 +313,12 @@ struct IfOptionsT; struct WhileOptions; struct WhileOptionsT; +struct NonMaxSuppressionV4Options; +struct NonMaxSuppressionV4OptionsT; + +struct NonMaxSuppressionV5Options; +struct NonMaxSuppressionV5OptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -589,11 +595,13 @@ enum BuiltinOperator { BuiltinOperator_HARD_SWISH = 117, BuiltinOperator_IF = 118, BuiltinOperator_WHILE = 119, + BuiltinOperator_NON_MAX_SUPPRESSION_V4 = 120, + BuiltinOperator_NON_MAX_SUPPRESSION_V5 = 121, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_WHILE + BuiltinOperator_MAX = BuiltinOperator_NON_MAX_SUPPRESSION_V5 }; -inline const BuiltinOperator (&EnumValuesBuiltinOperator())[120] { +inline const BuiltinOperator (&EnumValuesBuiltinOperator())[122] { static const BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -714,7 +722,9 @@ inline const BuiltinOperator (&EnumValuesBuiltinOperator())[120] { BuiltinOperator_ROUND, BuiltinOperator_HARD_SWISH, BuiltinOperator_IF, - BuiltinOperator_WHILE + BuiltinOperator_WHILE, + BuiltinOperator_NON_MAX_SUPPRESSION_V4, + BuiltinOperator_NON_MAX_SUPPRESSION_V5 }; return values; } @@ -841,13 +851,15 @@ inline 
const char * const *EnumNamesBuiltinOperator() { "HARD_SWISH", "IF", "WHILE", + "NON_MAX_SUPPRESSION_V4", + "NON_MAX_SUPPRESSION_V5", nullptr }; return names; } inline const char *EnumNameBuiltinOperator(BuiltinOperator e) { - if (e < BuiltinOperator_ADD || e > BuiltinOperator_WHILE) return ""; + if (e < BuiltinOperator_ADD || e > BuiltinOperator_NON_MAX_SUPPRESSION_V5) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOperator()[index]; } @@ -948,11 +960,13 @@ enum BuiltinOptions { BuiltinOptions_IfOptions = 92, BuiltinOptions_WhileOptions = 93, BuiltinOptions_DepthToSpaceOptions = 94, + BuiltinOptions_NonMaxSuppressionV4Options = 95, + BuiltinOptions_NonMaxSuppressionV5Options = 96, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_DepthToSpaceOptions + BuiltinOptions_MAX = BuiltinOptions_NonMaxSuppressionV5Options }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[95] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[97] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -1048,7 +1062,9 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[95] { BuiltinOptions_HardSwishOptions, BuiltinOptions_IfOptions, BuiltinOptions_WhileOptions, - BuiltinOptions_DepthToSpaceOptions + BuiltinOptions_DepthToSpaceOptions, + BuiltinOptions_NonMaxSuppressionV4Options, + BuiltinOptions_NonMaxSuppressionV5Options }; return values; } @@ -1150,13 +1166,15 @@ inline const char * const *EnumNamesBuiltinOptions() { "IfOptions", "WhileOptions", "DepthToSpaceOptions", + "NonMaxSuppressionV4Options", + "NonMaxSuppressionV5Options", nullptr }; return names; } inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { - if (e < BuiltinOptions_NONE || e > BuiltinOptions_DepthToSpaceOptions) return ""; + if (e < BuiltinOptions_NONE || e > BuiltinOptions_NonMaxSuppressionV5Options) return ""; const size_t index = static_cast(e); return EnumNamesBuiltinOptions()[index]; 
} @@ -1541,6 +1559,14 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DepthToSpaceOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV4Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NonMaxSuppressionV5Options; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -2325,6 +2351,22 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_DepthToSpaceOptions ? reinterpret_cast(value) : nullptr; } + NonMaxSuppressionV4OptionsT *AsNonMaxSuppressionV4Options() { + return type == BuiltinOptions_NonMaxSuppressionV4Options ? + reinterpret_cast(value) : nullptr; + } + const NonMaxSuppressionV4OptionsT *AsNonMaxSuppressionV4Options() const { + return type == BuiltinOptions_NonMaxSuppressionV4Options ? + reinterpret_cast(value) : nullptr; + } + NonMaxSuppressionV5OptionsT *AsNonMaxSuppressionV5Options() { + return type == BuiltinOptions_NonMaxSuppressionV5Options ? + reinterpret_cast(value) : nullptr; + } + const NonMaxSuppressionV5OptionsT *AsNonMaxSuppressionV5Options() const { + return type == BuiltinOptions_NonMaxSuppressionV5Options ? 
+ reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -8104,6 +8146,86 @@ inline flatbuffers::Offset CreateWhileOptions( flatbuffers::Offset CreateWhileOptions(flatbuffers::FlatBufferBuilder &_fbb, const WhileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct NonMaxSuppressionV4OptionsT : public flatbuffers::NativeTable { + typedef NonMaxSuppressionV4Options TableType; + NonMaxSuppressionV4OptionsT() { + } +}; + +struct NonMaxSuppressionV4Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NonMaxSuppressionV4OptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NonMaxSuppressionV4OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NonMaxSuppressionV4OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV4OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NonMaxSuppressionV4OptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit NonMaxSuppressionV4OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + NonMaxSuppressionV4OptionsBuilder &operator=(const NonMaxSuppressionV4OptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateNonMaxSuppressionV4Options( + flatbuffers::FlatBufferBuilder &_fbb) { + NonMaxSuppressionV4OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateNonMaxSuppressionV4Options(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV4OptionsT 
*_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NonMaxSuppressionV5OptionsT : public flatbuffers::NativeTable { + typedef NonMaxSuppressionV5Options TableType; + NonMaxSuppressionV5OptionsT() { + } +}; + +struct NonMaxSuppressionV5Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NonMaxSuppressionV5OptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NonMaxSuppressionV5OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NonMaxSuppressionV5OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NonMaxSuppressionV5OptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit NonMaxSuppressionV5OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + NonMaxSuppressionV5OptionsBuilder &operator=(const NonMaxSuppressionV5OptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateNonMaxSuppressionV5Options( + flatbuffers::FlatBufferBuilder &_fbb) { + NonMaxSuppressionV5OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateNonMaxSuppressionV5Options(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -8522,6 +8644,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const DepthToSpaceOptions 
*builtin_options_as_DepthToSpaceOptions() const { return builtin_options_type() == BuiltinOptions_DepthToSpaceOptions ? static_cast(builtin_options()) : nullptr; } + const NonMaxSuppressionV4Options *builtin_options_as_NonMaxSuppressionV4Options() const { + return builtin_options_type() == BuiltinOptions_NonMaxSuppressionV4Options ? static_cast(builtin_options()) : nullptr; + } + const NonMaxSuppressionV5Options *builtin_options_as_NonMaxSuppressionV5Options() const { + return builtin_options_type() == BuiltinOptions_NonMaxSuppressionV5Options ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -8934,6 +9062,14 @@ template<> inline const DepthToSpaceOptions *Operator::builtin_options_as inline const NonMaxSuppressionV4Options *Operator::builtin_options_as() const { + return builtin_options_as_NonMaxSuppressionV4Options(); +} + +template<> inline const NonMaxSuppressionV5Options *Operator::builtin_options_as() const { + return builtin_options_as_NonMaxSuppressionV5Options(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -12043,6 +12179,52 @@ inline flatbuffers::Offset CreateWhileOptions(flatbuffers::FlatBuf _body_subgraph_index); } +inline NonMaxSuppressionV4OptionsT *NonMaxSuppressionV4Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new NonMaxSuppressionV4OptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void NonMaxSuppressionV4Options::UnPackTo(NonMaxSuppressionV4OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset NonMaxSuppressionV4Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV4OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateNonMaxSuppressionV4Options(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset 
CreateNonMaxSuppressionV4Options(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV4OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NonMaxSuppressionV4OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNonMaxSuppressionV4Options( + _fbb); +} + +inline NonMaxSuppressionV5OptionsT *NonMaxSuppressionV5Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new NonMaxSuppressionV5OptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void NonMaxSuppressionV5Options::UnPackTo(NonMaxSuppressionV5OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset NonMaxSuppressionV5Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateNonMaxSuppressionV5Options(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateNonMaxSuppressionV5Options(flatbuffers::FlatBufferBuilder &_fbb, const NonMaxSuppressionV5OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NonMaxSuppressionV5OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNonMaxSuppressionV5Options( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -12712,6 +12894,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_NonMaxSuppressionV4Options: { + auto ptr 
= reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -13106,6 +13296,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_NonMaxSuppressionV4Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -13488,6 +13686,14 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateDepthToSpaceOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_NonMaxSuppressionV4Options: { + auto ptr = reinterpret_cast(value); + return CreateNonMaxSuppressionV4Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + auto ptr = reinterpret_cast(value); + return CreateNonMaxSuppressionV5Options(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -13870,6 +14076,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new DepthToSpaceOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_NonMaxSuppressionV4Options: { + value = new NonMaxSuppressionV4OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + value = new NonMaxSuppressionV5OptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -14347,6 +14561,16 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_NonMaxSuppressionV4Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NonMaxSuppressionV5Options: { + auto ptr = reinterpret_cast(value); + delete 
ptr; + break; + } default: break; } value = nullptr;