diff --git a/tensorflow/core/kernels/draw_bounding_box_op.cc b/tensorflow/core/kernels/draw_bounding_box_op.cc index 618c47e6848..de9223b17b6 100644 --- a/tensorflow/core/kernels/draw_bounding_box_op.cc +++ b/tensorflow/core/kernels/draw_bounding_box_op.cc @@ -52,7 +52,6 @@ class DrawBoundingBoxesOp : public OpKernel { const int64 batch_size = images.dim_size(0); const int64 height = images.dim_size(1); const int64 width = images.dim_size(2); - const int64 color_table_length = 10; // 0: yellow // 1: blue @@ -64,7 +63,7 @@ class DrawBoundingBoxesOp : public OpKernel { // 7: navy blue // 8: aqua // 9: fuchsia - float color_table[color_table_length][4] = { + std::vector> color_table = { {1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1}, {0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1}, {0, 1, 1, 1}, {1, 0, 1, 1}, @@ -73,10 +72,23 @@ class DrawBoundingBoxesOp : public OpKernel { // Reset first color channel to 1 if image is GRY. // For GRY images, this means all bounding boxes will be white. if (depth == 1) { - for (int64 i = 0; i < color_table_length; i++) { + for (int64 i = 0; i < color_table.size(); i++) { color_table[i][0] = 1; } } + if (context->num_inputs() == 3) { + const Tensor& colors_tensor = context->input(2); + OP_REQUIRES(context, colors_tensor.shape().dims() == 2, errors::InvalidArgument("colors must be a 2-D matrix", colors_tensor.shape().DebugString())); + OP_REQUIRES(context, colors_tensor.shape().dim_size(1) == 4, errors::InvalidArgument("colors must be n x 4 (RGBA)", colors_tensor.shape().DebugString())); + if (colors_tensor.NumElements() != 0) { + color_table.clear(); + + auto colors = colors_tensor.matrix(); + for (int64 i = 0; i < colors.dimension(0); i++) { + color_table.emplace_back(std::array{colors(i, 0), colors(i, 1), colors(i, 2), colors(i, 3)}); + } + } + } Tensor* output; OP_REQUIRES_OK( context, @@ -90,7 +102,7 @@ class DrawBoundingBoxesOp : public OpKernel { const int64 num_boxes = boxes.dim_size(1); const auto tboxes = boxes.tensor(); for (int64 bb = 0; bb < num_boxes; ++bb) { - int64 color_index = bb % color_table_length; + int64 color_index = bb % color_table.size(); const int64 min_box_row = static_cast(tboxes(b, bb, 0)) * (height - 1); const int64 min_box_row_clamp = std::max(min_box_row, int64{0}); @@ -179,6 +191,9 @@ class DrawBoundingBoxesOp : public OpKernel { #define REGISTER_CPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint("T"), \ + DrawBoundingBoxesOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("DrawBoundingBoxesV2").Device(DEVICE_CPU).TypeConstraint("T"), \ DrawBoundingBoxesOp); TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL);