Add an option to clip/not clip box outputs in CombinedNonMaxSuppression.
PiperOrigin-RevId: 243104621
This commit is contained in:
parent
9d724a8e60
commit
35e39cd3f1
@ -50,6 +50,14 @@ are padded/clipped to `max_total_size`. If true, the
|
||||
output nmsed boxes, scores and classes are padded to be of length
|
||||
`max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in
|
||||
which case it is clipped to `max_total_size`. Defaults to false.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "clip_boxes"
|
||||
description: <<END
|
||||
If true, assume the box coordinates are between [0, 1] and clip the output boxes
|
||||
if they fall beyond [0, 1]. If false, do not do clipping and output the box
|
||||
coordinates as they are.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "CombinedNonMaxSuppression"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -227,14 +227,11 @@ void BatchedNonMaxSuppressionOp(
|
||||
OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
|
||||
int num_boxes, const int max_size_per_class, const int total_size_per_batch,
|
||||
const float score_threshold, const float iou_threshold,
|
||||
bool pad_per_class = false) {
|
||||
bool pad_per_class = false, bool clip_boxes = true) {
|
||||
int q = inp_boxes.dim_size(2);
|
||||
int num_classes = inp_scores.dim_size(2);
|
||||
const int num_batches = inp_boxes.dim_size(0);
|
||||
|
||||
// Default clip window of [0, 0, 1, 1] if none specified
|
||||
std::vector<float> clip_window{0, 0, 1, 1};
|
||||
|
||||
// [num_batches, per_batch_size * 4]
|
||||
std::vector<std::vector<float>> nmsed_boxes(num_batches);
|
||||
// [num_batches, per_batch_size]
|
||||
@ -375,18 +372,23 @@ void BatchedNonMaxSuppressionOp(
|
||||
while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
|
||||
ResultCandidate next_candidate = result_candidate_vec[result_idx++];
|
||||
// Add to final output vectors
|
||||
if (clip_boxes) {
|
||||
const float box_min = 0.0;
|
||||
const float box_max = 1.0;
|
||||
nmsed_boxes[batch].push_back(
|
||||
std::max(std::min(next_candidate.box_coord[0], clip_window[2]),
|
||||
clip_window[0]));
|
||||
std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
|
||||
nmsed_boxes[batch].push_back(
|
||||
std::max(std::min(next_candidate.box_coord[1], clip_window[3]),
|
||||
clip_window[1]));
|
||||
std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
|
||||
nmsed_boxes[batch].push_back(
|
||||
std::max(std::min(next_candidate.box_coord[2], clip_window[2]),
|
||||
clip_window[0]));
|
||||
std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
|
||||
nmsed_boxes[batch].push_back(
|
||||
std::max(std::min(next_candidate.box_coord[3], clip_window[3]),
|
||||
clip_window[1]));
|
||||
std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
|
||||
} else {
|
||||
nmsed_boxes[batch].push_back(next_candidate.box_coord[0]);
|
||||
nmsed_boxes[batch].push_back(next_candidate.box_coord[1]);
|
||||
nmsed_boxes[batch].push_back(next_candidate.box_coord[2]);
|
||||
nmsed_boxes[batch].push_back(next_candidate.box_coord[3]);
|
||||
}
|
||||
nmsed_scores[batch].push_back(next_candidate.score);
|
||||
nmsed_classes[batch].push_back(next_candidate.class_idx);
|
||||
curr_total_size--;
|
||||
@ -679,6 +681,7 @@ class CombinedNonMaxSuppressionOp : public OpKernel {
|
||||
explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("clip_boxes", &clip_boxes_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
@ -734,11 +737,12 @@ class CombinedNonMaxSuppressionOp : public OpKernel {
|
||||
BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes,
|
||||
max_size_per_class, max_total_size_per_batch,
|
||||
score_threshold_val, iou_threshold_val,
|
||||
pad_per_class_);
|
||||
pad_per_class_, clip_boxes_);
|
||||
}
|
||||
|
||||
private:
|
||||
bool pad_per_class_;
|
||||
bool clip_boxes_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
|
||||
|
@ -863,7 +863,7 @@ TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestEmptyInput) {
|
||||
|
||||
class CombinedNonMaxSuppressionOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(bool pad_per_class = false) {
|
||||
void MakeOp(bool pad_per_class = false, bool clip_boxes = true) {
|
||||
TF_EXPECT_OK(NodeDefBuilder("combined_non_max_suppression_op",
|
||||
"CombinedNonMaxSuppression")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
@ -873,6 +873,7 @@ class CombinedNonMaxSuppressionOpTest : public OpsTestBase {
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("pad_per_class", pad_per_class)
|
||||
.Attr("clip_boxes", clip_boxes)
|
||||
.Finalize(node_def()));
|
||||
TF_EXPECT_OK(InitOp());
|
||||
}
|
||||
@ -942,6 +943,39 @@ TEST_F(CombinedNonMaxSuppressionOpTest, TestSelectFromThreeClusters) {
|
||||
test::ExpectTensorEqual<int>(expected_valid_d, *GetOutput(3));
|
||||
}
|
||||
|
||||
// Runs CombinedNonMaxSuppression with clip_boxes disabled on six boxes whose
// coordinates lie far outside [0, 1], and checks that the selected boxes are
// returned with their original coordinates rather than being clamped.
TEST_F(CombinedNonMaxSuppressionOpTest,
|
||||
TestSelectFromThreeClustersNoBoxClipping) {
|
||||
// MakeOp(pad_per_class, clip_boxes): both are turned off for this test.
MakeOp(false, false);
|
||||
// Box input: six boxes of four coordinates each, in a {1, 6, 1, 4} tensor.
AddInputFromArray<float>(TensorShape({1, 6, 1, 4}),
|
||||
{0, 0, 10, 10, 0, 1, 10, 11, 0, 1, 10, 9,
|
||||
0, 11, 10, 20, 0, 12, 10, 21, 0, 30, 100, 40});
|
||||
// Score input: one score per box, in a {1, 6, 1} tensor.
AddInputFromArray<float>(TensorShape({1, 6, 1}),
|
||||
{.9f, .75f, .6f, .95f, .5f, .3f});
|
||||
// Four scalar inputs. NOTE(review): presumably max_output_size_per_class,
// max_total_size, iou_threshold and score_threshold, in that order (matching
// the Python wrapper's argument order) -- confirm against the op definition.
AddInputFromArray<int>(TensorShape({}), {3});
|
||||
AddInputFromArray<int>(TensorShape({}), {3});
|
||||
AddInputFromArray<float>(TensorShape({}), {.5f});
|
||||
AddInputFromArray<float>(TensorShape({}), {0.0f});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||

||||
// boxes (output 0): three boxes are expected, ordered by descending score,
// with coordinates such as 10, 20 and 100 left unclipped.
|
||||
Tensor expected_boxes(allocator(), DT_FLOAT, TensorShape({1, 3, 4}));
|
||||
test::FillValues<float>(&expected_boxes,
|
||||
{0, 11, 10, 20, 0, 0, 10, 10, 0, 30, 100, 40});
|
||||
test::ExpectTensorEqual<float>(expected_boxes, *GetOutput(0));
|
||||
// scores (output 1): the scores of the three selected boxes.
|
||||
Tensor expected_scores(allocator(), DT_FLOAT, TensorShape({1, 3}));
|
||||
test::FillValues<float>(&expected_scores, {0.95, 0.9, 0.3});
|
||||
test::ExpectTensorEqual<float>(expected_scores, *GetOutput(1));
|
||||
// classes (output 2): every selected box has class value 0; the tensor is
// compared as float.
|
||||
Tensor expected_classes(allocator(), DT_FLOAT, TensorShape({1, 3}));
|
||||
test::FillValues<float>(&expected_classes, {0, 0, 0});
|
||||
test::ExpectTensorEqual<float>(expected_classes, *GetOutput(2));
|
||||
// valid (output 3): the number of valid detections for the single batch.
|
||||
Tensor expected_valid_d(allocator(), DT_INT32, TensorShape({1}));
|
||||
test::FillValues<int>(&expected_valid_d, {3});
|
||||
test::ExpectTensorEqual<int>(expected_valid_d, *GetOutput(3));
|
||||
}
|
||||
|
||||
TEST_F(CombinedNonMaxSuppressionOpTest,
|
||||
TestSelectFromThreeClustersWithScoreThreshold) {
|
||||
MakeOp();
|
||||
|
@ -922,6 +922,7 @@ REGISTER_OP("CombinedNonMaxSuppression")
|
||||
.Output("nmsed_classes: float")
|
||||
.Output("valid_detections: int32")
|
||||
.Attr("pad_per_class: bool = false")
|
||||
.Attr("clip_boxes: bool = true")
|
||||
.SetShapeFn(CombinedNMSShapeFn);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -3496,6 +3496,7 @@ def combined_non_max_suppression(boxes,
|
||||
iou_threshold=0.5,
|
||||
score_threshold=float('-inf'),
|
||||
pad_per_class=False,
|
||||
clip_boxes=True,
|
||||
name=None):
|
||||
"""Greedily selects a subset of bounding boxes in descending order of score.
|
||||
|
||||
@ -3532,6 +3533,9 @@ def combined_non_max_suppression(boxes,
|
||||
scores and classes are padded to be of length
|
||||
`max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in
|
||||
which case it is clipped to `max_total_size`. Defaults to false.
|
||||
clip_boxes: If true, the coordinates of output nmsed boxes will be clipped
|
||||
to [0, 1]. If false, output the box coordinates as they are. Defaults to
|
||||
true.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
@ -3553,7 +3557,7 @@ def combined_non_max_suppression(boxes,
|
||||
score_threshold, dtype=dtypes.float32, name='score_threshold')
|
||||
return gen_image_ops.combined_non_max_suppression(
|
||||
boxes, scores, max_output_size_per_class, max_total_size, iou_threshold,
|
||||
score_threshold, pad_per_class)
|
||||
score_threshold, pad_per_class, clip_boxes)
|
||||
|
||||
|
||||
@tf_export('image.draw_bounding_boxes', v1=[])
|
||||
|
@ -34,7 +34,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "combined_non_max_suppression"
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "convert_image_dtype"
|
||||
|
@ -964,10 +964,6 @@ tf_module {
|
||||
name: "colocate_with"
|
||||
argspec: "args=[\'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "combined_non_max_suppression"
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "complex"
|
||||
argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -618,7 +618,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "CombinedNonMaxSuppression"
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CompareAndBitpack"
|
||||
|
@ -34,7 +34,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "combined_non_max_suppression"
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "convert_image_dtype"
|
||||
|
@ -508,10 +508,6 @@ tf_module {
|
||||
name: "clip_by_value"
|
||||
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "combined_non_max_suppression"
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "complex"
|
||||
argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -618,7 +618,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "CombinedNonMaxSuppression"
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CompareAndBitpack"
|
||||
|
Loading…
Reference in New Issue
Block a user