Improve shape function of NonMaxSuppression (#16664)
* Improve shape function of NonMaxSuppression This fix tries to improve shape function of NonMaxSuppression. As was specified in the docs, the shapes of parameters of `tf.image.non_max_suppression` are clearly defined with: boxes: 2-D with shape [num_boxes, 4] scores: 1-D with shape [num_boxes] max_output_size: 0-D scalar iou_threshold: 0-D scalar However, there is no shape check in the shape function of NonMaxSuppression. This fix adds the shape check for NonMaxSuppression, and adds additinal test cases for it. Signed-off-by: Yong Tang <yong.tang.github@outlook.com> * Add additional test cases for shape check of NonMaxSuppression. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
474bf4e045
commit
087401a6a9
@ -586,6 +586,17 @@ REGISTER_OP("NonMaxSuppression")
|
||||
.Output("selected_indices: int32")
|
||||
.Attr("iou_threshold: float = 0.5")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// Get inputs and validate ranks.
|
||||
ShapeHandle boxes;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
|
||||
ShapeHandle scores;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
|
||||
ShapeHandle max_output_size;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
|
||||
// The boxes is a 2-D float Tensor of shape [num_boxes, 4].
|
||||
DimensionHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
|
||||
|
||||
c->set_output(0, c->Vector(c->UnknownDim()));
|
||||
return Status::OK();
|
||||
});
|
||||
@ -597,6 +608,19 @@ REGISTER_OP("NonMaxSuppressionV2")
|
||||
.Input("iou_threshold: float")
|
||||
.Output("selected_indices: int32")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
// Get inputs and validate ranks.
|
||||
ShapeHandle boxes;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
|
||||
ShapeHandle scores;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
|
||||
ShapeHandle max_output_size;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
|
||||
ShapeHandle iou_threshold;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
|
||||
// The boxes is a 2-D float Tensor of shape [num_boxes, 4].
|
||||
DimensionHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
|
||||
|
||||
c->set_output(0, c->Vector(c->UnknownDim()));
|
||||
return Status::OK();
|
||||
});
|
||||
|
@ -3169,6 +3169,46 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
|
||||
boxes, scores, max_output_size, iou_threshold).eval()
|
||||
self.assertAllClose(selected_indices, [3, 0, 5])
|
||||
|
||||
def testInvalidShape(self):
|
||||
# The boxes should be 2D of shape [num_boxes, 4].
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Shape must be rank 2 but is rank 1'):
|
||||
boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0])
|
||||
scores = constant_op.constant([0.9])
|
||||
selected_indices = image_ops.non_max_suppression(
|
||||
boxes, scores, 3, 0.5)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Dimension must be 4 but is 3'):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0]])
|
||||
scores = constant_op.constant([0.9])
|
||||
selected_indices = image_ops.non_max_suppression(
|
||||
boxes, scores, 3, 0.5)
|
||||
|
||||
# The scores should be 1D of shape [num_boxes].
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Shape must be rank 1 but is rank 2'):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
|
||||
scores = constant_op.constant([[0.9]])
|
||||
selected_indices = image_ops.non_max_suppression(
|
||||
boxes, scores, 3, 0.5)
|
||||
|
||||
# The max_output_size should be a scaler (0-D).
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Shape must be rank 0 but is rank 1'):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
|
||||
scores = constant_op.constant([0.9])
|
||||
selected_indices = image_ops.non_max_suppression(
|
||||
boxes, scores, [3], 0.5)
|
||||
|
||||
# The iou_threshold should be a scaler (0-D).
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Shape must be rank 0 but is rank 2'):
|
||||
boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
|
||||
scores = constant_op.constant([0.9])
|
||||
selected_indices = image_ops.non_max_suppression(
|
||||
boxes, scores, 3, [[0.5]])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user