Replace non_max_suppression_padded with a new implementation that supports batched inputs. This new implementation is considerably faster than previous implementations on TPU and GPU.
PiperOrigin-RevId: 302730463 Change-Id: I3ff9a1204dc892b7c6688b6b400281ff9cd5331f
This commit is contained in:
parent
21c8683cad
commit
448a04e7c4
@ -32,6 +32,7 @@ package_group(
|
||||
packages = [
|
||||
# To pass open source testing in the pip Kokoros.
|
||||
"//bazel_pip/tensorflow/compiler/tests/...",
|
||||
"//platforms/xla/tests/neural_nets",
|
||||
],
|
||||
)
|
||||
|
||||
@ -747,6 +748,7 @@ tf_xla_py_test(
|
||||
tf_xla_py_test(
|
||||
name = "image_ops_test",
|
||||
size = "small",
|
||||
timeout = "long",
|
||||
srcs = ["image_ops_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 10,
|
||||
|
@ -976,5 +976,275 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||
self.assertAllClose(indices_tf[:num_valid], [0, 2, 4])
|
||||
|
||||
|
||||
class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
  """Correctness tests for `image_ops.non_max_suppression_padded`.

  Every test feeds float32 boxes and scores through placeholders, builds the
  op inside `self.test_scope()`, and compares the returned indices and valid
  counts against hand-computed expectations.
  """

  # Two images with six boxes each, listed in descending score order and
  # already in [y_min, x_min, y_max, x_max] form.
  _SORTED_BOXES = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
                   [[0, 2, 1, 2], [0, 0.8, 1, 1.8], [0, 0.6, 1, 1.6],
                    [0, 0.4, 1, 1.4], [0, 0.2, 1, 1.2], [0, 0, 1, 1]]]
  # Scores matching _SORTED_BOXES, one row per image.
  _SORTED_SCORES = [[0.9, 0.7, 0.6, 0.5, 0.4, 0.3],
                    [0.8, 0.7, 0.6, 0.5, 0.4, 0.3]]

  def _run_nms(self, boxes_data, scores_data, **nms_kwargs):
    """Builds and runs `non_max_suppression_padded` on the given data.

    Args:
      boxes_data: nested list of box coordinates; converted to float32.
      scores_data: nested list of scores; converted to float32.
      **nms_kwargs: forwarded unchanged to
        `image_ops.non_max_suppression_padded`, so each test passes exactly
        the options it wants to exercise and nothing else.

    Returns:
      The `[indices, num_valid]` list produced by `sess.run`.
    """
    boxes_np = np.array(boxes_data, dtype=np.float32)
    scores_np = np.array(scores_data, dtype=np.float32)

    with self.session() as sess:
      boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
      scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)

      with self.test_scope():
        (indices, num_valid) = image_ops.non_max_suppression_padded(
            boxes=boxes, scores=scores, **nms_kwargs)

      inputs = {
          boxes: boxes_np,
          scores: scores_np
      }
      return sess.run([indices, num_valid], inputs)

  def testBatchedNMSFrom6(self):
    boxes_data = self._SORTED_BOXES
    indices_output, num_valid_output = self._run_nms(
        boxes_data,
        self._SORTED_SCORES,
        max_output_size=6,
        iou_threshold=0.5,
        pad_to_max_output_size=True,
        sorted_input=True,
        canonicalized_coordinates=True)
    # Padding slots are expected to hold the index of the last box.
    invalid_index = len(boxes_data[0]) - 1
    self.assertAllEqual([[0, 1, 2, 4, 5, invalid_index],
                         [0, 1, 3, 5, invalid_index, invalid_index]],
                        indices_output)
    self.assertAllEqual([5, 4], num_valid_output)

  def testBatchedNMSFrom6Max3(self):
    indices_output, num_valid_output = self._run_nms(
        self._SORTED_BOXES,
        self._SORTED_SCORES,
        max_output_size=3,
        iou_threshold=0.5,
        pad_to_max_output_size=True,
        sorted_input=True,
        canonicalized_coordinates=True)
    self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
    self.assertAllEqual([3, 3], num_valid_output)

  def testBatchedNMSSingleFrom6Max3(self):
    # Unbatched input: only the first image of the shared fixture.
    indices_output, num_valid_output = self._run_nms(
        self._SORTED_BOXES[0],
        self._SORTED_SCORES[0],
        max_output_size=3,
        iou_threshold=0.5,
        pad_to_max_output_size=True,
        sorted_input=True,
        canonicalized_coordinates=True)
    self.assertAllEqual([0, 1, 2], indices_output)
    self.assertAllEqual(3, num_valid_output)

  def testBatchedNMSSingleFrom6NoPad(self):
    # Unbatched input with pad_to_max_output_size left at its default.
    indices_output, num_valid_output = self._run_nms(
        self._SORTED_BOXES[0],
        self._SORTED_SCORES[0],
        max_output_size=6,
        iou_threshold=0.5,
        sorted_input=True,
        canonicalized_coordinates=True)
    self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
    self.assertAllEqual(5, num_valid_output)

  def testBatchedNMSBatchDimsFrom6Max3(self):
    # Wrapping the fixture once more yields two leading batch dimensions.
    indices_output, num_valid_output = self._run_nms(
        [self._SORTED_BOXES],
        [self._SORTED_SCORES],
        max_output_size=3,
        iou_threshold=0.5,
        pad_to_max_output_size=True,
        sorted_input=True,
        canonicalized_coordinates=True)
    self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
    self.assertAllEqual([[3, 3]], num_valid_output)

  def testBatchedNMSScoreThresholdFrom6Max3(self):
    boxes_data = self._SORTED_BOXES
    scores_data = [[0.9, 0.7, 0.6, 0.4, 0.3, 0.2],
                   [0.8, 0.7, 0.6, 0.4, 0.3, 0.1]]
    indices_output, num_valid_output = self._run_nms(
        boxes_data,
        scores_data,
        max_output_size=3,
        iou_threshold=0.5,
        score_threshold=0.5,
        pad_to_max_output_size=True,
        sorted_input=True,
        canonicalized_coordinates=True)
    invalid_index = len(boxes_data[0]) - 1
    self.assertAllEqual([3, 2], num_valid_output)
    self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)

  def testBatchedNMSUnsortedInputFrom6(self):
    boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
                   [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
                  [[0, 0.4, 1, 1.4], [0, 2, 1, 2], [0, 0.2, 1, 1.2],
                   [0, 0, 1, 1], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]]]
    scores_data = [[0.3, 0.7, 0.9, 0.6, 0.5, 0.4],
                   [0.5, 0.8, 0.4, 0.3, 0.6, 0.7]]
    # sorted_input is deliberately omitted so the op has to sort by score.
    indices_output, num_valid_output = self._run_nms(
        boxes_data,
        scores_data,
        max_output_size=6,
        iou_threshold=0.5,
        pad_to_max_output_size=True,
        canonicalized_coordinates=True)
    self.assertAllEqual([[2, 1, 3, 5, 0, 0], [1, 5, 0, 3, 3, 3]],
                        indices_output)
    self.assertAllEqual([5, 4], num_valid_output)

  def testBatchedNMSNoncanonicalizedInputFrom6(self):
    # Same geometry as the sorted fixture, but with y1 > y2 in every box.
    boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
                   [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
                  [[1, 2, 0, 2], [1, 0.8, 0, 1.8], [1, 0.6, 0, 1.6],
                   [1, 0.4, 0, 1.4], [1, 0.2, 0, 1.2], [1, 0, 0, 1]]]
    # canonicalized_coordinates is deliberately omitted so the op has to
    # canonicalize the boxes itself.
    indices_output, num_valid_output = self._run_nms(
        boxes_data,
        self._SORTED_SCORES,
        max_output_size=6,
        iou_threshold=0.5,
        pad_to_max_output_size=True,
        sorted_input=True)
    invalid_index = len(boxes_data[0]) - 1
    self.assertAllEqual([[0, 1, 2, 4, 5, invalid_index],
                         [0, 1, 3, 5, invalid_index, invalid_index]],
                        indices_output)
    self.assertAllEqual([5, 4], num_valid_output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -34,6 +34,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import sort_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.util import deprecation
|
||||
@ -3015,7 +3016,7 @@ def non_max_suppression_with_scores(boxes,
|
||||
box coordinates corresponding to the selected indices can then be obtained
|
||||
using the `tf.gather` operation. For example:
|
||||
```python
|
||||
selected_indices, selected_scores = tf.image.non_max_suppression_v2(
|
||||
selected_indices, selected_scores = tf.image.non_max_suppression_with_scores(
|
||||
boxes, scores, max_output_size, iou_threshold=1.0, score_threshold=0.1,
|
||||
soft_nms_sigma=0.5)
|
||||
selected_boxes = tf.gather(boxes, selected_indices)
|
||||
@ -3026,12 +3027,12 @@ def non_max_suppression_with_scores(boxes,
|
||||
Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
|
||||
of other overlapping boxes instead of directly causing them to be pruned.
|
||||
Consequently, in contrast to `tf.image.non_max_suppression`,
|
||||
`tf.image.non_max_suppression_v2` returns the new scores of each input box in
|
||||
the second output, `selected_scores`.
|
||||
`tf.image.non_max_suppression_with_scores` returns the new scores of each input box
|
||||
in the second output, `selected_scores`.
|
||||
|
||||
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
|
||||
larger than 0. When `soft_nms_sigma` equals 0, the behavior of
|
||||
`tf.image.non_max_suppression_v2` is identical to that of
|
||||
`tf.image.non_max_suppression_with_scores` is identical to that of
|
||||
`tf.image.non_max_suppression` (except for the extra output) both in function
|
||||
and in running time.
|
||||
|
||||
@ -3077,62 +3078,6 @@ def non_max_suppression_with_scores(boxes,
|
||||
return selected_indices, selected_scores
|
||||
|
||||
|
||||
@tf_export('image.non_max_suppression_padded')
def non_max_suppression_padded(boxes,
                               scores,
                               max_output_size,
                               iou_threshold=0.5,
                               score_threshold=float('-inf'),
                               pad_to_max_output_size=False,
                               name=None):
  """Runs greedy non-max suppression and can pad the result to a fixed size.

  The selection is algorithmically equivalent to
  `tf.image.non_max_suppression`. In addition, when `pad_to_max_output_size`
  is set, the index output is zero-padded to length `max_output_size`.

  Two values are returned: the selected indices into `boxes`, and the count
  of entries in that index tensor that are valid. The selected boxes can be
  recovered by trimming the padding and gathering, for example:

  ```python
    selected_indices_padded, num_valid = tf.image.non_max_suppression_padded(
        boxes, scores, max_output_size, iou_threshold,
        score_threshold, pad_to_max_output_size=True)
    selected_indices = selected_indices_padded[:num_valid]
    selected_boxes = tf.gather(boxes, selected_indices)
  ```

  Args:
    boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
    scores: A 1-D float `Tensor` of shape `[num_boxes]` holding one score per
      row of `boxes`.
    max_output_size: A scalar integer `Tensor`; the maximum number of boxes
      that non-max suppression may select.
    iou_threshold: A float; boxes whose IOU with an already selected box
      exceeds this overlap too much.
    score_threshold: A float; the score threshold used to decide when boxes
      are removed.
    pad_to_max_output_size: bool. If True, the `selected_indices` output is
      padded to `max_output_size`.
    name: A name for the operation (optional).

  Returns:
    selected_indices: A 1-D integer `Tensor` of shape `[M]` with the selected
      indices from the boxes tensor, where `M <= max_output_size`.
    valid_outputs: A scalar integer `Tensor` giving the number of valid
      elements in `selected_indices`. Valid elements come first, followed by
      padding.
  """
  with ops.name_scope(name, 'non_max_suppression_padded'):
    # The generated op takes both thresholds as tensors, not Python floats.
    iou_threshold_tensor = ops.convert_to_tensor(
        iou_threshold, name='iou_threshold')
    score_threshold_tensor = ops.convert_to_tensor(
        score_threshold, name='score_threshold')
    return gen_image_ops.non_max_suppression_v4(
        boxes, scores, max_output_size, iou_threshold_tensor,
        score_threshold_tensor, pad_to_max_output_size)
|
||||
|
||||
|
||||
@tf_export('image.non_max_suppression_overlaps')
|
||||
def non_max_suppression_with_overlaps(overlaps,
|
||||
scores,
|
||||
@ -4312,6 +4257,411 @@ def combined_non_max_suppression(boxes,
|
||||
score_threshold, pad_per_class, clip_boxes)
|
||||
|
||||
|
||||
def _bbox_overlap(boxes_a, boxes_b):
  """Computes the pairwise IoU between two batched sets of boxes.

  Args:
    boxes_a: a tensor of shape [batch_size, N, 4] holding N boxes per image,
      each given as [ymin, xmin, ymax, xmax].
    boxes_b: a tensor of shape [batch_size, M, 4] holding M boxes per image,
      each given as [ymin, xmin, ymax, xmax].

  Returns:
    A tensor of shape [batch_size, N, M] whose entry [b, i, j] is the
    intersection area divided by the union area of box i of `boxes_a` and
    box j of `boxes_b` in image b.
  """
  with ops.name_scope('bbox_overlap'):
    ymin_a, xmin_a, ymax_a, xmax_a = array_ops.split(
        value=boxes_a, num_or_size_splits=4, axis=2)
    ymin_b, xmin_b, ymax_b, xmax_b = array_ops.split(
        value=boxes_b, num_or_size_splits=4, axis=2)

    def _as_row(column):
      # [batch_size, M, 1] -> [batch_size, 1, M], so that it broadcasts
      # against the [batch_size, N, 1] columns taken from boxes_a.
      return array_ops.transpose(column, [0, 2, 1])

    # Intersection rectangle of every (a, b) pair; negative extents mean the
    # two boxes do not overlap and are clamped to zero.
    inter_xmin = math_ops.maximum(xmin_a, _as_row(xmin_b))
    inter_xmax = math_ops.minimum(xmax_a, _as_row(xmax_b))
    inter_ymin = math_ops.maximum(ymin_a, _as_row(ymin_b))
    inter_ymax = math_ops.minimum(ymax_a, _as_row(ymax_b))
    inter_width = math_ops.maximum((inter_xmax - inter_xmin), 0)
    inter_height = math_ops.maximum((inter_ymax - inter_ymin), 0)
    inter_area = inter_width * inter_height

    # Union = area(a) + area(b) - intersection. The small constant keeps the
    # division below well defined when both boxes have zero area.
    area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a)
    area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b)
    epsilon = 1e-8
    union_area = area_a + _as_row(area_b) - inter_area + epsilon

    return inter_area / union_area
|
||||
|
||||
|
||||
def _self_suppression(iou, _, iou_sum, iou_threshold):
  """Performs one round of suppression among the boxes of a single tile.

  Boxes that nothing in the tile suppresses are identified first; those boxes
  are then used to suppress other boxes of the same tile. The function has
  the signature of a `while_loop` body and is applied repeatedly until a
  round no longer changes the IoU matrix.

  Args:
    iou: a rank-3 tensor of shape [batch_size, tile_size, tile_size]. A
      non-zero entry [b, i, j] means box i may suppress box j in image b.
    _: unused; occupies the slot of the loop-condition variable.
    iou_sum: a tensor of shape [batch_size] with the per-image sum of `iou`
      before this round.
    iou_threshold: a scalar tensor.

  Returns:
    A list of four loop variables:
    iou_suppressed: `iou` with the rows of newly suppressed boxes zeroed, so
      that those boxes can no longer suppress anything.
    iou_diff: a scalar boolean tensor, True when this round lowered the IoU
      sum of any image by more than `iou_threshold`.
    iou_sum_new: a tensor of shape [batch_size] with the per-image IoU sum
      after this round.
    iou_threshold: passed through unchanged.
  """
  batch_size = array_ops.shape(iou)[0]
  per_row_shape = [batch_size, -1, 1]

  # Reducing over the rows looks at each column, i.e. at every box that might
  # be suppressed. A column without any entry at or above the threshold
  # belongs to a box nobody suppresses.
  not_suppressed = math_ops.reduce_max(iou, 1) < iou_threshold
  can_suppress_others = math_ops.cast(
      array_ops.reshape(not_suppressed, per_row_shape), iou.dtype)

  # Let only those boxes act as suppressors, then find the boxes that survive
  # them.
  survives = math_ops.reduce_max(can_suppress_others * iou, 1) < iou_threshold
  surviving_rows = array_ops.reshape(
      math_ops.cast(survives, iou.dtype), per_row_shape)
  iou_suppressed = surviving_rows * iou

  iou_sum_new = math_ops.reduce_sum(iou_suppressed, [1, 2])
  iou_diff = math_ops.reduce_any(iou_sum - iou_sum_new > iou_threshold)
  return [iou_suppressed, iou_diff, iou_sum_new, iou_threshold]
|
||||
|
||||
|
||||
def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size):
  """Suppresses the boxes of one tile using the boxes of another tile.

  The function has the signature of a `while_loop` body (with `tile_size`
  bound by the caller): it handles the suppressing tile `inner_idx` and
  returns the loop variables for the next tile.

  Args:
    boxes: a tensor of shape [batch_size, num_boxes_with_padding, 4].
    box_slice: a tensor of shape [batch_size, tile_size, 4]; the tile being
      suppressed.
    iou_threshold: a scalar tensor.
    inner_idx: a scalar tensor; index of the tile whose boxes are used to
      suppress `box_slice`.
    tile_size: an integer; the number of boxes in a tile.

  Returns:
    boxes: the input boxes, unchanged.
    box_slice_after_suppression: `box_slice` with every suppressed box
      multiplied by zero.
    iou_threshold: unchanged.
    inner_idx: the input `inner_idx` plus one.
  """
  batch_size = array_ops.shape(boxes)[0]
  tile_begin = [0, inner_idx * tile_size, 0]
  tile_extent = [batch_size, tile_size, 4]
  suppressing_tile = array_ops.slice(boxes, tile_begin, tile_extent)

  # iou[b, i, j]: overlap of suppressing box i with box j of box_slice.
  iou = _bbox_overlap(suppressing_tile, box_slice)

  # A box is kept only if it stays below the threshold against every box of
  # the suppressing tile.
  below_threshold_for_all = math_ops.reduce_all(iou < iou_threshold, [1])
  keep_mask = array_ops.expand_dims(
      math_ops.cast(below_threshold_for_all, box_slice.dtype), 2)
  box_slice_after_suppression = keep_mask * box_slice

  return boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1
|
||||
|
||||
|
||||
def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size):
  """Handles the boxes in the range [idx*tile_size, (idx+1)*tile_size).

  Tile `idx` is first suppressed by all earlier tiles, then by its own boxes,
  and finally written back into `boxes`. Suppressed boxes are multiplied by
  zero.

  Args:
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.
    output_size: an int32 tensor of size [batch_size] holding the number of
      boxes selected so far for each image.
    idx: an integer scalar; the induction variable (index of the tile).
    tile_size: an integer; the number of boxes in a tile.

  Returns:
    boxes: the boxes with tile `idx` replaced by its suppressed version.
    iou_threshold: handed on unchanged to the next iteration.
    output_size: `output_size` increased by the number of boxes in tile `idx`
      that still have a coordinate greater than zero.
    idx: the induction variable increased by one.
  """
  with ops.name_scope('suppression_loop_body'):
    boxes_shape = array_ops.shape(boxes)
    num_tiles = boxes_shape[1] // tile_size
    batch_size = boxes_shape[0]

    # Cross suppression: every tile before `idx` may suppress the current
    # tile.
    current_tile = array_ops.slice(
        boxes, [0, idx * tile_size, 0], [batch_size, tile_size, 4])
    _, current_tile, _, _ = control_flow_ops.while_loop(
        lambda _boxes, _tile, _threshold, inner_idx: inner_idx < idx,
        lambda all_boxes, tile, threshold, inner_idx: _cross_suppression(
            all_boxes, tile, threshold, inner_idx, tile_size),
        [boxes, current_tile, iou_threshold, constant_op.constant(0)])

    # Self suppression: keep only the IoU entries in which the row box comes
    # before the column box in the tile and the overlap reaches the
    # threshold, then iterate until no further box gets suppressed.
    tile_iou = _bbox_overlap(current_tile, current_tile)
    positions = math_ops.range(tile_size)
    column_after_row = array_ops.expand_dims(
        array_ops.reshape(positions, [1, -1]) >
        array_ops.reshape(positions, [-1, 1]), 0)
    can_suppress = math_ops.logical_and(
        column_after_row, tile_iou >= iou_threshold)
    tile_iou = tile_iou * math_ops.cast(can_suppress, tile_iou.dtype)
    settled_iou, _, _, _ = control_flow_ops.while_loop(
        lambda _iou, keep_going, _iou_sum, _threshold: keep_going,
        _self_suppression,
        [tile_iou, constant_op.constant(True),
         math_ops.reduce_sum(tile_iou, [1, 2]), iou_threshold])
    is_suppressed = math_ops.reduce_sum(settled_iou, 1) > 0
    keep_factor = 1.0 - math_ops.cast(is_suppressed, current_tile.dtype)
    current_tile = current_tile * array_ops.expand_dims(keep_factor, 2)

    # Write the processed tile back: a one-hot selector over the tiles picks
    # the new tile at position `idx` and the old contents everywhere else.
    tile_selector = array_ops.reshape(
        math_ops.cast(
            math_ops.equal(math_ops.range(num_tiles), idx), boxes.dtype),
        [1, -1, 1, 1])
    replicated_tile = array_ops.tile(
        array_ops.expand_dims(current_tile, [1]), [1, num_tiles, 1, 1])
    boxes_by_tile = array_ops.reshape(
        boxes, [batch_size, num_tiles, tile_size, 4])
    merged = (replicated_tile * tile_selector +
              boxes_by_tile * (1 - tile_selector))
    boxes = array_ops.reshape(merged, [batch_size, -1, 4])

    # A box of the tile counts as selected when any coordinate is still
    # greater than zero.
    box_selected = math_ops.reduce_any(current_tile > 0, [2])
    output_size += math_ops.reduce_sum(
        math_ops.cast(box_selected, dtypes.int32), [1])
  return boxes, iou_threshold, output_size, idx + 1
|
||||
|
||||
|
||||
@tf_export('image.non_max_suppression_padded')
|
||||
def non_max_suppression_padded(boxes,
|
||||
scores,
|
||||
max_output_size,
|
||||
iou_threshold=0.5,
|
||||
score_threshold=float('-inf'),
|
||||
pad_to_max_output_size=False,
|
||||
name=None,
|
||||
sorted_input=False,
|
||||
canonicalized_coordinates=False,
|
||||
tile_size=512):
|
||||
"""Non-maximum suppression.
|
||||
|
||||
Prunes away boxes that have high intersection-over-union (IOU) overlap
|
||||
with previously selected boxes. Bounding boxes are supplied as
|
||||
`[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the coordinates of any
|
||||
diagonal pair of box corners and the coordinates can be provided as normalized
|
||||
(i.e., lying in the interval `[0, 1]`) or absolute. The bounding box
|
||||
coordinates are canonicalized to `[y_min, x_min, y_max, x_max]`,
|
||||
where `(y_min, x_min)` and `(y_max, x_max)` are the coordinates of the lower
|
||||
left and upper right corner. User may indicate the input box coordinates are
|
||||
already canonicalized to eliminate redundant work by setting
|
||||
canonicalized_coordinates to `True`. Note that this algorithm is agnostic to
|
||||
where the origin is in the coordinate system. Note that this algorithm is
|
||||
invariant to orthogonal transformations and translations of the coordinate
|
||||
system; thus translating or reflections of the coordinate system result in the
|
||||
same boxes being selected by the algorithm.
|
||||
|
||||
Similar to tf.image.non_max_suppression, batched_non_max_suppression
|
||||
implements hard NMS but can operate on a batch of images and improves
|
||||
performance by tiling the bounding boxes. Batched_non_max_suppression should
|
||||
be preferred over tf.image.non_max_suppression when running on devices with
|
||||
abundant parallelism for higher computation speed. For soft NMS, refer to
|
||||
tf.image.non_max_suppression_with_scores.
|
||||
|
||||
While a serial NMS algorithm iteratively uses the highest-scored unprocessed
|
||||
box to suppress boxes, this algorithm uses many boxes to suppress other boxes
|
||||
in parallel. The key idea is to partition boxes into tiles based on their
|
||||
score and suppresses boxes tile by tile, thus achieving parallelism within a
|
||||
tile. The tile size determines the degree of parallelism.
|
||||
|
||||
In cross suppression (using boxes of tile A to suppress boxes of tile B),
|
||||
all boxes in A can independently suppress boxes in B.
|
||||
|
||||
Self suppression (suppressing boxes of the same tile) needs to be iteratively
|
||||
applied until there's no more suppression. In each iteration, boxes that
|
||||
cannot be suppressed are used to suppress boxes in the same tile.
|
||||
|
||||
boxes = boxes.pad_to_multiply_of(tile_size)
|
||||
num_tiles = len(boxes) // tile_size
|
||||
output_boxes = []
|
||||
for i in range(num_tiles):
|
||||
box_tile = boxes[i*tile_size : (i+1)*tile_size]
|
||||
for j in range(i):
|
||||
# in parallel suppress boxes in box_tile using boxes from suppressing_tile
|
||||
suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
|
||||
iou = _bbox_overlap(box_tile, suppressing_tile)
|
||||
# if the box is suppressed in iou, clear it to a dot
|
||||
box_tile *= _update_boxes(iou)
|
||||
# Iteratively handle the diagonal tile.
|
||||
iou = _bbox_overlap(box_tile, box_tile)
|
||||
iou_changed = True
|
||||
while iou_changed:
|
||||
# boxes that are not suppressed by anything else
|
||||
suppressing_boxes = _get_suppressing_boxes(iou)
|
||||
# boxes that are suppressed by suppressing_boxes
|
||||
suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
|
||||
# clear iou to 0 for boxes that are suppressed, as they cannot be used
|
||||
# to suppress other boxes any more
|
||||
new_iou = _clear_iou(iou, suppressed_boxes)
|
||||
iou_changed = (new_iou != iou)
|
||||
iou = new_iou
|
||||
# remaining boxes that can still suppress others, are selected boxes.
|
||||
output_boxes.append(_get_suppressing_boxes(iou))
|
||||
if len(output_boxes) >= max_output_size:
|
||||
break
|
||||
|
||||
Args:
|
||||
boxes: a tensor of rank 2 or higher with a shape of [..., num_boxes, 4].
|
||||
Dimensions except the last two are batch dimensions.
|
||||
scores: a tensor of rank 1 or higher with a shape of [..., num_boxes].
|
||||
max_output_size: a scalar integer `Tensor` representing the maximum number
|
||||
of boxes to be selected by non max suppression.
|
||||
iou_threshold: a float representing the threshold for deciding whether boxes
|
||||
overlap too much with respect to IoU (intersection over union).
|
||||
score_threshold: a float representing the threshold for box scores. Boxes
|
||||
with a score that is lower than this threshold will be suppressed.
|
||||
pad_to_max_output_size: whether to pad the output idx to max_output_size.
|
||||
Must be set to True when the input is a batch of images.
|
||||
name: name of operation.
|
||||
sorted_input: a boolean indicating whether the input boxes and scores
|
||||
are sorted in descending order by the score.
|
||||
canonicalized_coordinates: if box coordinates are given as
|
||||
`[y_min, x_min, y_max, x_max]`, setting to True eliminates redundant
|
||||
computation to canonicalize box coordinates.
|
||||
tile_size: an integer representing the number of boxes in a tile, i.e.,
|
||||
the maximum number of boxes per image that can be used to suppress other
|
||||
boxes in parallel; larger tile_size means larger parallelism and
|
||||
potentially more redundant work.
|
||||
Returns:
|
||||
idx: a tensor with a shape of [..., num_boxes] representing the
|
||||
indices selected by non-max suppression. The leading dimensions
|
||||
are the batch dimensions of the input boxes. All numbers are within
|
||||
[0, num_boxes). For each image (i.e., idx[i]), only the first num_valid[i]
|
||||
indices (i.e., idx[i][:num_valid[i]]) are valid.
|
||||
num_valid: a tensor of rank 0 or higher with a shape of [...]
|
||||
representing the number of valid indices in idx. Its dimensions are the
|
||||
batch dimensions of the input boxes.
|
||||
Raises:
|
||||
ValueError: When set pad_to_max_output_size to False for batched input.
|
||||
"""
|
||||
def _sort_scores_and_boxes(scores, boxes):
  """Reorders each image's scores and boxes from highest to lowest score.

  Args:
    scores: a tensor with a shape of [batch_size, num_boxes] representing
      the scores of boxes.
    boxes: a tensor with a shape of [batch_size, num_boxes, 4] representing
      the boxes.

  Returns:
    sorted_scores: a tensor with a shape of [batch_size, num_boxes]
      representing the sorted scores.
    sorted_boxes: a tensor with a shape of [batch_size, num_boxes, 4]
      representing the sorted boxes.
    sorted_scores_indices: a tensor with a shape of [batch_size, num_boxes]
      representing the index of the scores in a sorted descending order.
  """
  with ops.name_scope('sort_scores_and_boxes'):
    batch = array_ops.shape(boxes)[0]
    boxes_per_image = array_ops.shape(boxes)[1]
    sorted_scores_indices = sort_ops.argsort(
        scores, axis=1, direction='DESCENDING')

    # The per-image sort positions are turned into positions in the
    # flattened [batch * num_boxes] layout by adding each row's start offset,
    # so that one gather can reorder the whole batch at once.
    row_starts = array_ops.expand_dims(
        math_ops.range(batch) * boxes_per_image, 1)
    flat_positions = array_ops.reshape(
        sorted_scores_indices + row_starts, [-1])

    flat_scores = array_ops.reshape(scores, [-1])
    sorted_scores = array_ops.reshape(
        array_ops.gather(flat_scores, flat_positions), [batch, -1])

    flat_boxes = array_ops.reshape(boxes, [-1, 4])
    sorted_boxes = array_ops.reshape(
        array_ops.gather(flat_boxes, flat_positions), [batch, -1, 4])
  return sorted_scores, sorted_boxes, sorted_scores_indices
|
||||
|
||||
with ops.name_scope(name, 'batched_non_max_suppression'):
|
||||
if boxes.get_shape().ndims > 2 and not pad_to_max_output_size:
|
||||
raise ValueError("'pad_to_max_output_size' (value {}) must be "
|
||||
"True for batched input".format(pad_to_max_output_size))
|
||||
|
||||
batch_dims = boxes.get_shape().as_list()[:-2]
|
||||
num_boxes = array_ops.shape(boxes)[-2]
|
||||
boxes = array_ops.reshape(boxes, [-1, num_boxes, 4])
|
||||
scores = array_ops.reshape(scores, [-1, num_boxes])
|
||||
batch_size = array_ops.shape(boxes)[0]
|
||||
if score_threshold != float('-inf'):
|
||||
with ops.name_scope('filter_by_score'):
|
||||
score_mask = math_ops.cast(scores >= score_threshold, scores.dtype)
|
||||
scores *= score_mask
|
||||
box_mask = array_ops.expand_dims(
|
||||
math_ops.cast(score_mask, boxes.dtype), 2)
|
||||
boxes *= box_mask
|
||||
|
||||
if not canonicalized_coordinates:
|
||||
with ops.name_scope('canonicalize_coordinates'):
|
||||
y_1, x_1, y_2, x_2 = array_ops.split(
|
||||
value=boxes, num_or_size_splits=4, axis=2)
|
||||
y_1_is_min = math_ops.less(y_1[0, 0, 0], y_2[0, 0, 0])
|
||||
y_min, y_max = control_flow_ops.cond(
|
||||
y_1_is_min, lambda: (y_1, y_2), lambda: (y_2, y_1))
|
||||
x_1_is_min = math_ops.less(x_1[0, 0, 0], x_2[0, 0, 0])
|
||||
x_min, x_max = control_flow_ops.cond(
|
||||
x_1_is_min, lambda: (x_1, x_2), lambda: (x_2, x_1))
|
||||
boxes = array_ops.concat([y_min, x_min, y_max, x_max], axis=2)
|
||||
|
||||
if not sorted_input:
|
||||
scores, boxes, sorted_indices = _sort_scores_and_boxes(scores, boxes)
|
||||
|
||||
pad = math_ops.cast(
|
||||
math_ops.ceil(
|
||||
math_ops.cast(num_boxes, dtypes.float32) / tile_size),
|
||||
dtypes.int32) * tile_size - num_boxes
|
||||
boxes = array_ops.pad(
|
||||
math_ops.cast(boxes, dtypes.float32), [[0, 0], [0, pad], [0, 0]])
|
||||
scores = array_ops.pad(
|
||||
math_ops.cast(scores, dtypes.float32), [[0, 0], [0, pad]])
|
||||
num_boxes_after_padding = num_boxes + pad
|
||||
|
||||
def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
|
||||
return math_ops.logical_and(
|
||||
math_ops.reduce_min(output_size) < max_output_size,
|
||||
idx < num_boxes_after_padding // tile_size)
|
||||
|
||||
def suppression_loop_body(boxes, iou_threshold, output_size, idx):
|
||||
return _suppression_loop_body(
|
||||
boxes, iou_threshold, output_size, idx, tile_size)
|
||||
|
||||
selected_boxes, _, output_size, _ = control_flow_ops.while_loop(
|
||||
_loop_cond, suppression_loop_body,
|
||||
[boxes, iou_threshold, array_ops.zeros([batch_size], dtypes.int32),
|
||||
constant_op.constant(0)]
|
||||
)
|
||||
num_valid = math_ops.minimum(output_size, max_output_size)
|
||||
idx = num_boxes_after_padding - math_ops.cast(
|
||||
nn_ops.top_k(
|
||||
math_ops.cast(math_ops.reduce_any(
|
||||
selected_boxes > 0, [2]), dtypes.int32) *
|
||||
array_ops.expand_dims(
|
||||
math_ops.range(num_boxes_after_padding, 0, -1), 0),
|
||||
max_output_size)[0], dtypes.int32)
|
||||
idx = math_ops.minimum(idx, num_boxes - 1)
|
||||
if not sorted_input:
|
||||
index_offsets = math_ops.range(batch_size) * num_boxes
|
||||
gather_idx = array_ops.reshape(
|
||||
idx + array_ops.expand_dims(index_offsets, 1), [-1])
|
||||
idx = array_ops.reshape(
|
||||
array_ops.gather(array_ops.reshape(sorted_indices, [-1]),
|
||||
gather_idx),
|
||||
[batch_size, -1])
|
||||
|
||||
num_valid = array_ops.reshape(num_valid, batch_dims)
|
||||
if not pad_to_max_output_size:
|
||||
idx = idx[0, :num_valid]
|
||||
batch_dims.append(-1)
|
||||
idx = array_ops.reshape(idx, batch_dims)
|
||||
return idx, num_valid
|
||||
|
||||
|
||||
@tf_export('image.draw_bounding_boxes', v1=[])
|
||||
def draw_bounding_boxes_v2(images, boxes, colors, name=None):
|
||||
"""Draw bounding boxes on a batch of images.
|
||||
|
@ -4615,7 +4615,9 @@ class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True)
|
||||
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
|
||||
with self.cached_session():
|
||||
self.assertAllClose(selected_indices_padded.eval(), [3, 0, 5, 0, 0])
|
||||
invalid_index = len(boxes_np) - 1
|
||||
self.assertAllClose(selected_indices_padded.eval(),
|
||||
[3, 0, 5, invalid_index, invalid_index])
|
||||
self.assertEqual(num_valid_padded.eval(), 3)
|
||||
self.assertAllClose(selected_indices.eval(), [3, 0, 5])
|
||||
self.assertEqual(num_valid.eval(), 3)
|
||||
|
@ -138,7 +138,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "non_max_suppression_padded"
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\', \'sorted_input\', \'canonicalized_coordinates\', \'tile_size\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\', \'False\', \'False\', \'512\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "non_max_suppression_with_scores"
|
||||
|
@ -134,7 +134,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "non_max_suppression_padded"
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\', \'sorted_input\', \'canonicalized_coordinates\', \'tile_size\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\', \'False\', \'False\', \'512\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "non_max_suppression_with_scores"
|
||||
|
Loading…
x
Reference in New Issue
Block a user