Disable MLIR bridge for NMS image ops test
MLIR bridge doesn't support tf.NonMaxSuppressionV4 legalization that is conditionally generated by non_max_suppression_padded function. PiperOrigin-RevId: 320197235 Change-Id: If7242133254680b366771ced50de074ed6180563
This commit is contained in:
parent
1725ab6962
commit
bf61dd6420
|
@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
from tensorflow.compiler.tests import xla_test
|
from tensorflow.compiler.tests import xla_test
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gen_image_ops
|
from tensorflow.python.ops import gen_image_ops
|
||||||
from tensorflow.python.ops import image_ops
|
from tensorflow.python.ops import image_ops
|
||||||
|
@ -774,6 +775,7 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
class NonMaxSuppressionTest(xla_test.XLATestCase):
|
class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testNMS128From1024(self):
|
def testNMS128From1024(self):
|
||||||
num_boxes = 1024
|
num_boxes = 1024
|
||||||
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
|
boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4")
|
||||||
|
@ -808,6 +810,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
self.assertEqual(indices_tf.size, max_output_size)
|
self.assertEqual(indices_tf.size, max_output_size)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testNMS3From6Boxes(self):
|
def testNMS3From6Boxes(self):
|
||||||
# Three boxes are selected based on IOU.
|
# Three boxes are selected based on IOU.
|
||||||
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
|
||||||
|
@ -849,6 +852,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||||
self.assertEqual(num_valid, 3)
|
self.assertEqual(num_valid, 3)
|
||||||
self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
|
self.assertAllClose(indices_tf[:num_valid], [3, 0, 5])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testNMS3Then2WithScoreThresh(self):
|
def testNMS3Then2WithScoreThresh(self):
|
||||||
# Three boxes are selected based on IOU.
|
# Three boxes are selected based on IOU.
|
||||||
# One is filtered out by score threshold.
|
# One is filtered out by score threshold.
|
||||||
|
@ -891,6 +895,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||||
self.assertEqual(num_valid, 2)
|
self.assertEqual(num_valid, 2)
|
||||||
self.assertAllClose(indices_tf[:num_valid], [3, 0])
|
self.assertAllClose(indices_tf[:num_valid], [3, 0])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testNMS3Then1WithScoreMaxThresh(self):
|
def testNMS3Then1WithScoreMaxThresh(self):
|
||||||
# Three boxes are selected based on IOU.
|
# Three boxes are selected based on IOU.
|
||||||
# One is filtered out by score threshold.
|
# One is filtered out by score threshold.
|
||||||
|
@ -934,6 +939,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||||
self.assertEqual(num_valid, 1)
|
self.assertEqual(num_valid, 1)
|
||||||
self.assertAllClose(indices_tf[:num_valid], [3])
|
self.assertAllClose(indices_tf[:num_valid], [3])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testSelectFromContinuousOverLap(self):
|
def testSelectFromContinuousOverLap(self):
|
||||||
# Tests that a suppressed box does not itself suppress other boxes.
|
# Tests that a suppressed box does not itself suppress other boxes.
|
||||||
|
|
||||||
|
@ -978,6 +984,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSFrom6(self):
|
def testBatchedNMSFrom6(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
|
@ -1015,6 +1022,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
indices_output)
|
indices_output)
|
||||||
self.assertAllEqual([5, 4], num_valid_output)
|
self.assertAllEqual([5, 4], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSFrom6Max3(self):
|
def testBatchedNMSFrom6Max3(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
|
@ -1048,6 +1056,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
|
self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output)
|
||||||
self.assertAllEqual([3, 3], num_valid_output)
|
self.assertAllEqual([3, 3], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSSingleFrom6Max3(self):
|
def testBatchedNMSSingleFrom6Max3(self):
|
||||||
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
||||||
|
@ -1078,6 +1087,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
self.assertAllEqual([0, 1, 2], indices_output)
|
self.assertAllEqual([0, 1, 2], indices_output)
|
||||||
self.assertAllEqual(3, num_valid_output)
|
self.assertAllEqual(3, num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSSingleFrom6NoPad(self):
|
def testBatchedNMSSingleFrom6NoPad(self):
|
||||||
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
|
||||||
|
@ -1107,6 +1117,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
|
self.assertAllEqual([0, 1, 2, 4, 5], indices_output)
|
||||||
self.assertAllEqual(5, num_valid_output)
|
self.assertAllEqual(5, num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSBatchDimsFrom6Max3(self):
|
def testBatchedNMSBatchDimsFrom6Max3(self):
|
||||||
boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
|
@ -1140,6 +1151,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
|
self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output)
|
||||||
self.assertAllEqual([[3, 3]], num_valid_output)
|
self.assertAllEqual([[3, 3]], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSScoreThresholdFrom6Max3(self):
|
def testBatchedNMSScoreThresholdFrom6Max3(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
|
@ -1175,6 +1187,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
self.assertAllEqual([3, 2], num_valid_output)
|
self.assertAllEqual([3, 2], num_valid_output)
|
||||||
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSUnsortedInputFrom6(self):
|
def testBatchedNMSUnsortedInputFrom6(self):
|
||||||
boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
|
boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1],
|
||||||
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
|
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]],
|
||||||
|
@ -1211,6 +1224,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
indices_output)
|
indices_output)
|
||||||
self.assertAllEqual([5, 4], num_valid_output)
|
self.assertAllEqual([5, 4], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSNoncanonicalizedInputFrom6(self):
|
def testBatchedNMSNoncanonicalizedInputFrom6(self):
|
||||||
boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
|
boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4],
|
||||||
[1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
|
[1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]],
|
||||||
|
@ -1248,6 +1262,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
indices_output)
|
indices_output)
|
||||||
self.assertAllEqual([5, 4], num_valid_output)
|
self.assertAllEqual([5, 4], num_valid_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
|
def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
|
@ -1283,6 +1298,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase):
|
||||||
self.assertAllEqual([3, 2], num_valid_output)
|
self.assertAllEqual([3, 2], num_valid_output)
|
||||||
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("%1")
|
||||||
def testBatchedNMSFrom6DynamicInput(self):
|
def testBatchedNMSFrom6DynamicInput(self):
|
||||||
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4],
|
||||||
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
|
||||||
|
|
Loading…
Reference in New Issue