Add an option to clip/not clip box outputs in CombinedNonMaxSuppression.

PiperOrigin-RevId: 243104621
This commit is contained in:
Pengchong Jin 2019-04-11 11:13:09 -07:00 committed by TensorFlower Gardener
parent 9d724a8e60
commit 35e39cd3f1
12 changed files with 78 additions and 31 deletions

View File

@ -50,6 +50,14 @@ are padded/clipped to `max_total_size`. If true, the
output nmsed boxes, scores and classes are padded to be of length output nmsed boxes, scores and classes are padded to be of length
`max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in `max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in
which case it is clipped to `max_total_size`. Defaults to false. which case it is clipped to `max_total_size`. Defaults to false.
END
}
attr {
name: "clip_boxes"
description: <<END
If true, assume the box coordinates are between [0, 1] and clip the output boxes
if they fall beyond [0, 1]. If false, do not do clipping and output the box
coordinates as it is.
END END
} }
out_arg { out_arg {

View File

@ -0,0 +1,4 @@
op {
graph_op_name: "CombinedNonMaxSuppression"
visibility: HIDDEN
}

View File

@ -227,14 +227,11 @@ void BatchedNonMaxSuppressionOp(
OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores, OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
int num_boxes, const int max_size_per_class, const int total_size_per_batch, int num_boxes, const int max_size_per_class, const int total_size_per_batch,
const float score_threshold, const float iou_threshold, const float score_threshold, const float iou_threshold,
bool pad_per_class = false) { bool pad_per_class = false, bool clip_boxes = true) {
int q = inp_boxes.dim_size(2); int q = inp_boxes.dim_size(2);
int num_classes = inp_scores.dim_size(2); int num_classes = inp_scores.dim_size(2);
const int num_batches = inp_boxes.dim_size(0); const int num_batches = inp_boxes.dim_size(0);
// Default clip window of [0, 0, 1, 1] if none specified
std::vector<float> clip_window{0, 0, 1, 1};
// [num_batches, per_batch_size * 4] // [num_batches, per_batch_size * 4]
std::vector<std::vector<float>> nmsed_boxes(num_batches); std::vector<std::vector<float>> nmsed_boxes(num_batches);
// [num_batches, per_batch_size] // [num_batches, per_batch_size]
@ -375,18 +372,23 @@ void BatchedNonMaxSuppressionOp(
while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) { while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
ResultCandidate next_candidate = result_candidate_vec[result_idx++]; ResultCandidate next_candidate = result_candidate_vec[result_idx++];
// Add to final output vectors // Add to final output vectors
if (clip_boxes) {
const float box_min = 0.0;
const float box_max = 1.0;
nmsed_boxes[batch].push_back( nmsed_boxes[batch].push_back(
std::max(std::min(next_candidate.box_coord[0], clip_window[2]), std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
clip_window[0]));
nmsed_boxes[batch].push_back( nmsed_boxes[batch].push_back(
std::max(std::min(next_candidate.box_coord[1], clip_window[3]), std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
clip_window[1]));
nmsed_boxes[batch].push_back( nmsed_boxes[batch].push_back(
std::max(std::min(next_candidate.box_coord[2], clip_window[2]), std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
clip_window[0]));
nmsed_boxes[batch].push_back( nmsed_boxes[batch].push_back(
std::max(std::min(next_candidate.box_coord[3], clip_window[3]), std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
clip_window[1])); } else {
nmsed_boxes[batch].push_back(next_candidate.box_coord[0]);
nmsed_boxes[batch].push_back(next_candidate.box_coord[1]);
nmsed_boxes[batch].push_back(next_candidate.box_coord[2]);
nmsed_boxes[batch].push_back(next_candidate.box_coord[3]);
}
nmsed_scores[batch].push_back(next_candidate.score); nmsed_scores[batch].push_back(next_candidate.score);
nmsed_classes[batch].push_back(next_candidate.class_idx); nmsed_classes[batch].push_back(next_candidate.class_idx);
curr_total_size--; curr_total_size--;
@ -679,6 +681,7 @@ class CombinedNonMaxSuppressionOp : public OpKernel {
explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context) explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context)
: OpKernel(context) { : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_)); OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_));
OP_REQUIRES_OK(context, context->GetAttr("clip_boxes", &clip_boxes_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
@ -734,11 +737,12 @@ class CombinedNonMaxSuppressionOp : public OpKernel {
BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes, BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes,
max_size_per_class, max_total_size_per_batch, max_size_per_class, max_total_size_per_batch,
score_threshold_val, iou_threshold_val, score_threshold_val, iou_threshold_val,
pad_per_class_); pad_per_class_, clip_boxes_);
} }
private: private:
bool pad_per_class_; bool pad_per_class_;
bool clip_boxes_;
}; };
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),

