Add an option to clip/not clip box outputs in CombinedNonMaxSuppression.
PiperOrigin-RevId: 243104621
This commit is contained in:
parent
9d724a8e60
commit
35e39cd3f1
@ -50,6 +50,14 @@ are padded/clipped to `max_total_size`. If true, the
|
|||||||
output nmsed boxes, scores and classes are padded to be of length
|
output nmsed boxes, scores and classes are padded to be of length
|
||||||
`max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in
|
`max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in
|
||||||
which case it is clipped to `max_total_size`. Defaults to false.
|
which case it is clipped to `max_total_size`. Defaults to false.
|
||||||
|
END
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "clip_boxes"
|
||||||
|
description: <<END
|
||||||
|
If true, assume the box coordinates are between [0, 1] and clip the output boxes
|
||||||
|
if they fall beyond [0, 1]. If false, do not do clipping and output the box
|
||||||
|
coordinates as they are.
|
||||||
END
|
END
|
||||||
}
|
}
|
||||||
out_arg {
|
out_arg {
|
||||||
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "CombinedNonMaxSuppression"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -227,14 +227,11 @@ void BatchedNonMaxSuppressionOp(
|
|||||||
OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
|
OpKernelContext* context, const Tensor& inp_boxes, const Tensor& inp_scores,
|
||||||
int num_boxes, const int max_size_per_class, const int total_size_per_batch,
|
int num_boxes, const int max_size_per_class, const int total_size_per_batch,
|
||||||
const float score_threshold, const float iou_threshold,
|
const float score_threshold, const float iou_threshold,
|
||||||
bool pad_per_class = false) {
|
bool pad_per_class = false, bool clip_boxes = true) {
|
||||||
int q = inp_boxes.dim_size(2);
|
int q = inp_boxes.dim_size(2);
|
||||||
int num_classes = inp_scores.dim_size(2);
|
int num_classes = inp_scores.dim_size(2);
|
||||||
const int num_batches = inp_boxes.dim_size(0);
|
const int num_batches = inp_boxes.dim_size(0);
|
||||||
|
|
||||||
// Default clip window of [0, 0, 1, 1] if none specified
|
|
||||||
std::vector<float> clip_window{0, 0, 1, 1};
|
|
||||||
|
|
||||||
// [num_batches, per_batch_size * 4]
|
// [num_batches, per_batch_size * 4]
|
||||||
std::vector<std::vector<float>> nmsed_boxes(num_batches);
|
std::vector<std::vector<float>> nmsed_boxes(num_batches);
|
||||||
// [num_batches, per_batch_size]
|
// [num_batches, per_batch_size]
|
||||||
@ -375,18 +372,23 @@ void BatchedNonMaxSuppressionOp(
|
|||||||
while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
|
while (curr_total_size > 0 && result_idx < result_candidate_vec.size()) {
|
||||||
ResultCandidate next_candidate = result_candidate_vec[result_idx++];
|
ResultCandidate next_candidate = result_candidate_vec[result_idx++];
|
||||||
// Add to final output vectors
|
// Add to final output vectors
|
||||||
nmsed_boxes[batch].push_back(
|
if (clip_boxes) {
|
||||||
std::max(std::min(next_candidate.box_coord[0], clip_window[2]),
|
const float box_min = 0.0;
|
||||||
clip_window[0]));
|
const float box_max = 1.0;
|
||||||
nmsed_boxes[batch].push_back(
|
nmsed_boxes[batch].push_back(
|
||||||
std::max(std::min(next_candidate.box_coord[1], clip_window[3]),
|
std::max(std::min(next_candidate.box_coord[0], box_max), box_min));
|
||||||
clip_window[1]));
|
nmsed_boxes[batch].push_back(
|
||||||
nmsed_boxes[batch].push_back(
|
std::max(std::min(next_candidate.box_coord[1], box_max), box_min));
|
||||||
std::max(std::min(next_candidate.box_coord[2], clip_window[2]),
|
nmsed_boxes[batch].push_back(
|
||||||
clip_window[0]));
|
std::max(std::min(next_candidate.box_coord[2], box_max), box_min));
|
||||||
nmsed_boxes[batch].push_back(
|
nmsed_boxes[batch].push_back(
|
||||||
std::max(std::min(next_candidate.box_coord[3], clip_window[3]),
|
std::max(std::min(next_candidate.box_coord[3], box_max), box_min));
|
||||||
clip_window[1]));
|
} else {
|
||||||
|
nmsed_boxes[batch].push_back(next_candidate.box_coord[0]);
|
||||||
|
nmsed_boxes[batch].push_back(next_candidate.box_coord[1]);
|
||||||
|
nmsed_boxes[batch].push_back(next_candidate.box_coord[2]);
|
||||||
|
nmsed_boxes[batch].push_back(next_candidate.box_coord[3]);
|
||||||
|
}
|
||||||
nmsed_scores[batch].push_back(next_candidate.score);
|
nmsed_scores[batch].push_back(next_candidate.score);
|
||||||
nmsed_classes[batch].push_back(next_candidate.class_idx);
|
nmsed_classes[batch].push_back(next_candidate.class_idx);
|
||||||
curr_total_size--;
|
curr_total_size--;
|
||||||
@ -679,6 +681,7 @@ class CombinedNonMaxSuppressionOp : public OpKernel {
|
|||||||
explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context)
|
explicit CombinedNonMaxSuppressionOp(OpKernelConstruction* context)
|
||||||
: OpKernel(context) {
|
: OpKernel(context) {
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_));
|
OP_REQUIRES_OK(context, context->GetAttr("pad_per_class", &pad_per_class_));
|
||||||
|
OP_REQUIRES_OK(context, context->GetAttr("clip_boxes", &clip_boxes_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
@ -734,11 +737,12 @@ class CombinedNonMaxSuppressionOp : public OpKernel {
|
|||||||
BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes,
|
BatchedNonMaxSuppressionOp(context, boxes, scores, num_boxes,
|
||||||
max_size_per_class, max_total_size_per_batch,
|
max_size_per_class, max_total_size_per_batch,
|
||||||
score_threshold_val, iou_threshold_val,
|
score_threshold_val, iou_threshold_val,
|
||||||
pad_per_class_);
|
pad_per_class_, clip_boxes_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
bool pad_per_class_;
|
bool pad_per_class_;
|
||||||
|
bool clip_boxes_;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
|
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
|
||||||
|
@ -863,7 +863,7 @@ TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestEmptyInput) {
|
|||||||
|
|
||||||
class CombinedNonMaxSuppressionOpTest : public OpsTestBase {
|
class CombinedNonMaxSuppressionOpTest : public OpsTestBase {
|
||||||
protected:
|
protected:
|
||||||
void MakeOp(bool pad_per_class = false) {
|
void MakeOp(bool pad_per_class = false, bool clip_boxes = true) {
|
||||||
TF_EXPECT_OK(NodeDefBuilder("combined_non_max_suppression_op",
|
TF_EXPECT_OK(NodeDefBuilder("combined_non_max_suppression_op",
|
||||||
"CombinedNonMaxSuppression")
|
"CombinedNonMaxSuppression")
|
||||||
.Input(FakeInput(DT_FLOAT))
|
.Input(FakeInput(DT_FLOAT))
|
||||||
@ -873,6 +873,7 @@ class CombinedNonMaxSuppressionOpTest : public OpsTestBase {
|
|||||||
.Input(FakeInput(DT_FLOAT))
|
.Input(FakeInput(DT_FLOAT))
|
||||||
.Input(FakeInput(DT_FLOAT))
|
.Input(FakeInput(DT_FLOAT))
|
||||||
.Attr("pad_per_class", pad_per_class)
|
.Attr("pad_per_class", pad_per_class)
|
||||||
|
.Attr("clip_boxes", clip_boxes)
|
||||||
.Finalize(node_def()));
|
.Finalize(node_def()));
|
||||||
TF_EXPECT_OK(InitOp());
|
TF_EXPECT_OK(InitOp());
|
||||||
}
|
}
|
||||||
@ -942,6 +943,39 @@ TEST_F(CombinedNonMaxSuppressionOpTest, TestSelectFromThreeClusters) {
|
|||||||
test::ExpectTensorEqual<int>(expected_valid_d, *GetOutput(3));
|
test::ExpectTensorEqual<int>(expected_valid_d, *GetOutput(3));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CombinedNonMaxSuppressionOpTest,
|
||||||
|
TestSelectFromThreeClustersNoBoxClipping) {
|
||||||
|
MakeOp(false, false);
|
||||||
|
AddInputFromArray<float>(TensorShape({1, 6, 1, 4}),
|
||||||
|
{0, 0, 10, 10, 0, 1, 10, 11, 0, 1, 10, 9,
|
||||||
|
0, 11, 10, 20, 0, 12, 10, 21, 0, 30, 100, 40});
|
||||||
|
AddInputFromArray<float>(TensorShape({1, 6, 1}),
|
||||||
|
{.9f, .75f, .6f, .95f, .5f, .3f});
|
||||||
|
AddInputFromArray<int>(TensorShape({}), {3});
|
||||||
|
AddInputFromArray<int>(TensorShape({}), {3});
|
||||||
|
AddInputFromArray<float>(TensorShape({}), {.5f});
|
||||||
|
AddInputFromArray<float>(TensorShape({}), {0.0f});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
|
// boxes
|
||||||
|
Tensor expected_boxes(allocator(), DT_FLOAT, TensorShape({1, 3, 4}));
|
||||||
|
test::FillValues<float>(&expected_boxes,
|
||||||
|
{0, 11, 10, 20, 0, 0, 10, 10, 0, 30, 100, 40});
|
||||||
|
test::ExpectTensorEqual<float>(expected_boxes, *GetOutput(0));
|
||||||
|
// scores
|
||||||
|
Tensor expected_scores(allocator(), DT_FLOAT, TensorShape({1, 3}));
|
||||||
|
test::FillValues<float>(&expected_scores, {0.95, 0.9, 0.3});
|
||||||
|
test::ExpectTensorEqual<float>(expected_scores, *GetOutput(1));
|
||||||
|
// classes
|
||||||
|
Tensor expected_classes(allocator(), DT_FLOAT, TensorShape({1, 3}));
|
||||||
|
test::FillValues<float>(&expected_classes, {0, 0, 0});
|
||||||
|
test::ExpectTensorEqual<float>(expected_classes, *GetOutput(2));
|
||||||
|
// valid
|
||||||
|
Tensor expected_valid_d(allocator(), DT_INT32, TensorShape({1}));
|
||||||
|
test::FillValues<int>(&expected_valid_d, {3});
|
||||||
|
test::ExpectTensorEqual<int>(expected_valid_d, *GetOutput(3));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(CombinedNonMaxSuppressionOpTest,
|
TEST_F(CombinedNonMaxSuppressionOpTest,
|
||||||
TestSelectFromThreeClustersWithScoreThreshold) {
|
TestSelectFromThreeClustersWithScoreThreshold) {
|
||||||
MakeOp();
|
MakeOp();
|
||||||
|
@ -922,6 +922,7 @@ REGISTER_OP("CombinedNonMaxSuppression")
|
|||||||
.Output("nmsed_classes: float")
|
.Output("nmsed_classes: float")
|
||||||
.Output("valid_detections: int32")
|
.Output("valid_detections: int32")
|
||||||
.Attr("pad_per_class: bool = false")
|
.Attr("pad_per_class: bool = false")
|
||||||
|
.Attr("clip_boxes: bool = true")
|
||||||
.SetShapeFn(CombinedNMSShapeFn);
|
.SetShapeFn(CombinedNMSShapeFn);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -3496,6 +3496,7 @@ def combined_non_max_suppression(boxes,
|
|||||||
iou_threshold=0.5,
|
iou_threshold=0.5,
|
||||||
score_threshold=float('-inf'),
|
score_threshold=float('-inf'),
|
||||||
pad_per_class=False,
|
pad_per_class=False,
|
||||||
|
clip_boxes=True,
|
||||||
name=None):
|
name=None):
|
||||||
"""Greedily selects a subset of bounding boxes in descending order of score.
|
"""Greedily selects a subset of bounding boxes in descending order of score.
|
||||||
|
|
||||||
@ -3532,6 +3533,9 @@ def combined_non_max_suppression(boxes,
|
|||||||
scores and classes are padded to be of length
|
scores and classes are padded to be of length
|
||||||
`max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in
|
`max_size_per_class`*`num_classes`, unless it exceeds `max_total_size` in
|
||||||
which case it is clipped to `max_total_size`. Defaults to false.
|
which case it is clipped to `max_total_size`. Defaults to false.
|
||||||
|
clip_boxes: If true, the coordinates of output nmsed boxes will be clipped
|
||||||
|
to [0, 1]. If false, output the box coordinates as they are. Defaults to
|
||||||
|
true.
|
||||||
name: A name for the operation (optional).
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -3553,7 +3557,7 @@ def combined_non_max_suppression(boxes,
|
|||||||
score_threshold, dtype=dtypes.float32, name='score_threshold')
|
score_threshold, dtype=dtypes.float32, name='score_threshold')
|
||||||
return gen_image_ops.combined_non_max_suppression(
|
return gen_image_ops.combined_non_max_suppression(
|
||||||
boxes, scores, max_output_size_per_class, max_total_size, iou_threshold,
|
boxes, scores, max_output_size_per_class, max_total_size, iou_threshold,
|
||||||
score_threshold, pad_per_class)
|
score_threshold, pad_per_class, clip_boxes)
|
||||||
|
|
||||||
|
|
||||||
@tf_export('image.draw_bounding_boxes', v1=[])
|
@tf_export('image.draw_bounding_boxes', v1=[])
|
||||||
|
@ -34,7 +34,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "combined_non_max_suppression"
|
name: "combined_non_max_suppression"
|
||||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
|
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "convert_image_dtype"
|
name: "convert_image_dtype"
|
||||||
|
@ -964,10 +964,6 @@ tf_module {
|
|||||||
name: "colocate_with"
|
name: "colocate_with"
|
||||||
argspec: "args=[\'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
argspec: "args=[\'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
|
||||||
name: "combined_non_max_suppression"
|
|
||||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
|
||||||
}
|
|
||||||
member_method {
|
member_method {
|
||||||
name: "complex"
|
name: "complex"
|
||||||
argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -618,7 +618,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CombinedNonMaxSuppression"
|
name: "CombinedNonMaxSuppression"
|
||||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CompareAndBitpack"
|
name: "CompareAndBitpack"
|
||||||
|
@ -34,7 +34,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "combined_non_max_suppression"
|
name: "combined_non_max_suppression"
|
||||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
|
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "convert_image_dtype"
|
name: "convert_image_dtype"
|
||||||
|
@ -508,10 +508,6 @@ tf_module {
|
|||||||
name: "clip_by_value"
|
name: "clip_by_value"
|
||||||
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'t\', \'clip_value_min\', \'clip_value_max\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
|
||||||
name: "combined_non_max_suppression"
|
|
||||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
|
||||||
}
|
|
||||||
member_method {
|
member_method {
|
||||||
name: "complex"
|
name: "complex"
|
||||||
argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -618,7 +618,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CombinedNonMaxSuppression"
|
name: "CombinedNonMaxSuppression"
|
||||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
argspec: "args=[\'boxes\', \'scores\', \'max_output_size_per_class\', \'max_total_size\', \'iou_threshold\', \'score_threshold\', \'pad_per_class\', \'clip_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "CompareAndBitpack"
|
name: "CompareAndBitpack"
|
||||||
|
Loading…
Reference in New Issue
Block a user