Replace non_max_suppression_padded with a new implementation that supports batched inputs. This new implementation is considerably faster than previous implementations on TPU and GPU.

PiperOrigin-RevId: 302730463
Change-Id: I3ff9a1204dc892b7c6688b6b400281ff9cd5331f
This commit is contained in:
A. Unique TensorFlower 2020-03-24 13:06:39 -07:00 committed by TensorFlower Gardener
parent 21c8683cad
commit 448a04e7c4
6 changed files with 687 additions and 63 deletions

View File

@ -32,6 +32,7 @@ package_group(
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/compiler/tests/...",
"//platforms/xla/tests/neural_nets",
],
)
@ -747,6 +748,7 @@ tf_xla_py_test(
tf_xla_py_test(
name = "image_ops_test",
size = "small",
timeout = "long",
srcs = ["image_ops_test.py"],
python_version = "PY3",
shard_count = 10,

View File

@ -976,5 +976,275 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
self.assertAllClose(indices_tf[:num_valid], [0, 2, 4])
class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
def testBatchedNMSFrom6(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
[[0, 2, 1, 2], [0, 0.8, 1, 1.8], [0, 0.6, 1, 1.6],
[0, 0.4, 1, 1.4], [0, 0.2, 1, 1.2], [0, 0, 1, 1]]]
scores_data = [[0.9, 0.7, 0.6, 0.5, 0.4, 0.3],
[0.8, 0.7, 0.6, 0.5, 0.4, 0.3]]
max_output_size = 6
iou_threshold = 0.5
boxes_np = np.array(boxes_data, dtype=np.float32)
scores_np = np.array(scores_data, dtype=np.float32)
with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
with self.test_scope():
(indices, num_valid) = image_ops.non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
pad_to_max_output_size=True,
sorted_input=True,
canonicalized_coordinates=True)
inputs = {
boxes: boxes_np,
scores: scores_np
}
indices_output, num_valid_output = sess.run([indices, num_valid], inputs)
invalid_index = len(boxes_data[0]) - 1
self.assertAllEqual([[0, 1, 2, 4, 5, invalid_index],
[0, 1, 3, 5, invalid_index, invalid_index]],
indices_output)
self.assertAllEqual([5, 4], num_valid_output)
def testBatchedNMSFrom6Max3(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
[[0, 2, 1, 2], [0, 0.8, 1, 1.8], [0, 0.6, 1, 1.6],
[0, 0.4, 1, 1.4], [0, 0.2, 1, 1.2], [0, 0, 1, 1]]]
scores_data = [[0.9, 0.7, 0.6, 0.5, 0.4, 0.3],
[0.8, 0.7, 0.6, 0.5, 0.4, 0.3]]
max_output_size = 3
iou_threshold = 0.5
boxes_np = np.array(boxes_data, dtype=np.float32)
scores_np = np.array(scores_data, dtype=np.float32)
with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
with self.test_scope():
(indices, num_valid) = image_ops.non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
pad_to_max_output_size=True,
sorted_input=True,
canonicalized_coordinates=True)
inputs = {
boxes: boxes_np,
scores: scores_np
}
indices_output, num_valid_output = sess.run([indices, num_valid], inputs)
self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
self.assertAllEqual([3, 3], num_valid_output)
def testBatchedNMSSingleFrom6Max3(self):
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
scores_data = [0.9, 0.7, 0.6, 0.5, 0.4, 0.3]
max_output_size = 3
iou_threshold = 0.5
boxes_np = np.array(boxes_data, dtype=np.float32)
scores_np = np.array(scores_data, dtype=np.float32)
with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
with self.test_scope():
(indices, num_valid) = image_ops.non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
pad_to_max_output_size=True,
sorted_input=True,
canonicalized_coordinates=True)
inputs = {
boxes: boxes_np,
scores: scores_np
}
indices_output, num_valid_output = sess.run([indices, num_valid], inputs)
self.assertAllEqual([0, 1, 2], indices_output)
self.assertAllEqual(3, num_valid_output)
def testBatchedNMSSingleFrom6NoPad(self):
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
scores_data = [0.9, 0.7, 0.6, 0.5, 0.4, 0.3]
max_output_size = 6
iou_threshold = 0.5
boxes_np = np.array(boxes_data, dtype=np.float32)
scores_np = np.array(scores_data, dtype=np.float32)
with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
with self.test_scope():
(indices, num_valid) = image_ops.non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
sorted_input=True,
canonicalized_coordinates=True)
inputs = {
boxes: boxes_np,
scores: scores_np
}
indices_output, num_valid_output = sess.run([indices, num_valid], inputs)
self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
self.assertAllEqual(5, num_valid_output)
def testBatchedNMSBatchDimsFrom6Max3(self):
boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
[[0, 2, 1, 2], [0, 0.8, 1, 1.8], [0, 0.6, 1, 1.6],
[0, 0.4, 1, 1.4], [0, 0.2, 1, 1.2], [0, 0, 1, 1]]]]
scores_data = [[[0.9, 0.7, 0.6, 0.5, 0.4, 0.3],
[0.8, 0.7, 0.6, 0.5, 0.4, 0.3]]]
max_output_size = 3
iou_threshold = 0.5
boxes_np = np.array(boxes_data, dtype=np.float32)
scores_np = np.array(scores_data, dtype=np.float32)
with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
with self.test_scope():
(indices, num_valid) = image_ops.non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
pad_to_max_output_size=True,
sorted_input=True,
canonicalized_coordinates=True)
inputs = {
boxes: boxes_np,
scores: scores_np
}
indices_output, num_valid_output = sess.run([indices, num_valid], inputs)
self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
self.assertAllEqual([[3, 3]], num_valid_output)
def testBatchedNMSScoreThresholdFrom6Max3(self):
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
[[0, 2, 1, 2], [0, 0.8, 1, 1.8], [0, 0.6, 1, 1.6],
[0, 0.4, 1, 1.4], [0, 0.2, 1, 1.2], [0, 0, 1, 1]]]
scores_data = [[0.9, 0.7, 0.6, 0.4, 0.3, 0.2],
[0.8, 0.7, 0.6, 0.4, 0.3, 0.1]]
max_output_size = 3
iou_threshold = 0.5
boxes_np = np.array(boxes_data, dtype=np.float32)
scores_np = np.array(scores_data, dtype=np.float32)
with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
with self.test_scope():
(indices, num_valid) = image_ops.non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
score_threshold=0.5,
pad_to_max_output_size=True,
sorted_input=True,
canonicalized_coordinates=True)
inputs = {
boxes: boxes_np,
scores: scores_np
}
indices_output, num_valid_output = sess.run([indices, num_valid], inputs)
invalid_index = len(boxes_data[0]) - 1
self.assertAllEqual([3, 2], num_valid_output)
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
def testBatchedNMSUnsortedInputFrom6(self):
boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
[[0, 0.4, 1, 1.4], [0, 2, 1, 2], [0, 0.2, 1, 1.2],
[0, 0, 1, 1], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]]]
scores_data = [[0.3, 0.7, 0.9, 0.6, 0.5, 0.4],
[0.5, 0.8, 0.4, 0.3, 0.6, 0.7]]
max_output_size = 6
iou_threshold = 0.5
boxes_np = np.array(boxes_data, dtype=np.float32)
scores_np = np.array(scores_data, dtype=np.float32)
with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
with self.test_scope():
(indices, num_valid) = image_ops.non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
pad_to_max_output_size=True,
canonicalized_coordinates=True)
inputs = {
boxes: boxes_np,
scores: scores_np
}
indices_output, num_valid_output = sess.run([indices, num_valid], inputs)
self.assertAllEqual([[2, 1, 3, 5, 0, 0], [1, 5, 0, 3, 3, 3]],
indices_output)
self.assertAllEqual([5, 4], num_valid_output)
def testBatchedNMSNoncanonicalizedInputFrom6(self):
boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
[1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
[[1, 2, 0, 2], [1, 0.8, 0, 1.8], [1, 0.6, 0, 1.6],
[1, 0.4, 0, 1.4], [1, 0.2, 0, 1.2], [1, 0, 0, 1]]]
scores_data = [[0.9, 0.7, 0.6, 0.5, 0.4, 0.3],
[0.8, 0.7, 0.6, 0.5, 0.4, 0.3]]
max_output_size = 6
iou_threshold = 0.5
boxes_np = np.array(boxes_data, dtype=np.float32)
scores_np = np.array(scores_data, dtype=np.float32)
with self.session() as sess:
boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape)
scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape)
with self.test_scope():
(indices, num_valid) = image_ops.non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=max_output_size,
iou_threshold=iou_threshold,
pad_to_max_output_size=True,
sorted_input=True)
inputs = {
boxes: boxes_np,
scores: scores_np
}
indices_output, num_valid_output = sess.run([indices, num_valid], inputs)
invalid_index = len(boxes_data[0]) - 1
self.assertAllEqual([[0, 1, 2, 4, 5, invalid_index],
[0, 1, 3, 5, invalid_index, invalid_index]],
indices_output)
self.assertAllEqual([5, 4], num_valid_output)
if __name__ == "__main__":
test.main()