View File

@ -863,7 +863,7 @@ TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestEmptyInput) {
class CombinedNonMaxSuppressionOpTest : public OpsTestBase { class CombinedNonMaxSuppressionOpTest : public OpsTestBase {
protected: protected:
void MakeOp(bool pad_per_class = false) { void MakeOp(bool pad_per_class = false, bool clip_boxes = true) {
TF_EXPECT_OK(NodeDefBuilder("combined_non_max_suppression_op", TF_EXPECT_OK(NodeDefBuilder("combined_non_max_suppression_op",
"CombinedNonMaxSuppression") "CombinedNonMaxSuppression")
.Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT))
@ -873,6 +873,7 @@ class CombinedNonMaxSuppressionOpTest : public OpsTestBase {
.Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT))
.Attr("pad_per_class", pad_per_class) .Attr("pad_per_class", pad_per_class)
.Attr("clip_boxes", clip_boxes)
.Finalize(node_def())); .Finalize(node_def()));
TF_EXPECT_OK(InitOp()); TF_EXPECT_OK(InitOp());
} }
@ -942,6 +943,39 @@ TEST_F(CombinedNonMaxSuppressionOpTest, TestSelectFromThreeClusters) {
test::ExpectTensorEqual<int>(expected_valid_d, *GetOutput(3)); test::ExpectTensorEqual<int>(expected_valid_d, *GetOutput(3));
} }
TEST_F(CombinedNonMaxSuppressionOpTest,
TestSelectFromThreeClustersNoBoxClipping) {
MakeOp(false, false);
AddInputFromArray<float>(TensorShape({1, 6, 1, 4}),
{0, 0, 10, 10, 0, 1, 10, 11, 0, 1, 10, 9,
0, 11, 10, 20, 0, 12, 10, 21, 0, 30, 100, 40});
AddInputFromArray<float>(TensorShape({1, 6, 1}),
{.9f, .75f, .6f, .95f, .5f, .3f});
AddInputFromArray<int>(TensorShape({}), {3});
AddInputFromArray<int>(TensorShape({}), {3});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
TF_ASSERT_OK(RunOpKernel());
// boxes
Tensor expected_boxes(allocator(), DT_FLOAT, TensorShape({1, 3, 4}));
test::FillValues<float>(&expected_boxes,
{0, 11, 10, 20, 0, 0, 10, 10, 0, 30, 100, 40});
test::ExpectTensorEqual<float>(expected_boxes, *GetOutput(0));
// scores
Tensor expected_scores(allocator(), DT_FLOAT, TensorShape({1, 3}));
test::FillValues<float>(&expected_scores, {0.95, 0.9, 0.3});
test::ExpectTensorEqual<float>(expected_scores, *GetOutput(1));
// classes
Tensor expected_classes(allocator(), DT_FLOAT, TensorShape({1, 3}));
test::FillValues<float>(&expected_classes, {0, 0, 0});
test::ExpectTensorEqual<float>(expected_classes, *GetOutput(2));
// valid
Tensor expected_valid_d(allocator(), DT_INT32, TensorShape({1}));
test::FillValues<int>(&expected_valid_d, {3});
test::ExpectTensorEqual<int>(expected_valid_d, *GetOutput(3));
}
TEST_F(CombinedNonMaxSuppressionOpTest, TEST_F(CombinedNonMaxSuppressionOpTest,
TestSelectFromThreeClustersWithScoreThreshold) { TestSelectFromThreeClustersWithScoreThreshold) {
MakeOp(); MakeOp();

View File

@ -922,6 +922,7 @@ REGISTER_OP("CombinedNonMaxSuppression")
.Output("nmsed_classes: float") .Output("nmsed_classes: float")
.Output("valid_detections: int32") .Output("valid_detections: int32")
.Attr("pad_per_class: bool = false") .Attr("pad_per_class: bool = false")
.Attr("clip_boxes: bool = true")
.SetShapeFn(CombinedNMSShapeFn); .SetShapeFn(CombinedNMSShapeFn);
} // namespace tensorflow } // namespace tensorflow

View File

@ -3496,6 +3496,7 @@ def combined_non_max_suppression(boxes,
iou_threshold=0.5, iou_threshold=0.5,
score_threshold=float('-inf'), score_threshold=float('-inf'),
pad_per_class=False, pad_per_class=False,
clip_boxes=True,
name=None): name=None):
"""Greedily selects a subset of bounding boxes in descending order of score. """Greedily selects a subset of bounding boxes in descending order of score.
@ -3532,6 +3533,9 @@ def combined_non_max_suppression(boxes,
scores and classes are padded to be of length scores and classes are padded to be of length
`max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in `max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in
which case it is clipped to `max_total_size`. Defaults to false. which case it is clipped to `max_total_size`. Defaults to false.
clip_boxes: If true, the coordinates of output nmsed boxes will be clipped
to [0, 1]. If false, output the box coordinates as it is. Defaults to
true.
name: A name for the operation (optional). name: A name for the operation (optional).
Returns: Returns:
@ -3553,7 +3557,7 @@ def combined_non_max_suppression(boxes,
score_threshold, dtype=dtypes.float32, name='score_threshold') score_threshold, dtype=dtypes.float32, name='score_threshold')
return gen_image_ops.combined_non_max_suppression( return gen_image_ops.combined_non_max_suppression(
boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, boxes, scores, max_output_size_per_class, max_total_size, iou_threshold,
score_threshold, pad_per_class) score_threshold, pad_per_class, clip_boxes)
@tf_export('image.draw_bounding_boxes', v1=[]) @tf_export('image.draw_bounding_boxes', v1=[])

