From 6f031a7728deb2a6bc599cbed9f109c4300d478c Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Wed, 15 May 2019 16:10:45 -0700 Subject: [PATCH 01/17] Adding GenerateBoxProposals op --- tensorflow/core/kernels/BUILD | 7 + .../kernels/generate_box_proposals_op.cu.cc | 631 ++++++++++++++++++ tensorflow/core/ops/image_ops.cc | 50 ++ 3 files changed, 688 insertions(+) create mode 100644 tensorflow/core/kernels/generate_box_proposals_op.cu.cc diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 3746007278e..85d83357e39 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2636,6 +2636,7 @@ cc_library( ":encode_jpeg_op", ":encode_png_op", ":extract_jpeg_shape_op", + ":generate_box_proposals_op", ":non_max_suppression_op", ":random_crop_op", ":resize_area_op", @@ -2739,6 +2740,12 @@ tf_kernel_library( deps = IMAGE_DEPS, ) +tf_kernel_library( + name = "generate_box_proposals_op", + prefix = "generate_box_proposals_op", + deps = [":non_max_suppression_op"] + if_cuda(["@cub_archive//:cub"]), +) + tf_kernel_library( name = "non_max_suppression_op", prefix = "non_max_suppression_op", diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc new file mode 100644 index 00000000000..802ec7e139c --- /dev/null +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -0,0 +1,631 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// An example Op. + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU + +#include +#include +#include "tensorflow/core/kernels/non_max_suppression_op.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/gpu_kernel_helper.h" +#include "tensorflow/core/util/gpu_launch_config.h" +#include "third_party/cub/device/device_radix_sort.cuh" +#include "third_party/cub/device/device_segmented_radix_sort.cuh" +#include "third_party/cub/device/device_select.cuh" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +namespace { + +__global__ void GeneratePreNMSUprightBoxesKernel( + const Cuda2DLaunchConfig config, const int* d_sorted_scores_keys, + const float4* d_bbox_deltas, const float4* d_anchors, const int H, + const int W, const int A, const float min_size, const float* d_img_info_vec, + const float bbox_xform_clip, const bool correct_transform, + float4* d_out_boxes, + const int prenms_nboxes, // leading dimension of out_boxes + float* d_inout_scores, char* d_boxes_keep_flags) { + const int K = H * W; + const int WA = W * A; + const int KA = K * A; + int nboxes_to_generate = config.virtual_thread_count.x; + int num_images = config.virtual_thread_count.y; + int num_true = 0; + CUDA_AXIS_KERNEL_LOOP(image_index, config.virtual_thread_count.y, Y) { + CUDA_AXIS_KERNEL_LOOP(ibox, config.virtual_thread_count.x, X) { + // CUDA_2D_KERNEL_LOOP(ibox, nboxes_to_generate, image_index, + // num_images){ { box_conv_index : # of the same box, but indexed in the + // scores from the conv layer, of shape (A,H,W) the num_images dimension + // was already removed box_conv_index = a*K + h*W + w + const int box_conv_index = d_sorted_scores_keys[image_index * KA + ibox]; + + // We want to decompose box_conv_index in (h,w,a) + // such as box_conv_index = h*W*A + W*A + a + // (avoiding modulos in the process) + int remaining = box_conv_index; + const int dH = WA; // stride of H + const int h = remaining / dH; + remaining -= h * dH; + const int dW = A; // stride of H + const int w = remaining / dW; + remaining -= w * dW; + const int a = remaining; // dA = 1 + // Loading the anchor a + // float4 is a struct with float x,y,z,w + const float4 anchor = d_anchors[box_conv_index]; + // x1,y1,x2,y2 :coordinates of anchor a, shifted for position (h,w) + float x1 = anchor.y; + float x2 = anchor.w; + float y1 = anchor.x; + float y2 = anchor.z; + + // TODO use fast math when possible + + // Deltas of shape (N,H,W,A4) + int deltas_idx = box_conv_index + image_index * KA; + float4 deltas = d_bbox_deltas[deltas_idx]; + float dx = deltas.y; + float dy = deltas.x; + float dw = deltas.w; + float dh = deltas.z; + // printf("deltas_idx=%d dx=%f, dy=%f, dw=%f, + // dh=%f\n",deltas_idx,dx,dy,dw,dh); + // Upper bound on dw,dh + dw = fmin(dw, bbox_xform_clip); + dh = fmin(dh, bbox_xform_clip); + + // Applying the deltas + float width = x2 - x1 + 1.0f; + const float ctr_x = x1 + 0.5f * width; + const float pred_ctr_x = ctr_x + width * dx; // TODO fuse madd + const float pred_w = width * expf(dw); + x1 = pred_ctr_x - 0.5f * pred_w; + x2 = pred_ctr_x + 0.5f * pred_w; + + float height = y2 - y1 + 1.0f; + const float ctr_y = y1 + 0.5f * height; + const float pred_ctr_y = ctr_y + height * dy; + const float pred_h = height * expf(dh); + y1 = pred_ctr_y - 0.5f * pred_h; + y2 = pred_ctr_y + 0.5f * pred_h; + + if (correct_transform) { + x2 -= 1.0f; + y2 -= 1.0f; + } + // const float y2_old=y2; + // const float x2_old=x2; + // const float x1_old=x1; + // const float y1_old=y1; + // Clipping box to image + const float img_height = d_img_info_vec[5 * image_index + 0]; + const float img_width = d_img_info_vec[5 * image_index + 1]; + const float min_size_scaled = + min_size * d_img_info_vec[5 * image_index + 2]; + // min_size * d_img_info_vec[3 * image_index + 2]; + x1 = fmax(fmin(x1, img_width - 1.0f), 0.0f); + y1 = fmax(fmin(y1, img_height - 1.0f), 0.0f); + x2 = fmax(fmin(x2, img_width - 1.0f), 0.0f); + y2 = fmax(fmin(y2, img_height - 1.0f), 0.0f); + + // Filter boxes + // Removing boxes with one dim < min_size + // (center of box is in image, because of previous step) + width = x2 - x1 + 1.0f; // may have changed + height = y2 - y1 + 1.0f; + bool keep_box = fmin(width, height) >= min_size_scaled; + + // We are not deleting the box right now even if !keep_box + // we want to keep the relative order of the elements stable + // we'll do it in such a way later + // d_boxes_keep_flags size: (num_images,prenms_nboxes) + // d_out_boxes size: (num_images,prenms_nboxes) + const int out_index = image_index * prenms_nboxes + ibox; + + d_boxes_keep_flags[out_index] = keep_box; + d_out_boxes[out_index] = {x1, y1, x2, y2}; + // if(keep_box)printf("Has keep box %d\n",image_index); + // d_inout_scores size: (num_images,KA) + if (!keep_box) + d_inout_scores[image_index * KA + ibox] = FLT_MIN; // for NMS + } + } +} + +// Copy the selected boxes and scores to output tensors. +// +__global__ void WriteUprightBoxesOutput( + const CudaLaunchConfig nboxes, const float4* d_image_boxes, + const float* d_image_scores, const int* d_image_boxes_keep_list, + const int n_rois, float* d_image_out_rois, float* d_image_out_rois_probs) { + CUDA_1D_KERNEL_LOOP(i, nboxes.virtual_thread_count) { + if (i < n_rois) { // copy rois to output + const int ibox = d_image_boxes_keep_list[i]; + const float4 box = d_image_boxes[ibox]; + const float score = d_image_scores[ibox]; + // Scattered memory accesses + // postnms_nboxes is small anyway + d_image_out_rois_probs[i] = score; + const int base_idx = 4 * i; + d_image_out_rois[base_idx + 0] = box.y; + d_image_out_rois[base_idx + 1] = box.x; + d_image_out_rois[base_idx + 2] = box.w; + d_image_out_rois[base_idx + 3] = box.z; + } else { // set trailing entries to 0 + d_image_out_rois_probs[i] = 0.; + const int base_idx = 4 * i; + d_image_out_rois[base_idx + 0] = 0.; + d_image_out_rois[base_idx + 1] = 0.; + d_image_out_rois[base_idx + 2] = 0.; + d_image_out_rois[base_idx + 3] = 0.; + } + } +} + +// Allocate scratch spaces that are needed for operation +// + +Status AllocateGenerationTempTensors( + OpKernelContext* context, Tensor* d_conv_layer_indexes, + Tensor* d_image_offset, Tensor* d_cub_sort_buffer, + Tensor* d_cub_select_buffer, Tensor* d_sorted_conv_layer_indexes, + Tensor* d_sorted_scores, Tensor* dev_boxes, Tensor* dev_boxes_keep_flags, + int num_images, int conv_layer_nboxes, size_t cub_sort_temp_storage_bytes, + size_t cub_select_temp_storage_bytes, int nboxes_to_generate, int box_dim) { + auto d = context->eigen_gpu_device(); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_images, conv_layer_nboxes}), + d_conv_layer_indexes)); + CudaLaunchConfig zconfig = + GetCudaLaunchConfig(d_conv_layer_indexes->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, + (*d_conv_layer_indexes).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_images + 1}), d_image_offset)); + zconfig = GetCudaLaunchConfig(d_image_offset->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, (*d_image_offset).flat().data()); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}), + d_cub_sort_buffer)); + zconfig = GetCudaLaunchConfig(d_cub_sort_buffer->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, (*d_cub_sort_buffer).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({(int64)cub_select_temp_storage_bytes}), + d_cub_select_buffer)); + zconfig = GetCudaLaunchConfig(d_cub_select_buffer->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, (*d_cub_select_buffer).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_images, conv_layer_nboxes}), + d_sorted_conv_layer_indexes)); + zconfig = GetCudaLaunchConfig(d_sorted_conv_layer_indexes->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, + (*d_sorted_conv_layer_indexes).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_images, conv_layer_nboxes}), + d_sorted_scores)); + zconfig = GetCudaLaunchConfig(d_sorted_scores->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, (*d_sorted_scores).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, + TensorShape({num_images, box_dim * nboxes_to_generate}), dev_boxes)); + zconfig = GetCudaLaunchConfig(dev_boxes->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, (*dev_boxes).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({num_images, nboxes_to_generate}), + dev_boxes_keep_flags)); + zconfig = GetCudaLaunchConfig(dev_boxes_keep_flags->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, + (*dev_boxes_keep_flags).flat().data()); + + return Status::OK(); +} + +// Allocate workspace for NMS operation +Status AllocatePreNMSTempTensors( + OpKernelContext* context, Tensor* dev_image_prenms_boxes, + Tensor* dev_image_prenms_scores, Tensor* dev_image_boxes_keep_list, + Tensor* dev_postnms_rois, Tensor* dev_postnms_rois_probs, + Tensor* dev_prenms_nboxes, Tensor* dev_nms_mask, Tensor* host_nms_mask, + int num_images, int nboxes_to_generate, int box_dim, int post_nms_topn, + int pre_nms_topn) { + auto d = context->eigen_gpu_device(); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({box_dim * nboxes_to_generate}), + dev_image_prenms_boxes)); + CudaLaunchConfig zconfig = + GetCudaLaunchConfig(dev_image_prenms_boxes->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, + (*dev_image_prenms_boxes).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_FLOAT, + TensorShape({nboxes_to_generate}), + dev_image_prenms_scores)); + + zconfig = GetCudaLaunchConfig(dev_image_prenms_scores->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, + (*dev_image_prenms_scores).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, + TensorShape({nboxes_to_generate}), + dev_image_boxes_keep_list)); + zconfig = GetCudaLaunchConfig(dev_image_boxes_keep_list->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, + (*dev_image_boxes_keep_list).flat().data()); + + const int max_postnms_nboxes = std::min(nboxes_to_generate, post_nms_topn); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, + TensorShape({box_dim * num_images * max_postnms_nboxes}), + dev_postnms_rois)); + zconfig = GetCudaLaunchConfig(dev_postnms_rois->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, (*dev_postnms_rois).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_images * max_postnms_nboxes}), + dev_postnms_rois_probs)); + zconfig = GetCudaLaunchConfig(dev_postnms_rois_probs->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, + (*dev_postnms_rois_probs).flat().data()); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_images}), dev_prenms_nboxes)); + zconfig = GetCudaLaunchConfig(dev_prenms_nboxes->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, (*dev_prenms_nboxes).flat().data()); + int64 max_nms_mask_size = + pre_nms_topn * + ((pre_nms_topn + NMS_BOXES_PER_THREAD - 1) / NMS_BOXES_PER_THREAD); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({max_nms_mask_size}), dev_nms_mask)); + + zconfig = GetCudaLaunchConfig(dev_nms_mask->NumElements(), d); + SetZero<<>>( + zconfig.virtual_thread_count, (*dev_nms_mask).flat().data()); + + AllocatorAttributes alloc_attr; + alloc_attr.set_on_host(true); + alloc_attr.set_gpu_compatible(true); + TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, + TensorShape({max_nms_mask_size}), + host_nms_mask, alloc_attr)); + return Status::OK(); +} + +// Initialize index and offset arrays. +// num_images is the batch size, KA is the number of anchors +__global__ void InitializeDataKernel(const Cuda2DLaunchConfig config, + int* d_image_offsets, + int* d_boxes_keys_iota) { + const int KA = config.virtual_thread_count.x; + const int num_images = config.virtual_thread_count.y; + // printf("num_images %d KA %d\n",num_images,KA); + CUDA_AXIS_KERNEL_LOOP(img_idx, config.virtual_thread_count.y, Y) { + CUDA_AXIS_KERNEL_LOOP(box_idx, config.virtual_thread_count.x, X) { + // CUDA_2D_KERNEL_LOOP(box_idx, KA, img_idx, num_images) { + d_boxes_keys_iota[img_idx * KA + box_idx] = box_idx; + + // One 1D line sets the 1D data + if (box_idx == 0) { + d_image_offsets[img_idx] = KA * img_idx; + // One thread sets the last+1 offset + if (img_idx == 0) d_image_offsets[num_images] = KA * num_images; + } + } + } +} + +} // namespace + +class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { + public: + explicit GenerateBoundingBoxProposals( + tensorflow::OpKernelConstruction* context) + : AsyncOpKernel(context) { + // OP_REQUIRES_OK(context, context->GetAttr("spatial_scale", + // &spatial_scale_)); feat_stride_ = 1.0 / spatial_scale_; + OP_REQUIRES_OK(context, context->GetAttr("pre_nms_topn", &pre_nms_topn_)); + OP_REQUIRES_OK(context, context->GetAttr("post_nms_topn", &post_nms_topn_)); + OP_REQUIRES_OK(context, context->GetAttr("nms_threshold", &nms_threshold_)); + OP_REQUIRES_OK(context, context->GetAttr("min_size", &min_size_)); + OP_REQUIRES_OK(context, context->GetAttr("debug", &debug_)); + // compatibility for detectron like networks. False for generic case + OP_REQUIRES_OK(context, context->GetAttr("correct_transform_coords", + &correct_transform_coords_)); + CHECK_GT(pre_nms_topn_, 0); + CHECK_GT(post_nms_topn_, 0); + CHECK_GT(nms_threshold_, 0); + CHECK_GE(min_size_, 0); + bbox_xform_clip_default_ = log(1000.0 / 16.); + } + + void ComputeAsync(tensorflow::OpKernelContext* context, + DoneCallback done) override { + // .Input("scores: float") + // .Input("bbox_deltas: float") + // .Input("image_info: float") + // .Input("anchors: float") + + const auto scores = context->input(0); + const auto bbox_deltas = context->input(1); + const auto image_info = context->input(2); + const auto anchors = context->input(3); + const auto num_images = scores.dim_size(0); + const auto A = scores.dim_size(3); + const auto H = scores.dim_size(1); + const auto W = scores.dim_size(2); + const auto box_dim = anchors.dim_size(0) / A; + CHECK_EQ(box_dim, 4); + // TODO(skama): make sure that inputs are ok. + const int K = H * W; + // VLOG(0)<<"num_images="<(d_image_boxes), + d_image_boxes_keep_flags, + reinterpret_cast(d_image_prenms_boxes), d_prenms_nboxes, + nboxes_generated, d.stream()); + CHECK_EQ(cuda_ret, CUDA_SUCCESS); + cuda_ret = cub::DeviceSelect::Flagged( + d_cub_select_temp_storage, cub_select_temp_storage_bytes, + d_image_sorted_scores, d_image_boxes_keep_flags, + d_image_prenms_scores, d_prenms_nboxes, nboxes_generated, d.stream()); + CHECK_EQ(cuda_ret, CUDA_SUCCESS); + d.memcpyDeviceToHost(&h_prenms_nboxes, d_prenms_nboxes, sizeof(int)); + d.synchronize(); + + // We know prenms_boxes <= topN_prenms, because nboxes_generated <= + // topN_prenms. Calling NMS on the generated boxes + const int prenms_nboxes = h_prenms_nboxes; + // printf("Host boxes=%d ngen=%d\n",h_prenms_nboxes,nboxes_generated); + int nkeep; + // printf("Before nms\n"); + nms_gpu(d_image_prenms_boxes, prenms_nboxes, nms_threshold_, + d_image_boxes_keep_list, &nkeep, d_nms_mask, h_nms_mask, context); + CHECK_EQ(cudaGetLastError(), CUDA_SUCCESS); + // printf("After nms nkeep=%d\n",nkeep); + // All operations done after previous sort were keeping the relative order + // of the elements the elements are still sorted keep topN <=> truncate + // the array + const int postnms_nboxes = std::min(nkeep, post_nms_topn_); + // Moving the out boxes to the output tensors, + // adding the image_index dimension on the fly + CudaLaunchConfig config = GetCudaLaunchConfig(post_nms_topn_, d); + // make this single kernel + WriteUprightBoxesOutput<<>>( + config, reinterpret_cast(d_image_prenms_boxes), + d_image_prenms_scores, d_image_boxes_keep_list, postnms_nboxes, + d_image_postnms_rois, d_image_postnms_rois_probs); + nrois_in_output += postnms_nboxes; + CHECK_EQ(cudaGetLastError(), CUDA_SUCCESS); + } + done(); + } + + private: + int pre_nms_topn_; + int post_nms_topn_; + float nms_threshold_; + float min_size_; + float feat_stride_; + float bbox_xform_clip_default_; + bool correct_transform_coords_; + bool debug_; +}; + +REGISTER_KERNEL_BUILDER( + Name("GenerateBoundingBoxProposals").Device(tensorflow::DEVICE_GPU), + tensorflow::GenerateBoundingBoxProposals); +} // namespace tensorflow +#endif \ No newline at end of file diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 3dd37bd97ce..763ac91cef4 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -925,4 +925,54 @@ REGISTER_OP("CombinedNonMaxSuppression") .Attr("clip_boxes: bool = true") .SetShapeFn(CombinedNMSShapeFn); +REGISTER_OP("GenerateBoundingBoxProposals") + .Input("scores: float") + .Input("bbox_deltas: float") + .Input("image_info: float") + .Input("anchors: float") + .Output("rois: float") + .Output("roi_probabilities: float") + .Attr("pre_nms_topn: int = 6000") + .Attr("post_nms_topn: int = 300") + .Attr("nms_threshold: float = 0.7") + .Attr("min_size: float = 16") + .Attr("debug: bool = false") + .Attr("correct_transform_coords: bool = true") + .SetShapeFn([](InferenceContext* c) -> Status { + // make sure input tensors have are correct rank + ShapeHandle scores, images, bounding_boxes, anchors; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &scores)); //(N, H, W, A) + TF_RETURN_IF_ERROR( + c->WithRank(c->input(1), 4, &bounding_boxes)); //(N,H,W,A4) + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &images)); // (N,5) + auto im_info = c->Dim(images, 1); + TF_RETURN_IF_ERROR(c->WithValue(im_info, 5, &im_info)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 3, &anchors)); // (A4) + // TODO(skama): verify that the inputs are compatible + int post_nms_top_n; + TF_RETURN_IF_ERROR(c->GetAttr("post_nms_topn", &post_nms_top_n)); + auto roi_shape = c->MakeShape( + {c->Dim(scores, 0), post_nms_top_n, 4}); //(N,post_nms_top_n,4) + auto prob_shape = c->MakeShape( + {c->Dim(scores, 0), post_nms_top_n}); // (N,post_nms_top_n) + c->set_output(0, roi_shape); + c->set_output(1, prob_shape); + return Status::OK(); + }) + .Doc(R"doc( + This op produces Region of Interests from given bounding boxes(bbox_deltas) encoded wrt + anchors according to eq.2 in arXiv:1506.01497 + The op selects top pre_nms_topn scoring boxes, decodes them with respect to anchors, + applies non-maximal suppression on overlapping boxes with higher than + nms_threshold intersection-over-union (iou) value, discarding boxes where shorter + side is less than min_size. + + scores: A 4D tensor of shape [Batch, Height, Width, Num Anchors] containing the scores per anchor at given postion + bbox_deltas: is a tensor of shape [Batch, Height, Width, 4 x Num Anchors] boxes encoded to each anchor + anchors: A 1D tensor of shape [4 x Num Anchors], representing the anchors. + + rois: output RoIs, a 3D tensor of shape [Batch, post_nms_topn, 4], padded by 0 if less than post_nms_topn candidates found. + roi_probabilities: probability scores of each roi in 'rois', a 2D tensor of shape [Batch,post_nms_topn], padded with 0 if needed. + + )doc"); } // namespace tensorflow From 34d8fab68a706b36f6678367268e5540799fc967 Mon Sep 17 00:00:00 2001 From: Sami Kama Date: Wed, 15 May 2019 17:40:23 -0700 Subject: [PATCH 02/17] Update api_def files and add api_def for the op --- .../api_def_GenerateBoundingBoxProposals.pbtxt | 17 +++++++++++++++++ tensorflow/core/ops/image_ops.cc | 18 +----------------- .../tools/api/golden/v1/tensorflow.pbtxt | 4 ++++ .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 ++++ .../tools/api/golden/v2/tensorflow.pbtxt | 4 ++++ .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 ++++ 6 files changed, 34 insertions(+), 17 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt new file mode 100644 index 00000000000..b560095ae89 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt @@ -0,0 +1,17 @@ +op { + graph_op_name: "GenerateBoundingBoxProposals" + summary: "This op produces Region of Interests from given bounding boxes(bbox_deltas) encoded wrtanchors according to eq.2 in arXiv:1506.01497" + description: <set_output(0, roi_shape); c->set_output(1, prob_shape); return Status::OK(); - }) - .Doc(R"doc( - This op produces Region of Interests from given bounding boxes(bbox_deltas) encoded wrt - anchors according to eq.2 in arXiv:1506.01497 - The op selects top pre_nms_topn scoring boxes, decodes them with respect to anchors, - applies non-maximal suppression on overlapping boxes with higher than - nms_threshold intersection-over-union (iou) value, discarding boxes where shorter - side is less than min_size. - - scores: A 4D tensor of shape [Batch, Height, Width, Num Anchors] containing the scores per anchor at given postion - bbox_deltas: is a tensor of shape [Batch, Height, Width, 4 x Num Anchors] boxes encoded to each anchor - anchors: A 1D tensor of shape [4 x Num Anchors], representing the anchors. - - rois: output RoIs, a 3D tensor of shape [Batch, post_nms_topn, 4], padded by 0 if less than post_nms_topn candidates found. - roi_probabilities: probability scores of each roi in 'rois', a 2D tensor of shape [Batch,post_nms_topn], padded with 0 if needed. - - )doc"); + }); } // namespace tensorflow diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 091cc04357e..14ef81917dc 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1296,6 +1296,10 @@ tf_module { name: "gather_nd" argspec: "args=[\'params\', \'indices\', \'name\', \'batch_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'0\'], " } + member_method { + name: "generate_bounding_box_proposals" + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'pre_nms_topn\', \'post_nms_topn\', \'nms_threshold\', \'min_size\', \'debug\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'6000\', \'300\', \'0.7\', \'16\', \'False\', \'True\', \'None\'], " + } member_method { name: "get_collection" argspec: "args=[\'key\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 12e668952bc..4a81413e4fb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1432,6 +1432,10 @@ tf_module { name: "GatherV2" argspec: "args=[\'params\', \'indices\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "GenerateBoundingBoxProposals" + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'pre_nms_topn\', \'post_nms_topn\', \'nms_threshold\', \'min_size\', \'debug\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'6000\', \'300\', \'0.7\', \'16\', \'False\', \'True\', \'None\'], " + } member_method { name: "GenerateVocabRemapping" argspec: "args=[\'new_vocab_file\', \'old_vocab_file\', \'new_vocab_offset\', \'num_new_vocab\', \'old_vocab_size\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 656d026cb63..b4d8ebd865b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -632,6 +632,10 @@ tf_module { name: "gather_nd" argspec: "args=[\'params\', \'indices\', \'batch_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " } + member_method { + name: "generate_bounding_box_proposals" + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'pre_nms_topn\', \'post_nms_topn\', \'nms_threshold\', \'min_size\', \'debug\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'6000\', \'300\', \'0.7\', \'16\', \'False\', \'True\', \'None\'], " + } member_method { name: "get_logger" argspec: "args=[], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 12e668952bc..4a81413e4fb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1432,6 +1432,10 @@ tf_module { name: "GatherV2" argspec: "args=[\'params\', \'indices\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "GenerateBoundingBoxProposals" + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'pre_nms_topn\', \'post_nms_topn\', \'nms_threshold\', \'min_size\', \'debug\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'6000\', \'300\', \'0.7\', \'16\', \'False\', \'True\', \'None\'], " + } member_method { name: "GenerateVocabRemapping" argspec: "args=[\'new_vocab_file\', \'old_vocab_file\', \'new_vocab_offset\', \'num_new_vocab\', \'old_vocab_size\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " From 1ccc9c64bf54a5b87aa80fbace46e35b70da6493 Mon Sep 17 00:00:00 2001 From: Sami Date: Mon, 1 Jul 2019 16:41:25 -0700 Subject: [PATCH 03/17] Fixes for review --- ...api_def_GenerateBoundingBoxProposals.pbtxt | 71 +++- tensorflow/core/kernels/BUILD | 2 +- .../kernels/generate_box_proposals_op.cu.cc | 334 ++++++++---------- tensorflow/core/ops/image_ops.cc | 15 +- tensorflow/python/ops/image_ops_impl.py | 10 + .../api/golden/v1/tensorflow.image.pbtxt | 4 + .../tools/api/golden/v1/tensorflow.pbtxt | 2 +- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 2 +- .../api/golden/v2/tensorflow.image.pbtxt | 4 + .../tools/api/golden/v2/tensorflow.pbtxt | 2 +- .../api/golden/v2/tensorflow.raw_ops.pbtxt | 2 +- 11 files changed, 257 insertions(+), 191 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt index b560095ae89..31796cbf151 100644 --- a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt @@ -1,6 +1,75 @@ op { graph_op_name: "GenerateBoundingBoxProposals" - summary: "This op produces Region of Interests from given bounding boxes(bbox_deltas) encoded wrtanchors according to eq.2 in arXiv:1506.01497" + in_arg { + name: "scores" + description: <SetStatus(errors::Internal("Cuda call failed with", \ + cudaGetErrorString(error))); \ + done(); \ + return; \ + } \ + } while (0) namespace { - +// Decode d_bbox_deltas with respect to anchors into absolute coordinates, +// clipping if necessary. +// H is height, W is width, A is the number of anchors. +// prenms_nboxes maximum number of boxes per image to decode. +// d_boxes_keep_flags mask for boxes to consider in NMS. +// min_size is the lower bound of the shortest edge for the boxes to consider. +// bbox_xform_clip is the upper bound of encoded width and height. +// correct_transform is a flag to apply coordinate transformation correction to +// x2 and y2. __global__ void GeneratePreNMSUprightBoxesKernel( const Cuda2DLaunchConfig config, const int* d_sorted_scores_keys, const float4* d_bbox_deltas, const float4* d_anchors, const int H, @@ -49,17 +67,14 @@ __global__ void GeneratePreNMSUprightBoxesKernel( float4* d_out_boxes, const int prenms_nboxes, // leading dimension of out_boxes float* d_inout_scores, char* d_boxes_keep_flags) { - const int K = H * W; - const int WA = W * A; - const int KA = K * A; - int nboxes_to_generate = config.virtual_thread_count.x; - int num_images = config.virtual_thread_count.y; - int num_true = 0; + // constants to calculate offsets in to the input and output arrays. + const int K = H * W; // Stride of Anchor + const int WA = W * A; // Stride of H + const int KA = K * A; // Stride of image CUDA_AXIS_KERNEL_LOOP(image_index, config.virtual_thread_count.y, Y) { CUDA_AXIS_KERNEL_LOOP(ibox, config.virtual_thread_count.x, X) { - // CUDA_2D_KERNEL_LOOP(ibox, nboxes_to_generate, image_index, - // num_images){ { box_conv_index : # of the same box, but indexed in the - // scores from the conv layer, of shape (A,H,W) the num_images dimension + // box_conv_index : # of the same box, but indexed in the + // scores from the conv layer, of shape (H,W,A) the num_images dimension // was already removed box_conv_index = a*K + h*W + w const int box_conv_index = d_sorted_scores_keys[image_index * KA + ibox]; @@ -70,10 +85,9 @@ __global__ void GeneratePreNMSUprightBoxesKernel( const int dH = WA; // stride of H const int h = remaining / dH; remaining -= h * dH; - const int dW = A; // stride of H + const int dW = A; // stride of W const int w = remaining / dW; remaining -= w * dW; - const int a = remaining; // dA = 1 // Loading the anchor a // float4 is a struct with float x,y,z,w const float4 anchor = d_anchors[box_conv_index]; @@ -148,7 +162,6 @@ __global__ void GeneratePreNMSUprightBoxesKernel( d_boxes_keep_flags[out_index] = keep_box; d_out_boxes[out_index] = {x1, y1, x2, y2}; - // if(keep_box)printf("Has keep box %d\n",image_index); // d_inout_scores size: (num_images,KA) if (!keep_box) d_inout_scores[image_index * KA + ibox] = FLT_MIN; // for NMS @@ -186,6 +199,13 @@ __global__ void WriteUprightBoxesOutput( } } +template +void ResetTensor(Tensor* t, const Eigen::GpuDevice& d) { + CudaLaunchConfig zconfig = GetCudaLaunchConfig(t->NumElements(), d); + TF_CHECK_OK(GpuLaunchKernel( + SetZero, zconfig.block_count, zconfig.thread_per_block, 0, d.stream(), + zconfig.virtual_thread_count, (*t).flat().data())); +} // Allocate scratch spaces that are needed for operation // @@ -200,61 +220,34 @@ Status AllocateGenerationTempTensors( TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images, conv_layer_nboxes}), d_conv_layer_indexes)); - CudaLaunchConfig zconfig = - GetCudaLaunchConfig(d_conv_layer_indexes->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, - (*d_conv_layer_indexes).flat().data()); - + ResetTensor(d_conv_layer_indexes, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images + 1}), d_image_offset)); - zconfig = GetCudaLaunchConfig(d_image_offset->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, (*d_image_offset).flat().data()); + ResetTensor(d_image_offset, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}), d_cub_sort_buffer)); - zconfig = GetCudaLaunchConfig(d_cub_sort_buffer->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, (*d_cub_sort_buffer).flat().data()); - + ResetTensor(d_cub_sort_buffer, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT8, TensorShape({(int64)cub_select_temp_storage_bytes}), d_cub_select_buffer)); - zconfig = GetCudaLaunchConfig(d_cub_select_buffer->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, (*d_cub_select_buffer).flat().data()); - + ResetTensor(d_cub_select_buffer, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images, conv_layer_nboxes}), d_sorted_conv_layer_indexes)); - zconfig = GetCudaLaunchConfig(d_sorted_conv_layer_indexes->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, - (*d_sorted_conv_layer_indexes).flat().data()); - + ResetTensor(d_sorted_conv_layer_indexes, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({num_images, conv_layer_nboxes}), d_sorted_scores)); - zconfig = GetCudaLaunchConfig(d_sorted_scores->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, (*d_sorted_scores).flat().data()); - + ResetTensor(d_sorted_scores, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({num_images, box_dim * nboxes_to_generate}), dev_boxes)); - zconfig = GetCudaLaunchConfig(dev_boxes->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, (*dev_boxes).flat().data()); - + ResetTensor(dev_boxes, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT8, TensorShape({num_images, nboxes_to_generate}), dev_boxes_keep_flags)); - zconfig = GetCudaLaunchConfig(dev_boxes_keep_flags->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, - (*dev_boxes_keep_flags).flat().data()); - + ResetTensor(dev_boxes_keep_flags, d); return Status::OK(); } @@ -263,75 +256,40 @@ Status AllocatePreNMSTempTensors( OpKernelContext* context, Tensor* dev_image_prenms_boxes, Tensor* dev_image_prenms_scores, Tensor* dev_image_boxes_keep_list, Tensor* dev_postnms_rois, Tensor* dev_postnms_rois_probs, - Tensor* dev_prenms_nboxes, Tensor* dev_nms_mask, Tensor* host_nms_mask, - int num_images, int nboxes_to_generate, int box_dim, int post_nms_topn, - int pre_nms_topn) { + Tensor* dev_prenms_nboxes, int num_images, int num_boxes_to_generate, + int box_dim, int post_nms_topn, int pre_nms_topn) { auto d = context->eigen_gpu_device(); TF_RETURN_IF_ERROR(context->allocate_temp( - DataType::DT_FLOAT, TensorShape({box_dim * nboxes_to_generate}), + DataType::DT_FLOAT, TensorShape({box_dim * num_boxes_to_generate}), dev_image_prenms_boxes)); - CudaLaunchConfig zconfig = - GetCudaLaunchConfig(dev_image_prenms_boxes->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, - (*dev_image_prenms_boxes).flat().data()); + ResetTensor(dev_image_prenms_boxes, d); - TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_FLOAT, - TensorShape({nboxes_to_generate}), - dev_image_prenms_scores)); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_boxes_to_generate}), + dev_image_prenms_scores)); + ResetTensor(dev_image_prenms_scores, d); - zconfig = GetCudaLaunchConfig(dev_image_prenms_scores->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, - (*dev_image_prenms_scores).flat().data()); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes_to_generate}), + dev_image_boxes_keep_list)); + ResetTensor(dev_image_boxes_keep_list, d); - TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, - TensorShape({nboxes_to_generate}), - dev_image_boxes_keep_list)); - zconfig = GetCudaLaunchConfig(dev_image_boxes_keep_list->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, - (*dev_image_boxes_keep_list).flat().data()); - - const int max_postnms_nboxes = std::min(nboxes_to_generate, post_nms_topn); + const int max_postnms_nboxes = std::min(num_boxes_to_generate, post_nms_topn); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({box_dim * num_images * max_postnms_nboxes}), dev_postnms_rois)); - zconfig = GetCudaLaunchConfig(dev_postnms_rois->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, (*dev_postnms_rois).flat().data()); + ResetTensor(dev_postnms_rois, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({num_images * max_postnms_nboxes}), dev_postnms_rois_probs)); - zconfig = GetCudaLaunchConfig(dev_postnms_rois_probs->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, - (*dev_postnms_rois_probs).flat().data()); + ResetTensor(dev_postnms_rois_probs, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images}), dev_prenms_nboxes)); - zconfig = GetCudaLaunchConfig(dev_prenms_nboxes->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, (*dev_prenms_nboxes).flat().data()); - int64 max_nms_mask_size = - pre_nms_topn * - ((pre_nms_topn + NMS_BOXES_PER_THREAD - 1) / NMS_BOXES_PER_THREAD); + ResetTensor(dev_prenms_nboxes, d); - TF_RETURN_IF_ERROR(context->allocate_temp( - DataType::DT_INT32, TensorShape({max_nms_mask_size}), dev_nms_mask)); - - zconfig = GetCudaLaunchConfig(dev_nms_mask->NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, (*dev_nms_mask).flat().data()); - - AllocatorAttributes alloc_attr; - alloc_attr.set_on_host(true); - alloc_attr.set_gpu_compatible(true); - TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, - TensorShape({max_nms_mask_size}), - host_nms_mask, alloc_attr)); return Status::OK(); } @@ -365,23 +323,26 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { explicit GenerateBoundingBoxProposals( tensorflow::OpKernelConstruction* context) : AsyncOpKernel(context) { - // OP_REQUIRES_OK(context, context->GetAttr("spatial_scale", - // &spatial_scale_)); feat_stride_ = 1.0 / spatial_scale_; - OP_REQUIRES_OK(context, context->GetAttr("pre_nms_topn", &pre_nms_topn_)); OP_REQUIRES_OK(context, context->GetAttr("post_nms_topn", &post_nms_topn_)); - OP_REQUIRES_OK(context, context->GetAttr("nms_threshold", &nms_threshold_)); - OP_REQUIRES_OK(context, context->GetAttr("min_size", &min_size_)); - OP_REQUIRES_OK(context, context->GetAttr("debug", &debug_)); // compatibility for detectron like networks. False for generic case OP_REQUIRES_OK(context, context->GetAttr("correct_transform_coords", &correct_transform_coords_)); - CHECK_GT(pre_nms_topn_, 0); CHECK_GT(post_nms_topn_, 0); - CHECK_GT(nms_threshold_, 0); - CHECK_GE(min_size_, 0); bbox_xform_clip_default_ = log(1000.0 / 16.); } + template + Status GetScalarValue(OpKernelContext* context, int input, T* value) { + const Tensor& scalar_tensor = context->input(input); + if (!TensorShapeUtils::IsScalar(scalar_tensor.shape())) { + return errors::InvalidArgument("Expected a scalar in input ", input, + "but got shape ", + scalar_tensor.shape().DebugString()); + } + *value = scalar_tensor.scalar()(); + return Status::OK(); + } + void ComputeAsync(tensorflow::OpKernelContext* context, DoneCallback done) override { // .Input("scores: float") @@ -407,6 +368,27 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { // The following calls to CUB primitives do nothing // (because the first arg is nullptr) // except setting cub_*_temp_storage_bytes + float nms_threshold; + int pre_nms_topn; + float min_size; + OP_REQUIRES_OK_ASYNC(context, GetScalarValue(context, 4, &nms_threshold), + done); + if (nms_threshold < 0 || nms_threshold > 1.0) { + context->SetStatus(errors::InvalidArgument( + "nms_threshold should be between 0 and 1. Got ", nms_threshold)); + done(); + return; + } + OP_REQUIRES_OK_ASYNC(context, GetScalarValue(context, 5, &pre_nms_topn), + done); + if (pre_nms_topn <= 0) { + context->SetStatus(errors::InvalidArgument( + "pre_nms_topn should be greater than 0", pre_nms_topn)); + done(); + return; + } + + OP_REQUIRES_OK_ASYNC(context, GetScalarValue(context, 6, &min_size), done); auto cuda_stream = GetCudaStream(context); size_t cub_sort_temp_storage_bytes = 0; float* flt_ptr = nullptr; @@ -421,10 +403,12 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { size_t cub_select_temp_storage_bytes = 0; char* char_ptr = nullptr; float4* f4_ptr = nullptr; - cuda_ret = cub::DeviceSelect::Flagged( - nullptr, cub_select_temp_storage_bytes, f4_ptr, char_ptr, f4_ptr, - int_ptr, K * A, cuda_stream); - CHECK_EQ(cuda_ret, CUDA_SUCCESS); + TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC( + context, + cub::DeviceSelect::Flagged(nullptr, cub_select_temp_storage_bytes, + f4_ptr, char_ptr, f4_ptr, int_ptr, K * A, + cuda_stream), + done); Tensor d_conv_layer_indexes; // box indices on device Tensor d_image_offset; // starting offsets boxes for each image Tensor d_cub_sort_buffer; // buffer for cub sorting @@ -435,7 +419,7 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { Tensor dev_boxes; // boxes on device Tensor dev_boxes_keep_flags; // bitmask for keeping the boxes or rejecting // from output - const int nboxes_to_generate = std::min(conv_layer_nboxes, pre_nms_topn_); + const int nboxes_to_generate = std::min(conv_layer_nboxes, pre_nms_topn); OP_REQUIRES_OK_ASYNC( context, AllocateGenerationTempTensors( @@ -449,41 +433,43 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { Cuda2DLaunchConfig conf2d = GetCuda2DLaunchConfig(conv_layer_nboxes, num_images, d); // create box indices and offsets for each image on device - InitializeDataKernel<<>>( - conf2d, d_image_offset.flat().data(), - d_conv_layer_indexes.flat().data()); + TF_CHECK_OK(GpuLaunchKernel(InitializeDataKernel, conf2d.block_count, + conf2d.thread_per_block, 0, d.stream(), conf2d, + d_image_offset.flat().data(), + d_conv_layer_indexes.flat().data())); // sort boxes with their scores. // d_sorted_conv_layer_indexes will hold the pointers to old indices. - cuda_ret = cub::DeviceSegmentedRadixSort::SortPairsDescending( - d_cub_sort_buffer.flat().data(), cub_sort_temp_storage_bytes, - scores.flat().data(), dev_sorted_scores.flat().data(), - d_conv_layer_indexes.flat().data(), - d_sorted_conv_layer_indexes.flat().data(), - num_images * conv_layer_nboxes, num_images, - d_image_offset.flat().data(), - d_image_offset.flat().data() + 1, 0, - 8 * sizeof(float), // sort all bits - cuda_stream); + TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC( + context, + cub::DeviceSegmentedRadixSort::SortPairsDescending( + d_cub_sort_buffer.flat().data(), cub_sort_temp_storage_bytes, + scores.flat().data(), dev_sorted_scores.flat().data(), + d_conv_layer_indexes.flat().data(), + d_sorted_conv_layer_indexes.flat().data(), + num_images * conv_layer_nboxes, num_images, + d_image_offset.flat().data(), + d_image_offset.flat().data() + 1, 0, + 8 * sizeof(float), // sort all bits + cuda_stream), + done); // Keeping only the topN pre_nms - CHECK_EQ(cuda_ret, CUDA_SUCCESS); conf2d = GetCuda2DLaunchConfig(nboxes_to_generate, num_images, d); // create box y1,x1,y2,x2 from box_deltas and anchors (decode the boxes) and - // mark the boxes which are smaller that min_size_ ignored. - GeneratePreNMSUprightBoxesKernel<<< - conf2d.block_count, conf2d.thread_per_block, 0, d.stream()>>>( - conf2d, d_sorted_conv_layer_indexes.flat().data(), + // mark the boxes which are smaller that min_size ignored. + TF_CHECK_OK(GpuLaunchKernel( + GeneratePreNMSUprightBoxesKernel, conf2d.block_count, + conf2d.thread_per_block, 0, d.stream(), conf2d, + d_sorted_conv_layer_indexes.flat().data(), reinterpret_cast(bbox_deltas.flat().data()), reinterpret_cast(anchors.flat().data()), H, W, A, - min_size_, image_info.flat().data(), bbox_xform_clip_default_, + min_size, image_info.flat().data(), bbox_xform_clip_default_, correct_transform_coords_, reinterpret_cast(dev_boxes.flat().data()), nboxes_to_generate, dev_sorted_scores.flat().data(), - (char*)dev_boxes_keep_flags.flat().data()); - CHECK_EQ(cudaGetLastError(), CUDA_SUCCESS); + (char*)dev_boxes_keep_flags.flat().data())); const int nboxes_generated = nboxes_to_generate; const int roi_cols = box_dim; const int max_postnms_nboxes = std::min(nboxes_generated, post_nms_topn_); @@ -493,17 +479,14 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { Tensor dev_postnms_rois; Tensor dev_postnms_rois_probs; Tensor dev_prenms_nboxes; - Tensor dev_nms_mask; - Tensor host_nms_mask; // Allocate workspaces needed for NMS OP_REQUIRES_OK_ASYNC( context, AllocatePreNMSTempTensors( context, &dev_image_prenms_boxes, &dev_image_prenms_scores, &dev_image_boxes_keep_list, &dev_postnms_rois, - &dev_postnms_rois_probs, &dev_prenms_nboxes, &dev_nms_mask, - &host_nms_mask, num_images, nboxes_generated, box_dim, - this->post_nms_topn_, this->pre_nms_topn_), + &dev_postnms_rois_probs, &dev_prenms_nboxes, num_images, + nboxes_generated, box_dim, post_nms_topn_, pre_nms_topn), done); // get the pointers for temp storages int* d_prenms_nboxes = dev_prenms_nboxes.flat().data(); @@ -513,8 +496,7 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { float* d_image_prenms_boxes = dev_image_prenms_boxes.flat().data(); float* d_image_prenms_scores = dev_image_prenms_scores.flat().data(); int* d_image_boxes_keep_list = dev_image_boxes_keep_list.flat().data(); - int* h_nms_mask = host_nms_mask.flat().data(); - int* d_nms_mask = dev_nms_mask.flat().data(); + int nrois_in_output = 0; // get the pointers to boxes and scores char* d_boxes_keep_flags = (char*)dev_boxes_keep_flags.flat().data(); @@ -539,15 +521,9 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { float* d_postnms_rois_probs = (*output_roi_probs).flat().data(); // Do per-image nms - CudaLaunchConfig zconfig; for (int image_index = 0; image_index < num_images; ++image_index) { // reset output workspaces - zconfig = GetCudaLaunchConfig(dev_nms_mask.NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, d_nms_mask); - zconfig = GetCudaLaunchConfig(dev_image_boxes_keep_list.NumElements(), d); - SetZero<<>>( - zconfig.virtual_thread_count, d_image_boxes_keep_list); + ResetTensor(&dev_image_boxes_keep_list, d); // Sub matrices for current image // boxes const float* d_image_boxes = @@ -567,33 +543,35 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { // Moving valid boxes (ie the ones with d_boxes_keep_flags[ibox] == true) // to the output tensors - // printf("Host before flagged boxes=%d - // ngen=%d\n",h_prenms_nboxes,nboxes_generated); - cuda_ret = cub::DeviceSelect::Flagged( - d_cub_select_temp_storage, cub_select_temp_storage_bytes, - reinterpret_cast(d_image_boxes), - d_image_boxes_keep_flags, - reinterpret_cast(d_image_prenms_boxes), d_prenms_nboxes, - nboxes_generated, d.stream()); - CHECK_EQ(cuda_ret, CUDA_SUCCESS); - cuda_ret = cub::DeviceSelect::Flagged( - d_cub_select_temp_storage, cub_select_temp_storage_bytes, - d_image_sorted_scores, d_image_boxes_keep_flags, - d_image_prenms_scores, d_prenms_nboxes, nboxes_generated, d.stream()); - CHECK_EQ(cuda_ret, CUDA_SUCCESS); + TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC( + context, + cub::DeviceSelect::Flagged( + d_cub_select_temp_storage, cub_select_temp_storage_bytes, + reinterpret_cast(d_image_boxes), + d_image_boxes_keep_flags, + reinterpret_cast(d_image_prenms_boxes), d_prenms_nboxes, + nboxes_generated, d.stream()), + done); + + TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC( + context, + cub::DeviceSelect::Flagged( + d_cub_select_temp_storage, cub_select_temp_storage_bytes, + d_image_sorted_scores, d_image_boxes_keep_flags, + d_image_prenms_scores, d_prenms_nboxes, nboxes_generated, + d.stream()), + done); d.memcpyDeviceToHost(&h_prenms_nboxes, d_prenms_nboxes, sizeof(int)); d.synchronize(); - // We know prenms_boxes <= topN_prenms, because nboxes_generated <= // topN_prenms. Calling NMS on the generated boxes const int prenms_nboxes = h_prenms_nboxes; - // printf("Host boxes=%d ngen=%d\n",h_prenms_nboxes,nboxes_generated); int nkeep; - // printf("Before nms\n"); - nms_gpu(d_image_prenms_boxes, prenms_nboxes, nms_threshold_, - d_image_boxes_keep_list, &nkeep, d_nms_mask, h_nms_mask, context); - CHECK_EQ(cudaGetLastError(), CUDA_SUCCESS); - // printf("After nms nkeep=%d\n",nkeep); + OP_REQUIRES_OK_ASYNC( + context, + NmsGpu(d_image_prenms_boxes, prenms_nboxes, nms_threshold, + d_image_boxes_keep_list, &nkeep, context), + done); // All operations done after previous sort were keeping the relative order // of the elements the elements are still sorted keep topN <=> truncate // the array @@ -602,26 +580,22 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { // adding the image_index dimension on the fly CudaLaunchConfig config = GetCudaLaunchConfig(post_nms_topn_, d); // make this single kernel - WriteUprightBoxesOutput<<>>( - config, reinterpret_cast(d_image_prenms_boxes), + TF_CHECK_OK(GpuLaunchKernel( + WriteUprightBoxesOutput, config.block_count, config.thread_per_block, + 0, d.stream(), config, + reinterpret_cast(d_image_prenms_boxes), d_image_prenms_scores, d_image_boxes_keep_list, postnms_nboxes, - d_image_postnms_rois, d_image_postnms_rois_probs); + d_image_postnms_rois, d_image_postnms_rois_probs)); nrois_in_output += postnms_nboxes; - CHECK_EQ(cudaGetLastError(), CUDA_SUCCESS); + TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC(context, cudaGetLastError(), done); } done(); } private: - int pre_nms_topn_; int post_nms_topn_; - float nms_threshold_; - float min_size_; - float feat_stride_; float bbox_xform_clip_default_; bool correct_transform_coords_; - bool debug_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 95a9b71fffb..86762cf2c70 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -988,17 +988,17 @@ REGISTER_OP("GenerateBoundingBoxProposals") .Input("bbox_deltas: float") .Input("image_info: float") .Input("anchors: float") + .Input("nms_threshold: float") + .Input("pre_nms_topn: int32") + .Input("min_size: float") .Output("rois: float") .Output("roi_probabilities: float") - .Attr("pre_nms_topn: int = 6000") .Attr("post_nms_topn: int = 300") - .Attr("nms_threshold: float = 0.7") - .Attr("min_size: float = 16") - .Attr("debug: bool = false") .Attr("correct_transform_coords: bool = true") .SetShapeFn([](InferenceContext* c) -> Status { // make sure input tensors have are correct rank - ShapeHandle scores, images, bounding_boxes, anchors; + ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold, + n_pre_nms, min_box_size; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &scores)); //(N, H, W, A) TF_RETURN_IF_ERROR( c->WithRank(c->input(1), 4, &bounding_boxes)); //(N,H,W,A4) @@ -1006,6 +1006,11 @@ REGISTER_OP("GenerateBoundingBoxProposals") auto im_info = c->Dim(images, 1); TF_RETURN_IF_ERROR(c->WithValue(im_info, 5, &im_info)); TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 3, &anchors)); // (A4) + // check scalar tensors + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &nms_threshold)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &n_pre_nms)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &min_box_size)); + // TODO(skama): verify that the inputs are compatible int post_nms_top_n; TF_RETURN_IF_ERROR(c->GetAttr("post_nms_topn", &post_nms_top_n)); diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index f1164d5adbd..362d5f7f130 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -3888,3 +3888,13 @@ def draw_bounding_boxes(images, boxes, name=None, colors=None): A `Tensor`. Has the same type as `images`. """ return draw_bounding_boxes_v2(images, boxes, colors, name) + + +@tf_export("image.generate_bounding_box_proposals") +def generate_bounding_box_proposals(scores, bbox_deltas, image_info, anchors, nms_threshold=0.7, pre_nms_topn=6000, min_size=16, post_nms_topn=300, correct_transform_coords=True, name=None): + """ Generate bounding box proposals from encoded bounding boxes. + Returns: + rois: Region of interest boxes sorted by their scores. + roi_probabilities: scores of the roi boxes in the rois tensor. + """ + return gen_image_ops.generate_bounding_box_proposals(scores, bbox_deltas, image_info, anchors, nms_threshold, pre_nms_topn, min_size, post_nms_topn, correct_transform_coords) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt index ea5110674a6..ef946d406cd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt @@ -108,6 +108,10 @@ tf_module { name: "flip_up_down" argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "generate_bounding_box_proposals" + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'0.7\', \'6000\', \'16\', \'300\', \'True\', \'None\'], " + } member_method { name: "grayscale_to_rgb" argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 4d6a6f0a737..64c93cd25e2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1330,7 +1330,7 @@ tf_module { } member_method { name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'pre_nms_topn\', \'post_nms_topn\', \'nms_threshold\', \'min_size\', \'debug\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'6000\', \'300\', \'0.7\', \'16\', \'False\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'True\', \'None\'], " } member_method { name: "get_collection" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 50a66d78780..c50ca51efee 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1478,7 +1478,7 @@ tf_module { } member_method { name: "GenerateBoundingBoxProposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'pre_nms_topn\', \'post_nms_topn\', \'nms_threshold\', \'min_size\', \'debug\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'6000\', \'300\', \'0.7\', \'16\', \'False\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'True\', \'None\'], " } member_method { name: "GenerateVocabRemapping" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt index 231fc631e83..dae98dc3ed7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt @@ -104,6 +104,10 @@ tf_module { name: "flip_up_down" argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "generate_bounding_box_proposals" + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'0.7\', \'6000\', \'16\', \'300\', \'True\', \'None\'], " + } member_method { name: "grayscale_to_rgb" argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 3cd85b06683..92b9a3340ab 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -658,7 +658,7 @@ tf_module { } member_method { name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'pre_nms_topn\', \'post_nms_topn\', \'nms_threshold\', \'min_size\', \'debug\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'6000\', \'300\', \'0.7\', \'16\', \'False\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'True\', \'None\'], " } member_method { name: "get_logger" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 50a66d78780..c50ca51efee 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1478,7 +1478,7 @@ tf_module { } member_method { name: "GenerateBoundingBoxProposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'pre_nms_topn\', \'post_nms_topn\', \'nms_threshold\', \'min_size\', \'debug\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'6000\', \'300\', \'0.7\', \'16\', \'False\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'True\', \'None\'], " } member_method { name: "GenerateVocabRemapping" From 4d7c3a22349d707e2efd078dfed5bca8c75318e5 Mon Sep 17 00:00:00 2001 From: Sami Date: Thu, 11 Jul 2019 13:28:55 -0700 Subject: [PATCH 04/17] Fix pylint problem --- tensorflow/python/ops/image_ops_impl.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 362d5f7f130..492048d89fe 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -3891,10 +3891,21 @@ def draw_bounding_boxes(images, boxes, name=None, colors=None): @tf_export("image.generate_bounding_box_proposals") -def generate_bounding_box_proposals(scores, bbox_deltas, image_info, anchors, nms_threshold=0.7, pre_nms_topn=6000, min_size=16, post_nms_topn=300, correct_transform_coords=True, name=None): +def generate_bounding_box_proposals(scores, + bbox_deltas, + image_info, + anchors, + nms_threshold=0.7, + pre_nms_topn=6000, + min_size=16, + post_nms_topn=300, + correct_transform_coords=True, + name=None): """ Generate bounding box proposals from encoded bounding boxes. Returns: rois: Region of interest boxes sorted by their scores. roi_probabilities: scores of the roi boxes in the rois tensor. """ - return gen_image_ops.generate_bounding_box_proposals(scores, bbox_deltas, image_info, anchors, nms_threshold, pre_nms_topn, min_size, post_nms_topn, correct_transform_coords) + return gen_image_ops.generate_bounding_box_proposals( + scores, bbox_deltas, image_info, anchors, nms_threshold, pre_nms_topn, + min_size, post_nms_topn, correct_transform_coords) From d2ecf134432047cc3c4e3fb4fdcba3e0cfe140d9 Mon Sep 17 00:00:00 2001 From: Sami Date: Thu, 18 Jul 2019 15:49:53 -0700 Subject: [PATCH 05/17] Address review comments --- ...api_def_GenerateBoundingBoxProposals.pbtxt | 4 +- .../kernels/generate_box_proposals_op.cu.cc | 293 +++++++++--------- tensorflow/core/ops/image_ops.cc | 2 +- .../tools/api/golden/v1/tensorflow.pbtxt | 2 +- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 2 +- .../tools/api/golden/v2/tensorflow.pbtxt | 2 +- .../api/golden/v2/tensorflow.raw_ops.pbtxt | 2 +- 7 files changed, 148 insertions(+), 159 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt index 31796cbf151..e4405becabb 100644 --- a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt @@ -64,9 +64,9 @@ An integer. Maximum number of rois in the output. END } attr { - name: "correct_transform_coords" + name: "use_detectron_offset" description: <SetStatus(errors::Internal("Cuda call failed with", \ cudaGetErrorString(error))); \ - done(); \ return; \ } \ } while (0) namespace { + +template +__device__ float legacy_op(float); +template <> +__device__ float legacy_op(float a) { + return a + 1.; +} +template <> +__device__ float legacy_op(float a) { + return a; +} // Decode d_bbox_deltas with respect to anchors into absolute coordinates, // clipping if necessary. -// H is height, W is width, A is the number of anchors. // prenms_nboxes maximum number of boxes per image to decode. // d_boxes_keep_flags mask for boxes to consider in NMS. // min_size is the lower bound of the shortest edge for the boxes to consider. // bbox_xform_clip is the upper bound of encoded width and height. -// correct_transform is a flag to apply coordinate transformation correction to -// x2 and y2. +template __global__ void GeneratePreNMSUprightBoxesKernel( const Cuda2DLaunchConfig config, const int* d_sorted_scores_keys, - const float4* d_bbox_deltas, const float4* d_anchors, const int H, - const int W, const int A, const float min_size, const float* d_img_info_vec, - const float bbox_xform_clip, const bool correct_transform, + const float4* d_bbox_deltas, const float4* d_anchors, const int height, + const int width, const int num_anchors, const float min_size, + const float* d_img_info_vec, const float bbox_xform_clip, float4* d_out_boxes, const int prenms_nboxes, // leading dimension of out_boxes - float* d_inout_scores, char* d_boxes_keep_flags) { + float* d_scores, char* d_boxes_keep_flags) { // constants to calculate offsets in to the input and output arrays. - const int K = H * W; // Stride of Anchor - const int WA = W * A; // Stride of H - const int KA = K * A; // Stride of image + const int anchor_stride = height * width; // Stride of Anchor + const int height_stride = width * num_anchors; // Stride of height + const int image_stride = anchor_stride * num_anchors; // Stride of image CUDA_AXIS_KERNEL_LOOP(image_index, config.virtual_thread_count.y, Y) { CUDA_AXIS_KERNEL_LOOP(ibox, config.virtual_thread_count.x, X) { // box_conv_index : # of the same box, but indexed in the - // scores from the conv layer, of shape (H,W,A) the num_images dimension - // was already removed box_conv_index = a*K + h*W + w - const int box_conv_index = d_sorted_scores_keys[image_index * KA + ibox]; + // scores from the conv layer, of shape (height,width,A) the num_images + // dimension was already removed box_conv_index = a*image_stride + h*width + // + w + const int box_conv_index = + d_sorted_scores_keys[image_index * image_stride + ibox]; // We want to decompose box_conv_index in (h,w,a) - // such as box_conv_index = h*W*A + W*A + a + // such as box_conv_index = h*width*A + width*A + a // (avoiding modulos in the process) int remaining = box_conv_index; - const int dH = WA; // stride of H - const int h = remaining / dH; - remaining -= h * dH; - const int dW = A; // stride of W - const int w = remaining / dW; - remaining -= w * dW; + const int delta_height = height_stride; // stride of height + const int h = remaining / delta_height; + remaining -= h * delta_height; + const int delta_width = num_anchors; // stride of width + const int w = remaining / delta_width; + remaining -= w * delta_width; // Loading the anchor a // float4 is a struct with float x,y,z,w const float4 anchor = d_anchors[box_conv_index]; @@ -99,58 +109,47 @@ __global__ void GeneratePreNMSUprightBoxesKernel( // TODO use fast math when possible - // Deltas of shape (N,H,W,A4) - int deltas_idx = box_conv_index + image_index * KA; + // Deltas of shape (N,height,width,A4) + int deltas_idx = box_conv_index + image_index * image_stride; float4 deltas = d_bbox_deltas[deltas_idx]; float dx = deltas.y; float dy = deltas.x; float dw = deltas.w; float dh = deltas.z; - // printf("deltas_idx=%d dx=%f, dy=%f, dw=%f, - // dh=%f\n",deltas_idx,dx,dy,dw,dh); // Upper bound on dw,dh dw = fmin(dw, bbox_xform_clip); dh = fmin(dh, bbox_xform_clip); // Applying the deltas - float width = x2 - x1 + 1.0f; + float width = legacy_op(x2 - x1); const float ctr_x = x1 + 0.5f * width; const float pred_ctr_x = ctr_x + width * dx; // TODO fuse madd const float pred_w = width * expf(dw); x1 = pred_ctr_x - 0.5f * pred_w; - x2 = pred_ctr_x + 0.5f * pred_w; + x2 = -legacy_op(-(pred_ctr_x + 0.5f * pred_w)); - float height = y2 - y1 + 1.0f; + float height = legacy_op(y2 - y1); const float ctr_y = y1 + 0.5f * height; const float pred_ctr_y = ctr_y + height * dy; const float pred_h = height * expf(dh); y1 = pred_ctr_y - 0.5f * pred_h; - y2 = pred_ctr_y + 0.5f * pred_h; + y2 = -legacy_op(-(pred_ctr_y + 0.5f * pred_h)); // -1 if legacy_op - if (correct_transform) { - x2 -= 1.0f; - y2 -= 1.0f; - } - // const float y2_old=y2; - // const float x2_old=x2; - // const float x1_old=x1; - // const float y1_old=y1; // Clipping box to image const float img_height = d_img_info_vec[5 * image_index + 0]; const float img_width = d_img_info_vec[5 * image_index + 1]; const float min_size_scaled = min_size * d_img_info_vec[5 * image_index + 2]; - // min_size * d_img_info_vec[3 * image_index + 2]; - x1 = fmax(fmin(x1, img_width - 1.0f), 0.0f); - y1 = fmax(fmin(y1, img_height - 1.0f), 0.0f); - x2 = fmax(fmin(x2, img_width - 1.0f), 0.0f); - y2 = fmax(fmin(y2, img_height - 1.0f), 0.0f); + x1 = fmax(fmin(x1, -legacy_op(-img_width)), 0.0f); + y1 = fmax(fmin(y1, -legacy_op(-img_height)), 0.0f); + x2 = fmax(fmin(x2, -legacy_op(-img_width)), 0.0f); + y2 = fmax(fmin(y2, -legacy_op(-img_height)), 0.0f); // Filter boxes // Removing boxes with one dim < min_size // (center of box is in image, because of previous step) - width = x2 - x1 + 1.0f; // may have changed - height = y2 - y1 + 1.0f; + width = legacy_op(x2 - x1); // may have changed + height = legacy_op(y2 - y1); bool keep_box = fmin(width, height) >= min_size_scaled; // We are not deleting the box right now even if !keep_box @@ -162,9 +161,12 @@ __global__ void GeneratePreNMSUprightBoxesKernel( d_boxes_keep_flags[out_index] = keep_box; d_out_boxes[out_index] = {x1, y1, x2, y2}; - // d_inout_scores size: (num_images,KA) - if (!keep_box) - d_inout_scores[image_index * KA + ibox] = FLT_MIN; // for NMS + // d_scores size: (num_images,KA) + // Set the score of the box to float minimum + if (!keep_box) { + d_scores[image_index * image_stride + ibox] = + std::numeric_limits::min(); // for NMS + } } } } @@ -215,7 +217,8 @@ Status AllocateGenerationTempTensors( Tensor* d_cub_select_buffer, Tensor* d_sorted_conv_layer_indexes, Tensor* d_sorted_scores, Tensor* dev_boxes, Tensor* dev_boxes_keep_flags, int num_images, int conv_layer_nboxes, size_t cub_sort_temp_storage_bytes, - size_t cub_select_temp_storage_bytes, int nboxes_to_generate, int box_dim) { + size_t cub_select_temp_storage_bytes, int num_boxes_to_generate, + int box_dim) { auto d = context->eigen_gpu_device(); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images, conv_layer_nboxes}), @@ -242,10 +245,10 @@ Status AllocateGenerationTempTensors( ResetTensor(d_sorted_scores, d); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, - TensorShape({num_images, box_dim * nboxes_to_generate}), dev_boxes)); + TensorShape({num_images, box_dim * num_boxes_to_generate}), dev_boxes)); ResetTensor(dev_boxes, d); TF_RETURN_IF_ERROR(context->allocate_temp( - DataType::DT_INT8, TensorShape({num_images, nboxes_to_generate}), + DataType::DT_INT8, TensorShape({num_images, num_boxes_to_generate}), dev_boxes_keep_flags)); ResetTensor(dev_boxes_keep_flags, d); return Status::OK(); @@ -298,19 +301,17 @@ Status AllocatePreNMSTempTensors( __global__ void InitializeDataKernel(const Cuda2DLaunchConfig config, int* d_image_offsets, int* d_boxes_keys_iota) { - const int KA = config.virtual_thread_count.x; + const int image_size = config.virtual_thread_count.x; const int num_images = config.virtual_thread_count.y; - // printf("num_images %d KA %d\n",num_images,KA); CUDA_AXIS_KERNEL_LOOP(img_idx, config.virtual_thread_count.y, Y) { CUDA_AXIS_KERNEL_LOOP(box_idx, config.virtual_thread_count.x, X) { - // CUDA_2D_KERNEL_LOOP(box_idx, KA, img_idx, num_images) { - d_boxes_keys_iota[img_idx * KA + box_idx] = box_idx; + d_boxes_keys_iota[img_idx * image_size + box_idx] = box_idx; // One 1D line sets the 1D data if (box_idx == 0) { - d_image_offsets[img_idx] = KA * img_idx; + d_image_offsets[img_idx] = image_size * img_idx; // One thread sets the last+1 offset - if (img_idx == 0) d_image_offsets[num_images] = KA * num_images; + if (img_idx == 0) d_image_offsets[num_images] = image_size * num_images; } } } @@ -318,15 +319,15 @@ __global__ void InitializeDataKernel(const Cuda2DLaunchConfig config, } // namespace -class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { +class GenerateBoundingBoxProposals : public tensorflow::OpKernel { public: explicit GenerateBoundingBoxProposals( tensorflow::OpKernelConstruction* context) - : AsyncOpKernel(context) { + : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("post_nms_topn", &post_nms_topn_)); // compatibility for detectron like networks. False for generic case - OP_REQUIRES_OK(context, context->GetAttr("correct_transform_coords", - &correct_transform_coords_)); + OP_REQUIRES_OK(context, context->GetAttr("use_detectron_offset", + &use_detectron_offset_)); CHECK_GT(post_nms_topn_, 0); bbox_xform_clip_default_ = log(1000.0 / 16.); } @@ -343,52 +344,43 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { return Status::OK(); } - void ComputeAsync(tensorflow::OpKernelContext* context, - DoneCallback done) override { - // .Input("scores: float") - // .Input("bbox_deltas: float") - // .Input("image_info: float") - // .Input("anchors: float") - + void Compute(tensorflow::OpKernelContext* context) override { const auto scores = context->input(0); const auto bbox_deltas = context->input(1); const auto image_info = context->input(2); const auto anchors = context->input(3); const auto num_images = scores.dim_size(0); - const auto A = scores.dim_size(3); - const auto H = scores.dim_size(1); - const auto W = scores.dim_size(2); - const auto box_dim = anchors.dim_size(0) / A; - CHECK_EQ(box_dim, 4); + const auto num_anchors = scores.dim_size(3); + const auto height = scores.dim_size(1); + const auto width = scores.dim_size(2); + const auto box_dim = anchors.dim_size(0) / num_anchors; + OP_REQUIRES(context, box_dim == 4, + errors::OutOfRange("Box dimensions need to be 4")); // TODO(skama): make sure that inputs are ok. - const int K = H * W; - // VLOG(0)<<"num_images="<SetStatus(errors::InvalidArgument( "pre_nms_topn should be greater than 0", pre_nms_topn)); - done(); return; } - OP_REQUIRES_OK_ASYNC(context, GetScalarValue(context, 6, &min_size), done); + OP_REQUIRES_OK(context, GetScalarValue(context, 6, &min_size)); auto cuda_stream = GetCudaStream(context); size_t cub_sort_temp_storage_bytes = 0; float* flt_ptr = nullptr; @@ -398,17 +390,15 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { int_ptr, num_images * conv_layer_nboxes, num_images, int_ptr, int_ptr, 0, 8 * sizeof(float), // sort all bits cuda_stream); - CHECK_EQ(cuda_ret, cudaSuccess); + TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); // get the size of select temp buffer size_t cub_select_temp_storage_bytes = 0; char* char_ptr = nullptr; float4* f4_ptr = nullptr; - TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC( - context, - cub::DeviceSelect::Flagged(nullptr, cub_select_temp_storage_bytes, - f4_ptr, char_ptr, f4_ptr, int_ptr, K * A, - cuda_stream), - done); + TF_OP_REQUIRES_CUDA_SUCCESS( + context, cub::DeviceSelect::Flagged( + nullptr, cub_select_temp_storage_bytes, f4_ptr, char_ptr, + f4_ptr, int_ptr, image_stride * num_anchors, cuda_stream)); Tensor d_conv_layer_indexes; // box indices on device Tensor d_image_offset; // starting offsets boxes for each image Tensor d_cub_sort_buffer; // buffer for cub sorting @@ -420,15 +410,14 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { Tensor dev_boxes_keep_flags; // bitmask for keeping the boxes or rejecting // from output const int nboxes_to_generate = std::min(conv_layer_nboxes, pre_nms_topn); - OP_REQUIRES_OK_ASYNC( + OP_REQUIRES_OK( context, AllocateGenerationTempTensors( context, &d_conv_layer_indexes, &d_image_offset, &d_cub_sort_buffer, &d_cub_select_buffer, &d_sorted_conv_layer_indexes, &dev_sorted_scores, &dev_boxes, &dev_boxes_keep_flags, num_images, conv_layer_nboxes, cub_sort_temp_storage_bytes, - cub_select_temp_storage_bytes, nboxes_to_generate, box_dim), - done); + cub_select_temp_storage_bytes, nboxes_to_generate, box_dim)); const GPUDevice& d = context->eigen_device(); Cuda2DLaunchConfig conf2d = GetCuda2DLaunchConfig(conv_layer_nboxes, num_images, d); @@ -441,7 +430,7 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { // sort boxes with their scores. // d_sorted_conv_layer_indexes will hold the pointers to old indices. - TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC( + TF_OP_REQUIRES_CUDA_SUCCESS( context, cub::DeviceSegmentedRadixSort::SortPairsDescending( d_cub_sort_buffer.flat().data(), cub_sort_temp_storage_bytes, @@ -452,24 +441,37 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { d_image_offset.flat().data(), d_image_offset.flat().data() + 1, 0, 8 * sizeof(float), // sort all bits - cuda_stream), - done); + cuda_stream)); // Keeping only the topN pre_nms conf2d = GetCuda2DLaunchConfig(nboxes_to_generate, num_images, d); // create box y1,x1,y2,x2 from box_deltas and anchors (decode the boxes) and // mark the boxes which are smaller that min_size ignored. - TF_CHECK_OK(GpuLaunchKernel( - GeneratePreNMSUprightBoxesKernel, conf2d.block_count, - conf2d.thread_per_block, 0, d.stream(), conf2d, - d_sorted_conv_layer_indexes.flat().data(), - reinterpret_cast(bbox_deltas.flat().data()), - reinterpret_cast(anchors.flat().data()), H, W, A, - min_size, image_info.flat().data(), bbox_xform_clip_default_, - correct_transform_coords_, - reinterpret_cast(dev_boxes.flat().data()), - nboxes_to_generate, dev_sorted_scores.flat().data(), - (char*)dev_boxes_keep_flags.flat().data())); + if (use_detectron_offset_) { + TF_CHECK_OK(GpuLaunchKernel( + GeneratePreNMSUprightBoxesKernel, conf2d.block_count, + conf2d.thread_per_block, 0, d.stream(), conf2d, + d_sorted_conv_layer_indexes.flat().data(), + reinterpret_cast(bbox_deltas.flat().data()), + reinterpret_cast(anchors.flat().data()), height, + width, num_anchors, min_size, image_info.flat().data(), + bbox_xform_clip_default_, + reinterpret_cast(dev_boxes.flat().data()), + nboxes_to_generate, dev_sorted_scores.flat().data(), + (char*)dev_boxes_keep_flags.flat().data())); + } else { + TF_CHECK_OK(GpuLaunchKernel( + GeneratePreNMSUprightBoxesKernel, conf2d.block_count, + conf2d.thread_per_block, 0, d.stream(), conf2d, + d_sorted_conv_layer_indexes.flat().data(), + reinterpret_cast(bbox_deltas.flat().data()), + reinterpret_cast(anchors.flat().data()), height, + width, num_anchors, min_size, image_info.flat().data(), + bbox_xform_clip_default_, + reinterpret_cast(dev_boxes.flat().data()), + nboxes_to_generate, dev_sorted_scores.flat().data(), + (char*)dev_boxes_keep_flags.flat().data())); + } const int nboxes_generated = nboxes_to_generate; const int roi_cols = box_dim; const int max_postnms_nboxes = std::min(nboxes_generated, post_nms_topn_); @@ -480,14 +482,12 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { Tensor dev_postnms_rois_probs; Tensor dev_prenms_nboxes; // Allocate workspaces needed for NMS - OP_REQUIRES_OK_ASYNC( - context, - AllocatePreNMSTempTensors( - context, &dev_image_prenms_boxes, &dev_image_prenms_scores, - &dev_image_boxes_keep_list, &dev_postnms_rois, - &dev_postnms_rois_probs, &dev_prenms_nboxes, num_images, - nboxes_generated, box_dim, post_nms_topn_, pre_nms_topn), - done); + OP_REQUIRES_OK( + context, AllocatePreNMSTempTensors( + context, &dev_image_prenms_boxes, &dev_image_prenms_scores, + &dev_image_boxes_keep_list, &dev_postnms_rois, + &dev_postnms_rois_probs, &dev_prenms_nboxes, num_images, + nboxes_generated, box_dim, post_nms_topn_, pre_nms_topn)); // get the pointers for temp storages int* d_prenms_nboxes = dev_prenms_nboxes.flat().data(); int h_prenms_nboxes = 0; @@ -506,17 +506,13 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { // Create output tensors Tensor* output_rois = nullptr; Tensor* output_roi_probs = nullptr; - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_output( - 0, TensorShape({num_images, post_nms_topn_, roi_cols}), - &output_rois), - done); - OP_REQUIRES_OK_ASYNC( - context, - context->allocate_output(1, TensorShape({num_images, post_nms_topn_}), - &output_roi_probs), - done); + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({num_images, post_nms_topn_, roi_cols}), + &output_rois)); + OP_REQUIRES_OK(context, context->allocate_output( + 1, TensorShape({num_images, post_nms_topn_}), + &output_roi_probs)); float* d_postnms_rois = (*output_rois).flat().data(); float* d_postnms_rois_probs = (*output_roi_probs).flat().data(); @@ -530,7 +526,7 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { &d_boxes[image_index * nboxes_generated * box_dim]; // scores const float* d_image_sorted_scores = - &d_sorted_scores[image_index * K * A]; + &d_sorted_scores[image_index * image_stride * num_anchors]; // keep flags char* d_image_boxes_keep_flags = &d_boxes_keep_flags[image_index * nboxes_generated]; @@ -543,35 +539,29 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { // Moving valid boxes (ie the ones with d_boxes_keep_flags[ibox] == true) // to the output tensors - TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC( - context, - cub::DeviceSelect::Flagged( - d_cub_select_temp_storage, cub_select_temp_storage_bytes, - reinterpret_cast(d_image_boxes), - d_image_boxes_keep_flags, - reinterpret_cast(d_image_prenms_boxes), d_prenms_nboxes, - nboxes_generated, d.stream()), - done); + TF_OP_REQUIRES_CUDA_SUCCESS( + context, cub::DeviceSelect::Flagged( + d_cub_select_temp_storage, cub_select_temp_storage_bytes, + reinterpret_cast(d_image_boxes), + d_image_boxes_keep_flags, + reinterpret_cast(d_image_prenms_boxes), + d_prenms_nboxes, nboxes_generated, d.stream())); - TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC( - context, - cub::DeviceSelect::Flagged( - d_cub_select_temp_storage, cub_select_temp_storage_bytes, - d_image_sorted_scores, d_image_boxes_keep_flags, - d_image_prenms_scores, d_prenms_nboxes, nboxes_generated, - d.stream()), - done); + TF_OP_REQUIRES_CUDA_SUCCESS( + context, cub::DeviceSelect::Flagged( + d_cub_select_temp_storage, cub_select_temp_storage_bytes, + d_image_sorted_scores, d_image_boxes_keep_flags, + d_image_prenms_scores, d_prenms_nboxes, nboxes_generated, + d.stream())); d.memcpyDeviceToHost(&h_prenms_nboxes, d_prenms_nboxes, sizeof(int)); d.synchronize(); // We know prenms_boxes <= topN_prenms, because nboxes_generated <= // topN_prenms. Calling NMS on the generated boxes const int prenms_nboxes = h_prenms_nboxes; int nkeep; - OP_REQUIRES_OK_ASYNC( - context, - NmsGpu(d_image_prenms_boxes, prenms_nboxes, nms_threshold, - d_image_boxes_keep_list, &nkeep, context), - done); + OP_REQUIRES_OK(context, + NmsGpu(d_image_prenms_boxes, prenms_nboxes, nms_threshold, + d_image_boxes_keep_list, &nkeep, context)); // All operations done after previous sort were keeping the relative order // of the elements the elements are still sorted keep topN <=> truncate // the array @@ -587,15 +577,14 @@ class GenerateBoundingBoxProposals : public tensorflow::AsyncOpKernel { d_image_prenms_scores, d_image_boxes_keep_list, postnms_nboxes, d_image_postnms_rois, d_image_postnms_rois_probs)); nrois_in_output += postnms_nboxes; - TF_OP_REQUIRES_CUDA_SUCCESS_ASYNC(context, cudaGetLastError(), done); + TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); } - done(); } private: int post_nms_topn_; float bbox_xform_clip_default_; - bool correct_transform_coords_; + bool use_detectron_offset_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 86762cf2c70..50945c3c2ea 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -994,7 +994,7 @@ REGISTER_OP("GenerateBoundingBoxProposals") .Output("rois: float") .Output("roi_probabilities: float") .Attr("post_nms_topn: int = 300") - .Attr("correct_transform_coords: bool = true") + .Attr("use_detectron_offset: bool = false") .SetShapeFn([](InferenceContext* c) -> Status { // make sure input tensors have are correct rank ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 64c93cd25e2..3b6649643a2 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1330,7 +1330,7 @@ tf_module { } member_method { name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], " } member_method { name: "get_collection" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index c50ca51efee..685644e9d0c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1478,7 +1478,7 @@ tf_module { } member_method { name: "GenerateBoundingBoxProposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], " } member_method { name: "GenerateVocabRemapping" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 92b9a3340ab..e3c58b0b24b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -658,7 +658,7 @@ tf_module { } member_method { name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], " } member_method { name: "get_logger" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index c50ca51efee..685644e9d0c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1478,7 +1478,7 @@ tf_module { } member_method { name: "GenerateBoundingBoxProposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], " } member_method { name: "GenerateVocabRemapping" From 7f0b4b77f29b6062fdfbc66b2e310ee8cbd4c7f4 Mon Sep 17 00:00:00 2001 From: Sami Date: Wed, 24 Jul 2019 13:24:38 -0700 Subject: [PATCH 06/17] Renamed legacy op and added subtract version --- .../kernels/generate_box_proposals_op.cu.cc | 67 ++++++++++--------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc index 67126b11cb1..2f6c4a16d34 100644 --- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -51,13 +51,24 @@ typedef Eigen::GpuDevice GPUDevice; namespace { template -__device__ float legacy_op(float); +__device__ float AddLegacyOffset(float); template <> -__device__ float legacy_op(float a) { +__device__ float AddLegacyOffset(float a) { return a + 1.; } template <> -__device__ float legacy_op(float a) { +__device__ float AddLegacyOffset(float a) { + return a; +} + +template +__device__ float SubtractLegacyOffset(float); +template <> +__device__ float SubtractLegacyOffset(float a) { + return a - 1.; +} +template <> +__device__ float SubtractLegacyOffset(float a) { return a; } // Decode d_bbox_deltas with respect to anchors into absolute coordinates, @@ -71,10 +82,10 @@ __global__ void GeneratePreNMSUprightBoxesKernel( const Cuda2DLaunchConfig config, const int* d_sorted_scores_keys, const float4* d_bbox_deltas, const float4* d_anchors, const int height, const int width, const int num_anchors, const float min_size, - const float* d_img_info_vec, const float bbox_xform_clip, - float4* d_out_boxes, + const float* d_img_info_vec, // Input "image_info" to the op [N,5] + const float bbox_xform_clip, float4* d_out_boxes, const int prenms_nboxes, // leading dimension of out_boxes - float* d_scores, char* d_boxes_keep_flags) { + char* d_boxes_keep_flags) { // constants to calculate offsets in to the input and output arrays. const int anchor_stride = height * width; // Stride of Anchor const int height_stride = width * num_anchors; // Stride of height @@ -82,14 +93,15 @@ __global__ void GeneratePreNMSUprightBoxesKernel( CUDA_AXIS_KERNEL_LOOP(image_index, config.virtual_thread_count.y, Y) { CUDA_AXIS_KERNEL_LOOP(ibox, config.virtual_thread_count.x, X) { // box_conv_index : # of the same box, but indexed in the - // scores from the conv layer, of shape (height,width,A) the num_images - // dimension was already removed box_conv_index = a*image_stride + h*width + // scores from the conv layer, of shape (height,width,num_anchors) the + // num_images dimension was already removed box_conv_index = + // a*image_stride + h*width // + w const int box_conv_index = d_sorted_scores_keys[image_index * image_stride + ibox]; // We want to decompose box_conv_index in (h,w,a) - // such as box_conv_index = h*width*A + width*A + a + // such as box_conv_index = h*width*num_anchors + width*num_anchors + a // (avoiding modulos in the process) int remaining = box_conv_index; const int delta_height = height_stride; // stride of height @@ -109,7 +121,7 @@ __global__ void GeneratePreNMSUprightBoxesKernel( // TODO use fast math when possible - // Deltas of shape (N,height,width,A4) + // Deltas of shape (N,height,width,num_anchors x 4) int deltas_idx = box_conv_index + image_index * image_stride; float4 deltas = d_bbox_deltas[deltas_idx]; float dx = deltas.y; @@ -121,35 +133,36 @@ __global__ void GeneratePreNMSUprightBoxesKernel( dh = fmin(dh, bbox_xform_clip); // Applying the deltas - float width = legacy_op(x2 - x1); + float width = AddLegacyOffset(x2 - x1); const float ctr_x = x1 + 0.5f * width; const float pred_ctr_x = ctr_x + width * dx; // TODO fuse madd const float pred_w = width * expf(dw); x1 = pred_ctr_x - 0.5f * pred_w; - x2 = -legacy_op(-(pred_ctr_x + 0.5f * pred_w)); + x2 = SubtractLegacyOffset(pred_ctr_x + 0.5f * pred_w); - float height = legacy_op(y2 - y1); + float height = AddLegacyOffset(y2 - y1); const float ctr_y = y1 + 0.5f * height; const float pred_ctr_y = ctr_y + height * dy; const float pred_h = height * expf(dh); y1 = pred_ctr_y - 0.5f * pred_h; - y2 = -legacy_op(-(pred_ctr_y + 0.5f * pred_h)); // -1 if legacy_op + y2 = SubtractLegacyOffset(pred_ctr_y + + 0.5f * pred_h); // -1 if legacy_op // Clipping box to image const float img_height = d_img_info_vec[5 * image_index + 0]; const float img_width = d_img_info_vec[5 * image_index + 1]; const float min_size_scaled = min_size * d_img_info_vec[5 * image_index + 2]; - x1 = fmax(fmin(x1, -legacy_op(-img_width)), 0.0f); - y1 = fmax(fmin(y1, -legacy_op(-img_height)), 0.0f); - x2 = fmax(fmin(x2, -legacy_op(-img_width)), 0.0f); - y2 = fmax(fmin(y2, -legacy_op(-img_height)), 0.0f); + x1 = fmax(fmin(x1, SubtractLegacyOffset(img_width)), 0.0f); + y1 = fmax(fmin(y1, SubtractLegacyOffset(img_height)), 0.0f); + x2 = fmax(fmin(x2, SubtractLegacyOffset(img_width)), 0.0f); + y2 = fmax(fmin(y2, SubtractLegacyOffset(img_height)), 0.0f); // Filter boxes // Removing boxes with one dim < min_size // (center of box is in image, because of previous step) - width = legacy_op(x2 - x1); // may have changed - height = legacy_op(y2 - y1); + width = AddLegacyOffset(x2 - x1); // may have changed + height = AddLegacyOffset(y2 - y1); bool keep_box = fmin(width, height) >= min_size_scaled; // We are not deleting the box right now even if !keep_box @@ -161,12 +174,6 @@ __global__ void GeneratePreNMSUprightBoxesKernel( d_boxes_keep_flags[out_index] = keep_box; d_out_boxes[out_index] = {x1, y1, x2, y2}; - // d_scores size: (num_images,KA) - // Set the score of the box to float minimum - if (!keep_box) { - d_scores[image_index * image_stride + ibox] = - std::numeric_limits::min(); // for NMS - } } } } @@ -297,7 +304,7 @@ Status AllocatePreNMSTempTensors( } // Initialize index and offset arrays. -// num_images is the batch size, KA is the number of anchors +// num_images is the batch size. __global__ void InitializeDataKernel(const Cuda2DLaunchConfig config, int* d_image_offsets, int* d_boxes_keys_iota) { @@ -457,8 +464,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { width, num_anchors, min_size, image_info.flat().data(), bbox_xform_clip_default_, reinterpret_cast(dev_boxes.flat().data()), - nboxes_to_generate, dev_sorted_scores.flat().data(), - (char*)dev_boxes_keep_flags.flat().data())); + nboxes_to_generate, (char*)dev_boxes_keep_flags.flat().data())); } else { TF_CHECK_OK(GpuLaunchKernel( GeneratePreNMSUprightBoxesKernel, conf2d.block_count, @@ -469,8 +475,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { width, num_anchors, min_size, image_info.flat().data(), bbox_xform_clip_default_, reinterpret_cast(dev_boxes.flat().data()), - nboxes_to_generate, dev_sorted_scores.flat().data(), - (char*)dev_boxes_keep_flags.flat().data())); + nboxes_to_generate, (char*)dev_boxes_keep_flags.flat().data())); } const int nboxes_generated = nboxes_to_generate; const int roi_cols = box_dim; From 0c8c8c22bf25acaf7b69cdced73021bf2b939d58 Mon Sep 17 00:00:00 2001 From: Sami Date: Wed, 28 Aug 2019 18:06:07 -0700 Subject: [PATCH 07/17] Removed legacy support --- ...api_def_GenerateBoundingBoxProposals.pbtxt | 10 +- tensorflow/core/kernels/BUILD | 2 +- .../kernels/generate_box_proposals_op.cu.cc | 96 +++++++------------ tensorflow/core/ops/image_ops.cc | 1 - .../api/golden/v1/tensorflow.raw_ops.pbtxt | 2 +- .../api/golden/v2/tensorflow.raw_ops.pbtxt | 2 +- 6 files changed, 39 insertions(+), 74 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt index e4405becabb..648b23e3c6e 100644 --- a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "scores" description: < -__device__ float AddLegacyOffset(float); -template <> -__device__ float AddLegacyOffset(float a) { - return a + 1.; -} -template <> -__device__ float AddLegacyOffset(float a) { - return a; -} - -template -__device__ float SubtractLegacyOffset(float); -template <> -__device__ float SubtractLegacyOffset(float a) { - return a - 1.; -} -template <> -__device__ float SubtractLegacyOffset(float a) { - return a; -} // Decode d_bbox_deltas with respect to anchors into absolute coordinates, // clipping if necessary. // prenms_nboxes maximum number of boxes per image to decode. // d_boxes_keep_flags mask for boxes to consider in NMS. // min_size is the lower bound of the shortest edge for the boxes to consider. // bbox_xform_clip is the upper bound of encoded width and height. -template __global__ void GeneratePreNMSUprightBoxesKernel( const Cuda2DLaunchConfig config, const int* d_sorted_scores_keys, const float4* d_bbox_deltas, const float4* d_anchors, const int height, @@ -133,36 +120,35 @@ __global__ void GeneratePreNMSUprightBoxesKernel( dh = fmin(dh, bbox_xform_clip); // Applying the deltas - float width = AddLegacyOffset(x2 - x1); + float width = x2 - x1; const float ctr_x = x1 + 0.5f * width; const float pred_ctr_x = ctr_x + width * dx; // TODO fuse madd const float pred_w = width * expf(dw); x1 = pred_ctr_x - 0.5f * pred_w; - x2 = SubtractLegacyOffset(pred_ctr_x + 0.5f * pred_w); + x2 = pred_ctr_x + 0.5f * pred_w; - float height = AddLegacyOffset(y2 - y1); + float height = y2 - y1; const float ctr_y = y1 + 0.5f * height; const float pred_ctr_y = ctr_y + height * dy; const float pred_h = height * expf(dh); y1 = pred_ctr_y - 0.5f * pred_h; - y2 = SubtractLegacyOffset(pred_ctr_y + - 0.5f * pred_h); // -1 if legacy_op + y2 = pred_ctr_y + 0.5f * pred_h; // Clipping box to image const float img_height = d_img_info_vec[5 * image_index + 0]; const float img_width = d_img_info_vec[5 * image_index + 1]; const float min_size_scaled = min_size * d_img_info_vec[5 * image_index + 2]; - x1 = fmax(fmin(x1, SubtractLegacyOffset(img_width)), 0.0f); - y1 = fmax(fmin(y1, SubtractLegacyOffset(img_height)), 0.0f); - x2 = fmax(fmin(x2, SubtractLegacyOffset(img_width)), 0.0f); - y2 = fmax(fmin(y2, SubtractLegacyOffset(img_height)), 0.0f); + x1 = fmax(fmin(x1, img_width), 0.0f); + y1 = fmax(fmin(y1, img_height), 0.0f); + x2 = fmax(fmin(x2, img_width), 0.0f); + y2 = fmax(fmin(y2, img_height), 0.0f); // Filter boxes // Removing boxes with one dim < min_size // (center of box is in image, because of previous step) - width = AddLegacyOffset(x2 - x1); // may have changed - height = AddLegacyOffset(y2 - y1); + width = x2 - x1; // may have changed + height = y2 - y1; bool keep_box = fmin(width, height) >= min_size_scaled; // We are not deleting the box right now even if !keep_box @@ -332,9 +318,6 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { tensorflow::OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("post_nms_topn", &post_nms_topn_)); - // compatibility for detectron like networks. False for generic case - OP_REQUIRES_OK(context, context->GetAttr("use_detectron_offset", - &use_detectron_offset_)); CHECK_GT(post_nms_topn_, 0); bbox_xform_clip_default_ = log(1000.0 / 16.); } @@ -454,29 +437,16 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { // create box y1,x1,y2,x2 from box_deltas and anchors (decode the boxes) and // mark the boxes which are smaller that min_size ignored. - if (use_detectron_offset_) { - TF_CHECK_OK(GpuLaunchKernel( - GeneratePreNMSUprightBoxesKernel, conf2d.block_count, - conf2d.thread_per_block, 0, d.stream(), conf2d, - d_sorted_conv_layer_indexes.flat().data(), - reinterpret_cast(bbox_deltas.flat().data()), - reinterpret_cast(anchors.flat().data()), height, - width, num_anchors, min_size, image_info.flat().data(), - bbox_xform_clip_default_, - reinterpret_cast(dev_boxes.flat().data()), - nboxes_to_generate, (char*)dev_boxes_keep_flags.flat().data())); - } else { - TF_CHECK_OK(GpuLaunchKernel( - GeneratePreNMSUprightBoxesKernel, conf2d.block_count, - conf2d.thread_per_block, 0, d.stream(), conf2d, - d_sorted_conv_layer_indexes.flat().data(), - reinterpret_cast(bbox_deltas.flat().data()), - reinterpret_cast(anchors.flat().data()), height, - width, num_anchors, min_size, image_info.flat().data(), - bbox_xform_clip_default_, - reinterpret_cast(dev_boxes.flat().data()), - nboxes_to_generate, (char*)dev_boxes_keep_flags.flat().data())); - } + TF_CHECK_OK(GpuLaunchKernel( + GeneratePreNMSUprightBoxesKernel, conf2d.block_count, + conf2d.thread_per_block, 0, d.stream(), conf2d, + d_sorted_conv_layer_indexes.flat().data(), + reinterpret_cast(bbox_deltas.flat().data()), + reinterpret_cast(anchors.flat().data()), height, + width, num_anchors, min_size, image_info.flat().data(), + bbox_xform_clip_default_, + reinterpret_cast(dev_boxes.flat().data()), + nboxes_to_generate, (char*)dev_boxes_keep_flags.flat().data())); const int nboxes_generated = nboxes_to_generate; const int roi_cols = box_dim; const int max_postnms_nboxes = std::min(nboxes_generated, post_nms_topn_); @@ -520,6 +490,8 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { &output_roi_probs)); float* d_postnms_rois = (*output_rois).flat().data(); float* d_postnms_rois_probs = (*output_roi_probs).flat().data(); + cudaEvent_t copy_done; + cudaEventCreate(©_done); // Do per-image nms for (int image_index = 0; image_index < num_images; ++image_index) { @@ -559,14 +531,15 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { d_image_prenms_scores, d_prenms_nboxes, nboxes_generated, d.stream())); d.memcpyDeviceToHost(&h_prenms_nboxes, d_prenms_nboxes, sizeof(int)); - d.synchronize(); + TF_OP_REQUIRES_CUDA_SUCCESS(context,cudaEventRecord(copy_done, d.stream())); + TF_OP_REQUIRES_CUDA_SUCCESS(context,cudaEventSynchronize(copy_done)); // We know prenms_boxes <= topN_prenms, because nboxes_generated <= // topN_prenms. Calling NMS on the generated boxes const int prenms_nboxes = h_prenms_nboxes; int nkeep; OP_REQUIRES_OK(context, NmsGpu(d_image_prenms_boxes, prenms_nboxes, nms_threshold, - d_image_boxes_keep_list, &nkeep, context)); + d_image_boxes_keep_list, &nkeep, context, post_nms_topn_)); // All operations done after previous sort were keeping the relative order // of the elements the elements are still sorted keep topN <=> truncate // the array @@ -589,7 +562,6 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { private: int post_nms_topn_; float bbox_xform_clip_default_; - bool use_detectron_offset_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 50945c3c2ea..57c032fbf37 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -994,7 +994,6 @@ REGISTER_OP("GenerateBoundingBoxProposals") .Output("rois: float") .Output("roi_probabilities: float") .Attr("post_nms_topn: int = 300") - .Attr("use_detectron_offset: bool = false") .SetShapeFn([](InferenceContext* c) -> Status { // make sure input tensors have are correct rank ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 685644e9d0c..6fd8b63e962 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -1478,7 +1478,7 @@ tf_module { } member_method { name: "GenerateBoundingBoxProposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], " } member_method { name: "GenerateVocabRemapping" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 685644e9d0c..6fd8b63e962 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -1478,7 +1478,7 @@ tf_module { } member_method { name: "GenerateBoundingBoxProposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], " } member_method { name: "GenerateVocabRemapping" From 629fa2f9fb241b29c0d98f83ee01df05dd9d699e Mon Sep 17 00:00:00 2001 From: Sami Date: Thu, 29 Aug 2019 11:04:26 -0700 Subject: [PATCH 08/17] Change GetCudaStream to GetGpuStream --- tensorflow/core/kernels/generate_box_proposals_op.cu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc index da4a0a4099e..9cf50cc214d 100644 --- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -371,7 +371,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { } OP_REQUIRES_OK(context, GetScalarValue(context, 6, &min_size)); - auto cuda_stream = GetCudaStream(context); + auto cuda_stream = GetGpuStream(context); size_t cub_sort_temp_storage_bytes = 0; float* flt_ptr = nullptr; int* int_ptr = nullptr; From 1fdd6e1fbc5595831a7a07c9f4c60d07a3b7d36b Mon Sep 17 00:00:00 2001 From: Sami Date: Tue, 3 Sep 2019 12:54:03 -0700 Subject: [PATCH 09/17] Fix api files --- tensorflow/tools/api/golden/v1/tensorflow.pbtxt | 2 +- tensorflow/tools/api/golden/v2/tensorflow.pbtxt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 6f0b11af7cd..8e7d068cd1c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1342,7 +1342,7 @@ tf_module { } member_method { name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], " } member_method { name: "get_collection" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index a9798d7eaae..4d4cd74c213 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -662,7 +662,7 @@ tf_module { } member_method { name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'use_detectron_offset\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'False\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], " } member_method { name: "get_logger" From d411e31e70bf666170bec818aadbbfc9b2f5c2ae Mon Sep 17 00:00:00 2001 From: Sami Date: Thu, 24 Oct 2019 14:34:32 -0700 Subject: [PATCH 10/17] Fix review comments and remove obsolete flag that is removed by earlier review modifications --- ...api_def_GenerateBoundingBoxProposals.pbtxt | 1 + tensorflow/core/kernels/BUILD | 2 +- .../kernels/generate_box_proposals_op.cu.cc | 54 +++++++++---------- tensorflow/python/ops/image_ops_impl.py | 12 +++-- .../api/golden/v1/tensorflow.image.pbtxt | 2 +- .../tools/api/golden/v1/tensorflow.pbtxt | 4 -- .../api/golden/v2/tensorflow.image.pbtxt | 2 +- .../tools/api/golden/v2/tensorflow.pbtxt | 4 -- 8 files changed, 38 insertions(+), 43 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt index 648b23e3c6e..6403e16a8bc 100644 --- a/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_GenerateBoundingBoxProposals.pbtxt @@ -1,5 +1,6 @@ op { graph_op_name: "GenerateBoundingBoxProposals" + visibility: HIDDEN in_arg { name: "scores" description: <eigen_gpu_device(); TF_RETURN_IF_ERROR(context->allocate_temp( @@ -221,13 +221,8 @@ Status AllocateGenerationTempTensors( DataType::DT_INT32, TensorShape({num_images + 1}), d_image_offset)); ResetTensor(d_image_offset, d); TF_RETURN_IF_ERROR(context->allocate_temp( - DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}), - d_cub_sort_buffer)); - ResetTensor(d_cub_sort_buffer, d); - TF_RETURN_IF_ERROR(context->allocate_temp( - DataType::DT_INT8, TensorShape({(int64)cub_select_temp_storage_bytes}), - d_cub_select_buffer)); - ResetTensor(d_cub_select_buffer, d); + DataType::DT_INT8, TensorShape({(int64)cub_temp_storage_bytes}), + d_cub_temp_buffer)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images, conv_layer_nboxes}), d_sorted_conv_layer_indexes)); @@ -335,6 +330,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { } void Compute(tensorflow::OpKernelContext* context) override { + VLOG(1)<<"Starting Compute "<input(0); const auto bbox_deltas = context->input(1); const auto image_info = context->input(2); @@ -343,7 +339,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { const auto num_anchors = scores.dim_size(3); const auto height = scores.dim_size(1); const auto width = scores.dim_size(2); - const auto box_dim = anchors.dim_size(0) / num_anchors; + const auto box_dim = anchors.dim_size(2) / num_anchors; OP_REQUIRES(context, box_dim == 4, errors::OutOfRange("Box dimensions need to be 4")); // TODO(skama): make sure that inputs are ok. @@ -369,7 +365,6 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { "pre_nms_topn should be greater than 0", pre_nms_topn)); return; } - OP_REQUIRES_OK(context, GetScalarValue(context, 6, &min_size)); auto cuda_stream = GetGpuStream(context); size_t cub_sort_temp_storage_bytes = 0; @@ -391,8 +386,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { f4_ptr, int_ptr, image_stride * num_anchors, cuda_stream)); Tensor d_conv_layer_indexes; // box indices on device Tensor d_image_offset; // starting offsets boxes for each image - Tensor d_cub_sort_buffer; // buffer for cub sorting - Tensor d_cub_select_buffer; // buffer for cub selection + Tensor d_cub_temp_buffer; // buffer for cub sorting Tensor d_sorted_conv_layer_indexes; // output of cub sorting, indices of // the sorted boxes Tensor dev_sorted_scores; // sorted scores, cub output @@ -400,14 +394,15 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { Tensor dev_boxes_keep_flags; // bitmask for keeping the boxes or rejecting // from output const int nboxes_to_generate = std::min(conv_layer_nboxes, pre_nms_topn); + size_t cub_temp_storage_bytes=max(cub_sort_temp_storage_bytes,cub_select_temp_storage_bytes); OP_REQUIRES_OK( context, AllocateGenerationTempTensors( - context, &d_conv_layer_indexes, &d_image_offset, &d_cub_sort_buffer, - &d_cub_select_buffer, &d_sorted_conv_layer_indexes, + context, &d_conv_layer_indexes, &d_image_offset, &d_cub_temp_buffer, + &d_sorted_conv_layer_indexes, &dev_sorted_scores, &dev_boxes, &dev_boxes_keep_flags, num_images, - conv_layer_nboxes, cub_sort_temp_storage_bytes, - cub_select_temp_storage_bytes, nboxes_to_generate, box_dim)); + conv_layer_nboxes, cub_temp_storage_bytes, + nboxes_to_generate, box_dim)); const GPUDevice& d = context->eigen_device(); Cuda2DLaunchConfig conf2d = GetCuda2DLaunchConfig(conv_layer_nboxes, num_images, d); @@ -419,11 +414,10 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { // sort boxes with their scores. // d_sorted_conv_layer_indexes will hold the pointers to old indices. - TF_OP_REQUIRES_CUDA_SUCCESS( context, cub::DeviceSegmentedRadixSort::SortPairsDescending( - d_cub_sort_buffer.flat().data(), cub_sort_temp_storage_bytes, + d_cub_temp_buffer.flat().data(), cub_temp_storage_bytes, scores.flat().data(), dev_sorted_scores.flat().data(), d_conv_layer_indexes.flat().data(), d_sorted_conv_layer_indexes.flat().data(), @@ -466,8 +460,8 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { // get the pointers for temp storages int* d_prenms_nboxes = dev_prenms_nboxes.flat().data(); int h_prenms_nboxes = 0; - char* d_cub_select_temp_storage = - (char*)d_cub_select_buffer.flat().data(); + char* d_cub_temp_storage = + (char*)d_cub_temp_buffer.flat().data(); float* d_image_prenms_boxes = dev_image_prenms_boxes.flat().data(); float* d_image_prenms_scores = dev_image_prenms_scores.flat().data(); int* d_image_boxes_keep_list = dev_image_boxes_keep_list.flat().data(); @@ -518,15 +512,14 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { // to the output tensors TF_OP_REQUIRES_CUDA_SUCCESS( context, cub::DeviceSelect::Flagged( - d_cub_select_temp_storage, cub_select_temp_storage_bytes, + d_cub_temp_storage, cub_temp_storage_bytes, reinterpret_cast(d_image_boxes), d_image_boxes_keep_flags, reinterpret_cast(d_image_prenms_boxes), d_prenms_nboxes, nboxes_generated, d.stream())); - TF_OP_REQUIRES_CUDA_SUCCESS( context, cub::DeviceSelect::Flagged( - d_cub_select_temp_storage, cub_select_temp_storage_bytes, + d_cub_temp_storage, cub_temp_storage_bytes, d_image_sorted_scores, d_image_boxes_keep_flags, d_image_prenms_scores, d_prenms_nboxes, nboxes_generated, d.stream())); @@ -564,8 +557,11 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { float bbox_xform_clip_default_; }; -REGISTER_KERNEL_BUILDER( - Name("GenerateBoundingBoxProposals").Device(tensorflow::DEVICE_GPU), - tensorflow::GenerateBoundingBoxProposals); +REGISTER_KERNEL_BUILDER(Name("GenerateBoundingBoxProposals") + .Device(tensorflow::DEVICE_GPU) + .HostMemory("nms_threshold") + .HostMemory("min_size") + .HostMemory("pre_nms_topn"), + tensorflow::GenerateBoundingBoxProposals); } // namespace tensorflow #endif \ No newline at end of file diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index a145abb94ad..794b996d46f 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -4010,7 +4010,6 @@ def generate_bounding_box_proposals(scores, pre_nms_topn=6000, min_size=16, post_nms_topn=300, - correct_transform_coords=True, name=None): """ Generate bounding box proposals from encoded bounding boxes. Returns: @@ -4018,5 +4017,12 @@ def generate_bounding_box_proposals(scores, roi_probabilities: scores of the roi boxes in the rois tensor. """ return gen_image_ops.generate_bounding_box_proposals( - scores, bbox_deltas, image_info, anchors, nms_threshold, pre_nms_topn, - min_size, post_nms_topn, correct_transform_coords) + scores=scores, + bbox_deltas=bbox_deltas, + image_info=image_info, + anchors=anchors, + nms_threshold=nms_threshold, + pre_nms_topn=pre_nms_topn, + min_size=min_size, + post_nms_topn=post_nms_topn, + name=Name) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt index ef946d406cd..70a6c7d966a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt @@ -110,7 +110,7 @@ tf_module { } member_method { name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'0.7\', \'6000\', \'16\', \'300\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'0.7\', \'6000\', \'16\', \'300\', \'None\'], " } member_method { name: "grayscale_to_rgb" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 8e7d068cd1c..bdccd5b436c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1340,10 +1340,6 @@ tf_module { name: "gather_nd" argspec: "args=[\'params\', \'indices\', \'name\', \'batch_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'0\'], " } - member_method { - name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], " - } member_method { name: "get_collection" argspec: "args=[\'key\', \'scope\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt index dae98dc3ed7..5436714c9d5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt @@ -106,7 +106,7 @@ tf_module { } member_method { name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'correct_transform_coords\', \'name\'], varargs=None, keywords=None, defaults=[\'0.7\', \'6000\', \'16\', \'300\', \'True\', \'None\'], " + argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'0.7\', \'6000\', \'16\', \'300\', \'None\'], " } member_method { name: "grayscale_to_rgb" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 4d4cd74c213..ee3c0cc22bb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -660,10 +660,6 @@ tf_module { name: "gather_nd" argspec: "args=[\'params\', \'indices\', \'batch_dims\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], " } - member_method { - name: "generate_bounding_box_proposals" - argspec: "args=[\'scores\', \'bbox_deltas\', \'image_info\', \'anchors\', \'nms_threshold\', \'pre_nms_topn\', \'min_size\', \'post_nms_topn\', \'name\'], varargs=None, keywords=None, defaults=[\'300\', \'None\'], " - } member_method { name: "get_logger" argspec: "args=[], varargs=None, keywords=None, defaults=None" From 4dc3be29621bac333cb4a46135aef0a46ec5d5f0 Mon Sep 17 00:00:00 2001 From: Sami Date: Thu, 24 Oct 2019 15:13:30 -0700 Subject: [PATCH 11/17] Fix buildifier issue --- tensorflow/core/kernels/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 082452ff3cf..1b3288df338 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2905,7 +2905,10 @@ tf_kernel_library( tf_kernel_library( name = "generate_box_proposals_op", gpu_srcs = ["generate_box_proposals_op.cu.cc"], - deps = if_cuda(["@cub_archive//:cub",":non_max_suppression_op_gpu"]), + deps = if_cuda([ + "@cub_archive//:cub", + ":non_max_suppression_op_gpu", + ]), ) tf_kernel_library( From dfc8a1e7f366becf7e988d008e0743424741bb18 Mon Sep 17 00:00:00 2001 From: Sami Date: Fri, 1 Nov 2019 15:50:44 -0700 Subject: [PATCH 12/17] Replaced CHECK_OK calls with context returns even though most of the failures were already FATAL. --- .../kernels/generate_box_proposals_op.cu.cc | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc index ae988e9892f..e0c76be40bf 100644 --- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -195,11 +195,11 @@ __global__ void WriteUprightBoxesOutput( } template -void ResetTensor(Tensor* t, const Eigen::GpuDevice& d) { +Status ResetTensor(Tensor* t, const Eigen::GpuDevice& d) { CudaLaunchConfig zconfig = GetCudaLaunchConfig(t->NumElements(), d); - TF_CHECK_OK(GpuLaunchKernel( + return GpuLaunchKernel( SetZero, zconfig.block_count, zconfig.thread_per_block, 0, d.stream(), - zconfig.virtual_thread_count, (*t).flat().data())); + zconfig.virtual_thread_count, (*t).flat().data()); } // Allocate scratch spaces that are needed for operation // @@ -216,29 +216,29 @@ Status AllocateGenerationTempTensors( TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images, conv_layer_nboxes}), d_conv_layer_indexes)); - ResetTensor(d_conv_layer_indexes, d); + TF_RETURN_IF_ERROR(ResetTensor(d_conv_layer_indexes, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images + 1}), d_image_offset)); - ResetTensor(d_image_offset, d); + TF_RETURN_IF_ERROR(ResetTensor(d_image_offset, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT8, TensorShape({(int64)cub_temp_storage_bytes}), d_cub_temp_buffer)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images, conv_layer_nboxes}), d_sorted_conv_layer_indexes)); - ResetTensor(d_sorted_conv_layer_indexes, d); + TF_RETURN_IF_ERROR(ResetTensor(d_sorted_conv_layer_indexes, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({num_images, conv_layer_nboxes}), d_sorted_scores)); - ResetTensor(d_sorted_scores, d); + TF_RETURN_IF_ERROR(ResetTensor(d_sorted_scores, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({num_images, box_dim * num_boxes_to_generate}), dev_boxes)); - ResetTensor(dev_boxes, d); + TF_RETURN_IF_ERROR(ResetTensor(dev_boxes, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT8, TensorShape({num_images, num_boxes_to_generate}), dev_boxes_keep_flags)); - ResetTensor(dev_boxes_keep_flags, d); + TF_RETURN_IF_ERROR(ResetTensor(dev_boxes_keep_flags, d)); return Status::OK(); } @@ -253,33 +253,33 @@ Status AllocatePreNMSTempTensors( TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({box_dim * num_boxes_to_generate}), dev_image_prenms_boxes)); - ResetTensor(dev_image_prenms_boxes, d); + TF_RETURN_IF_ERROR(ResetTensor(dev_image_prenms_boxes, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({num_boxes_to_generate}), dev_image_prenms_scores)); - ResetTensor(dev_image_prenms_scores, d); + TF_RETURN_IF_ERROR(ResetTensor(dev_image_prenms_scores, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_boxes_to_generate}), dev_image_boxes_keep_list)); - ResetTensor(dev_image_boxes_keep_list, d); + TF_RETURN_IF_ERROR(ResetTensor(dev_image_boxes_keep_list, d)); const int max_postnms_nboxes = std::min(num_boxes_to_generate, post_nms_topn); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({box_dim * num_images * max_postnms_nboxes}), dev_postnms_rois)); - ResetTensor(dev_postnms_rois, d); + TF_RETURN_IF_ERROR(ResetTensor(dev_postnms_rois, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({num_images * max_postnms_nboxes}), dev_postnms_rois_probs)); - ResetTensor(dev_postnms_rois_probs, d); + TF_RETURN_IF_ERROR(ResetTensor(dev_postnms_rois_probs, d)); TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_INT32, TensorShape({num_images}), dev_prenms_nboxes)); - ResetTensor(dev_prenms_nboxes, d); + TF_RETURN_IF_ERROR(ResetTensor(dev_prenms_nboxes, d)); return Status::OK(); } @@ -313,7 +313,8 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { tensorflow::OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("post_nms_topn", &post_nms_topn_)); - CHECK_GT(post_nms_topn_, 0); + OP_REQUIRES(context, post_nms_topn_<=0, + errors::InvalidArgument("post_nms_topn can't be 0 or less")); bbox_xform_clip_default_ = log(1000.0 / 16.); } @@ -407,7 +408,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { Cuda2DLaunchConfig conf2d = GetCuda2DLaunchConfig(conv_layer_nboxes, num_images, d); // create box indices and offsets for each image on device - TF_CHECK_OK(GpuLaunchKernel(InitializeDataKernel, conf2d.block_count, + OP_REQUIRES_OK(context,GpuLaunchKernel(InitializeDataKernel, conf2d.block_count, conf2d.thread_per_block, 0, d.stream(), conf2d, d_image_offset.flat().data(), d_conv_layer_indexes.flat().data())); @@ -431,7 +432,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { // create box y1,x1,y2,x2 from box_deltas and anchors (decode the boxes) and // mark the boxes which are smaller that min_size ignored. - TF_CHECK_OK(GpuLaunchKernel( + OP_REQUIRES_OK(context,GpuLaunchKernel( GeneratePreNMSUprightBoxesKernel, conf2d.block_count, conf2d.thread_per_block, 0, d.stream(), conf2d, d_sorted_conv_layer_indexes.flat().data(), @@ -490,7 +491,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { // Do per-image nms for (int image_index = 0; image_index < num_images; ++image_index) { // reset output workspaces - ResetTensor(&dev_image_boxes_keep_list, d); + OP_REQUIRES_OK(context,ResetTensor(&dev_image_boxes_keep_list, d)); // Sub matrices for current image // boxes const float* d_image_boxes = @@ -541,7 +542,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { // adding the image_index dimension on the fly CudaLaunchConfig config = GetCudaLaunchConfig(post_nms_topn_, d); // make this single kernel - TF_CHECK_OK(GpuLaunchKernel( + OP_REQUIRES_OK(context,GpuLaunchKernel( WriteUprightBoxesOutput, config.block_count, config.thread_per_block, 0, d.stream(), config, reinterpret_cast(d_image_prenms_boxes), From 920b543dd0de83823efc5ad9dae5d5d4733ec79e Mon Sep 17 00:00:00 2001 From: Sami Date: Tue, 5 Nov 2019 10:21:41 -0800 Subject: [PATCH 13/17] Mark GenerateBoxProposalsOp non-differentiable --- tensorflow/python/ops/image_ops_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 794b996d46f..6713059ecb3 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -54,7 +54,7 @@ ops.NotDifferentiable('ExtractGlimpse') ops.NotDifferentiable('NonMaxSuppression') ops.NotDifferentiable('NonMaxSuppressionV2') ops.NotDifferentiable('NonMaxSuppressionWithOverlaps') - +ops.NotDifferentiable('GenerateBoundingBoxProposals') # pylint: disable=invalid-name def _assert(cond, ex_type, msg): From 05aa39e2fa9d5ff2bdcce61b90d70b653ae763f8 Mon Sep 17 00:00:00 2001 From: Sami Date: Wed, 6 Nov 2019 15:24:32 -0800 Subject: [PATCH 14/17] max->std::max --- tensorflow/core/kernels/generate_box_proposals_op.cu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc index e0c76be40bf..269ad743147 100644 --- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -395,7 +395,7 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { Tensor dev_boxes_keep_flags; // bitmask for keeping the boxes or rejecting // from output const int nboxes_to_generate = std::min(conv_layer_nboxes, pre_nms_topn); - size_t cub_temp_storage_bytes=max(cub_sort_temp_storage_bytes,cub_select_temp_storage_bytes); + size_t cub_temp_storage_bytes=std::max(cub_sort_temp_storage_bytes,cub_select_temp_storage_bytes); OP_REQUIRES_OK( context, AllocateGenerationTempTensors( From 25e9335a234c6517865d146105fad553e5a24b4e Mon Sep 17 00:00:00 2001 From: Guangda Lai <31743510+aaroey@users.noreply.github.com> Date: Wed, 6 Nov 2019 22:03:14 -0800 Subject: [PATCH 15/17] Fix error 'unused variable 'max_postnms_nboxes'' --- tensorflow/core/kernels/generate_box_proposals_op.cu.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc index 269ad743147..fed27f0c27c 100644 --- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -444,7 +444,6 @@ class GenerateBoundingBoxProposals : public tensorflow::OpKernel { nboxes_to_generate, (char*)dev_boxes_keep_flags.flat().data())); const int nboxes_generated = nboxes_to_generate; const int roi_cols = box_dim; - const int max_postnms_nboxes = std::min(nboxes_generated, post_nms_topn_); Tensor dev_image_prenms_boxes; Tensor dev_image_prenms_scores; Tensor dev_image_boxes_keep_list; @@ -565,4 +564,4 @@ REGISTER_KERNEL_BUILDER(Name("GenerateBoundingBoxProposals") .HostMemory("pre_nms_topn"), tensorflow::GenerateBoundingBoxProposals); } // namespace tensorflow -#endif \ No newline at end of file +#endif From fbd53c577a51538239e5392756060d41f7e15940 Mon Sep 17 00:00:00 2001 From: Guangda Lai <31743510+aaroey@users.noreply.github.com> Date: Wed, 6 Nov 2019 22:03:25 -0800 Subject: [PATCH 16/17] Fix error 'unused variable 'max_postnms_nboxes'' From b0f61d8fbbc28a09de0017b93171ad99fb3a30be Mon Sep 17 00:00:00 2001 From: Guangda Lai <31743510+aaroey@users.noreply.github.com> Date: Thu, 7 Nov 2019 07:44:46 -0800 Subject: [PATCH 17/17] Add some changes to trigger copybara import again. --- tensorflow/core/kernels/generate_box_proposals_op.cu.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc index fed27f0c27c..344d71bfc59 100644 --- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// An example Op. - #if GOOGLE_CUDA #define EIGEN_USE_GPU @@ -82,8 +80,7 @@ __global__ void GeneratePreNMSUprightBoxesKernel( // box_conv_index : # of the same box, but indexed in the // scores from the conv layer, of shape (height,width,num_anchors) the // num_images dimension was already removed box_conv_index = - // a*image_stride + h*width - // + w + // a*image_stride + h*width + w const int box_conv_index = d_sorted_scores_keys[image_index * image_stride + ibox];