Adds support for dynamic-shaped outputs from NMSv4 & v5. MaskRCNN has NMS ops whose output-size input (max_output_size) is not constant.
PiperOrigin-RevId: 269587721
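For context, the change follows the usual TFLite pattern for outputs whose size depends on a non-constant input: Prepare() marks the output dynamic instead of sizing it, and Eval() resizes it once the value is actually available. Below is a minimal, illustrative sketch of that pattern, not the kernel code itself; the helper ResizeOutputTo and the function names PrepareSketch/EvalSketch are hypothetical, while IsConstantTensor, SetTensorToDynamic, IsDynamicTensor, and TfLiteContext::ResizeTensor are existing kernel_util/C-API helpers (header paths assume the current TFLite layout).

// Illustrative only: static-vs-dynamic output sizing, as used by the NMS change.
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace {

// Hypothetical helper: resize a 1-D output tensor to `size` elements.
TfLiteStatus ResizeOutputTo(TfLiteContext* context, TfLiteTensor* output,
                            int size) {
  TfLiteIntArray* shape = TfLiteIntArrayCreate(1);
  shape->data[0] = size;
  // ResizeTensor takes ownership of `shape`.
  return context->ResizeTensor(context, output, shape);
}

TfLiteStatus PrepareSketch(TfLiteContext* context,
                           const TfLiteTensor* max_output_size,
                           TfLiteTensor* output) {
  if (tflite::IsConstantTensor(max_output_size)) {
    // Size known at Prepare time: allocate the output statically.
    return ResizeOutputTo(context, output, max_output_size->data.i32[0]);
  }
  // Size only known at Eval time: defer allocation.
  tflite::SetTensorToDynamic(output);
  return kTfLiteOk;
}

TfLiteStatus EvalSketch(TfLiteContext* context,
                        const TfLiteTensor* max_output_size,
                        TfLiteTensor* output) {
  if (tflite::IsDynamicTensor(output)) {
    // The size input now has a value: resize before writing results.
    TF_LITE_ENSURE_OK(context, ResizeOutputTo(context, output,
                                              max_output_size->data.i32[0]));
  }
  // ... run NMS and fill `output` ...
  return kTfLiteOk;
}

}  // namespace

This is why, in the hunks below, Prepare() only reads and validates max_output_size when the tensor is constant, while Eval() re-reads it and resizes the index/score outputs before calling reference_ops::NonMaxSuppression.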
@@ -38,6 +38,7 @@ constexpr int kInputTensorBoxes = 0;
 // Type: Float.
 constexpr int kInputTensorScores = 1;
 // Max number of boxes to output. Actual output can be smaller.
+// The output tensors (indices/scores) are of this length.
 // Type: Int32.
 constexpr int kInputTensorMaxOutputSize = 2;
 // Type: Float.
@@ -99,14 +100,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
       GetInput(context, node, kInputTensorMaxOutputSize);
   TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
   TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
-  // TODO(b/265135869): Add support for non-constant max_output_size by making
-  // output dynamic?
-  if (!IsConstantTensor(input_max_output_size)) {
-    context->ReportError(context, "Max output size should be a constant");
-    return kTfLiteError;
-  }
-  int max_output_size_value = *GetTensorData<int>(input_max_output_size);
-  TF_LITE_ENSURE(context, (max_output_size_value >= 0));
+  const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
+  int max_output_size_value = 0;
+  if (is_max_output_size_const) {
+    max_output_size_value = *GetTensorData<int>(input_max_output_size);
+    TF_LITE_ENSURE(context, (max_output_size_value >= 0));
+  }
 
   // IoU & Score thresholds.
   const TfLiteTensor* input_iou_threshold =
@@ -128,25 +127,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
     TfLiteTensor* output_selected_indices =
         GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
     output_selected_indices->type = kTfLiteInt32;
-    SetTensorSizes(context, output_selected_indices, {max_output_size_value});
     TfLiteTensor* output_selected_scores =
         GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
     output_selected_scores->type = kTfLiteFloat32;
-    SetTensorSizes(context, output_selected_scores, {max_output_size_value});
     TfLiteTensor* output_num_selected_indices =
         GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
     output_num_selected_indices->type = kTfLiteInt32;
     SetTensorSizes(context, output_num_selected_indices, {});
+
+    if (is_max_output_size_const) {
+      SetTensorSizes(context, output_selected_indices, {max_output_size_value});
+      SetTensorSizes(context, output_selected_scores, {max_output_size_value});
+    } else {
+      SetTensorToDynamic(output_selected_indices);
+      SetTensorToDynamic(output_selected_scores);
+    }
   } else {
     TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
     TfLiteTensor* output_selected_indices =
         GetOutput(context, node, kNMSOutputTensorSelectedIndices);
     output_selected_indices->type = kTfLiteInt32;
-    SetTensorSizes(context, output_selected_indices, {max_output_size_value});
     TfLiteTensor* output_num_selected_indices =
         GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
     output_num_selected_indices->type = kTfLiteInt32;
     SetTensorSizes(context, output_num_selected_indices, {});
+
+    if (is_max_output_size_const) {
+      SetTensorSizes(context, output_selected_indices, {max_output_size_value});
+    } else {
+      SetTensorToDynamic(output_selected_indices);
+    }
   }
 
   return kTfLiteOk;
@@ -181,6 +191,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input_max_output_size =
       GetInput(context, node, kInputTensorMaxOutputSize);
   const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
+  TF_LITE_ENSURE(context, (max_output_size_value >= 0));
+  const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
   const TfLiteTensor* input_iou_threshold =
       GetInput(context, node, kInputTensorIouThreshold);
   const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
@@ -208,6 +220,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
     output_num_selected_indices =
         GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
+    if (!is_max_output_size_const) {
+      SetTensorSizes(context, output_selected_indices, {max_output_size_value});
+      SetTensorSizes(context, output_selected_scores, {max_output_size_value});
+    }
     reference_ops::NonMaxSuppression(
         input_boxes->data.f, num_boxes, input_scores->data.f,
         max_output_size_value, iou_threshold, score_threshold, soft_nms_sigma,
@@ -221,6 +237,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
         GetOutput(context, node, kNMSOutputTensorSelectedIndices);
     output_num_selected_indices =
         GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
+    if (!is_max_output_size_const) {
+      SetTensorSizes(context, output_selected_indices, {max_output_size_value});
+    }
     reference_ops::NonMaxSuppression(
         input_boxes->data.f, num_boxes, input_scores->data.f,
         max_output_size_value, iou_threshold, score_threshold, /**sigma=**/ 0.0,
@@ -29,6 +29,10 @@ class BaseNMSOp : public SingleOpModel {
     PopulateTensor(input_scores_, data);
   }
 
+  void SetMaxOutputSize(int max_output_size) {
+    PopulateTensor(input_max_output_size_, {max_output_size});
+  }
+
   void SetScoreThreshold(float score_threshold) {
     PopulateTensor(input_score_threshold_, {score_threshold});
   }
@@ -60,12 +64,18 @@ class BaseNMSOp : public SingleOpModel {
 
 class NonMaxSuppressionV4OpModel : public BaseNMSOp {
  public:
-  explicit NonMaxSuppressionV4OpModel(const int max_output_size,
-                                      const float iou_threshold) {
+  explicit NonMaxSuppressionV4OpModel(const float iou_threshold,
+                                      const bool static_shaped_outputs,
+                                      const int max_output_size = -1) {
     const int num_boxes = 6;
     input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
     input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
-    input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size});
+    if (static_shaped_outputs) {
+      input_max_output_size_ =
+          AddConstInput(TensorType_INT32, {max_output_size});
+    } else {
+      input_max_output_size_ = AddInput(TensorType_INT32, {});
+    }
     input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
     input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
 
@@ -94,7 +104,9 @@ class NonMaxSuppressionV4OpModel : public BaseNMSOp {
 };
 
 TEST(NonMaxSuppressionV4OpModel, TestOutput) {
-  NonMaxSuppressionV4OpModel nms(6, 0.5);
+  NonMaxSuppressionV4OpModel nms(/**iou_threshold=**/ 0.5,
+                                 /**static_shaped_outputs=**/ true,
+                                 /**max_output_size=**/ 6);
   nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
   nms.SetScoreThreshold(0.4);
   nms.Invoke();
@@ -108,8 +120,32 @@ TEST(NonMaxSuppressionV4OpModel, TestOutput) {
   EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
 }
 
+TEST(NonMaxSuppressionV4OpModel, TestDynamicOutput) {
+  NonMaxSuppressionV4OpModel nms(/**iou_threshold=**/ 0.5,
+                                 /**static_shaped_outputs=**/ false);
+  nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
+  nms.SetScoreThreshold(0.4);
+
+  nms.SetMaxOutputSize(1);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
+
+  nms.SetMaxOutputSize(2);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
+
+  nms.SetScoreThreshold(0.99);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0}));
+}
+
 TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
-  NonMaxSuppressionV4OpModel nms(0, 0.5);
+  NonMaxSuppressionV4OpModel nms(/**iou_threshold=**/ 0.5,
+                                 /**static_shaped_outputs=**/ true,
+                                 /**max_output_size=**/ 0);
   nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
   nms.SetScoreThreshold(0.4);
   nms.Invoke();
@@ -118,13 +154,19 @@ TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
 
 class NonMaxSuppressionV5OpModel : public BaseNMSOp {
  public:
-  explicit NonMaxSuppressionV5OpModel(const int max_output_size,
-                                      const float iou_threshold,
-                                      const float sigma) {
+  explicit NonMaxSuppressionV5OpModel(const float iou_threshold,
+                                      const float sigma,
+                                      const bool static_shaped_outputs,
+                                      const int max_output_size = -1) {
     const int num_boxes = 6;
     input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
     input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
-    input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size});
+    if (static_shaped_outputs) {
+      input_max_output_size_ =
+          AddConstInput(TensorType_INT32, {max_output_size});
+    } else {
+      input_max_output_size_ = AddInput(TensorType_INT32, {});
+    }
     input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
     input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
     input_sigma_ = AddConstInput(TensorType_FLOAT32, {sigma});
@@ -155,7 +197,10 @@ class NonMaxSuppressionV5OpModel : public BaseNMSOp {
 };
 
 TEST(NonMaxSuppressionV5OpModel, TestOutput) {
-  NonMaxSuppressionV5OpModel nms(6, 0.5, 0.5);
+  NonMaxSuppressionV5OpModel nms(/**iou_threshold=**/ 0.5,
+                                 /**sigma=**/ 0.5,
+                                 /**static_shaped_outputs=**/ true,
+                                 /**max_output_size=**/ 6);
   nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
   nms.SetScoreThreshold(0.0);
   nms.Invoke();
@@ -172,5 +217,39 @@ TEST(NonMaxSuppressionV5OpModel, TestOutput) {
   EXPECT_THAT(nms.GetSelectedScores(),
               ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
 }
+
+TEST(NonMaxSuppressionV5OpModel, TestDynamicOutput) {
+  NonMaxSuppressionV5OpModel nms(/**iou_threshold=**/ 0.5,
+                                 /**sigma=**/ 0.5,
+                                 /**static_shaped_outputs=**/ false,
+                                 /**max_output_size=**/ 6);
+  nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
+  nms.SetScoreThreshold(0.0);
+
+  nms.SetMaxOutputSize(2);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
+  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9}));
+
+  nms.SetMaxOutputSize(1);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
+  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95}));
+
+  nms.SetMaxOutputSize(3);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({3}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 5}));
+  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9, 0.3}));
+
+  // No candidate gets selected. But the outputs should be zeroed out.
+  nms.SetScoreThreshold(0.99);
+  nms.Invoke();
+  EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
+  EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0}));
+  EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.0, 0.0, 0.0}));
+}
 }  // namespace
 }  // namespace tflite