diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc index 50fdcd76b35..1749a66579b 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc @@ -15,6 +15,8 @@ limitations under the License. #if GOOGLE_CUDA #define EIGEN_USE_GPU +#include <limits> + #include "absl/strings/str_cat.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/cub/device/device_radix_sort.cuh" @@ -82,10 +84,9 @@ __device__ EIGEN_STRONG_INLINE void Swap(T& a, T& b) { // Check whether two boxes have an IoU greater than threshold. template <typename T> -__device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* __restrict__ a, - const Box* __restrict__ b, - float a_area, - T iou_threshold) { +__device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b, + const float a_area, + const T iou_threshold) { const float b_area = (b->x2 - b->x1) * (b->y2 - b->y1); if (a_area == 0.0f || b_area == 0.0f) return false; const float xx1 = fmaxf(a->x1, b->x1); @@ -94,8 +95,8 @@ __device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* __restrict__ a, const float yy2 = fminf(a->y2, b->y2); // fdimf computes the positive difference between xx2+1 and xx1. - const float w = fdimf(xx2 + 1.0f, xx1); - const float h = fdimf(yy2 + 1.0f, yy1); + const float w = fdimf(xx2, xx1); + const float h = fdimf(yy2, yy1); const float intersection = w * h; // Testing for aa/bb > t @@ -118,6 +119,47 @@ __device__ EIGEN_STRONG_INLINE void Flipped(Box& box) { if (box.x1 > box.x2) Swap(box.x1, box.x2); if (box.y1 > box.y2) Swap(box.y1, box.y2); } +template <typename T> +__device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) { + constexpr int kShiftLen = NumBits(8 * sizeof(T)) - 1; + constexpr int kRemainderMask = 8 * sizeof(T) - 1; + int bin = bit >> kShiftLen; + return (bit_mask[bin] >> (bit & kRemainderMask)) & 1; +} + +// Produce a global bitmask (result_mask) of selected boxes from the bitmask +// generated by NMSKernel. Abort early if max_boxes boxes are selected. Bitmask +// is num_boxes*bit_mask_len bits indicating whether to keep or remove a box. +__global__ void NMSReduce(const int* bitmask, const int bit_mask_len, + const int num_boxes, const int max_boxes, + char* result_mask) { + extern __shared__ int local[]; + // set global mask to accept all boxes + for (int box : CudaGridRangeX(bit_mask_len)) { + local[box] = 0xFFFFFFFF; + } + __syncthreads(); + int accepted_boxes = 0; + for (int box = 0; box < num_boxes - 1; ++box) { + // if current box is masked by an earlier box, skip it. + if (!CheckBit(local, box)) { + continue; + } + accepted_boxes += 1; + int offset = box * bit_mask_len; + // update global mask with current box's mask + for (int b : CudaGridRangeX(bit_mask_len)) { + local[b] &= ~bitmask[offset + b]; + } + __syncthreads(); + if (accepted_boxes > max_boxes) break; + } + // copy global mask to result_mask char array. char array is needed for + // cub::DeviceSelect later. + for (int box : CudaGridRangeX(num_boxes)) { + result_mask[box] = CheckBit(local, box); + } +} // For each box, compute a bitmask of boxes which has an overlap with given box // above threshold.
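// [Editor's note] The sketch below is an illustrative host-side reference of
// the greedy reduction that CheckBit/NMSReduce perform on the GPU above; the
// function name and types are hypothetical and it is not part of the patch.
// Row i of `bitmask` (bit_mask_len 32-bit words) marks the lower-scoring boxes
// that box i suppresses; boxes are sorted by descending score and a box is
// assumed not to mark itself. The kernel applies the max_boxes cutoff after
// accepting a box and relies on later trimming, so its exact cutoff differs
// slightly from this simplified version.
#include <cstdint>
#include <vector>

std::vector<int> GreedyNmsReduceReference(const std::vector<uint32_t>& bitmask,
                                          int num_boxes, int bit_mask_len,
                                          int max_boxes) {
  std::vector<uint32_t> keep(bit_mask_len, 0xFFFFFFFFu);  // accept all boxes
  std::vector<int> selected;
  for (int i = 0; i < num_boxes; ++i) {
    if (static_cast<int>(selected.size()) >= max_boxes) break;
    // Skip boxes already suppressed by an earlier (higher-scoring) box.
    if (!((keep[i >> 5] >> (i & 31)) & 1u)) continue;
    selected.push_back(i);
    // Drop every box that box i overlaps above the IoU threshold.
    for (int b = 0; b < bit_mask_len; ++b) {
      keep[b] &= ~bitmask[i * bit_mask_len + b];
    }
  }
  return selected;
}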
@@ -131,9 +173,9 @@ __device__ EIGEN_STRONG_INLINE void Flipped(Box& box) { // x1 __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__ - void NMSKernel(const Box* __restrict__ d_desc_sorted_boxes, - const int num_boxes, const float iou_threshold, - const int bit_mask_len, int* __restrict__ d_delete_mask) { + void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes, + const float iou_threshold, const int bit_mask_len, + int* d_delete_mask) { // Storing boxes used by this CUDA block in the shared memory. __shared__ Box shared_i_boxes[kNmsBlockDim]; // Same thing with areas @@ -173,8 +215,8 @@ __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__ Box j_box = d_desc_sorted_boxes[j]; const Box i_box = shared_i_boxes[threadIdx.x]; Flipped(j_box); - if (OverThreshold(&i_box, &j_box, shared_i_areas[threadIdx.x], - iou_threshold)) { + if (OverThreshold(&i_box, &j_box, shared_i_areas[threadIdx.x], + iou_threshold)) { // we have score[j] <= score[i]. above_threshold |= (1U << ib); } @@ -196,8 +238,7 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, template __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, const Index i_original, - const T* __restrict__ original, - T* __restrict__ selected, + const T* original, T* selected, Args... args) { selected[i_selected] = original[i_original]; SelectHelper(i_selected, i_original, args...); @@ -210,18 +251,15 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, // IndexMultiSelect(num_elements, indices, original1 ,selected1, original2, // selected2). template -__global__ void IndexMultiSelect(const int num_elements, - const Index* __restrict__ indices, - const T* __restrict__ original, - T* __restrict__ selected, Args... args) { +__global__ void IndexMultiSelect(const int num_elements, const Index* indices, + const T* original, T* selected, Args... args) { for (const int idx : CudaGridRangeX(num_elements)) { SelectHelper(idx, indices[idx], original, selected, args...); } } template -__global__ void Iota(const int num_elements, const T offset, - T* __restrict__ to_fill) { +__global__ void Iota(const int num_elements, const T offset, T* to_fill) { for (int idx : CudaGridRangeX(num_elements)) { to_fill[idx] = static_cast(idx) + offset; } @@ -229,7 +267,7 @@ __global__ void Iota(const int num_elements, const T offset, Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, const float iou_threshold, int* d_selected_indices, int* h_nkeep, - OpKernelContext* context, bool flip_boxes) { + OpKernelContext* context, const int max_boxes, bool flip_boxes) { // Making sure we respect the __align(16)__ // we promised to the compiler. auto iptr = reinterpret_cast(d_sorted_boxes_float_ptr); @@ -237,7 +275,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, return errors::InvalidArgument("Boxes should be aligned to 16 Bytes."); } // allocate bitmask arrays on host and on device - Tensor h_nms_mask, d_nms_mask; + Tensor h_num_selected, d_nms_mask; const int bit_mask_len = (num_boxes + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread; @@ -257,12 +295,11 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, alloc_attr.set_gpu_compatible(true); // Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 and // using it as a ring buffer. However savings should be a few MB . 
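// [Editor's note] Host-side sketch of the variadic gather pattern that
// SelectHelper/IndexMultiSelect above implement as a CUDA kernel: every
// (original, selected) pointer pair is gathered through one shared index
// array in a single pass. Names are hypothetical; illustrative only, not
// part of the patch.
inline void GatherOne(int /*dst*/, int /*src*/) {}  // recursion terminator

template <typename T, typename... Args>
void GatherOne(int dst, int src, const T* original, T* selected,
               Args... args) {
  selected[dst] = original[src];  // copy one element of this pair
  GatherOne(dst, src, args...);   // recurse into the remaining pairs
}

template <typename... Args>
void IndexMultiSelectHost(int num_elements, const int* indices, Args... args) {
  for (int i = 0; i < num_elements; ++i) GatherOne(i, indices[i], args...);
}

// Example: gather boxes and scores through one sorted-index list.
// IndexMultiSelectHost(n, sorted_idx, boxes_in, boxes_out, scores_in, scores_out);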
- TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, - TensorShape({max_nms_mask_size}), - &h_nms_mask, alloc_attr)); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({1}), &h_num_selected, alloc_attr)); int* d_delete_mask = d_nms_mask.flat().data(); - int* h_delete_mask = h_nms_mask.flat().data(); + int* h_selected_count = h_num_selected.flat().data(); const Box* d_sorted_boxes = reinterpret_cast(d_sorted_boxes_float_ptr); dim3 block_dim, thread_block; @@ -286,58 +323,222 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); // Overlapping CPU computes and D2H memcpy // both take about the same time - int num_to_copy = std::min(kNmsChunkSize, num_boxes); + + config = GetGpuLaunchConfig(num_boxes, device); + Tensor selected_boxes; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({num_boxes}), &selected_boxes)); + Tensor d_indices; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes}), &d_indices)); + TF_CHECK_OK(GpuLaunchKernel(Iota, config.block_count, + config.thread_per_block, 0, device.stream(), + config.virtual_thread_count, 0, + d_indices.flat().data())); + + char* selected = (char*)(selected_boxes.flat().data()); + TF_CHECK_OK(GpuLaunchKernel(NMSReduce, 1, 1024, bit_mask_len * sizeof(int), + device.stream(), d_delete_mask, bit_mask_len, + num_boxes, max_boxes, selected)); + TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); + // do Cub::deviceSelect::flagged + size_t flagged_buffer_size = 0; + cub::DeviceSelect::Flagged(static_cast(nullptr), // temp_storage + flagged_buffer_size, + static_cast(nullptr), // input + static_cast(nullptr), // selection flag + static_cast(nullptr), // selected items + static_cast(nullptr), // num_selected + num_boxes, device.stream()); + Tensor cub_scratch; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}), + &cub_scratch)); + Tensor d_num_selected; + TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, + TensorShape({1}), &d_num_selected)); + + cub::DeviceSelect::Flagged( + (void*)cub_scratch.flat().data(), // temp_storage + flagged_buffer_size, + d_indices.flat().data(), // input + selected, // selection flag + d_selected_indices, // selected items + d_num_selected.flat().data(), num_boxes, device.stream()); cudaEvent_t copy_done; - cudaEventCreate(©_done); - device.memcpyDeviceToHost(&h_delete_mask[0], &d_delete_mask[0], - num_to_copy * bit_mask_len * sizeof(int)); + TF_RETURN_IF_CUDA_ERROR( + cudaEventCreateWithFlags(©_done, cudaEventDisableTiming)); + device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat().data(), + sizeof(int)); TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream())); - int offset = 0; - std::vector h_selected_indices; - // Reserve worst case scenario. Since box count is not huge, this should have - // negligible footprint. 
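// [Editor's note] The block above uses the standard two-pass CUB idiom: call
// cub::DeviceSelect::Flagged with a null temp-storage pointer to query the
// workspace size, allocate that many bytes, then call it again to compact the
// flagged indices while preserving their order. Minimal standalone sketch,
// assuming raw cudaMalloc instead of context->allocate_temp; illustrative
// only, not part of the patch.
#include <cuda_runtime.h>
#include <cub/cub.cuh>

cudaError_t CompactFlagged(const int* d_in, const char* d_flags, int* d_out,
                           int* d_num_selected, int num_items,
                           cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Pass 1: null temp storage -> CUB only reports the required workspace size.
  cub::DeviceSelect::Flagged(nullptr, temp_bytes, d_in, d_flags, d_out,
                             d_num_selected, num_items, stream);
  void* d_temp = nullptr;
  cudaError_t err = cudaMalloc(&d_temp, temp_bytes);
  if (err != cudaSuccess) return err;
  // Pass 2: compact the items whose flag is non-zero, keeping relative order.
  err = cub::DeviceSelect::Flagged(d_temp, temp_bytes, d_in, d_flags, d_out,
                                   d_num_selected, num_items, stream);
  cudaStreamSynchronize(stream);  // keep the sketch simple before freeing
  cudaFree(d_temp);
  return err;
}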
- h_selected_indices.reserve(num_boxes); - std::vector to_remove(bit_mask_len, 0); - while (offset < num_boxes) { - const int num_copied = num_to_copy; - int next_offset = offset + num_copied; - num_to_copy = std::min(kNmsChunkSize, num_boxes - next_offset); - if (num_to_copy > 0) { - device.memcpyDeviceToHost(&h_delete_mask[next_offset * bit_mask_len], - &d_delete_mask[next_offset * bit_mask_len], - num_to_copy * bit_mask_len * sizeof(int)); - } - // Waiting for previous copy - TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done)); - if (num_to_copy > 0) { - TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream())); - } - // Starting from highest scoring box, mark any box with iou>threshold and - // lower score for deletion if current box is not marked for deletion. Add - // current box to to_keep list. - for (int i = offset; i < next_offset; ++i) { - // See the comment at the beginning of the file. - // Bit shift and logical And operations are used - // instead of division and modulo operations. - int iblock = i >> kNmsBoxesPerThreadShiftBits; - int inblock = i & kNmsBoxesPerThreadModuloMask; - if (!(to_remove[iblock] & (1 << inblock))) { - h_selected_indices.push_back(i); - int* p = &h_delete_mask[i * bit_mask_len]; - for (int ib = 0; ib < bit_mask_len; ++ib) { - to_remove[ib] |= p[ib]; - } - } - } - offset = next_offset; - } + TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done)); + *h_nkeep = *h_selected_count; cudaEventDestroy(copy_done); + return Status::OK(); +} - const int nkeep = h_selected_indices.size(); - device.memcpyHostToDevice(d_selected_indices, &h_selected_indices[0], - nkeep * sizeof(int)); +struct GreaterThanCubOp { + float threshold_; + __host__ __device__ __forceinline__ GreaterThanCubOp(float threshold) + : threshold_(threshold) {} + __host__ __device__ __forceinline__ bool operator()(const float& val) const { + return (val > threshold_); + } +}; +// Use DeviceSelect::If to count number of elements. +// TODO(sami) Not really a good way. Perhaps consider using thrust? 
+template +Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op, + int num_elements, int* result) { + Tensor scratch_output; + Tensor workspace; + Tensor element_count; + size_t workspace_size = 0; + auto cuda_stream = tensorflow::GetGpuStream(context); + auto device = context->eigen_gpu_device(); + cub::DeviceSelect::If(nullptr, workspace_size, static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), num_elements, op); - *h_nkeep = nkeep; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output)); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace)); + TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, + TensorShape({1}), &element_count)); + cudaEvent_t copy_done; + TF_RETURN_IF_CUDA_ERROR( + cudaEventCreateWithFlags(©_done, cudaEventDisableTiming)); + TF_RETURN_IF_CUDA_ERROR(cub::DeviceSelect::If( + workspace.flat().data(), workspace_size, dev_array, + scratch_output.flat().data(), element_count.flat().data(), + num_elements, op, cuda_stream)); + device.memcpyDeviceToHost(result, element_count.flat().data(), + sizeof(int)); + TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream())); + TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done)); + return Status::OK(); +} + +Status DoNMS(OpKernelContext* context, const Tensor& boxes, + const Tensor& scores, const int64_t max_output_size, + const float iou_threshold_val, const float score_threshold) { + const int output_size = max_output_size; + int num_boxes = boxes.dim_size(0); + size_t cub_sort_temp_storage_bytes = 0; + auto cuda_stream = GetGpuStream(context); + auto device = context->eigen_gpu_device(); + // Calling cub with nullptrs as inputs will make it return + // workspace size needed for the operation instead of doing the operation. + // In this specific instance, cub_sort_temp_storage_bytes will contain the + // necessary workspace size for sorting after the call. 
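// [Editor's note] CountIf above reads the device-side count back with an
// asynchronous copy plus a CUDA event, so the host blocks only until that
// copy has drained the stream. Standalone sketch of the same readback pattern
// using cudaMallocHost/cudaMemcpyAsync in place of the TF allocator and
// device wrappers; illustrative only, not part of the patch.
#include <cuda_runtime.h>

cudaError_t ReadBackCount(const int* d_count, int* h_result,
                          cudaStream_t stream) {
  int* h_pinned = nullptr;
  // Pinned host memory keeps the device-to-host copy truly asynchronous.
  cudaError_t err = cudaMallocHost(&h_pinned, sizeof(int));
  if (err != cudaSuccess) return err;
  cudaEvent_t copy_done;
  cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming);
  cudaMemcpyAsync(h_pinned, d_count, sizeof(int), cudaMemcpyDeviceToHost,
                  stream);
  cudaEventRecord(copy_done, stream);
  // Wait only for the copy (and the work queued before it) to finish.
  err = cudaEventSynchronize(copy_done);
  if (err == cudaSuccess) *h_result = *h_pinned;
  cudaEventDestroy(copy_done);
  cudaFreeHost(h_pinned);
  return err;
}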
+ if (num_boxes == 0) { + Tensor* output_indices = nullptr; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({0}), &output_indices)); + return Status::OK(); + } + + cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending( + nullptr, cub_sort_temp_storage_bytes, + static_cast(nullptr), // scores + static_cast(nullptr), // sorted scores + static_cast(nullptr), // input indices + static_cast(nullptr), // sorted indices + num_boxes, // num items + 0, 8 * sizeof(float), // sort all bits + cuda_stream); + TF_RETURN_IF_CUDA_ERROR(cuda_ret); + Tensor d_cub_sort_buffer; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}), + &d_cub_sort_buffer)); + Tensor d_indices; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes}), &d_indices)); + Tensor d_sorted_indices; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes}), &d_sorted_indices)); + Tensor d_selected_indices; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes}), &d_selected_indices)); + Tensor d_sorted_scores; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_boxes}), &d_sorted_scores)); + Tensor d_sorted_boxes; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_boxes, 4}), &d_sorted_boxes)); + + // this will return sorted scores and their indices + auto config = GetGpuLaunchConfig(num_boxes, device); + // initialize box and score indices + TF_CHECK_OK(GpuLaunchKernel(Iota, config.block_count, + config.thread_per_block, 0, device.stream(), + config.virtual_thread_count, 0, + d_indices.flat().data())); + TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); + cuda_ret = cub::DeviceRadixSort::SortPairsDescending( + d_cub_sort_buffer.flat().data(), cub_sort_temp_storage_bytes, + scores.flat().data(), d_sorted_scores.flat().data(), + d_indices.flat().data(), d_sorted_indices.flat().data(), + num_boxes, 0, + 8 * sizeof(float), // sort all bits + cuda_stream); + TF_RETURN_IF_CUDA_ERROR(cuda_ret); + + // get pointers for easy access + const float4* original_boxes = + reinterpret_cast(boxes.flat().data()); + float4* sorted_boxes = + reinterpret_cast(d_sorted_boxes.flat().data()); + const int* sorted_indices = d_sorted_indices.flat().data(); + // sort boxes using indices + TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect, config.block_count, + config.thread_per_block, 0, device.stream(), + config.virtual_thread_count, sorted_indices, + original_boxes, sorted_boxes)); + int limited_num_boxes = num_boxes; + // filter boxes by scores if nms v3 + if (score_threshold > std::numeric_limits::lowest()) { + GreaterThanCubOp score_limit(score_threshold); + TF_RETURN_IF_ERROR(CountIf(context, d_sorted_scores.flat().data(), + score_limit, num_boxes, &limited_num_boxes)); + if (limited_num_boxes == 0) { + Tensor* output_indices = nullptr; + VLOG(1) << "Number of boxes above score threshold " << score_threshold + << " is 0"; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({0}), &output_indices)); + return Status::OK(); + } else { + VLOG(2) << "Number of boxes above threshold=" << score_threshold << " is " + << limited_num_boxes; + } + } + int num_to_keep = 0; + // There is no guarantee that boxes are given in the for x1().data(), limited_num_boxes, + iou_threshold_val, d_selected_indices.flat().data(), + &num_to_keep, context, output_size, flip_boxes); + TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); + 
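// [Editor's note] DoNMS above sorts the scores (and, through the index
// payload, the boxes) with the same two-pass CUB idiom: query
// cub::DeviceRadixSort::SortPairsDescending for its workspace size using null
// inputs, allocate, then sort over all 32 bits of the float keys. Standalone
// sketch with raw cudaMalloc instead of context->allocate_temp; illustrative
// only, not part of the patch.
#include <cuda_runtime.h>
#include <cub/cub.cuh>

cudaError_t SortScoresDescending(const float* d_scores_in, float* d_scores_out,
                                 const int* d_indices_in, int* d_indices_out,
                                 int num_items, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Pass 1: null temp storage -> only the required workspace size is written.
  cub::DeviceRadixSort::SortPairsDescending(
      nullptr, temp_bytes, d_scores_in, d_scores_out, d_indices_in,
      d_indices_out, num_items, /*begin_bit=*/0, /*end_bit=*/8 * sizeof(float),
      stream);
  void* d_temp = nullptr;
  cudaError_t err = cudaMalloc(&d_temp, temp_bytes);
  if (err != cudaSuccess) return err;
  // Pass 2: sort scores descending, carrying the box indices along as values.
  err = cub::DeviceRadixSort::SortPairsDescending(
      d_temp, temp_bytes, d_scores_in, d_scores_out, d_indices_in,
      d_indices_out, num_items, 0, 8 * sizeof(float), stream);
  cudaStreamSynchronize(stream);  // keep the sketch simple before freeing
  cudaFree(d_temp);
  return err;
}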
if (!status.ok()) { + context->SetStatus(status); + return status; + } + Tensor* output_indices = nullptr; + int num_outputs = std::min(num_to_keep, output_size); // no padding! + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({num_outputs}), &output_indices)); + if (num_outputs == 0) return Status::OK(); + config = GetGpuLaunchConfig(num_outputs, device); + TF_CHECK_OK(GpuLaunchKernel( + IndexMultiSelect, config.block_count, config.thread_per_block, + 0, device.stream(), config.virtual_thread_count, + d_selected_indices.flat().data(), sorted_indices, + (*output_indices).flat().data())); + TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); return Status::OK(); } @@ -384,112 +585,84 @@ class NonMaxSuppressionV2GPUOp : public OpKernel { &output_indices)); return; } - const int output_size = max_output_size.scalar()(); - size_t cub_sort_temp_storage_bytes = 0; - auto cuda_stream = GetGpuStream(context); - auto device = context->eigen_gpu_device(); - // Calling cub with nullptrs as inputs will make it return - // workspace size needed for the operation instead of doing the operation. - // In this specific instance, cub_sort_temp_storage_bytes will contain the - // necessary workspace size for sorting after the call. - cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending( - nullptr, cub_sort_temp_storage_bytes, - static_cast(nullptr), // scores - static_cast(nullptr), // sorted scores - static_cast(nullptr), // input indices - static_cast(nullptr), // sorted indices - num_boxes, // num items - 0, 8 * sizeof(float), // sort all bits - cuda_stream); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); - Tensor d_cub_sort_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp( - DataType::DT_INT8, - TensorShape({(int64)cub_sort_temp_storage_bytes}), - &d_cub_sort_buffer)); - Tensor d_indices; + const int64_t output_size = max_output_size.scalar()(); OP_REQUIRES_OK( - context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), &d_indices)); - Tensor d_sorted_indices; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), - &d_sorted_indices)); - Tensor d_selected_indices; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), - &d_selected_indices)); - Tensor d_sorted_scores; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT, - TensorShape({num_boxes}), - &d_sorted_scores)); - Tensor d_sorted_boxes; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT, - TensorShape({num_boxes, 4}), - &d_sorted_boxes)); - - // this will return sorted scores and their indices - auto config = GetGpuLaunchConfig(num_boxes, device); - // initialize box and score indices - TF_CHECK_OK(GpuLaunchKernel(Iota, config.block_count, - config.thread_per_block, 0, device.stream(), - config.virtual_thread_count, 0, - d_indices.flat().data())); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); - cuda_ret = cub::DeviceRadixSort::SortPairsDescending( - d_cub_sort_buffer.flat().data(), cub_sort_temp_storage_bytes, - scores.flat().data(), d_sorted_scores.flat().data(), - d_indices.flat().data(), d_sorted_indices.flat().data(), - num_boxes, 0, - 8 * sizeof(float), // sort all bits - cuda_stream); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); - - // get pointers for easy access - const float4* original_boxes = - reinterpret_cast(boxes.flat().data()); - float4* sorted_boxes = - reinterpret_cast(d_sorted_boxes.flat().data()); - const int* sorted_indices = 
d_sorted_indices.flat().data(); - // sort boxes using indices - TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect, - config.block_count, config.thread_per_block, 0, - device.stream(), config.virtual_thread_count, - sorted_indices, original_boxes, sorted_boxes)); - - int num_to_keep = 0; - // There is no guarantee that boxes are given in the for x1().data(), num_boxes, - iou_threshold_val, d_selected_indices.flat().data(), - &num_to_keep, context, flip_boxes); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); - if (!status.ok()) { - context->SetStatus(status); - return; - } - Tensor* output_indices = nullptr; - int num_outputs = std::min(num_to_keep, output_size); // no padding! - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({num_outputs}), - &output_indices)); - if (num_outputs == 0) return; - config = GetGpuLaunchConfig(num_outputs, device); - TF_CHECK_OK(GpuLaunchKernel( - IndexMultiSelect, config.block_count, config.thread_per_block, - 0, device.stream(), config.virtual_thread_count, - d_selected_indices.flat().data(), sorted_indices, - (*output_indices).flat().data())); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); + context, + DoNMS(context, boxes, scores, output_size, iou_threshold_val, + /*score_threshold is float min if score threshold is disabled*/ + std::numeric_limits::lowest())); } }; -REGISTER_KERNEL_BUILDER( - Name("NonMaxSuppressionV2").TypeConstraint("T").Device(DEVICE_GPU), - NonMaxSuppressionV2GPUOp); +class NonMaxSuppressionV3GPUOp : public OpKernel { + public: + explicit NonMaxSuppressionV3GPUOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // boxes: [num_boxes, 4] + const Tensor& boxes = context->input(0); + // scores: [num_boxes] + const Tensor& scores = context->input(1); + // max_output_size: scalar + const Tensor& max_output_size = context->input(2); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(max_output_size.shape()), + errors::InvalidArgument("max_output_size must be 0-D, got shape ", + max_output_size.shape().DebugString())); + // iou_threshold: scalar + const Tensor& iou_threshold = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), + errors::InvalidArgument("iou_threshold must be 0-D, got shape ", + iou_threshold.shape().DebugString())); + const float iou_threshold_val = iou_threshold.scalar()(); + + const Tensor& score_threshold = context->input(4); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(score_threshold.shape()), + errors::InvalidArgument("score_threshold must be 0-D, got shape ", + score_threshold.shape().DebugString())); + const float score_threshold_val = score_threshold.scalar()(); + + OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1, + errors::InvalidArgument("iou_threshold must be in [0, 1]")); + OP_REQUIRES(context, boxes.dims() == 2, + errors::InvalidArgument("boxes must be a rank 2 tensor!")); + int num_boxes = boxes.dim_size(0); + OP_REQUIRES(context, boxes.dim_size(1) == 4, + errors::InvalidArgument("boxes must be Nx4")); + OP_REQUIRES(context, scores.dims() == 1, + errors::InvalidArgument("scores must be a vector!")); + OP_REQUIRES( + context, scores.dim_size(0) == num_boxes, + errors::InvalidArgument( + "scores has incompatible shape")); // message must be exactly this + // otherwise tests fail! 
+ if (num_boxes == 0) { + Tensor* output_indices = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}), + &output_indices)); + return; + } + const int output_size = max_output_size.scalar()(); + OP_REQUIRES_OK(context, DoNMS(context, boxes, scores, output_size, + iou_threshold_val, score_threshold_val)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2") + .TypeConstraint("T") + .Device(DEVICE_GPU) + .HostMemory("iou_threshold") + .HostMemory("max_output_size"), + NonMaxSuppressionV2GPUOp); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3") + .TypeConstraint("T") + .Device(DEVICE_GPU) + .HostMemory("iou_threshold") + .HostMemory("max_output_size") + .HostMemory("score_threshold"), + NonMaxSuppressionV3GPUOp); } // namespace tensorflow #endif diff --git a/tensorflow/core/kernels/non_max_suppression_op.h b/tensorflow/core/kernels/non_max_suppression_op.h index 7ff4f16c689..eaa1b28ad4b 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.h +++ b/tensorflow/core/kernels/non_max_suppression_op.h @@ -54,7 +54,7 @@ extern const int kNmsBoxesPerTread; Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, const float iou_threshold, int* d_selected_indices, int* h_num_boxes_to_keep, OpKernelContext* context, - bool flip_boxes = false); + const int max_boxes, bool flip_boxes = false); #endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc b/tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc index 1034aeaf37f..d33433f0af6 100644 --- a/tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc +++ b/tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc @@ -203,6 +203,222 @@ TEST_F(NonMaxSuppressionV2GPUOpTest, TestEmptyInput) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +// +// NonMaxSuppressionV3GPUOp Tests +// Copied from CPU tests + +class NonMaxSuppressionV3GPUOpTest : public OpsTestBase { + protected: + void MakeOp() { + SetDevice(DEVICE_GPU, + std::unique_ptr(DeviceFactory::NewDevice( + "GPU", {}, "/job:a/replica:0/task:0"))); + + TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV3") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectFromThreeClusters) { + MakeOp(); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues(&expected, {3, 0, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectFromThreeClustersWithScoreThreshold) { + MakeOp(); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {0.5f}); + AddInputFromArray(TensorShape({}), {0.4f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, 
TensorShape({2})); + test::FillValues(&expected, {3, 0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectFromThreeClustersWithScoreThresholdZeroScores) { + MakeOp(); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({6}), {.1, 0, 0, .3, .2, -5.0}); + // If we ask for more boxes than we actually expect to get back; + // should still only get 2 boxes back. + AddInputFromArray(TensorShape({}), {6}); + AddInputFromArray(TensorShape({}), {0.5f}); + AddInputFromArray(TensorShape({}), {-3.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({2})); + test::FillValues(&expected, {3, 0}); + + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectFromThreeClustersFlippedCoordinates) { + MakeOp(); + AddInputFromArray(TensorShape({6, 4}), + {1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f, + 0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues(&expected, {3, 0, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectAtMostTwoBoxesFromThreeClusters) { + MakeOp(); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {2}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({2})); + test::FillValues(&expected, {3, 0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectAtMostThirtyBoxesFromThreeClusters) { + MakeOp(); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray(TensorShape({}), {30}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues(&expected, {3, 0, 5}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectSingleBox) { + MakeOp(); + AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray(TensorShape({1}), {.9f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({1})); + test::FillValues(&expected, {0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectFromTenIdenticalBoxes) { + MakeOp(); + + int num_boxes = 10; + std::vector corners(num_boxes * 4); + std::vector scores(num_boxes); + for (int i = 0; i < num_boxes; ++i) { + corners[i * 4 + 0] = 
0; + corners[i * 4 + 1] = 0; + corners[i * 4 + 2] = 1; + corners[i * 4 + 3] = 1; + scores[i] = .9; + } + AddInputFromArray(TensorShape({num_boxes, 4}), corners); + AddInputFromArray(TensorShape({num_boxes}), scores); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({1})); + test::FillValues(&expected, {0}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestInconsistentBoxAndScoreShapes) { + MakeOp(); + AddInputFromArray( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f}); + AddInputFromArray(TensorShape({}), {30}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + Status s = RunOpKernel(); + + ASSERT_FALSE(s.ok()); + EXPECT_TRUE(absl::StrContains(s.ToString(), "scores has incompatible shape")) + << s; +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestInvalidIOUThreshold) { + MakeOp(); + AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray(TensorShape({1}), {.9f}); + AddInputFromArray(TensorShape({}), {3}); + AddInputFromArray(TensorShape({}), {1.2f}); + AddInputFromArray(TensorShape({}), {0.0f}); + Status s = RunOpKernel(); + + ASSERT_FALSE(s.ok()); + EXPECT_TRUE( + absl::StrContains(s.ToString(), "iou_threshold must be in [0, 1]")) + << s; +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestEmptyInput) { + MakeOp(); + AddInputFromArray(TensorShape({0, 4}), {}); + AddInputFromArray(TensorShape({0}), {}); + AddInputFromArray(TensorShape({}), {30}); + AddInputFromArray(TensorShape({}), {.5f}); + AddInputFromArray(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({0})); + test::FillValues(&expected, {}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + #endif } // namespace tensorflow