Adds support for dynamic-shaped outputs from the NMSv4 & v5 kernels. MaskRCNN has NMS ops where the max_output_size input is not constant.

PiperOrigin-RevId: 269587721
Sachin Joglekar 2019-09-17 09:37:33 -07:00 committed by TensorFlower Gardener
parent 0d76b0f988
commit 498dc69fb5
2 changed files with 118 additions and 20 deletions


@@ -38,6 +38,7 @@ constexpr int kInputTensorBoxes = 0;
// Type: Float.
constexpr int kInputTensorScores = 1;
// Max number of boxes to output. Actual output can be smaller.
// The output tensors (indices/scores) are of this length.
// Type: Int32.
constexpr int kInputTensorMaxOutputSize = 2;
// Type: Float.
@@ -99,14 +100,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
GetInput(context, node, kInputTensorMaxOutputSize);
TF_LITE_ENSURE_EQ(context, input_max_output_size->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, NumDimensions(input_max_output_size), 0);
// TODO(b/265135869): Add support for non-constant max_output_size by making
// output dynamic?
if (!IsConstantTensor(input_max_output_size)) {
context->ReportError(context, "Max output size should be a constant");
return kTfLiteError;
}
const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
int max_output_size_value = 0;
if (is_max_output_size_const) {
max_output_size_value = *GetTensorData<int>(input_max_output_size);
TF_LITE_ENSURE(context, (max_output_size_value >= 0));
}
int max_output_size_value = *GetTensorData<int>(input_max_output_size);
TF_LITE_ENSURE(context, (max_output_size_value >= 0));
// IoU & Score thresholds.
const TfLiteTensor* input_iou_threshold =
@@ -128,25 +127,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorSelectedIndices);
output_selected_indices->type = kTfLiteInt32;
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
TfLiteTensor* output_selected_scores =
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
output_selected_scores->type = kTfLiteFloat32;
SetTensorSizes(context, output_selected_scores, {max_output_size_value});
TfLiteTensor* output_num_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
output_num_selected_indices->type = kTfLiteInt32;
SetTensorSizes(context, output_num_selected_indices, {});
if (is_max_output_size_const) {
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
SetTensorSizes(context, output_selected_scores, {max_output_size_value});
} else {
SetTensorToDynamic(output_selected_indices);
SetTensorToDynamic(output_selected_scores);
}
} else {
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2);
TfLiteTensor* output_selected_indices =
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
output_selected_indices->type = kTfLiteInt32;
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
TfLiteTensor* output_num_selected_indices =
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
output_num_selected_indices->type = kTfLiteInt32;
SetTensorSizes(context, output_num_selected_indices, {});
if (is_max_output_size_const) {
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
} else {
SetTensorToDynamic(output_selected_indices);
}
}
return kTfLiteOk;
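The hunks above and below call a file-local SetTensorSizes helper whose definition falls outside the diff. Presumably it just wraps context->ResizeTensor, roughly as in this sketch; this is an assumption, not the committed code.

// Assumed shape of the file-local helper used throughout this kernel: build a
// TfLiteIntArray from the requested dims and hand it to ResizeTensor.
// (Requires <initializer_list>.)
TfLiteStatus SetTensorSizes(TfLiteContext* context, TfLiteTensor* tensor,
                            std::initializer_list<int> values) {
  TfLiteIntArray* size = TfLiteIntArrayCreate(values.size());
  int index = 0;
  for (int v : values) {
    size->data[index++] = v;
  }
  return context->ResizeTensor(context, tensor, size);
}

An empty list, as used for output_num_selected_indices, yields a 0-D (scalar) tensor.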
@@ -181,6 +191,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_max_output_size =
GetInput(context, node, kInputTensorMaxOutputSize);
const int max_output_size_value = *GetTensorData<int>(input_max_output_size);
TF_LITE_ENSURE(context, (max_output_size_value >= 0));
const bool is_max_output_size_const = IsConstantTensor(input_max_output_size);
const TfLiteTensor* input_iou_threshold =
GetInput(context, node, kInputTensorIouThreshold);
const float iou_threshold = *GetTensorData<float>(input_iou_threshold);
@@ -208,6 +220,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetOutput(context, node, kSoftNMSOutputTensorSelectedScores);
output_num_selected_indices =
GetOutput(context, node, kSoftNMSOutputTensorNumSelectedIndices);
if (!is_max_output_size_const) {
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
SetTensorSizes(context, output_selected_scores, {max_output_size_value});
}
reference_ops::NonMaxSuppression(
input_boxes->data.f, num_boxes, input_scores->data.f,
max_output_size_value, iou_threshold, score_threshold, soft_nms_sigma,
@@ -221,6 +237,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetOutput(context, node, kNMSOutputTensorSelectedIndices);
output_num_selected_indices =
GetOutput(context, node, kNMSOutputTensorNumSelectedIndices);
if (!is_max_output_size_const) {
SetTensorSizes(context, output_selected_indices, {max_output_size_value});
}
reference_ops::NonMaxSuppression(
input_boxes->data.f, num_boxes, input_scores->data.f,
max_output_size_value, iou_threshold, score_threshold, /*sigma=*/0.0,
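At the interpreter level, the effect of the two Eval changes is that the selected-indices (and, for v5, selected-scores) outputs are reallocated on every Invoke to match the runtime max_output_size. A hedged sketch of driver code follows; the graph construction, the input/output positions, and the prior AllocateTensors() call are assumptions.

// Hypothetical caller, assuming `interpreter` already holds an NMSv4 graph whose
// max_output_size is a regular runtime input and AllocateTensors() has been run.
#include <cstdio>
#include "tensorflow/lite/interpreter.h"

void RunWithMaxOutputSize(tflite::Interpreter* interpreter,
                          int max_size_input_pos, int max_output_size) {
  // Populate the scalar max_output_size input for this invocation.
  interpreter->typed_input_tensor<int>(max_size_input_pos)[0] = max_output_size;
  if (interpreter->Invoke() != kTfLiteOk) return;

  // The selected-indices output was marked dynamic in Prepare and resized in
  // Eval, so its first dimension now equals max_output_size.
  const TfLiteTensor* indices = interpreter->tensor(interpreter->outputs()[0]);
  std::printf("selected_indices length: %d\n", indices->dims->data[0]);
}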


@@ -29,6 +29,10 @@ class BaseNMSOp : public SingleOpModel {
PopulateTensor(input_scores_, data);
}
void SetMaxOutputSize(int max_output_size) {
PopulateTensor(input_max_output_size_, {max_output_size});
}
void SetScoreThreshold(float score_threshold) {
PopulateTensor(input_score_threshold_, {score_threshold});
}
@@ -60,12 +64,18 @@ class BaseNMSOp : public SingleOpModel {
class NonMaxSuppressionV4OpModel : public BaseNMSOp {
public:
explicit NonMaxSuppressionV4OpModel(const int max_output_size,
const float iou_threshold) {
explicit NonMaxSuppressionV4OpModel(const float iou_threshold,
const bool static_shaped_outputs,
const int max_output_size = -1) {
const int num_boxes = 6;
input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size});
if (static_shaped_outputs) {
input_max_output_size_ =
AddConstInput(TensorType_INT32, {max_output_size});
} else {
input_max_output_size_ = AddInput(TensorType_INT32, {});
}
input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
@@ -94,7 +104,9 @@ class NonMaxSuppressionV4OpModel : public BaseNMSOp {
};
TEST(NonMaxSuppressionV4OpModel, TestOutput) {
NonMaxSuppressionV4OpModel nms(6, 0.5);
NonMaxSuppressionV4OpModel nms(/*iou_threshold=*/0.5,
/*static_shaped_outputs=*/true,
/*max_output_size=*/6);
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
nms.SetScoreThreshold(0.4);
nms.Invoke();
@@ -108,8 +120,32 @@ TEST(NonMaxSuppressionV4OpModel, TestOutput) {
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0, 0, 0, 0}));
}
TEST(NonMaxSuppressionV4OpModel, TestDynamicOutput) {
NonMaxSuppressionV4OpModel nms(/*iou_threshold=*/0.5,
/*static_shaped_outputs=*/false);
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
nms.SetScoreThreshold(0.4);
nms.SetMaxOutputSize(1);
nms.Invoke();
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
nms.SetMaxOutputSize(2);
nms.Invoke();
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
nms.SetScoreThreshold(0.99);
nms.Invoke();
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0}));
}
TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
NonMaxSuppressionV4OpModel nms(0, 0.5);
NonMaxSuppressionV4OpModel nms(/*iou_threshold=*/0.5,
/*static_shaped_outputs=*/true,
/*max_output_size=*/0);
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
nms.SetScoreThreshold(0.4);
nms.Invoke();
@@ -118,13 +154,19 @@ TEST(NonMaxSuppressionV4OpModel, TestOutputWithZeroMaxOutput) {
class NonMaxSuppressionV5OpModel : public BaseNMSOp {
public:
explicit NonMaxSuppressionV5OpModel(const int max_output_size,
const float iou_threshold,
const float sigma) {
explicit NonMaxSuppressionV5OpModel(const float iou_threshold,
const float sigma,
const bool static_shaped_outputs,
const int max_output_size = -1) {
const int num_boxes = 6;
input_boxes_ = AddInput({TensorType_FLOAT32, {num_boxes, 4}});
input_scores_ = AddInput({TensorType_FLOAT32, {num_boxes}});
input_max_output_size_ = AddConstInput(TensorType_INT32, {max_output_size});
if (static_shaped_outputs) {
input_max_output_size_ =
AddConstInput(TensorType_INT32, {max_output_size});
} else {
input_max_output_size_ = AddInput(TensorType_INT32, {});
}
input_iou_threshold_ = AddConstInput(TensorType_FLOAT32, {iou_threshold});
input_score_threshold_ = AddInput({TensorType_FLOAT32, {}});
input_sigma_ = AddConstInput(TensorType_FLOAT32, {sigma});
@@ -155,7 +197,10 @@ class NonMaxSuppressionV5OpModel : public BaseNMSOp {
};
TEST(NonMaxSuppressionV5OpModel, TestOutput) {
NonMaxSuppressionV5OpModel nms(6, 0.5, 0.5);
NonMaxSuppressionV5OpModel nms(/*iou_threshold=*/0.5,
/*sigma=*/0.5,
/*static_shaped_outputs=*/true,
/*max_output_size=*/6);
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
nms.SetScoreThreshold(0.0);
nms.Invoke();
@@ -172,5 +217,39 @@ TEST(NonMaxSuppressionV5OpModel, TestOutput) {
EXPECT_THAT(nms.GetSelectedScores(),
ElementsAreArray({0.0, 0.0, 0.0, 0.0, 0.0, 0.0}));
}
TEST(NonMaxSuppressionV5OpModel, TestDynamicOutput) {
NonMaxSuppressionV5OpModel nms(/*iou_threshold=*/0.5,
/*sigma=*/0.5,
/*static_shaped_outputs=*/false,
/*max_output_size=*/6);
nms.SetScores({0.9, 0.75, 0.6, 0.95, 0.5, 0.3});
nms.SetScoreThreshold(0.0);
nms.SetMaxOutputSize(2);
nms.Invoke();
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({2}));
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0}));
EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9}));
nms.SetMaxOutputSize(1);
nms.Invoke();
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({1}));
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3}));
EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95}));
nms.SetMaxOutputSize(3);
nms.Invoke();
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({3}));
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({3, 0, 5}));
EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.95, 0.9, 0.3}));
// No candidate gets selected, but the outputs should still be zeroed out.
nms.SetScoreThreshold(0.99);
nms.Invoke();
EXPECT_THAT(nms.GetNumSelectedIndices(), ElementsAreArray({0}));
EXPECT_THAT(nms.GetSelectedIndices(), ElementsAreArray({0, 0, 0}));
EXPECT_THAT(nms.GetSelectedScores(), ElementsAreArray({0.0, 0.0, 0.0}));
}
} // namespace
} // namespace tflite