Disable MLIR bridge for NMS image ops test
MLIR bridge doesn't support tf.NonMaxSuppressionV4 legalization that is conditionally generated by non_max_suppression_padded function. PiperOrigin-RevId: 320197235 Change-Id: If7242133254680b366771ced50de074ed6180563
This commit is contained in:
		
							parent
							
								
									1725ab6962
								
							
						
					
					
						commit
						bf61dd6420
					
				| @ -30,6 +30,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin | |||||||
| from tensorflow.compiler.tests import xla_test | from tensorflow.compiler.tests import xla_test | ||||||
| from tensorflow.python.framework import dtypes | from tensorflow.python.framework import dtypes | ||||||
| from tensorflow.python.framework import ops | from tensorflow.python.framework import ops | ||||||
|  | from tensorflow.python.framework import test_util | ||||||
| from tensorflow.python.ops import array_ops | from tensorflow.python.ops import array_ops | ||||||
| from tensorflow.python.ops import gen_image_ops | from tensorflow.python.ops import gen_image_ops | ||||||
| from tensorflow.python.ops import image_ops | from tensorflow.python.ops import image_ops | ||||||
| @ -774,6 +775,7 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase): | |||||||
| 
 | 
 | ||||||
| class NonMaxSuppressionTest(xla_test.XLATestCase): | class NonMaxSuppressionTest(xla_test.XLATestCase): | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testNMS128From1024(self): |   def testNMS128From1024(self): | ||||||
|     num_boxes = 1024 |     num_boxes = 1024 | ||||||
|     boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") |     boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") | ||||||
| @ -808,6 +810,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): | |||||||
| 
 | 
 | ||||||
|       self.assertEqual(indices_tf.size, max_output_size) |       self.assertEqual(indices_tf.size, max_output_size) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testNMS3From6Boxes(self): |   def testNMS3From6Boxes(self): | ||||||
|     # Three boxes are selected based on IOU. |     # Three boxes are selected based on IOU. | ||||||
|     boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], |     boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], | ||||||
| @ -849,6 +852,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): | |||||||
|       self.assertEqual(num_valid, 3) |       self.assertEqual(num_valid, 3) | ||||||
|       self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) |       self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testNMS3Then2WithScoreThresh(self): |   def testNMS3Then2WithScoreThresh(self): | ||||||
|     # Three boxes are selected based on IOU. |     # Three boxes are selected based on IOU. | ||||||
|     # One is filtered out by score threshold. |     # One is filtered out by score threshold. | ||||||
| @ -891,6 +895,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): | |||||||
|       self.assertEqual(num_valid, 2) |       self.assertEqual(num_valid, 2) | ||||||
|       self.assertAllClose(indices_tf[:num_valid], [3, 0]) |       self.assertAllClose(indices_tf[:num_valid], [3, 0]) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testNMS3Then1WithScoreMaxThresh(self): |   def testNMS3Then1WithScoreMaxThresh(self): | ||||||
|     # Three boxes are selected based on IOU. |     # Three boxes are selected based on IOU. | ||||||
|     # One is filtered out by score threshold. |     # One is filtered out by score threshold. | ||||||
| @ -934,6 +939,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): | |||||||
|       self.assertEqual(num_valid, 1) |       self.assertEqual(num_valid, 1) | ||||||
|       self.assertAllClose(indices_tf[:num_valid], [3]) |       self.assertAllClose(indices_tf[:num_valid], [3]) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testSelectFromContinuousOverLap(self): |   def testSelectFromContinuousOverLap(self): | ||||||
|     # Tests that a suppressed box does not itself suppress other boxes. |     # Tests that a suppressed box does not itself suppress other boxes. | ||||||
| 
 | 
 | ||||||
| @ -978,6 +984,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): | |||||||
| 
 | 
 | ||||||
