Add colors support for tf.image.draw_bounding_boxes

This fix tries to address the issue raised in 15692 where
it was not possible to specify the colors for boxes in
`tf.image.draw_bounding_boxes`. Instead, a predefined
fixed color table was used to cycle through colors.

This fix adds `colors` Input to `DrawBoundingBoxexV2`
so that it is possible to specify the color. In case
no color is specified, the default color table will
be used.

Since there is an API change, the op is labeled as V2.

This fix fixes 15692.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2018-01-01 23:51:37 +00:00
parent 2b9900638a
commit 031f7790ef

View File

@ -52,7 +52,6 @@ class DrawBoundingBoxesOp : public OpKernel {
const int64 batch_size = images.dim_size(0);
const int64 height = images.dim_size(1);
const int64 width = images.dim_size(2);
const int64 color_table_length = 10;
// 0: yellow
// 1: blue
@ -64,7 +63,7 @@ class DrawBoundingBoxesOp : public OpKernel {
// 7: navy blue
// 8: aqua
// 9: fuchsia
float color_table[color_table_length][4] = {
std::vector<std::array<float, 4>> color_table = {
{1, 1, 0, 1}, {0, 0, 1, 1}, {1, 0, 0, 1}, {0, 1, 0, 1},
{0.5, 0, 0.5, 1}, {0.5, 0.5, 0, 1}, {0.5, 0, 0, 1}, {0, 0, 0.5, 1},
{0, 1, 1, 1}, {1, 0, 1, 1},
@ -73,10 +72,23 @@ class DrawBoundingBoxesOp : public OpKernel {
// Reset first color channel to 1 if image is GRY.
// For GRY images, this means all bounding boxes will be white.
if (depth == 1) {
for (int64 i = 0; i < color_table_length; i++) {
for (int64 i = 0; i < color_table.size(); i++) {
color_table[i][0] = 1;
}
}
if (context->num_inputs() == 3) {
const Tensor& colors_tensor = context->input(2);
OP_REQUIRES(context, colors_tensor.shape().dims() == 2, errors::InvalidArgument("colors must be a 2-D matrix", colors_tensor.shape().DebugString()));
OP_REQUIRES(context, colors_tensor.shape().dim_size(1) == 4, errors::InvalidArgument("colors must be n x 4 (RGBA)", colors_tensor.shape().DebugString()));
if (colors_tensor.NumElements() != 0) {
color_table.clear();
auto colors = colors_tensor.matrix<float>();
for (int64 i = 0; i < colors.dimension(0); i++) {
color_table.emplace_back(std::array<float, 4>{colors(i, 0), colors(i, 1), colors(i, 2), colors(i, 3)});
}
}
}
Tensor* output;
OP_REQUIRES_OK(
context,
@ -90,7 +102,7 @@ class DrawBoundingBoxesOp : public OpKernel {
const int64 num_boxes = boxes.dim_size(1);
const auto tboxes = boxes.tensor<T, 3>();
for (int64 bb = 0; bb < num_boxes; ++bb) {
int64 color_index = bb % color_table_length;
int64 color_index = bb % color_table.size();
const int64 min_box_row =
static_cast<float>(tboxes(b, bb, 0)) * (height - 1);
const int64 min_box_row_clamp = std::max<int64>(min_box_row, int64{0});
@ -179,6 +191,9 @@ class DrawBoundingBoxesOp : public OpKernel {
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("DrawBoundingBoxes").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
DrawBoundingBoxesOp<T>); \
REGISTER_KERNEL_BUILDER( \
Name("DrawBoundingBoxesV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
DrawBoundingBoxesOp<T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);