View File

@ -34,6 +34,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.util import deprecation
@ -3015,7 +3016,7 @@ def non_max_suppression_with_scores(boxes,
box coordinates corresponding to the selected indices can then be obtained
using the `tf.gather` operation. For example:
```python
selected_indices, selected_scores = tf.image.non_max_suppression_v2(
selected_indices, selected_scores = tf.image.non_max_suppression_padded(
boxes, scores, max_output_size, iou_threshold=1.0, score_threshold=0.1,
soft_nms_sigma=0.5)
selected_boxes = tf.gather(boxes, selected_indices)
@ -3026,12 +3027,12 @@ def non_max_suppression_with_scores(boxes,
Bodla et al, https://arxiv.org/abs/1704.04503) where boxes reduce the score
of other overlapping boxes instead of directly causing them to be pruned.
Consequently, in contrast to `tf.image.non_max_suppression`,
`tf.image.non_max_suppression_v2` returns the new scores of each input box in
the second output, `selected_scores`.
`tf.image.non_max_suppression_padded` returns the new scores of each input box
in the second output, `selected_scores`.
To enable this Soft-NMS mode, set the `soft_nms_sigma` parameter to be
larger than 0. When `soft_nms_sigma` equals 0, the behavior of
`tf.image.non_max_suppression_v2` is identical to that of
`tf.image.non_max_suppression_padded` is identical to that of
`tf.image.non_max_suppression` (except for the extra output) both in function
and in running time.
@ -3077,62 +3078,6 @@ def non_max_suppression_with_scores(boxes,
return selected_indices, selected_scores
@tf_export('image.non_max_suppression_padded')
def non_max_suppression_padded(boxes,
scores,
max_output_size,
iou_threshold=0.5,
score_threshold=float('-inf'),
pad_to_max_output_size=False,
name=None):
"""Greedily selects a subset of bounding boxes in descending order of score.
Performs algorithmically equivalent operation to tf.image.non_max_suppression,
with the addition of an optional parameter which zero-pads the output to
be of size `max_output_size`.
The output of this operation is a tuple containing the set of integers
indexing into the input collection of bounding boxes representing the selected
boxes and the number of valid indices in the index set. The bounding box
coordinates corresponding to the selected indices can then be obtained using
the `tf.slice` and `tf.gather` operations. For example:
```python
selected_indices_padded, num_valid = tf.image.non_max_suppression_padded(
boxes, scores, max_output_size, iou_threshold,
score_threshold, pad_to_max_output_size=True)
selected_indices = tf.slice(
selected_indices_padded, tf.constant([0]), num_valid)
selected_boxes = tf.gather(boxes, selected_indices)
```
Args:
boxes: A 2-D float `Tensor` of shape `[num_boxes, 4]`.
scores: A 1-D float `Tensor` of shape `[num_boxes]` representing a single
score corresponding to each box (each row of boxes).
max_output_size: A scalar integer `Tensor` representing the maximum number
of boxes to be selected by non-max suppression.
iou_threshold: A float representing the threshold for deciding whether boxes
overlap too much with respect to IOU.
score_threshold: A float representing the threshold for deciding when to
remove boxes based on score.
pad_to_max_output_size: bool. If True, size of `selected_indices` output is
padded to `max_output_size`.
name: A name for the operation (optional).
Returns:
selected_indices: A 1-D integer `Tensor` of shape `[M]` representing the
selected indices from the boxes tensor, where `M <= max_output_size`.
valid_outputs: A scalar integer `Tensor` denoting how many elements in
`selected_indices` are valid. Valid elements occur first, then padding.
"""
with ops.name_scope(name, 'non_max_suppression_padded'):
iou_threshold = ops.convert_to_tensor(iou_threshold, name='iou_threshold')
score_threshold = ops.convert_to_tensor(
score_threshold, name='score_threshold')
return gen_image_ops.non_max_suppression_v4(boxes, scores, max_output_size,
iou_threshold, score_threshold,
pad_to_max_output_size)
@tf_export('image.non_max_suppression_overlaps')
def non_max_suppression_with_overlaps(overlaps,
scores,
@ -4312,6 +4257,411 @@ def combined_non_max_suppression(boxes,
score_threshold, pad_per_class, clip_boxes)
def _bbox_overlap(boxes_a, boxes_b):
"""Calculates the overlap (iou - intersection over union) between boxes_a and boxes_b.
Args:
boxes_a: a tensor with a shape of [batch_size, N, 4]. N is the number of
boxes per image. The last dimension is the pixel coordinates in
[ymin, xmin, ymax, xmax] form.
boxes_b: a tensor with a shape of [batch_size, M, 4]. M is the number of
boxes. The last dimension is the pixel coordinates in
[ymin, xmin, ymax, xmax] form.
Returns:
intersection_over_union: a tensor with as a shape of [batch_size, N, M],
representing the ratio of intersection area over union area (IoU) between
two boxes
"""
with ops.name_scope('bbox_overlap'):
a_y_min, a_x_min, a_y_max, a_x_max = array_ops.split(
value=boxes_a, num_or_size_splits=4, axis=2)
b_y_min, b_x_min, b_y_max, b_x_max = array_ops.split(
value=boxes_b, num_or_size_splits=4, axis=2)
# Calculates the intersection area.
i_xmin = math_ops.maximum(
a_x_min, array_ops.transpose(b_x_min, [0, 2, 1]))
i_xmax = math_ops.minimum(
a_x_max, array_ops.transpose(b_x_max, [0, 2, 1]))
i_ymin = math_ops.maximum(
a_y_min, array_ops.transpose(b_y_min, [0, 2, 1]))
i_ymax = math_ops.minimum(
a_y_max, array_ops.transpose(b_y_max, [0, 2, 1]))
i_area = math_ops.maximum(
(i_xmax - i_xmin), 0) * math_ops.maximum((i_ymax - i_ymin), 0)
# Calculates the union area.
a_area = (a_y_max - a_y_min) * (a_x_max - a_x_min)
b_area = (b_y_max - b_y_min) * (b_x_max - b_x_min)
EPSILON = 1e-8
# Adds a small epsilon to avoid divide-by-zero.
u_area = a_area + array_ops.transpose(b_area, [0, 2, 1]) - i_area + EPSILON
# Calculates IoU.
intersection_over_union = i_area / u_area
return intersection_over_union
def _self_suppression(iou, _, iou_sum, iou_threshold):
"""Suppress boxes in the same tile.
Compute boxes that cannot be suppressed by others (i.e.,
can_suppress_others), and then use them to suppress boxes in the same tile.
Args:
iou: a tensor of shape [batch_size, num_boxes_with_padding] representing
intersection over union.
iou_sum: a scalar tensor.
iou_threshold: a scalar tensor.
Returns:
iou_suppressed: a tensor of shape [batch_size, num_boxes_with_padding].
iou_diff: a scalar tensor representing whether any box is supressed in
this step.
iou_sum_new: a scalar tensor of shape [batch_size] that represents
the iou sum after suppression.
iou_threshold: a scalar tensor.
"""
batch_size = array_ops.shape(iou)[0]
can_suppress_others = math_ops.cast(
array_ops.reshape(
math_ops.reduce_max(iou, 1) < iou_threshold, [batch_size, -1, 1]),
iou.dtype)
iou_after_suppression = array_ops.reshape(
math_ops.cast(
math_ops.reduce_max(can_suppress_others * iou, 1) < iou_threshold,
iou.dtype),
[batch_size, -1, 1]) * iou
iou_sum_new = math_ops.reduce_sum(iou_after_suppression, [1, 2])
return [
iou_after_suppression,
math_ops.reduce_any(iou_sum - iou_sum_new > iou_threshold), iou_sum_new,
iou_threshold
]
def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size):
"""Suppress boxes between different tiles.
Args:
boxes: a tensor of shape [batch_size, num_boxes_with_padding, 4]
box_slice: a tensor of shape [batch_size, tile_size, 4]
iou_threshold: a scalar tensor
inner_idx: a scalar tensor representing the tile index of the tile
that is used to supress box_slice
tile_size: an integer representing the number of boxes in a tile
Returns:
boxes: unchanged boxes as input
box_slice_after_suppression: box_slice after suppression
iou_threshold: unchanged
"""
batch_size = array_ops.shape(boxes)[0]
new_slice = array_ops.slice(
boxes, [0, inner_idx * tile_size, 0],
[batch_size, tile_size, 4])
iou = _bbox_overlap(new_slice, box_slice)
box_slice_after_suppression = array_ops.expand_dims(
math_ops.cast(math_ops.reduce_all(iou < iou_threshold, [1]),
box_slice.dtype),
2) * box_slice
return boxes, box_slice_after_suppression, iou_threshold, inner_idx + 1
def _suppression_loop_body(boxes, iou_threshold, output_size, idx, tile_size):
"""Process boxes in the range [idx*tile_size, (idx+1)*tile_size).
Args:
boxes: a tensor with a shape of [batch_size, anchors, 4].
iou_threshold: a float representing the threshold for deciding whether boxes
overlap too much with respect to IOU.
output_size: an int32 tensor of size [batch_size]. Representing the number
of selected boxes for each batch.
idx: an integer scalar representing induction variable.
tile_size: an integer representing the number of boxes in a tile
Returns:
boxes: updated boxes.
iou_threshold: pass down iou_threshold to the next iteration.
output_size: the updated output_size.
idx: the updated induction variable.
"""
with ops.name_scope('suppression_loop_body'):
num_tiles = array_ops.shape(boxes)[1] // tile_size
batch_size = array_ops.shape(boxes)[0]
def cross_suppression_func(boxes, box_slice, iou_threshold, inner_idx):
return _cross_suppression(boxes, box_slice, iou_threshold, inner_idx,
tile_size)
# Iterates over tiles that can possibly suppress the current tile.
box_slice = array_ops.slice(boxes, [0, idx * tile_size, 0],
[batch_size, tile_size, 4])
_, box_slice, _, _ = control_flow_ops.while_loop(
lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
cross_suppression_func,
[boxes, box_slice, iou_threshold, constant_op.constant(0)])
# Iterates over the current tile to compute self-suppression.
iou = _bbox_overlap(box_slice, box_slice)
mask = array_ops.expand_dims(
array_ops.reshape(
math_ops.range(tile_size), [1, -1]) > array_ops.reshape(
math_ops.range(tile_size), [-1, 1]), 0)
iou *= math_ops.cast(
math_ops.logical_and(mask, iou >= iou_threshold), iou.dtype)
suppressed_iou, _, _, _ = control_flow_ops.while_loop(
lambda _iou, loop_condition, _iou_sum, _: loop_condition,
_self_suppression,
[iou, constant_op.constant(True), math_ops.reduce_sum(iou, [1, 2]),
iou_threshold])
suppressed_box = math_ops.reduce_sum(suppressed_iou, 1) > 0
box_slice *= array_ops.expand_dims(
1.0 - math_ops.cast(suppressed_box, box_slice.dtype), 2)
# Uses box_slice to update the input boxes.
mask = array_ops.reshape(
math_ops.cast(
math_ops.equal(math_ops.range(num_tiles), idx), boxes.dtype),
[1, -1, 1, 1])
boxes = array_ops.tile(array_ops.expand_dims(
box_slice, [1]), [1, num_tiles, 1, 1]) * mask + array_ops.reshape(
boxes, [batch_size, num_tiles, tile_size, 4]) * (1 - mask)
boxes = array_ops.reshape(boxes, [batch_size, -1, 4])
# Updates output_size.
output_size += math_ops.reduce_sum(
math_ops.cast(
math_ops.reduce_any(box_slice > 0, [2]), dtypes.int32), [1])
return boxes, iou_threshold, output_size, idx + 1
@tf_export('image.non_max_suppression_padded')
def non_max_suppression_padded(boxes,
scores,
max_output_size,
iou_threshold=0.5,
score_threshold=float('-inf'),
pad_to_max_output_size=False,
name=None,
sorted_input=False,
canonicalized_coordinates=False,
tile_size=512):
"""Non-maximum suppression.
Prunes away boxes that have high intersection-over-union (IOU) overlap
with previously selected boxes. Bounding boxes are supplied as
`[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the coordinates of any
diagonal pair of box corners and the coordinates can be provided as normalized
(i.e., lying in the interval `[0, 1]`) or absolute. The bounding box
coordinates are cannonicalized to `[y_min, x_min, y_max, x_max]`,
where `(y_min, x_min)` and `(y_max, x_mas)` are the coordinates of the lower
left and upper right corner. User may indiciate the input box coordinates are
already canonicalized to eliminate redundant work by setting
canonicalized_coordinates to `True`. Note that this algorithm is agnostic to
where the origin is in the coordinate system. Note that this algorithm is
invariant to orthogonal transformations and translations of the coordinate
system; thus translating or reflections of the coordinate system result in the
same boxes being selected by the algorithm.
Similar to tf.image.non_max_suppression, batched_non_max_suppression
implements hard NMS but can operate on a batch of images and improves
performance by titling the bounding boxes. Batched_non_max_suppression should
be preferred over tf.image_non_max_suppression when running on devices with
abundant parallelsim for higher computation speed. For soft NMS, refer to
tf.image.non_max_suppression_with_scores.
While a serial NMS algorithm iteratively uses the highest-scored unprocessed
box to suppress boxes, this algorithm uses many boxes to suppress other boxes
in parallel. The key idea is to partition boxes into tiles based on their
score and suppresses boxes tile by tile, thus achieving parallelism within a
tile. The tile size determines the degree of parallelism.
In cross suppression (using boxes of tile A to suppress boxes of tile B),
all boxes in A can independently suppress boxes in B.
Self suppression (suppressing boxes of the same tile) needs to be iteratively
applied until there's no more suppression. In each iteration, boxes that
cannot be suppressed are used to suppress boxes in the same tile.
boxes = boxes.pad_to_multiply_of(tile_size)
num_tiles = len(boxes) // tile_size
output_boxes = []
for i in range(num_tiles):
box_tile = boxes[i*tile_size : (i+1)*tile_size]
for j in range(i - 1):
# in parallel suppress boxes in box_tile using boxes from suppressing_tile
suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
iou = _bbox_overlap(box_tile, suppressing_tile)
# if the box is suppressed in iou, clear it to a dot
box_tile *= _update_boxes(iou)
# Iteratively handle the diagnal tile.
iou = _box_overlap(box_tile, box_tile)
iou_changed = True
while iou_changed:
# boxes that are not suppressed by anything else
suppressing_boxes = _get_suppressing_boxes(iou)
# boxes that are suppressed by suppressing_boxes
suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
# clear iou to 0 for boxes that are suppressed, as they cannot be used
# to suppress other boxes any more
new_iou = _clear_iou(iou, suppressed_boxes)
iou_changed = (new_iou != iou)
iou = new_iou
# remaining boxes that can still suppress others, are selected boxes.
output_boxes.append(_get_suppressing_boxes(iou))
if len(output_boxes) >= max_output_size:
break
Args:
boxes: a tensor of rank 2 or higher with a shape of [..., num_boxes, 4].
Dimensions except the last two are batch dimensions.
scores: a tensor of rank 1 or higher with a shape of [..., num_boxes].
max_output_size: a scalar integer `Tensor` representing the maximum number
of boxes to be selected by non max suppression.
iou_threshold: a float representing the threshold for deciding whether boxes
overlap too much with respect to IoU (intersection over union).
score_threshold: a float representing the threshold for box scores. Boxes
with a score that is lower than this threshold will be suppressed.
pad_to_max_output_size: whether to pad the output idx to max_output_size.
Must be set to True when the input is a batch of images.
name: name of operation.
sorted_input: a boolean indicating whether the input boxes and scores
are sorted in descending order by the score.
canonicalized_coordinates: if box coordinates are given as
`[y_min, x_min, y_max, x_max]`, settign to True eliminate redundant
computation to canonicalize box coordinates.
tile_size: an integer representing the number of boxes in a tile, i.e.,
the maximum number of boxes per image that can be used to suppress other
boxes in parallel; larger tile_size means larger parallelism and
potentially more redundant work.
Returns:
idx: a tensor with a shape of [..., num_boxes] representing the
indices selected by non-max suppression. The leadign dimensions
are the batch dimensions of the input boxes. All numbers are are within
[0, num_boxes). For each image (i.e., idx[i]), only the first num_valid[i]
indices (i.e., idx[i][:num_valid[i]]) are valid.
num_valid: a tensor of rank 0 or higher with a shape of [...]
representing the number of valid indices in idx. Its dimensions are the
batch dimensions of the input boxes.
Raises:
ValueError: When set pad_to_max_output_size to False for batched input.
"""
def _sort_scores_and_boxes(scores, boxes):
"""Sort boxes based their score from highest to lowest.
Args:
scores: a tensor with a shape of [batch_size, num_boxes] representing
the scores of boxes.
boxes: a tensor with a shape of [batch_size, num_boxes, 4] representing
the boxes.
Returns:
sorted_scores: a tensor with a shape of [batch_size, num_boxes]
representing the sorted scores.
sorted_boxes: a tensor representing the sorted boxes.
sorted_scores_indices: a tensor with a shape of [batch_size, num_boxes]
representing the index of the scores in a sorted descending order.
"""
with ops.name_scope('sort_scores_and_boxes'):
batch_size = array_ops.shape(boxes)[0]
num_boxes = array_ops.shape(boxes)[1]
sorted_scores_indices = sort_ops.argsort(
scores, axis=1, direction='DESCENDING')
index_offsets = math_ops.range(batch_size) * num_boxes
indices = array_ops.reshape(
sorted_scores_indices + array_ops.expand_dims(index_offsets, 1), [-1])
sorted_scores = array_ops.reshape(
array_ops.gather(array_ops.reshape(scores, [-1]), indices),
[batch_size, -1])
sorted_boxes = array_ops.reshape(
array_ops.gather(array_ops.reshape(boxes, [-1, 4]), indices),
[batch_size, -1, 4])
return sorted_scores, sorted_boxes, sorted_scores_indices
with ops.name_scope(name, 'batched_non_max_suppression'):
if boxes.get_shape().ndims > 2 and not pad_to_max_output_size:
raise ValueError("'pad_to_max_output_size' (value {}) must be "
"True for batched input".format(pad_to_max_output_size))
batch_dims = boxes.get_shape().as_list()[:-2]
num_boxes = array_ops.shape(boxes)[-2]
boxes = array_ops.reshape(boxes, [-1, num_boxes, 4])
scores = array_ops.reshape(scores, [-1, num_boxes])
batch_size = array_ops.shape(boxes)[0]
if score_threshold != float('-inf'):
with ops.name_scope('filter_by_score'):
score_mask = math_ops.cast(scores >= score_threshold, scores.dtype)
scores *= score_mask
box_mask = array_ops.expand_dims(
math_ops.cast(score_mask, boxes.dtype), 2)
boxes *= box_mask
if not canonicalized_coordinates:
with ops.name_scope('canonicalize_coordinates'):
y_1, x_1, y_2, x_2 = array_ops.split(
value=boxes, num_or_size_splits=4, axis=2)
y_1_is_min = math_ops.less(y_1[0, 0, 0], y_2[0, 0, 0])
y_min, y_max = control_flow_ops.cond(
y_1_is_min, lambda: (y_1, y_2), lambda: (y_2, y_1))
x_1_is_min = math_ops.less(x_1[0, 0, 0], x_2[0, 0, 0])
x_min, x_max = control_flow_ops.cond(
x_1_is_min, lambda: (x_1, x_2), lambda: (x_2, x_1))
boxes = array_ops.concat([y_min, x_min, y_max, x_max], axis=2)
if not sorted_input:
scores, boxes, sorted_indices = _sort_scores_and_boxes(scores, boxes)
pad = math_ops.cast(
math_ops.ceil(
math_ops.cast(num_boxes, dtypes.float32) / tile_size),
dtypes.int32) * tile_size - num_boxes
boxes = array_ops.pad(
math_ops.cast(boxes, dtypes.float32), [[0, 0], [0, pad], [0, 0]])
scores = array_ops.pad(
math_ops.cast(scores, dtypes.float32), [[0, 0], [0, pad]])
num_boxes_after_padding = num_boxes + pad
def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
return math_ops.logical_and(
math_ops.reduce_min(output_size) < max_output_size,
idx < num_boxes_after_padding // tile_size)
def suppression_loop_body(boxes, iou_threshold, output_size, idx):
return _suppression_loop_body(
boxes, iou_threshold, output_size, idx, tile_size)
selected_boxes, _, output_size, _ = control_flow_ops.while_loop(
_loop_cond, suppression_loop_body,
[boxes, iou_threshold, array_ops.zeros([batch_size], dtypes.int32),
constant_op.constant(0)]
)
num_valid = math_ops.minimum(output_size, max_output_size)
idx = num_boxes_after_padding - math_ops.cast(
nn_ops.top_k(
math_ops.cast(math_ops.reduce_any(
selected_boxes > 0, [2]), dtypes.int32) *
array_ops.expand_dims(
math_ops.range(num_boxes_after_padding, 0, -1), 0),
max_output_size)[0], dtypes.int32)
idx = math_ops.minimum(idx, num_boxes - 1)
if not sorted_input:
index_offsets = math_ops.range(batch_size) * num_boxes
gather_idx = array_ops.reshape(
idx + array_ops.expand_dims(index_offsets, 1), [-1])
idx = array_ops.reshape(
array_ops.gather(array_ops.reshape(sorted_indices, [-1]),
gather_idx),
[batch_size, -1])
num_valid = array_ops.reshape(num_valid, batch_dims)
if not pad_to_max_output_size:
idx = idx[0, :num_valid]
batch_dims.append(-1)
idx = array_ops.reshape(idx, batch_dims)
return idx, num_valid
@tf_export('image.draw_bounding_boxes', v1=[])
def draw_bounding_boxes_v2(images, boxes, colors, name=None):
"""Draw bounding boxes on a batch of images.

View File

@ -4615,7 +4615,9 @@ class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase):
self.assertEqual(selected_indices_padded.shape.is_fully_defined(), True)
self.assertEqual(selected_indices.shape.is_fully_defined(), False)
with self.cached_session():
self.assertAllClose(selected_indices_padded.eval(), [3, 0, 5, 0, 0])
invalid_index = len(boxes_np) - 1
self.assertAllClose(selected_indices_padded.eval(),
[3, 0, 5, invalid_index, invalid_index])
self.assertEqual(num_valid_padded.eval(), 3)
self.assertAllClose(selected_indices.eval(), [3, 0, 5])
self.assertEqual(num_valid.eval(), 3)

View File

@ -138,7 +138,7 @@ tf_module {
}
member_method {
name: "non_max_suppression_padded"
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\', \'sorted_input\', \'canonicalized_coordinates\', \'tile_size\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\', \'False\', \'False\', \'512\'], "
}
member_method {
name: "non_max_suppression_with_scores"

View File

@ -134,7 +134,7 @@ tf_module {
}
member_method {
name: "non_max_suppression_padded"
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\', \'sorted_input\', \'canonicalized_coordinates\', \'tile_size\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\', \'False\', \'False\', \'512\'], "
}
member_method {
name: "non_max_suppression_with_scores"