View File

@ -34,7 +34,7 @@ tf_module {
} }
member_method { member_method {
name: "combined_non_max_suppression" name: "combined_non_max_suppression"
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], " argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'True\', \'None\'], "
} }
member_method { member_method {
name: "convert_image_dtype" name: "convert_image_dtype"

View File

@ -964,10 +964,6 @@ tf_module {
name: "colocate_with" name: "colocate_with"
argspec: "args=[\'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], " argspec: "args=[\'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], "
} }
member_method {
name: "combined_non_max_suppression"
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method { member_method {
name: "complex" name: "complex"
argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -618,7 +618,7 @@ tf_module {
} }
member_method { member_method {
name: "CombinedNonMaxSuppression" name: "CombinedNonMaxSuppression"
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
} }
member_method { member_method {
name: "CompareAndBitpack" name: "CompareAndBitpack"

View File

@ -34,7 +34,7 @@ tf_module {
} }
member_method { member_method {
name: "combined_non_max_suppression" name: "combined_non_max_suppression"
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], " argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'True\', \'None\'], "
} }
member_method { member_method {
name: "convert_image_dtype" name: "convert_image_dtype"

View File

@ -508,10 +508,6 @@ tf_module {
name: "clip_by_value" name: "clip_by_value"
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "combined_non_max_suppression"
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method { member_method {
name: "complex" name: "complex"
argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -618,7 +618,7 @@ tf_module {
} }
member_method { member_method {
name: "CombinedNonMaxSuppression" name: "CombinedNonMaxSuppression"
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
} }
member_method { member_method {
name: "CompareAndBitpack" name: "CompareAndBitpack"