diff --git a/tensorflow/lite/kernels/non_max_suppression.cc b/tensorflow/lite/kernels/non_max_suppression.cc
index 0e0bf9c1246..33de6830a3d 100644
--- a/tensorflow/lite/kernels/non_max_suppression.cc
+++ b/tensorflow/lite/kernels/non_max_suppression.cc
@@ -38,6 +38,7 @@ constexpr int kInputTensorBoxes = 0;
 // Type: Float.
 constexpr int kInputTensorScores = 1;
 // Max number of boxes to output. Actual output can be smaller.
+// The output tensors (indices/scores) are of this length.
 // Type: Int32.
 constexpr int kInputTensorMaxOutputSize = 2;
 // Type: Float.
@@ -99,14 +100,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       GetInput(context, node, kInputTensorMaxOutputSize);
   TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
-  // TODO(b/265135869): Add support for non-constant max_output_size by making
-  // output dynamic?
-  if (!IsConstantTensor(input_max_output_size)) {
-    context->ReportError(context, "Max output size should be a constant");
-    return kTfLiteError;
+  const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
+  int max_output_size_value = 0;
+  if (is_max_output_size_const) {
+    max_output_size_value = *GetTensorData<int>(input_max_output_size);
+    TF_LITE_ENSURE(context, (max_output_size_value >= 0));
   }
-  int max_output_size_value = *GetTensorData<int>(input_max_output_size);
-  TF_LITE_ENSURE(context, (max_output_size_value >= 0));
 
   // IoU & Score thresholds.
   const TfLiteTensor* input_iou_threshold =
@@ -128,25 +127,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TfLiteTensor* output_selected_indices =
        GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
     output_selected_indices->type = kTfLiteInt32;
-    SetTensorSizes(context, output_selected_indices, {max_output_size_value});
     TfLiteTensor* output_selected_scores =
        GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
     output_selected_scores->type = kTfLiteFloat32;
-    SetTensorSizes(context, output_selected_scores, {max_output_size_value});
     TfLiteTensor* output_num_selected_indices =
        GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
     output_num_selected_indices->type = kTfLiteInt32;
     SetTensorSizes(context, output_num_selected_indices, {});
+
+    if (is_max_output_size_const) {
+      SetTensorSizes(context, output_selected_indices, {max_output_size_value});
+      SetTensorSizes(context, output_selected_scores, {max_output_size_value});
+    } else {
+      SetTensorToDynamic(output_selected_indices);
+      SetTensorToDynamic(output_selected_scores);
+    }
   } else {
     TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
     TfLiteTensor* output_selected_indices =
        GetOutput(context, node, kNMSOutputTensorSelectedIndices);
     output_selected_indices->type = kTfLiteInt32;
-    SetTensorSizes(context, output_selected_indices, {max_output_size_value});
     TfLiteTensor* output_num_selected_indices =
        GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
     output_num_selected_indices->type = kTfLiteInt32;
     SetTensorSizes(context, output_num_selected_indices, {});
+
+    if (is_max_output_size_const) {
+      SetTensorSizes(context, output_selected_indices, {max_output_size_value});
+    } else {
+      SetTensorToDynamic(output_selected_indices);
+    }
   }
 
   return kTfLiteOk;
@@ -181,6 +191,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input_max_output_size =
      GetInput(context, node, kInputTensorMaxOutputSize);
   const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
+  TF_LITE_ENSURE(context, (max_output_size_value >= 0));
+  const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
   const TfLiteTensor* input_iou_threshold =
      GetInput(context, node, kInputTensorIouThreshold);
   const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
@@ -208,6 +220,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
        GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
     output_num_selected_indices =
        GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
+    if (!is_max_output_size_const) {
+      SetTensorSizes(context, output_selected_indices, {max_output_size_value});
+      SetTensorSizes(context, output_selected_scores, {max_output_size_value});
+    }
     reference_ops::NonMaxSuppression(
        input_boxes->data.f, num_boxes, input_scores->data.f,
        max_output_size_value, iou_threshold, score_threshold, soft_nms_sigma,
@@ -221,6 +237,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
        GetOutput(context, node, kNMSOutputTensorSelectedIndices);
     output_num_selected_indices =
        GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
+    if (!is_max_output_size_const) {
+      SetTensorSizes(context, output_selected_indices, {max_output_size_value});
+    }
     reference_ops::NonMaxSuppression(
        input_boxes->data.f, num_boxes, input_scores->data.f,
        max_output_size_value, iou_threshold, score_threshold, /*sigma=*/0.0,
diff --git a/tensorflow/lite/kernels/non_max_suppression_test.cc b/tensorflow/lite/kernels/non_max_suppression_test.cc
index dd8efc0a300..454bb5a0959 100644
--- a/tensorflow/lite/kernels/non_max_suppression_test.cc
+++ b/tensorflow/lite/kernels/non_max_suppression_test.cc
@@ -29,6 +29,10 @@ class BaseNMSOp : public SingleOpModel {
     PopulateTensor(input_scores_, data);
   }
 
+  void SetMaxOutputSize(int max_output_size) {
+    PopulateTensor(input_max_output_size_, {max_output_size});
+  }
+
   void SetScoreThreshold(float score_threshold) {
     PopulateTensor(input_score_threshold_, {score_threshold});
   }
@@ -60,12 +64,18 @@ class NonMaxSuppressionV4OpModel : public BaseNMSOp {
  public:
-  explicit NonMaxSuppressionV4OpModel(const int max_output_size,
-                                      const float iou_threshold) {
+  explicit NonMaxSuppressionV4OpModel(const float iou_threshold,
+                                      const bool static_shaped_outputs,
+                                      const int max_output_size = -1) {
     const int num_boxes = 6;
     input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
     input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
-    input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size});
+    if (static_shaped_outputs) {
+      input_max_output_size_ =
+          AddConstInput(TensorType_INT32, {max_output_size});
+    } else {
+      input_max_output_size_ = AddInput(TensorType_INT32, {});
+    }
     input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
     input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
@@ -94,7 +104,9 @@ class NonMaxSuppressionV4OpModel : public BaseNMSOp {
 };
 
 TEST(NonMaxSuppressionV4OpModel, TestOutput) {
-  NonMaxSuppressionV4OpModel nms(6, 0.5);
+  NonMaxSuppressionV4OpModel nms(/*iou_threshold=*/0.5,
+                                 /*static_shaped_outputs=*/true,
+                                 /*max_output_size=*/6);
   nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
   nms.SetScoreThreshold(0.4);
   nms.Invoke();
@@ -108,8 +120,32 @@ TEST(NonMaxSuppressionV4OpModel, TestOutput) {
   EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
 }
 
+TEST(NonMaxSuppressionV4OpModel, TestDynamicOutput) {
+  NonMaxSuppressionV4OpModel nms(/*iou_threshold=*/0.5,
+                                 /*static_shaped_outputs=*/false);
+  nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
+  nms.SetScoreThreshold(0.4);
+
+  nms.SetMaxOutputSize(1);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
+
+  nms.SetMaxOutputSize(2);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
+
+  nms.SetScoreThreshold(0.99);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0}));
+}
+
 TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
-  NonMaxSuppressionV4OpModel nms(0, 0.5);
+  NonMaxSuppressionV4OpModel nms(/*iou_threshold=*/0.5,
+                                 /*static_shaped_outputs=*/true,
+                                 /*max_output_size=*/0);
   nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
   nms.SetScoreThreshold(0.4);
   nms.Invoke();
@@ -118,13 +154,19 @@ TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
 
 class NonMaxSuppressionV5OpModel : public BaseNMSOp {
  public:
-  explicit NonMaxSuppressionV5OpModel(const int max_output_size,
-                                      const float iou_threshold,
-                                      const float sigma) {
+  explicit NonMaxSuppressionV5OpModel(const float iou_threshold,
+                                      const float sigma,
+                                      const bool static_shaped_outputs,
+                                      const int max_output_size = -1) {
     const int num_boxes = 6;
     input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
     input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
-    input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size});
+    if (static_shaped_outputs) {
+      input_max_output_size_ =
+          AddConstInput(TensorType_INT32, {max_output_size});
+    } else {
+      input_max_output_size_ = AddInput(TensorType_INT32, {});
+    }
     input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
     input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
     input_sigma_ = AddConstInput(TensorType_FLOAT32, {sigma});
@@ -155,7 +197,10 @@ class NonMaxSuppressionV5OpModel : public BaseNMSOp {
 };
 
 TEST(NonMaxSuppressionV5OpModel, TestOutput) {
-  NonMaxSuppressionV5OpModel nms(6, 0.5, 0.5);
+  NonMaxSuppressionV5OpModel nms(/*iou_threshold=*/0.5,
+                                 /*sigma=*/0.5,
+                                 /*static_shaped_outputs=*/true,
+                                 /*max_output_size=*/6);
   nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
   nms.SetScoreThreshold(0.0);
   nms.Invoke();
@@ -172,5 +217,39 @@ TEST(NonMaxSuppressionV5OpModel, TestOutput) {
   EXPECT_THAT(nms.GetSelectedScores(),
               ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
 }
+
+TEST(NonMaxSuppressionV5OpModel, TestDynamicOutput) {
+  NonMaxSuppressionV5OpModel nms(/*iou_threshold=*/0.5,
+                                 /*sigma=*/0.5,
+                                 /*static_shaped_outputs=*/false,
+                                 /*max_output_size=*/6);
+  nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
+  nms.SetScoreThreshold(0.0);
+
+  nms.SetMaxOutputSize(2);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
+  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9}));
+
+  nms.SetMaxOutputSize(1);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
+  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95}));
+
+  nms.SetMaxOutputSize(3);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({3}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 5}));
+  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9, 0.3}));
+
+  // No candidate gets selected, but the outputs should still be zeroed out.
+  nms.SetScoreThreshold(0.99);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0}));
+  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.0, 0.0, 0.0}));
+}
 }  // namespace
 }  // namespace tflite
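
Note on the pattern used above: when max_output_size is not a constant tensor, the kernel can no longer size its outputs in Prepare, so it marks them dynamic and resizes them in Eval before writing results; this is what the is_max_output_size_const branches implement, with the file-local SetTensorSizes helper doing the actual resize. The sketch below shows that Prepare/Eval split on a hypothetical single-output op. It is a minimal illustration, not part of this change: kSizeTensor, kOutputTensor, and the ResizeOutput helper are made-up names, and only the standard kernel_util/common.h calls (GetInput, GetOutput, IsConstantTensor, SetTensorToDynamic, IsDynamicTensor, context->ResizeTensor) are assumed to exist.

// Sketch only (assumed names, not the NMS kernel): the generic TFLite pattern
// for an output whose length comes from a possibly non-constant int32 input.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace {

constexpr int kSizeTensor = 0;    // Scalar int32 input driving the output shape.
constexpr int kOutputTensor = 0;  // 1-D output of length *kSizeTensor.

TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteTensor* output,
                          int length) {
  TfLiteIntArray* shape = TfLiteIntArrayCreate(1);
  shape->data[0] = length;
  // ResizeTensor takes ownership of `shape` and allocates the output buffer.
  return context->ResizeTensor(context, output, shape);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
  TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  if (IsConstantTensor(size)) {
    // Shape is known at Prepare time: allocate a statically shaped output now.
    return ResizeOutput(context, output, *GetTensorData<int>(size));
  }
  // Shape is only known at Invoke time: defer allocation to Eval.
  SetTensorToDynamic(output);
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* size = GetInput(context, node, kSizeTensor);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  if (IsDynamicTensor(output)) {
    // Mirrors the !is_max_output_size_const branches in the NMS Eval above.
    TF_LITE_ENSURE_OK(context,
                      ResizeOutput(context, output, *GetTensorData<int>(size)));
  }
  // ... compute results into output->data ...
  return kTfLiteOk;
}

}  // namespace
}  // namespace tflite

The resize must happen at the top of Eval, before anything is written to output->data, because a dynamic tensor's buffer is only allocated by ResizeTensor at invoke time; that is also why the new TestDynamicOutput tests can change SetMaxOutputSize between Invoke calls and still see correctly sized outputs.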