Adds support for dynamic-shaped outputs from NMSv4 & v5. MaskRCNN has NMS ops whose max_output_size input is not a constant tensor.
PiperOrigin-RevId: 269587721
This commit is contained in:
parent
0d76b0f988
commit
498dc69fb5
@ -38,6 +38,7 @@ constexpr int kInputTensorBoxes = 0;
|
|||||||
// Type: Float.
|
// Type: Float.
|
||||||
constexpr int kInputTensorScores = 1;
|
constexpr int kInputTensorScores = 1;
|
||||||
// Max number of boxes to output. Actual output can be smaller.
|
// Max number of boxes to output. Actual output can be smaller.
|
||||||
|
// The output tensors (indices/scores) are of this length.
|
||||||
// Type: Int32.
|
// Type: Int32.
|
||||||
constexpr int kInputTensorMaxOutputSize = 2;
|
constexpr int kInputTensorMaxOutputSize = 2;
|
||||||
// Type: Float.
|
// Type: Float.
|
||||||
@ -99,14 +100,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetInput(context, node, kInputTensorMaxOutputSize);
|
GetInput(context, node, kInputTensorMaxOutputSize);
|
||||||
TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
|
TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
|
||||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
|
TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
|
||||||
// TODO(b/265135869): Add support for non-constant max_output_size by making
|
const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
|
||||||
// output dynamic?
|
int max_output_size_value = 0;
|
||||||
if (!IsConstantTensor(input_max_output_size)) {
|
if (is_max_output_size_const) {
|
||||||
context->ReportError(context, "Max output size should be a constant");
|
max_output_size_value = *GetTensorData<int>(input_max_output_size);
|
||||||
return kTfLiteError;
|
|
||||||
}
|
|
||||||
int max_output_size_value = *GetTensorData<int>(input_max_output_size);
|
|
||||||
TF_LITE_ENSURE(context, (max_output_size_value >= 0));
|
TF_LITE_ENSURE(context, (max_output_size_value >= 0));
|
||||||
|
}
|
||||||
|
|
||||||
// IoU & Score thresholds.
|
// IoU & Score thresholds.
|
||||||
const TfLiteTensor* input_iou_threshold =
|
const TfLiteTensor* input_iou_threshold =
|
||||||
@ -128,25 +127,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TfLiteTensor* output_selected_indices =
|
TfLiteTensor* output_selected_indices =
|
||||||
GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
|
GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
|
||||||
output_selected_indices->type = kTfLiteInt32;
|
output_selected_indices->type = kTfLiteInt32;
|
||||||
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
|
|
||||||
TfLiteTensor* output_selected_scores =
|
TfLiteTensor* output_selected_scores =
|
||||||
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
|
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
|
||||||
output_selected_scores->type = kTfLiteFloat32;
|
output_selected_scores->type = kTfLiteFloat32;
|
||||||
SetTensorSizes(context, output_selected_scores, {max_output_size_value});
|
|
||||||
TfLiteTensor* output_num_selected_indices =
|
TfLiteTensor* output_num_selected_indices =
|
||||||
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
|
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
|
||||||
output_num_selected_indices->type = kTfLiteInt32;
|
output_num_selected_indices->type = kTfLiteInt32;
|
||||||
SetTensorSizes(context, output_num_selected_indices, {});
|
SetTensorSizes(context, output_num_selected_indices, {});
|
||||||
|
|
||||||
|
if (is_max_output_size_const) {
|
||||||
|
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
|
||||||
|
SetTensorSizes(context, output_selected_scores, {max_output_size_value});
|
||||||
|
} else {
|
||||||
|
SetTensorToDynamic(output_selected_indices);
|
||||||
|
SetTensorToDynamic(output_selected_scores);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
|
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
|
||||||
TfLiteTensor* output_selected_indices =
|
TfLiteTensor* output_selected_indices =
|
||||||
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
|
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
|
||||||
output_selected_indices->type = kTfLiteInt32;
|
output_selected_indices->type = kTfLiteInt32;
|
||||||
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
|
|
||||||
TfLiteTensor* output_num_selected_indices =
|
TfLiteTensor* output_num_selected_indices =
|
||||||
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
|
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
|
||||||
output_num_selected_indices->type = kTfLiteInt32;
|
output_num_selected_indices->type = kTfLiteInt32;
|
||||||
SetTensorSizes(context, output_num_selected_indices, {});
|
SetTensorSizes(context, output_num_selected_indices, {});
|
||||||
|
|
||||||
|
if (is_max_output_size_const) {
|
||||||
|
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
|
||||||
|
} else {
|
||||||
|
SetTensorToDynamic(output_selected_indices);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
@ -181,6 +191,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
const TfLiteTensor* input_max_output_size =
|
const TfLiteTensor* input_max_output_size =
|
||||||
GetInput(context, node, kInputTensorMaxOutputSize);
|
GetInput(context, node, kInputTensorMaxOutputSize);
|
||||||
const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
|
const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
|
||||||
|
TF_LITE_ENSURE(context, (max_output_size_value >= 0));
|
||||||
|
const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
|
||||||
const TfLiteTensor* input_iou_threshold =
|
const TfLiteTensor* input_iou_threshold =
|
||||||
GetInput(context, node, kInputTensorIouThreshold);
|
GetInput(context, node, kInputTensorIouThreshold);
|
||||||
const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
|
const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
|
||||||
@ -208,6 +220,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
|
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
|
||||||
output_num_selected_indices =
|
output_num_selected_indices =
|
||||||
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
|
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
|
||||||
|
if (!is_max_output_size_const) {
|
||||||
|
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
|
||||||
|
SetTensorSizes(context, output_selected_scores, {max_output_size_value});
|
||||||
|
}
|
||||||
reference_ops::NonMaxSuppression(
|
reference_ops::NonMaxSuppression(
|
||||||
input_boxes->data.f, num_boxes, input_scores->data.f,
|
input_boxes->data.f, num_boxes, input_scores->data.f,
|
||||||
max_output_size_value, iou_threshold, score_threshold, soft_nms_sigma,
|
max_output_size_value, iou_threshold, score_threshold, soft_nms_sigma,
|
||||||
@ -221,6 +237,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
|
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
|
||||||
output_num_selected_indices =
|
output_num_selected_indices =
|
||||||
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
|
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
|
||||||
|
if (!is_max_output_size_const) {
|
||||||
|
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
|
||||||
|
}
|
||||||
reference_ops::NonMaxSuppression(
|
reference_ops::NonMaxSuppression(
|
||||||
input_boxes->data.f, num_boxes, input_scores->data.f,
|
input_boxes->data.f, num_boxes, input_scores->data.f,
|
||||||
max_output_size_value, iou_threshold, score_threshold, /**sigma=**/ 0.0,
|
max_output_size_value, iou_threshold, score_threshold, /**sigma=**/ 0.0,
|
||||||
|
@ -29,6 +29,10 @@ class BaseNMSOp : public SingleOpModel {
|
|||||||
PopulateTensor(input_scores_, data);
|
PopulateTensor(input_scores_, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetMaxOutputSize(int max_output_size) {
|
||||||
|
PopulateTensor(input_max_output_size_, {max_output_size});
|
||||||
|
}
|
||||||
|
|
||||||
void SetScoreThreshold(float score_threshold) {
|
void SetScoreThreshold(float score_threshold) {
|
||||||
PopulateTensor(input_score_threshold_, {score_threshold});
|
PopulateTensor(input_score_threshold_, {score_threshold});
|
||||||
}
|
}
|
||||||
@ -60,12 +64,18 @@ class BaseNMSOp : public SingleOpModel {
|
|||||||
|
|
||||||
class NonMaxSuppressionV4OpModel : public BaseNMSOp {
|
class NonMaxSuppressionV4OpModel : public BaseNMSOp {
|
||||||
public:
|
public:
|
||||||
explicit NonMaxSuppressionV4OpModel(const int max_output_size,
|
explicit NonMaxSuppressionV4OpModel(const float iou_threshold,
|
||||||
const float iou_threshold) {
|
const bool static_shaped_outputs,
|
||||||
|
const int max_output_size = -1) {
|
||||||
const int num_boxes = 6;
|
const int num_boxes = 6;
|
||||||
input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
|
input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
|
||||||
input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
|
input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
|
||||||
input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size});
|
if (static_shaped_outputs) {
|
||||||
|
input_max_output_size_ =
|
||||||
|
AddConstInput(TensorType_INT32, {max_output_size});
|
||||||
|
} else {
|
||||||
|
input_max_output_size_ = AddInput(TensorType_INT32, {});
|
||||||
|
}
|
||||||
input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
|
input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
|
||||||
input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
|
input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
|
||||||
|
|
||||||
@ -94,7 +104,9 @@ class NonMaxSuppressionV4OpModel : public BaseNMSOp {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST(NonMaxSuppressionV4OpModel, TestOutput) {
|
TEST(NonMaxSuppressionV4OpModel, TestOutput) {
|
||||||
NonMaxSuppressionV4OpModel nms(6, 0.5);
|
NonMaxSuppressionV4OpModel nms(/**iou_threshold=**/ 0.5,
|
||||||
|
/**static_shaped_outputs=**/ true,
|
||||||
|
/**max_output_size=**/ 6);
|
||||||
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
|
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
|
||||||
nms.SetScoreThreshold(0.4);
|
nms.SetScoreThreshold(0.4);
|
||||||
nms.Invoke();
|
nms.Invoke();
|
||||||
@ -108,8 +120,32 @@ TEST(NonMaxSuppressionV4OpModel, TestOutput) {
|
|||||||
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
|
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(NonMaxSuppressionV4OpModel, TestDynamicOutput) {
|
||||||
|
NonMaxSuppressionV4OpModel nms(/**iou_threshold=**/ 0.5,
|
||||||
|
/**static_shaped_outputs=**/ false);
|
||||||
|
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
|
||||||
|
nms.SetScoreThreshold(0.4);
|
||||||
|
|
||||||
|
nms.SetMaxOutputSize(1);
|
||||||
|
nms.Invoke();
|
||||||
|
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
|
||||||
|
|
||||||
|
nms.SetMaxOutputSize(2);
|
||||||
|
nms.Invoke();
|
||||||
|
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
|
||||||
|
|
||||||
|
nms.SetScoreThreshold(0.99);
|
||||||
|
nms.Invoke();
|
||||||
|
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
|
TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
|
||||||
NonMaxSuppressionV4OpModel nms(0, 0.5);
|
NonMaxSuppressionV4OpModel nms(/**iou_threshold=**/ 0.5,
|
||||||
|
/**static_shaped_outputs=**/ true,
|
||||||
|
/**max_output_size=**/ 0);
|
||||||
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
|
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
|
||||||
nms.SetScoreThreshold(0.4);
|
nms.SetScoreThreshold(0.4);
|
||||||
nms.Invoke();
|
nms.Invoke();
|
||||||
@ -118,13 +154,19 @@ TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
|
|||||||
|
|
||||||
class NonMaxSuppressionV5OpModel : public BaseNMSOp {
|
class NonMaxSuppressionV5OpModel : public BaseNMSOp {
|
||||||
public:
|
public:
|
||||||
explicit NonMaxSuppressionV5OpModel(const int max_output_size,
|
explicit NonMaxSuppressionV5OpModel(const float iou_threshold,
|
||||||
const float iou_threshold,
|
const float sigma,
|
||||||
const float sigma) {
|
const bool static_shaped_outputs,
|
||||||
|
const int max_output_size = -1) {
|
||||||
const int num_boxes = 6;
|
const int num_boxes = 6;
|
||||||
input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
|
input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
|
||||||
input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
|
input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
|
||||||
input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size});
|
if (static_shaped_outputs) {
|
||||||
|
input_max_output_size_ =
|
||||||
|
AddConstInput(TensorType_INT32, {max_output_size});
|
||||||
|
} else {
|
||||||
|
input_max_output_size_ = AddInput(TensorType_INT32, {});
|
||||||
|
}
|
||||||
input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
|
input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
|
||||||
input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
|
input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
|
||||||
input_sigma_ = AddConstInput(TensorType_FLOAT32, {sigma});
|
input_sigma_ = AddConstInput(TensorType_FLOAT32, {sigma});
|
||||||
@ -155,7 +197,10 @@ class NonMaxSuppressionV5OpModel : public BaseNMSOp {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST(NonMaxSuppressionV5OpModel, TestOutput) {
|
TEST(NonMaxSuppressionV5OpModel, TestOutput) {
|
||||||
NonMaxSuppressionV5OpModel nms(6, 0.5, 0.5);
|
NonMaxSuppressionV5OpModel nms(/**iou_threshold=**/ 0.5,
|
||||||
|
/**sigma=**/ 0.5,
|
||||||
|
/**static_shaped_outputs=**/ true,
|
||||||
|
/**max_output_size=**/ 6);
|
||||||
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
|
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
|
||||||
nms.SetScoreThreshold(0.0);
|
nms.SetScoreThreshold(0.0);
|
||||||
nms.Invoke();
|
nms.Invoke();
|
||||||
@ -172,5 +217,39 @@ TEST(NonMaxSuppressionV5OpModel, TestOutput) {
|
|||||||
EXPECT_THAT(nms.GetSelectedScores(),
|
EXPECT_THAT(nms.GetSelectedScores(),
|
||||||
ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
|
ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(NonMaxSuppressionV5OpModel, TestDynamicOutput) {
|
||||||
|
NonMaxSuppressionV5OpModel nms(/**iou_threshold=**/ 0.5,
|
||||||
|
/**sigma=**/ 0.5,
|
||||||
|
/**static_shaped_outputs=**/ false,
|
||||||
|
/**max_output_size=**/ 6);
|
||||||
|
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
|
||||||
|
nms.SetScoreThreshold(0.0);
|
||||||
|
|
||||||
|
nms.SetMaxOutputSize(2);
|
||||||
|
nms.Invoke();
|
||||||
|
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9}));
|
||||||
|
|
||||||
|
nms.SetMaxOutputSize(1);
|
||||||
|
nms.Invoke();
|
||||||
|
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95}));
|
||||||
|
|
||||||
|
nms.SetMaxOutputSize(3);
|
||||||
|
nms.Invoke();
|
||||||
|
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({3}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 5}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9, 0.3}));
|
||||||
|
|
||||||
|
// No candidate gets selected, but the outputs should still be zeroed out.
|
||||||
|
nms.SetScoreThreshold(0.99);
|
||||||
|
nms.Invoke();
|
||||||
|
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0}));
|
||||||
|
EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.0, 0.0, 0.0}));
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
Loading…
Reference in New Issue
Block a user