| class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSFrom6(self): |   def testBatchedNMSFrom6(self): | ||||||
|     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], |     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], | ||||||
|                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], |                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], | ||||||
| @ -1015,6 +1022,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|                         indices_output) |                         indices_output) | ||||||
|     self.assertAllEqual([5, 4], num_valid_output) |     self.assertAllEqual([5, 4], num_valid_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSFrom6Max3(self): |   def testBatchedNMSFrom6Max3(self): | ||||||
|     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], |     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], | ||||||
|                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], |                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], | ||||||
| @ -1048,6 +1056,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|     self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) |     self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) | ||||||
|     self.assertAllEqual([3, 3], num_valid_output) |     self.assertAllEqual([3, 3], num_valid_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSSingleFrom6Max3(self): |   def testBatchedNMSSingleFrom6Max3(self): | ||||||
|     boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], |     boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], | ||||||
|                   [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] |                   [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] | ||||||
| @ -1078,6 +1087,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|     self.assertAllEqual([0, 1, 2], indices_output) |     self.assertAllEqual([0, 1, 2], indices_output) | ||||||
|     self.assertAllEqual(3, num_valid_output) |     self.assertAllEqual(3, num_valid_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSSingleFrom6NoPad(self): |   def testBatchedNMSSingleFrom6NoPad(self): | ||||||
|     boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], |     boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], | ||||||
|                   [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] |                   [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] | ||||||
| @ -1107,6 +1117,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|     self.assertAllEqual([0, 1, 2, 4, 5], indices_output) |     self.assertAllEqual([0, 1, 2, 4, 5], indices_output) | ||||||
|     self.assertAllEqual(5, num_valid_output) |     self.assertAllEqual(5, num_valid_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSBatchDimsFrom6Max3(self): |   def testBatchedNMSBatchDimsFrom6Max3(self): | ||||||
|     boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], |     boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], | ||||||
|                     [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], |                     [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], | ||||||
| @ -1140,6 +1151,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|     self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) |     self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) | ||||||
|     self.assertAllEqual([[3, 3]], num_valid_output) |     self.assertAllEqual([[3, 3]], num_valid_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSScoreThresholdFrom6Max3(self): |   def testBatchedNMSScoreThresholdFrom6Max3(self): | ||||||
|     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], |     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], | ||||||
|                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], |                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], | ||||||
| @ -1175,6 +1187,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|     self.assertAllEqual([3, 2], num_valid_output) |     self.assertAllEqual([3, 2], num_valid_output) | ||||||
|     self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) |     self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSUnsortedInputFrom6(self): |   def testBatchedNMSUnsortedInputFrom6(self): | ||||||
|     boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], |     boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], | ||||||
|                    [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], |                    [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], | ||||||
| @ -1211,6 +1224,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|                         indices_output) |                         indices_output) | ||||||
|     self.assertAllEqual([5, 4], num_valid_output) |     self.assertAllEqual([5, 4], num_valid_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSNoncanonicalizedInputFrom6(self): |   def testBatchedNMSNoncanonicalizedInputFrom6(self): | ||||||
|     boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], |     boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], | ||||||
|                    [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], |                    [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], | ||||||
| @ -1248,6 +1262,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|                         indices_output) |                         indices_output) | ||||||
|     self.assertAllEqual([5, 4], num_valid_output) |     self.assertAllEqual([5, 4], num_valid_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): |   def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): | ||||||
|     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], |     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], | ||||||
|                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], |                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], | ||||||
| @ -1283,6 +1298,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): | |||||||
|     self.assertAllEqual([3, 2], num_valid_output) |     self.assertAllEqual([3, 2], num_valid_output) | ||||||
|     self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) |     self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) | ||||||
| 
 | 
 | ||||||
|  |   @test_util.disable_mlir_bridge("%1") | ||||||
|   def testBatchedNMSFrom6DynamicInput(self): |   def testBatchedNMSFrom6DynamicInput(self): | ||||||
|     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], |     boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], | ||||||
|                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], |                    [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user