diff --git a/tensorflow/core/kernels/draw_bounding_box_op.cc b/tensorflow/core/kernels/draw_bounding_box_op.cc index c75a394992b..e493d4c5dc8 100644 --- a/tensorflow/core/kernels/draw_bounding_box_op.cc +++ b/tensorflow/core/kernels/draw_bounding_box_op.cc @@ -25,6 +25,30 @@ limitations under the License. namespace tensorflow { +namespace { + +std::vector> DefaultColorTable(int depth) { + std::vector> color_table; + color_table.emplace_back(std::vector({1, 1, 0, 1})); // 0: yellow + color_table.emplace_back(std::vector({0, 0, 1, 1})); // 1: blue + color_table.emplace_back(std::vector({1, 0, 0, 1})); // 2: red + color_table.emplace_back(std::vector({0, 1, 0, 1})); // 3: lime + color_table.emplace_back(std::vector({0.5, 0, 0.5, 1})); // 4: purple + color_table.emplace_back(std::vector({0.5, 0.5, 0, 1})); // 5: olive + color_table.emplace_back(std::vector({0.5, 0, 0, 1})); // 6: maroon + color_table.emplace_back(std::vector({0, 0, 0.5, 1})); // 7: navy blue + color_table.emplace_back(std::vector({0, 1, 1, 1})); // 8: aqua + color_table.emplace_back(std::vector({1, 0, 1, 1})); // 9: fuchsia + + if (depth == 1) { + for (int64 i = 0; i < color_table.size(); i++) { + color_table[i][0] = 1; + } + } + return color_table; +} +} // namespace + template class DrawBoundingBoxesOp : public OpKernel { public: @@ -52,47 +76,15 @@ class DrawBoundingBoxesOp : public OpKernel { const int64 batch_size = images.dim_size(0); const int64 height = images.dim_size(1); const int64 width = images.dim_size(2); - const int64 default_color_table_length = 10; - - // 0: yellow - // 1: blue - // 2: red - // 3: lime - // 4: purple - // 5: olive - // 6: maroon - // 7: navy blue - // 8: aqua - // 9: fuchsia - float default_color_table[default_color_table_length][4] = { - {1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1}, - {0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1}, - {0, 1, 1, 1}, {1, 0, 1, 1}, - }; - std::vector> color_table; - for (int64 i = 0; i < default_color_table_length; i++) { - std::vector color_value(4); - for (int64 j = 0; j < 4; j++) { - color_value[j] = default_color_table[i][j]; - } - color_table.emplace_back(color_value); - } - - // Reset first color channel to 1 if image is GRY. - // For GRY images, this means all bounding boxes will be white. - if (depth == 1) { - for (int64 i = 0; i < color_table.size(); i++) { - color_table[i][0] = 1; - } - } if (context->num_inputs() == 3) { const Tensor& colors_tensor = context->input(2); OP_REQUIRES(context, colors_tensor.shape().dims() == 2, errors::InvalidArgument("colors must be a 2-D matrix", colors_tensor.shape().DebugString())); - OP_REQUIRES(context, colors_tensor.shape().dim_size(1) == 4, - errors::InvalidArgument("colors must be n x 4 (RGBA)", + OP_REQUIRES(context, colors_tensor.shape().dim_size(1) >= depth, + errors::InvalidArgument("colors must have equal or more ", + "channels than the image provided: ", colors_tensor.shape().DebugString())); if (colors_tensor.NumElements() != 0) { color_table.clear(); @@ -107,6 +99,9 @@ class DrawBoundingBoxesOp : public OpKernel { } } } + if (color_table.empty()) { + color_table = DefaultColorTable(depth); + } Tensor* output; OP_REQUIRES_OK( context,