From c2a4931076e69bff99d9f6e40ae607c5335c1e6a Mon Sep 17 00:00:00 2001
From: Sami
Date: Thu, 22 Aug 2019 18:47:03 -0700
Subject: [PATCH] Pure GPU NMS implementation.

---
 .../core/kernels/non_max_suppression_op.cu.cc | 183 +++++++++++-------
 .../core/kernels/non_max_suppression_op.h     |   5 +-
 2 files changed, 117 insertions(+), 71 deletions(-)

diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc
index 81f458b0326..79d3ea71c8a 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc
@@ -128,6 +128,47 @@ __device__ EIGEN_STRONG_INLINE void Flipped(Box& box) {
   if (box.x1 > box.x2) Swap(box.x1, box.x2);
   if (box.y1 > box.y2) Swap(box.y1, box.y2);
 }
+template <typename T>
+__device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) {
+  constexpr int SHIFTLEN = NumBits(8 * sizeof(T)) - 1;
+  constexpr int REMAINDER_MASK = 8 * sizeof(T) - 1;
+  int bin = bit >> SHIFTLEN;
+  return (bit_mask[bin] >> (bit & REMAINDER_MASK)) & 1;
+}
+
+// Produce a global bitmask (result_mask) of selected boxes from the bitmask
+// generated by NMSKernel. Abort early if max_boxes boxes are selected. The
+// bitmask is num_boxes*bit_mask_len bits indicating keep or remove per box.
+__global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
+                          const int num_boxes, const int max_boxes,
+                          char* result_mask) {
+  extern __shared__ int local[];
+  // set global mask to accept all boxes
+  for (int box : CudaGridRangeX(bit_mask_len)) {
+    local[box] = 0xFFFFFFFF;
+  }
+  __syncthreads();
+  int accepted_boxes = 0;
+  for (int box = 0; box < num_boxes - 1; ++box) {
+    // if current box is masked by an earlier box, skip it.
+    if (!CheckBit(local, box)) {
+      continue;
+    }
+    accepted_boxes += 1;
+    int offset = box * bit_mask_len;
+    // update global mask with current box's mask
+    for (int b : CudaGridRangeX(bit_mask_len)) {
+      local[b] &= ~bitmask[offset + b];
+    }
+    __syncthreads();
+    if (accepted_boxes > max_boxes) break;
+  }
+  // copy global mask to result_mask char array. char array is needed for
+  // cub::DeviceSelect later.
+  for (int box : CudaGridRangeX(num_boxes)) {
+    result_mask[box] = CheckBit(local, box);
+  }
+}
 
 // For each box, compute a bitmask of boxes which has an overlap with given box
 // above threshold.
@@ -235,7 +276,8 @@ __global__ void Iota(const int num_elements, const T offset, T* to_fill) {
 
 Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
               const float iou_threshold, int* d_selected_indices, int* h_nkeep,
-              OpKernelContext* context, bool flip_boxes, bool legacy_mode) {
+              OpKernelContext* context, const int max_boxes, bool flip_boxes,
+              bool legacy_mode) {
   // Making sure we respect the __align(16)__
   // we promised to the compiler.
   auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr);
@@ -243,7 +285,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
     return errors::InvalidArgument("Boxes should be aligned to 16 Bytes.");
   }
   // allocate bitmask arrays on host and on device
-  Tensor h_nms_mask, d_nms_mask;
+  Tensor h_num_selected, d_nms_mask;
   const int bit_mask_len =
       (num_boxes + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread;
@@ -264,11 +306,11 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   // Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 and
   // using it as a ring buffer. However savings should be a few MB.
   TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
-                                            TensorShape({max_nms_mask_size}),
-                                            &h_nms_mask, alloc_attr));
+                                            TensorShape({1}),
+                                            &h_num_selected, alloc_attr));
 
   int* d_delete_mask = d_nms_mask.flat<int>().data();
-  int* h_delete_mask = h_nms_mask.flat<int>().data();
+  int* h_selected_count = h_num_selected.flat<int>().data();
   const Box* d_sorted_boxes =
       reinterpret_cast<const Box*>(d_sorted_boxes_float_ptr);
   dim3 block_dim, thread_block;
@@ -308,59 +350,57 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
   // Overlapping CPU computes and D2H memcpy
   // both take about the same time
-  int num_to_copy = std::min(kNmsChunkSize, num_boxes);
+
+  config = GetGpuLaunchConfig(num_boxes, device);
+  Tensor selected_boxes;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT8, TensorShape({num_boxes}), &selected_boxes));
+  Tensor d_indices;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
+  TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
+                              config.thread_per_block, 0, device.stream(),
+                              config.virtual_thread_count, 0,
+                              d_indices.flat<int>().data()));
+
+  char* selected = (char*)(selected_boxes.flat<int8>().data());
+  TF_CHECK_OK(GpuLaunchKernel(NMSReduce, 1, 1024, bit_mask_len * sizeof(int),
+                              device.stream(), d_delete_mask, bit_mask_len,
+                              num_boxes, max_boxes, selected));
+  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
+  // do cub::DeviceSelect::Flagged
+  size_t flagged_buffer_size = 0;
+  cub::DeviceSelect::Flagged(static_cast<void*>(nullptr),  // temp_storage
+                             flagged_buffer_size,
+                             static_cast<int*>(nullptr),   // input
+                             static_cast<char*>(nullptr),  // selection flag
+                             static_cast<int*>(nullptr),   // selected items
+                             static_cast<int*>(nullptr),   // num_selected
+                             num_boxes, device.stream());
+  Tensor cub_scratch;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}),
+      &cub_scratch));
+  Tensor d_num_selected;
+  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
+                                            TensorShape({1}), &d_num_selected));
+
+  cub::DeviceSelect::Flagged(
+      (void*)cub_scratch.flat<int8>().data(),  // temp_storage
+      flagged_buffer_size,
+      d_indices.flat<int>().data(),  // input
+      selected,                      // selection flag
+      d_selected_indices,            // selected items
+      d_num_selected.flat<int>().data(), num_boxes, device.stream());
   cudaEvent_t copy_done;
   TF_RETURN_IF_CUDA_ERROR(
       cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
-  device.memcpyDeviceToHost(&h_delete_mask[0], &d_delete_mask[0],
-                            num_to_copy * bit_mask_len * sizeof(int));
+  device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(),
+                            sizeof(int));
   TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-  int offset = 0;
-  std::vector<int> h_selected_indices;
-  // Reserve worst case scenario. Since box count is not huge, this should have
-  // negligible footprint.
-  h_selected_indices.reserve(num_boxes);
-  std::vector<int> to_remove(bit_mask_len, 0);
-  while (offset < num_boxes) {
-    const int num_copied = num_to_copy;
-    int next_offset = offset + num_copied;
-    num_to_copy = std::min(kNmsChunkSize, num_boxes - next_offset);
-    if (num_to_copy > 0) {
-      device.memcpyDeviceToHost(&h_delete_mask[next_offset * bit_mask_len],
-                                &d_delete_mask[next_offset * bit_mask_len],
-                                num_to_copy * bit_mask_len * sizeof(int));
-    }
-    // Waiting for previous copy
-    TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
-    if (num_to_copy > 0) {
-      TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-    }
-    // Starting from highest scoring box, mark any box with iou>threshold and
-    // lower score for deletion if current box is not marked for deletion. Add
-    // current box to to_keep list.
-    for (int i = offset; i < next_offset; ++i) {
-      // See the comment at the beginning of the file.
-      // Bit shift and logical And operations are used
-      // instead of division and modulo operations.
-      int iblock = i >> kNmsBoxesPerThreadShiftBits;
-      int inblock = i & kNmsBoxesPerThreadModuloMask;
-      if (!(to_remove[iblock] & (1 << inblock))) {
-        h_selected_indices.push_back(i);
-        int* p = &h_delete_mask[i * bit_mask_len];
-        for (int ib = 0; ib < bit_mask_len; ++ib) {
-          to_remove[ib] |= p[ib];
-        }
-      }
-    }
-    offset = next_offset;
-  }
+  TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
+  *h_nkeep = *h_selected_count;
   cudaEventDestroy(copy_done);
-
-  const int nkeep = h_selected_indices.size();
-  device.memcpyHostToDevice(d_selected_indices, &h_selected_indices[0],
-                            nkeep * sizeof(int));
-
-  *h_nkeep = nkeep;
   return Status::OK();
 }
@@ -485,10 +525,10 @@ class NonMaxSuppressionV2GPUOp : public OpKernel {
     // There is no guarantee that boxes are given in the format x1<x2, y1<y2.
     const bool flip_boxes = true;
-    auto status = NmsGpu(d_sorted_boxes.flat<float>().data(), num_boxes,
-                         iou_threshold_val, d_selected_indices.flat<int>().data(),
-                         &num_to_keep, context, flip_boxes, /*legacy_mode*/ false);
+    auto status = NmsGpu(
+        d_sorted_boxes.flat<float>().data(), num_boxes, iou_threshold_val,
+        d_selected_indices.flat<int>().data(), &num_to_keep, context,
+        output_size, flip_boxes, /*legacy_mode*/ false);
     TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
     if (!status.ok()) {
       context->SetStatus(status);
@@ -535,13 +575,10 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
   TF_RETURN_IF_ERROR(context->allocate_temp(
       DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output));
-  TF_RETURN_IF_ERROR(
-      context,
-      context->allocate_temp(DataType::DT_INT8,
-                             TensorShape({(int64)workspace_size}), &workspace));
-  TF_RETURN_IF_ERROR(
-      context, context->allocate_temp(DataType::DT_INT32, TensorShape({1}),
-                                      &element_count));
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));
+  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
+                                            TensorShape({1}), &element_count));
   cudaEvent_t copy_done;
   TF_RETURN_IF_CUDA_ERROR(
       cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
@@ -697,15 +734,18 @@ class NonMaxSuppressionV3GPUOp : public OpKernel {
       OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
                                                        &output_indices));
       return;
+    } else {
+      VLOG(2) << "Number of boxes above threshold=" << score_threshold_val
+              << " is " << limited_num_boxes;
     }
     int num_to_keep = 0;
     // There is no guarantee that boxes are given in the format x1<x2, y1<y2.
     const bool flip_boxes = true;
-    auto status = NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes,
-                         iou_threshold_val, d_selected_indices.flat<int>().data(),
-                         &num_to_keep, context, flip_boxes, /*legacy_mode*/ false);
+    auto status = NmsGpu(
+        d_sorted_boxes.flat<float>().data(), limited_num_boxes,
+        iou_threshold_val, d_selected_indices.flat<int>().data(), &num_to_keep,
+        context, output_size, flip_boxes, /*legacy_mode*/ false);
     TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
     if (!status.ok()) {
       context->SetStatus(status);
@@ -716,7 +756,12 @@ class NonMaxSuppressionV3GPUOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({num_outputs}),
                                             &output_indices));
-    if (num_outputs == 0) return;
+    if (num_outputs == 0) {
+      VLOG(1) << "No outputs!";
+      return;
+    } else {
+      VLOG(2) << "Num outputs= " << num_outputs;
+    }
     config = GetGpuLaunchConfig(num_outputs, device);
     TF_CHECK_OK(GpuLaunchKernel(
         IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
diff --git a/tensorflow/core/kernels/non_max_suppression_op.h b/tensorflow/core/kernels/non_max_suppression_op.h
index bf1c97e66dd..fbf4dbfcd1a 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.h
+++ b/tensorflow/core/kernels/non_max_suppression_op.h
@@ -16,10 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_NON_MAX_SUPPRESSION_OP_H_
 #define TENSORFLOW_CORE_KERNELS_NON_MAX_SUPPRESSION_OP_H_
 
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 namespace tensorflow {
 namespace functor {
@@ -54,7 +54,8 @@ extern const int kNmsBoxesPerTread;
 Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
               const float iou_threshold, int* d_selected_indices,
               int* h_num_boxes_to_keep, OpKernelContext* context,
-              bool flip_boxes = false,bool legacy_mode=false);
+              const int max_boxes, bool flip_boxes = false,
+              bool legacy_mode = false);
 #endif
 
 }  // namespace tensorflow
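
Note (reviewer sketch, not part of the patch): the bit-packed selection logic above can be hard to read in isolation. Below is a minimal, CPU-only C++ reference of the same greedy reduction that NMSKernel and NMSReduce implement on the GPU: build one suppression-bitmask row per box (boxes assumed pre-sorted by descending score), walk the boxes in score order, skip any box already marked as suppressed, OR the accepted box's row into a running "removed" set, and stop once max_boxes boxes have been kept. The names BoxRef, IouGreaterThan, and SelectBoxes are illustrative only and do not exist in TensorFlow, and the early-exit condition is simplified relative to the kernel.

// Illustrative, CPU-only reference for the GPU reduction above (assumption:
// boxes are already sorted by score in descending order).
#include <algorithm>
#include <cstdint>
#include <vector>

struct BoxRef {  // hypothetical host-side box type
  float x1, y1, x2, y2;
};

// Hypothetical helper: true if IoU(a, b) exceeds `threshold`.
static bool IouGreaterThan(const BoxRef& a, const BoxRef& b, float threshold) {
  const float area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
  const float area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
  const float xx1 = std::max(a.x1, b.x1), yy1 = std::max(a.y1, b.y1);
  const float xx2 = std::min(a.x2, b.x2), yy2 = std::min(a.y2, b.y2);
  const float inter = std::max(0.0f, xx2 - xx1) * std::max(0.0f, yy2 - yy1);
  return inter > threshold * (area_a + area_b - inter);
}

// Hypothetical helper mirroring NMSKernel + NMSReduce: returns the indices of
// the kept boxes, at most max_boxes of them.
static std::vector<int> SelectBoxes(const std::vector<BoxRef>& boxes,
                                    float iou_threshold, int max_boxes) {
  const int n = static_cast<int>(boxes.size());
  const int words = (n + 31) / 32;  // bit_mask_len: 32 boxes per 32-bit word
  std::vector<uint32_t> mask(static_cast<size_t>(n) * words, 0);
  // NMSKernel equivalent: row i marks every lower-scored box j suppressed by i.
  for (int i = 0; i < n; ++i) {
    for (int j = i + 1; j < n; ++j) {
      if (IouGreaterThan(boxes[i], boxes[j], iou_threshold)) {
        mask[static_cast<size_t>(i) * words + j / 32] |= 1u << (j % 32);
      }
    }
  }
  // NMSReduce equivalent: greedy pass in score order with early exit.
  std::vector<uint32_t> removed(words, 0);
  std::vector<int> keep;
  for (int i = 0; i < n && static_cast<int>(keep.size()) < max_boxes; ++i) {
    if (removed[i / 32] & (1u << (i % 32))) continue;  // already suppressed
    keep.push_back(i);
    for (int w = 0; w < words; ++w) {
      removed[w] |= mask[static_cast<size_t>(i) * words + w];
    }
  }
  return keep;
}

On the GPU, the "removed" set lives in shared memory (local[]) with the bits stored inverted (1 = still alive), and the resulting per-box keep/drop flags are compacted into d_selected_indices with cub::DeviceSelect::Flagged instead of the host-side loop that this patch removes.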