From 031f7790efd75742f9e8541d6c9107483418abfe Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Mon, 1 Jan 2018 23:51:37 +0000 Subject: [PATCH] Add `colors` support for `tf.image.draw_bounding_boxes` This fix tries to address the issue raised in 15692 where it was not possible to specify the colors for boxes in `tf.image.draw_bounding_boxes`. Instead, a predefined fixed color table was used to cycle through colors. This fix adds `colors` Input to `DrawBoundingBoxexV2` so that it is possible to specify the color. In case no color is specified, the default color table will be used. Since there is an API change, the op is labeled as V2. This fix fixes 15692. Signed-off-by: Yong Tang --- .../core/kernels/draw_bounding_box_op.cc | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/kernels/draw_bounding_box_op.cc b/tensorflow/core/kernels/draw_bounding_box_op.cc index 618c47e6848..de9223b17b6 100644 --- a/tensorflow/core/kernels/draw_bounding_box_op.cc +++ b/tensorflow/core/kernels/draw_bounding_box_op.cc @@ -52,7 +52,6 @@ class DrawBoundingBoxesOp : public OpKernel { const int64 batch_size = images.dim_size(0); const int64 height = images.dim_size(1); const int64 width = images.dim_size(2); - const int64 color_table_length = 10; // 0: yellow // 1: blue @@ -64,7 +63,7 @@ class DrawBoundingBoxesOp : public OpKernel { // 7: navy blue // 8: aqua // 9: fuchsia - float color_table[color_table_length][4] = { + std::vector> color_table = { {1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1}, {0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1}, {0, 1, 1, 1}, {1, 0, 1, 1}, @@ -73,10 +72,23 @@ class DrawBoundingBoxesOp : public OpKernel { // Reset first color channel to 1 if image is GRY. // For GRY images, this means all bounding boxes will be white. if (depth == 1) { - for (int64 i = 0; i < color_table_length; i++) { + for (int64 i = 0; i < color_table.size(); i++) { color_table[i][0] = 1; } } + if (context->num_inputs() == 3) { + const Tensor& colors_tensor = context->input(2); + OP_REQUIRES(context, colors_tensor.shape().dims() == 2, errors::InvalidArgument("colors must be a 2-D matrix", colors_tensor.shape().DebugString())); + OP_REQUIRES(context, colors_tensor.shape().dim_size(1) == 4, errors::InvalidArgument("colors must be n x 4 (RGBA)", colors_tensor.shape().DebugString())); + if (colors_tensor.NumElements() != 0) { + color_table.clear(); + + auto colors = colors_tensor.matrix(); + for (int64 i = 0; i < colors.dimension(0); i++) { + color_table.emplace_back(std::array{colors(i, 0), colors(i, 1), colors(i, 2), colors(i, 3)}); + } + } + } Tensor* output; OP_REQUIRES_OK( context, @@ -90,7 +102,7 @@ class DrawBoundingBoxesOp : public OpKernel { const int64 num_boxes = boxes.dim_size(1); const auto tboxes = boxes.tensor(); for (int64 bb = 0; bb < num_boxes; ++bb) { - int64 color_index = bb % color_table_length; + int64 color_index = bb % color_table.size(); const int64 min_box_row = static_cast(tboxes(b, bb, 0)) * (height - 1); const int64 min_box_row_clamp = std::max(min_box_row, int64{0}); @@ -179,6 +191,9 @@ class DrawBoundingBoxesOp : public OpKernel { #define REGISTER_CPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint("T"), \ + DrawBoundingBoxesOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("DrawBoundingBoxesV2").Device(DEVICE_CPU).TypeConstraint("T"), \ DrawBoundingBoxesOp); TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL);