diff --git a/tensorflow/python/kernel_tests/draw_bounding_box_op_test.py b/tensorflow/python/kernel_tests/draw_bounding_box_op_test.py index 6aa757e293e..22644736a16 100644 --- a/tensorflow/python/kernel_tests/draw_bounding_box_op_test.py +++ b/tensorflow/python/kernel_tests/draw_bounding_box_op_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import image_ops_impl from tensorflow.python.ops import math_ops @@ -54,17 +55,21 @@ class DrawBoundingBoxOpTest(test.TestCase): image[height - 1, 0:width, 0:depth] = color return image - def _testDrawBoundingBoxColorCycling(self, img): + def _testDrawBoundingBoxColorCycling(self, img, colors=None): """Tests if cycling works appropriately. Args: img: 3-D numpy image on which to draw. """ # THIS TABLE MUST MATCH draw_bounding_box_op.cc - color_table = np.asarray([[1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 0, 1], - [0, 1, 0, 1], [0.5, 0, 0.5, 1], [0.5, 0.5, 0, 1], - [0.5, 0, 0, 1], [0, 0, 0.5, 1], [0, 1, 1, 1], - [1, 0, 1, 1]]) + default_color_table = np.asarray([[1, 1, 0, 1], [0, 0, 1, 1], + [1, 0, 0, 1], [0, 1, 0, 1], + [0.5, 0, 0.5, 1], [0.5, 0.5, 0, 1], + [0.5, 0, 0, 1], [0, 0, 0.5, 1], + [0, 1, 1, 1], [1, 0, 1, 1]]) + color_table = default_color_table + if colors is not None: + color_table = colors assert len(img.shape) == 3 depth = img.shape[2] assert depth <= color_table.shape[1] @@ -85,9 +90,12 @@ class DrawBoundingBoxOpTest(test.TestCase): image = ops.convert_to_tensor(image) image = image_ops_impl.convert_image_dtype(image, dtypes.float32) image = array_ops.expand_dims(image, 0) - image = image_ops.draw_bounding_boxes(image, bboxes) + if colors is None: + image = image_ops.draw_bounding_boxes(image, bboxes) + else: + image = gen_image_ops.draw_bounding_boxes_v2(image, bboxes, colors) with self.cached_session(use_gpu=False) as sess: - op_drawn_image = np.squeeze(self.evaluate(image), 0) + op_drawn_image = np.squeeze(sess.run(image), 0) self.assertAllEqual(test_drawn_image, op_drawn_image) def testDrawBoundingBoxRGBColorCycling(self): @@ -105,6 +113,21 @@ class DrawBoundingBoxOpTest(test.TestCase): image = np.zeros([4, 4, 1], "float32") self._testDrawBoundingBoxColorCycling(image) + def testDrawBoundingBoxRGBColorCyclingWithColors(self): + """Test if RGB color cycling works correctly with provided colors.""" + image = np.zeros([10, 10, 3], "float32") + colors = np.asarray([[1, 1, 0, 1], [0, 0, 1, 1], + [0.5, 0, 0.5, 1], [0.5, 0.5, 0, 1], + [0, 1, 1, 1], [1, 0, 1, 1]]) + self._testDrawBoundingBoxColorCycling(image, colors=colors) + + def testDrawBoundingBoxRGBAColorCyclingWithColors(self): + """Test if RGBA color cycling works correctly with provided colors.""" + image = np.zeros([10, 10, 4], "float32") + colors = np.asarray([[0.5, 0, 0.5, 1], [0.5, 0.5, 0, 1], + [0.5, 0, 0, 1], [0, 0, 0.5, 1]]) + self._testDrawBoundingBoxColorCycling(image, colors=colors) + if __name__ == "__main__": test.main()