From b2b8e56f5f68987d22e7f1405299891908431fc4 Mon Sep 17 00:00:00 2001 From: Sami <skama@nvidia.com> Date: Fri, 19 Jul 2019 20:08:06 -0700 Subject: [PATCH 1/5] Add NMSv3 GPU op --- .../core/kernels/non_max_suppression_op.cu.cc | 299 ++++++++++++++++-- .../core/kernels/non_max_suppression_op.h | 2 +- .../non_max_suppression_op_gpu_test.cc | 216 +++++++++++++ 3 files changed, 494 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc index 9b8526a75c3..81f458b0326 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc @@ -16,10 +16,6 @@ limitations under the License. #if GOOGLE_CUDA #define EIGEN_USE_GPU #include "absl/strings/str_cat.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/cub/device/device_radix_sort.cuh" -#include "third_party/cub/device/device_segmented_radix_sort.cuh" -#include "third_party/cub/device/device_select.cuh" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" @@ -27,6 +23,10 @@ limitations under the License. #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_launch_config.h" #include "tensorflow/stream_executor/stream_executor.h" +#include "third_party/cub/device/device_radix_sort.cuh" +#include "third_party/cub/device/device_segmented_radix_sort.cuh" +#include "third_party/cub/device/device_select.cuh" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #define TF_RETURN_IF_CUDA_ERROR(result) \ do { \ @@ -80,8 +80,19 @@ __device__ EIGEN_STRONG_INLINE void Swap(T& a, T& b) { b = c; } +template <bool T> +__device__ float legacy_offset(float); +template <> +__device__ EIGEN_STRONG_INLINE float legacy_offset<true>(float a) { + return a + 1.0; +} +template <> +__device__ EIGEN_STRONG_INLINE float legacy_offset<false>(float a) { + return a; +} + // Check whether two boxes have an IoU greater than threshold. -template <typename T> +template <typename T, bool L> __device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b, float a_area, T iou_threshold) { @@ -93,8 +104,8 @@ __device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b, const float yy2 = fminf(a->y2, b->y2); // fdimf computes the positive difference between xx2+1 and xx1. - const float w = fdimf(xx2 + 1.0f, xx1); - const float h = fdimf(yy2 + 1.0f, yy1); + const float w = fdimf(legacy_offset<L>(xx2), xx1); + const float h = fdimf(legacy_offset<L>(yy2), yy1); const float intersection = w * h; // Testing for aa/bb > t @@ -128,7 +139,7 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) { // If flip_box is true boxes may have x1>x2 and or y1>y2. If so change the // coordinates such that for all boxes x1<x2 and y1<y2. Else boxes should have // x1<x2 and y1<y2. 
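 // Worked example for OverThreshold above (no legacy offset): for a = (0, 0, 2, 2)
 // and b = (1, 1, 3, 3) the intersection is 1 * 1 = 1 and the union is
 // 4 + 4 - 1 = 7, so the equivalent test intersection > iou_threshold * union,
 // i.e. 1 > 0.5 * 7 for iou_threshold = 0.5, fails and neither box suppresses
 // the other.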
-template <bool flip_box> +template <bool flip_box, bool legacy_mode> __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__ void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes, const float iou_threshold, const int bit_mask_len, @@ -172,8 +183,8 @@ __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__ Box j_box = d_desc_sorted_boxes[j]; const Box i_box = shared_i_boxes[threadIdx.x]; Flipped<flip_box>(j_box); - if (OverThreshold(&i_box, &j_box, shared_i_areas[threadIdx.x], - iou_threshold)) { + if (OverThreshold<float, legacy_mode>( + &i_box, &j_box, shared_i_areas[threadIdx.x], iou_threshold)) { // we have score[j] <= score[i]. above_threshold |= (1U << ib); } @@ -224,7 +235,7 @@ __global__ void Iota(const int num_elements, const T offset, T* to_fill) { Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, const float iou_threshold, int* d_selected_indices, int* h_nkeep, - OpKernelContext* context, bool flip_boxes) { + OpKernelContext* context, bool flip_boxes, bool legacy_mode) { // Making sure we respect the __align(16)__ // we promised to the compiler. auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr); @@ -270,20 +281,37 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, thread_block.y = kNmsBlockDim; thread_block.z = 1; if (flip_boxes) { - TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true>, block_dim, thread_block, 0, - device.stream(), d_sorted_boxes, num_boxes, - iou_threshold, bit_mask_len, d_delete_mask)); + if (!legacy_mode) { + TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true, false>, block_dim, + thread_block, 0, device.stream(), + d_sorted_boxes, num_boxes, iou_threshold, + bit_mask_len, d_delete_mask)); + } else { + TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true, true>, block_dim, + thread_block, 0, device.stream(), + d_sorted_boxes, num_boxes, iou_threshold, + bit_mask_len, d_delete_mask)); + } } else { - TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false>, block_dim, thread_block, 0, - device.stream(), d_sorted_boxes, num_boxes, - iou_threshold, bit_mask_len, d_delete_mask)); + if (!legacy_mode) { + TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false, false>, block_dim, + thread_block, 0, device.stream(), + d_sorted_boxes, num_boxes, iou_threshold, + bit_mask_len, d_delete_mask)); + } else { + TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false, true>, block_dim, + thread_block, 0, device.stream(), + d_sorted_boxes, num_boxes, iou_threshold, + bit_mask_len, d_delete_mask)); + } } TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); // Overlapping CPU computes and D2H memcpy // both take about the same time int num_to_copy = std::min(kNmsChunkSize, num_boxes); cudaEvent_t copy_done; - cudaEventCreate(©_done); + TF_RETURN_IF_CUDA_ERROR( + cudaEventCreateWithFlags(©_done, cudaEventDisableTiming)); device.memcpyDeviceToHost(&h_delete_mask[0], &d_delete_mask[0], num_to_copy * bit_mask_len * sizeof(int)); TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream())); @@ -460,7 +488,7 @@ class NonMaxSuppressionV2GPUOp : public OpKernel { auto status = NmsGpu(d_sorted_boxes.flat<float>().data(), num_boxes, iou_threshold_val, d_selected_indices.flat<int>().data(), - &num_to_keep, context, flip_boxes); + &num_to_keep, context, flip_boxes, /*legacy_mode*/ false); TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); if (!status.ok()) { context->SetStatus(status); @@ -482,9 +510,236 @@ class NonMaxSuppressionV2GPUOp : public OpKernel { } }; -REGISTER_KERNEL_BUILDER( - 
Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_GPU), - NonMaxSuppressionV2GPUOp); +struct GreaterThanCubOp { + float threshold_; + __host__ __device__ __forceinline__ GreaterThanCubOp(float threshold) + : threshold_(threshold) {} + __host__ __device__ __forceinline__ bool operator()(const float& val) const { + return (val > threshold_); + } +}; +// Use DeviceSelect::If to count number of elements. +// TODO(sami) Not really a good way. Perhaps consider using thrust? +template <typename Op> +Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op, + int num_elements, int* result) { + Tensor scratch_output; + Tensor workspace; + Tensor element_count; + size_t workspace_size = 0; + auto cuda_stream = tensorflow::GetGpuStream(context); + auto device = context->eigen_gpu_device(); + cub::DeviceSelect::If(nullptr, workspace_size, static_cast<float*>(nullptr), + static_cast<float*>(nullptr), + static_cast<int*>(nullptr), num_elements, op); + + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output)); + TF_RETURN_IF_ERROR( + context, + context->allocate_temp(DataType::DT_INT8, + TensorShape({(int64)workspace_size}), &workspace)); + TF_RETURN_IF_ERROR( + context, context->allocate_temp(DataType::DT_INT32, TensorShape({1}), + &element_count)); + cudaEvent_t copy_done; + TF_RETURN_IF_CUDA_ERROR( + cudaEventCreateWithFlags(©_done, cudaEventDisableTiming)); + TF_RETURN_IF_CUDA_ERROR(cub::DeviceSelect::If( + workspace.flat<int8>().data(), workspace_size, dev_array, + scratch_output.flat<float>().data(), element_count.flat<int32>().data(), + num_elements, op, cuda_stream)); + device.memcpyDeviceToHost(result, element_count.flat<int32>().data(), + sizeof(int)); + TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream())); + TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done)); + return Status::OK(); +} + +class NonMaxSuppressionV3GPUOp : public OpKernel { + public: + explicit NonMaxSuppressionV3GPUOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // boxes: [num_boxes, 4] + const Tensor& boxes = context->input(0); + // scores: [num_boxes] + const Tensor& scores = context->input(1); + // max_output_size: scalar + const Tensor& max_output_size = context->input(2); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(max_output_size.shape()), + errors::InvalidArgument("max_output_size must be 0-D, got shape ", + max_output_size.shape().DebugString())); + // iou_threshold: scalar + const Tensor& iou_threshold = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), + errors::InvalidArgument("iou_threshold must be 0-D, got shape ", + iou_threshold.shape().DebugString())); + const float iou_threshold_val = iou_threshold.scalar<float>()(); + + const Tensor& score_threshold = context->input(4); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(score_threshold.shape()), + errors::InvalidArgument("score_threshold must be 0-D, got shape ", + score_threshold.shape().DebugString())); + const float score_threshold_val = score_threshold.scalar<float>()(); + + OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1, + errors::InvalidArgument("iou_threshold must be in [0, 1]")); + OP_REQUIRES(context, boxes.dims() == 2, + errors::InvalidArgument("boxes must be a rank 2 tensor!")); + int num_boxes = boxes.dim_size(0); + OP_REQUIRES(context, boxes.dim_size(1) == 4, + errors::InvalidArgument("boxes 
must be Nx4")); + OP_REQUIRES(context, scores.dims() == 1, + errors::InvalidArgument("scores must be a vector!")); + OP_REQUIRES( + context, scores.dim_size(0) == num_boxes, + errors::InvalidArgument( + "scores has incompatible shape")); // message must be exactly this + // otherwise tests fail! + if (num_boxes == 0) { + Tensor* output_indices = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}), + &output_indices)); + return; + } + const int output_size = max_output_size.scalar<int>()(); + size_t cub_sort_temp_storage_bytes = 0; + auto cuda_stream = tensorflow::GetGpuStream(context); + auto device = context->eigen_gpu_device(); + // Calling cub with nullptrs as inputs will make it return + // workspace size needed for the operation instead of doing the operation. + // In this specific instance, cub_sort_temp_storage_bytes will contain the + // necessary workspace size for sorting after the call. + cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending( + nullptr, cub_sort_temp_storage_bytes, + static_cast<float*>(nullptr), // scores + static_cast<float*>(nullptr), // sorted scores + static_cast<int*>(nullptr), // input indices + static_cast<int*>(nullptr), // sorted indices + num_boxes, // num items + 0, 8 * sizeof(float), // sort all bits + cuda_stream); + TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); + Tensor d_cub_sort_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp( + DataType::DT_INT8, + TensorShape({(int64)cub_sort_temp_storage_bytes}), + &d_cub_sort_buffer)); + Tensor d_indices; + OP_REQUIRES_OK( + context, context->allocate_temp(DataType::DT_INT32, + TensorShape({num_boxes}), &d_indices)); + Tensor d_sorted_indices; + OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, + TensorShape({num_boxes}), + &d_sorted_indices)); + Tensor d_selected_indices; + OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, + TensorShape({num_boxes}), + &d_selected_indices)); + Tensor d_sorted_scores; + OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT, + TensorShape({num_boxes}), + &d_sorted_scores)); + Tensor d_sorted_boxes; + OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT, + TensorShape({num_boxes, 4}), + &d_sorted_boxes)); + + // this will return sorted scores and their indices + auto config = GetGpuLaunchConfig(num_boxes, device); + // initialize box and score indices + TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count, + config.thread_per_block, 0, device.stream(), + config.virtual_thread_count, 0, + d_indices.flat<int>().data())); + TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); + cuda_ret = cub::DeviceRadixSort::SortPairsDescending( + d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes, + scores.flat<float>().data(), d_sorted_scores.flat<float>().data(), + d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(), + num_boxes, 0, + 8 * sizeof(float), // sort all bits + cuda_stream); + TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); + // get pointers for easy access + const float4* original_boxes = + reinterpret_cast<const float4*>(boxes.flat<float>().data()); + float4* sorted_boxes = + reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data()); + const int* sorted_indices = d_sorted_indices.flat<int>().data(); + // sort boxes using indices + TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>, + config.block_count, config.thread_per_block, 0, + device.stream(), config.virtual_thread_count, + sorted_indices, original_boxes, 
sorted_boxes)); + + // Unfortunately we had to sort scores to find the number of boxes which has + // a threshold above score_threshold_val. It can be done before sorting but + // that would require either implementing a custom sort or a generic random + // access iterator for cub. For the time being we search for the location of + // the score_threshold_val in the sorted array and limit num_boxes to its + // index. + GreaterThanCubOp score_limit(score_threshold_val); + int limited_num_boxes = 0; + OP_REQUIRES_OK(context, + CountIf(context, d_sorted_scores.flat<float>().data(), + score_limit, num_boxes, &limited_num_boxes)); + if (limited_num_boxes == 0) { + Tensor* output_indices = nullptr; + VLOG(1) << "Number of boxes above score threshold " << score_threshold_val + << " is 0"; + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}), + &output_indices)); + return; + } + int num_to_keep = 0; + // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2, + // flip boxes if necessary! + const bool flip_boxes = true; + auto status = + NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes, + iou_threshold_val, d_selected_indices.flat<int>().data(), + &num_to_keep, context, flip_boxes, /*legacy_mode*/ false); + TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); + if (!status.ok()) { + context->SetStatus(status); + return; + } + Tensor* output_indices = nullptr; + int num_outputs = std::min(num_to_keep, output_size); // no padding! + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({num_outputs}), + &output_indices)); + if (num_outputs == 0) return; + config = GetGpuLaunchConfig(num_outputs, device); + TF_CHECK_OK(GpuLaunchKernel( + IndexMultiSelect<int, int>, config.block_count, config.thread_per_block, + 0, device.stream(), config.virtual_thread_count, + d_selected_indices.flat<int>().data(), sorted_indices, + (*output_indices).flat<int>().data())); + TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); + } +}; + +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2") + .TypeConstraint<float>("T") + .Device(DEVICE_GPU) + .HostMemory("iou_threshold") + .HostMemory("max_output_size"), + NonMaxSuppressionV2GPUOp); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3") + .TypeConstraint<float>("T") + .Device(DEVICE_GPU) + .HostMemory("iou_threshold") + .HostMemory("max_output_size") + .HostMemory("score_threshold"), + NonMaxSuppressionV3GPUOp); } // namespace tensorflow #endif diff --git a/tensorflow/core/kernels/non_max_suppression_op.h b/tensorflow/core/kernels/non_max_suppression_op.h index 7ff4f16c689..bf1c97e66dd 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.h +++ b/tensorflow/core/kernels/non_max_suppression_op.h @@ -54,7 +54,7 @@ extern const int kNmsBoxesPerTread; Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, const float iou_threshold, int* d_selected_indices, int* h_num_boxes_to_keep, OpKernelContext* context, - bool flip_boxes = false); + bool flip_boxes = false,bool legacy_mode=false); #endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc b/tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc index 1034aeaf37f..d33433f0af6 100644 --- a/tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc +++ b/tensorflow/core/kernels/non_max_suppression_op_gpu_test.cc @@ -203,6 +203,222 @@ TEST_F(NonMaxSuppressionV2GPUOpTest, TestEmptyInput) { test::ExpectTensorEqual<int>(expected, *GetOutput(0)); } +// +// NonMaxSuppressionV3GPUOp 
Tests +// Copied from CPU tests + +class NonMaxSuppressionV3GPUOpTest : public OpsTestBase { + protected: + void MakeOp() { + SetDevice(DEVICE_GPU, + std::unique_ptr<tensorflow::Device>(DeviceFactory::NewDevice( + "GPU", {}, "/job:a/replica:0/task:0"))); + + TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV3") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } +}; + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectFromThreeClusters) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues<int>(&expected, {3, 0, 5}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectFromThreeClustersWithScoreThreshold) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {0.5f}); + AddInputFromArray<float>(TensorShape({}), {0.4f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({2})); + test::FillValues<int>(&expected, {3, 0}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectFromThreeClustersWithScoreThresholdZeroScores) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.1, 0, 0, .3, .2, -5.0}); + // If we ask for more boxes than we actually expect to get back; + // should still only get 2 boxes back. 
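+  // (With score_threshold -3.0 only box 5, score -5.0, is filtered out; NMS on
+  // the remaining boxes keeps one box per cluster: index 3 (score 0.3) and
+  // index 0 (score 0.1).)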
+ AddInputFromArray<int>(TensorShape({}), {6}); + AddInputFromArray<float>(TensorShape({}), {0.5f}); + AddInputFromArray<float>(TensorShape({}), {-3.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({2})); + test::FillValues<int>(&expected, {3, 0}); + + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectFromThreeClustersFlippedCoordinates) { + MakeOp(); + AddInputFromArray<float>(TensorShape({6, 4}), + {1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f, + 0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues<int>(&expected, {3, 0, 5}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectAtMostTwoBoxesFromThreeClusters) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {2}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({2})); + test::FillValues<int>(&expected, {3, 0}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, + TestSelectAtMostThirtyBoxesFromThreeClusters) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f}); + AddInputFromArray<int>(TensorShape({}), {30}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({3})); + test::FillValues<int>(&expected, {3, 0, 5}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectSingleBox) { + MakeOp(); + AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray<float>(TensorShape({1}), {.9f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({1})); + test::FillValues<int>(&expected, {0}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectFromTenIdenticalBoxes) { + MakeOp(); + + int num_boxes = 10; + std::vector<float> corners(num_boxes * 4); + std::vector<float> scores(num_boxes); + for (int i = 0; i < num_boxes; ++i) { + corners[i * 4 + 0] = 0; + corners[i * 4 + 1] = 0; + corners[i * 4 + 2] = 1; + corners[i * 4 + 3] = 1; + scores[i] = .9; + } + AddInputFromArray<float>(TensorShape({num_boxes, 4}), corners); + AddInputFromArray<float>(TensorShape({num_boxes}), scores); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + 
AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({1})); + test::FillValues<int>(&expected, {0}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestInconsistentBoxAndScoreShapes) { + MakeOp(); + AddInputFromArray<float>( + TensorShape({6, 4}), + {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); + AddInputFromArray<float>(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f}); + AddInputFromArray<int>(TensorShape({}), {30}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + Status s = RunOpKernel(); + + ASSERT_FALSE(s.ok()); + EXPECT_TRUE(absl::StrContains(s.ToString(), "scores has incompatible shape")) + << s; +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestInvalidIOUThreshold) { + MakeOp(); + AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1}); + AddInputFromArray<float>(TensorShape({1}), {.9f}); + AddInputFromArray<int>(TensorShape({}), {3}); + AddInputFromArray<float>(TensorShape({}), {1.2f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + Status s = RunOpKernel(); + + ASSERT_FALSE(s.ok()); + EXPECT_TRUE( + absl::StrContains(s.ToString(), "iou_threshold must be in [0, 1]")) + << s; +} + +TEST_F(NonMaxSuppressionV3GPUOpTest, TestEmptyInput) { + MakeOp(); + AddInputFromArray<float>(TensorShape({0, 4}), {}); + AddInputFromArray<float>(TensorShape({0}), {}); + AddInputFromArray<int>(TensorShape({}), {30}); + AddInputFromArray<float>(TensorShape({}), {.5f}); + AddInputFromArray<float>(TensorShape({}), {0.0f}); + TF_ASSERT_OK(RunOpKernel()); + + Tensor expected(allocator(), DT_INT32, TensorShape({0})); + test::FillValues<int>(&expected, {}); + test::ExpectTensorEqual<int>(expected, *GetOutput(0)); +} + #endif } // namespace tensorflow From c2a4931076e69bff99d9f6e40ae607c5335c1e6a Mon Sep 17 00:00:00 2001 From: Sami <skama@nvidia.com> Date: Thu, 22 Aug 2019 18:47:03 -0700 Subject: [PATCH 2/5] Pure GPU NMS implementation. --- .../core/kernels/non_max_suppression_op.cu.cc | 183 +++++++++++------- .../core/kernels/non_max_suppression_op.h | 5 +- 2 files changed, 117 insertions(+), 71 deletions(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc index 81f458b0326..79d3ea71c8a 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc @@ -128,6 +128,47 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) { if (box.x1 > box.x2) Swap(box.x1, box.x2); if (box.y1 > box.y2) Swap(box.y1, box.y2); } +template <typename T> +__device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) { + constexpr int SHIFTLEN = NumBits(8 * sizeof(T)) - 1; + constexpr int REMAINDER_MASK = 8 * sizeof(T) - 1; + int bin = bit >> SHIFTLEN; + return (bit_mask[bin] >> (bit & REMAINDER_MASK)) & 1; +} + +// Produce a global bitmask (result_mask) of selected boxes from bitmask +// generated by NMSKernel Abort early if max_boxes boxes are selected. Bitmask +// is num_boxes*bit_mask_len bits indicating whether to keep or remove a box. 
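+// Roughly: the shared-memory copy of the mask starts as all ones (keep every
+// box); boxes are then visited in descending score order, and each box that is
+// still marked as kept clears the bits of the boxes it suppresses using its row
+// of bitmask. The kernel is launched as a single thread block with
+// bit_mask_len ints of dynamic shared memory, and the surviving bits are
+// finally written out as one char per box for cub::DeviceSelect::Flagged.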
+__global__ void NMSReduce(const int* bitmask, const int bit_mask_len, + const int num_boxes, const int max_boxes, + char* result_mask) { + extern __shared__ int local[]; + // set global mask to accept all boxes + for (int box : CudaGridRangeX(bit_mask_len)) { + local[box] = 0xFFFFFFFF; + } + __syncthreads(); + int accepted_boxes = 0; + for (int box = 0; box < num_boxes - 1; ++box) { + // if current box is masked by an earlier box, skip it. + if (!CheckBit(local, box)) { + continue; + } + accepted_boxes += 1; + int offset = box * bit_mask_len; + // update global mask with current box's mask + for (int b : CudaGridRangeX(bit_mask_len)) { + local[b] &= ~bitmask[offset + b]; + } + __syncthreads(); + if (accepted_boxes > max_boxes) break; + } + // copy global mask to result_max char array. char array is needed for + // cub::DeviceSelect later. + for (int box : CudaGridRangeX(num_boxes)) { + result_mask[box] = CheckBit(local, box); + } +} // For each box, compute a bitmask of boxes which has an overlap with given box // above threshold. @@ -235,7 +276,8 @@ __global__ void Iota(const int num_elements, const T offset, T* to_fill) { Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, const float iou_threshold, int* d_selected_indices, int* h_nkeep, - OpKernelContext* context, bool flip_boxes, bool legacy_mode) { + OpKernelContext* context, const int max_boxes, bool flip_boxes, + bool legacy_mode) { // Making sure we respect the __align(16)__ // we promised to the compiler. auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr); @@ -243,7 +285,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, return errors::InvalidArgument("Boxes should be aligned to 16 Bytes."); } // allocate bitmask arrays on host and on device - Tensor h_nms_mask, d_nms_mask; + Tensor h_num_selected, d_nms_mask; const int bit_mask_len = (num_boxes + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread; @@ -264,11 +306,11 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, // Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 and // using it as a ring buffer. However savings should be a few MB . 
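   // With the device-side NMSReduce + cub::DeviceSelect::Flagged path below,
   // only the final selection count has to be copied back to the host, so the
   // host-side mask buffer is replaced by a single int (h_num_selected).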
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, - TensorShape({max_nms_mask_size}), - &h_nms_mask, alloc_attr)); + TensorShape({1}), + &h_num_selected, alloc_attr)); int* d_delete_mask = d_nms_mask.flat<int>().data(); - int* h_delete_mask = h_nms_mask.flat<int>().data(); + int* h_selected_count = h_num_selected.flat<int>().data(); const Box* d_sorted_boxes = reinterpret_cast<const Box*>(d_sorted_boxes_float_ptr); dim3 block_dim, thread_block; @@ -308,59 +350,57 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); // Overlapping CPU computes and D2H memcpy // both take about the same time - int num_to_copy = std::min(kNmsChunkSize, num_boxes); + + config = GetGpuLaunchConfig(num_boxes, device); + Tensor selected_boxes; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({num_boxes}), &selected_boxes)); + Tensor d_indices; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes}), &d_indices)); + TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count, + config.thread_per_block, 0, device.stream(), + config.virtual_thread_count, 0, + d_indices.flat<int>().data())); + + char* selected = (char*)(selected_boxes.flat<int8>().data()); + TF_CHECK_OK(GpuLaunchKernel(NMSReduce, 1, 1024, bit_mask_len * sizeof(int), + device.stream(), d_delete_mask, bit_mask_len, + num_boxes, max_boxes, selected)); + TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); + // do Cub::deviceSelect::flagged + size_t flagged_buffer_size = 0; + cub::DeviceSelect::Flagged(static_cast<void*>(nullptr), // temp_storage + flagged_buffer_size, + static_cast<int*>(nullptr), // input + static_cast<char*>(nullptr), // selection flag + static_cast<int*>(nullptr), // selected items + static_cast<int*>(nullptr), // num_selected + num_boxes, device.stream()); + Tensor cub_scratch; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}), + &cub_scratch)); + Tensor d_num_selected; + TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, + TensorShape({1}), &d_num_selected)); + + cub::DeviceSelect::Flagged( + (void*)cub_scratch.flat<int8>().data(), // temp_storage + flagged_buffer_size, + d_indices.flat<int>().data(), // input + selected, // selection flag + d_selected_indices, // selected items + d_num_selected.flat<int>().data(), num_boxes, device.stream()); cudaEvent_t copy_done; TF_RETURN_IF_CUDA_ERROR( cudaEventCreateWithFlags(©_done, cudaEventDisableTiming)); - device.memcpyDeviceToHost(&h_delete_mask[0], &d_delete_mask[0], - num_to_copy * bit_mask_len * sizeof(int)); + device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(), + sizeof(int)); TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream())); - int offset = 0; - std::vector<int> h_selected_indices; - // Reserve worst case scenario. Since box count is not huge, this should have - // negligible footprint. 
- h_selected_indices.reserve(num_boxes); - std::vector<int> to_remove(bit_mask_len, 0); - while (offset < num_boxes) { - const int num_copied = num_to_copy; - int next_offset = offset + num_copied; - num_to_copy = std::min(kNmsChunkSize, num_boxes - next_offset); - if (num_to_copy > 0) { - device.memcpyDeviceToHost(&h_delete_mask[next_offset * bit_mask_len], - &d_delete_mask[next_offset * bit_mask_len], - num_to_copy * bit_mask_len * sizeof(int)); - } - // Waiting for previous copy - TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done)); - if (num_to_copy > 0) { - TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream())); - } - // Starting from highest scoring box, mark any box with iou>threshold and - // lower score for deletion if current box is not marked for deletion. Add - // current box to to_keep list. - for (int i = offset; i < next_offset; ++i) { - // See the comment at the beginning of the file. - // Bit shift and logical And operations are used - // instead of division and modulo operations. - int iblock = i >> kNmsBoxesPerThreadShiftBits; - int inblock = i & kNmsBoxesPerThreadModuloMask; - if (!(to_remove[iblock] & (1 << inblock))) { - h_selected_indices.push_back(i); - int* p = &h_delete_mask[i * bit_mask_len]; - for (int ib = 0; ib < bit_mask_len; ++ib) { - to_remove[ib] |= p[ib]; - } - } - } - offset = next_offset; - } + TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done)); + *h_nkeep = *h_selected_count; cudaEventDestroy(copy_done); - - const int nkeep = h_selected_indices.size(); - device.memcpyHostToDevice(d_selected_indices, &h_selected_indices[0], - nkeep * sizeof(int)); - - *h_nkeep = nkeep; return Status::OK(); } @@ -485,10 +525,10 @@ class NonMaxSuppressionV2GPUOp : public OpKernel { // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2, // flip boxes if necessary! 
const bool flip_boxes = true; - auto status = - NmsGpu(d_sorted_boxes.flat<float>().data(), num_boxes, - iou_threshold_val, d_selected_indices.flat<int>().data(), - &num_to_keep, context, flip_boxes, /*legacy_mode*/ false); + auto status = NmsGpu( + d_sorted_boxes.flat<float>().data(), num_boxes, iou_threshold_val, + d_selected_indices.flat<int>().data(), &num_to_keep, context, + output_size, flip_boxes, /*legacy_mode*/ false); TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); if (!status.ok()) { context->SetStatus(status); @@ -535,13 +575,10 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op, TF_RETURN_IF_ERROR(context->allocate_temp( DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output)); - TF_RETURN_IF_ERROR( - context, - context->allocate_temp(DataType::DT_INT8, - TensorShape({(int64)workspace_size}), &workspace)); - TF_RETURN_IF_ERROR( - context, context->allocate_temp(DataType::DT_INT32, TensorShape({1}), - &element_count)); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace)); + TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, + TensorShape({1}), &element_count)); cudaEvent_t copy_done; TF_RETURN_IF_CUDA_ERROR( cudaEventCreateWithFlags(©_done, cudaEventDisableTiming)); @@ -697,15 +734,18 @@ class NonMaxSuppressionV3GPUOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}), &output_indices)); return; + } else { + VLOG(2) << "Number of boxes above threshold=" << score_threshold_val + << " is " << limited_num_boxes; } int num_to_keep = 0; // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2, // flip boxes if necessary! const bool flip_boxes = true; - auto status = - NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes, - iou_threshold_val, d_selected_indices.flat<int>().data(), - &num_to_keep, context, flip_boxes, /*legacy_mode*/ false); + auto status = NmsGpu( + d_sorted_boxes.flat<float>().data(), limited_num_boxes, + iou_threshold_val, d_selected_indices.flat<int>().data(), &num_to_keep, + context, output_size, flip_boxes, /*legacy_mode*/ false); TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); if (!status.ok()) { context->SetStatus(status); @@ -716,7 +756,12 @@ class NonMaxSuppressionV3GPUOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({num_outputs}), &output_indices)); - if (num_outputs == 0) return; + if (num_outputs == 0) { + VLOG(1) << "No outputs!"; + return; + } else { + VLOG(2) << "Num outputs= " << num_outputs; + } config = GetGpuLaunchConfig(num_outputs, device); TF_CHECK_OK(GpuLaunchKernel( IndexMultiSelect<int, int>, config.block_count, config.thread_per_block, diff --git a/tensorflow/core/kernels/non_max_suppression_op.h b/tensorflow/core/kernels/non_max_suppression_op.h index bf1c97e66dd..fbf4dbfcd1a 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.h +++ b/tensorflow/core/kernels/non_max_suppression_op.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_KERNELS_NON_MAX_SUPPRESSION_OP_H_ #define TENSORFLOW_CORE_KERNELS_NON_MAX_SUPPRESSION_OP_H_ -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" namespace tensorflow { namespace functor { @@ -54,7 +54,8 @@ extern const int kNmsBoxesPerTread; Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, const float iou_threshold, int* d_selected_indices, int* h_num_boxes_to_keep, OpKernelContext* context, - bool flip_boxes = false,bool legacy_mode=false); + const int max_boxes, bool flip_boxes = false, + bool legacy_mode = false); #endif } // namespace tensorflow From e882f1749879ecd4d479ca8aeba4b04f48fa7085 Mon Sep 17 00:00:00 2001 From: Sami <skama@nvidia.com> Date: Wed, 28 Aug 2019 18:01:54 -0700 Subject: [PATCH 3/5] Revert all __restrict__ keywords, they are negatively impacting performance --- .../core/kernels/non_max_suppression_op.cu.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc index 8b51a39cb58..4afb5d16e51 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc @@ -182,9 +182,9 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len, // x1<x2 and y1<y2. template <bool flip_box, bool legacy_mode> __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__ - void NMSKernel(const Box* __restrict__ d_desc_sorted_boxes, + void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes, const float iou_threshold, - const int bit_mask_len, int* __restrict__ d_delete_mask) { + const int bit_mask_len, int* d_delete_mask) { // Storing boxes used by this CUDA block in the shared memory. __shared__ Box shared_i_boxes[kNmsBlockDim]; // Same thing with areas @@ -247,8 +247,8 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, template <typename Index, typename T, typename... Args> __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, const Index i_original, - const T* __restrict__ original, - T* __restrict__ selected, + const T* original, + T* selected, Args... args) { selected[i_selected] = original[i_original]; SelectHelper(i_selected, i_original, args...); @@ -262,9 +262,9 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, // selected2). template <typename Index, typename T, typename... Args> __global__ void IndexMultiSelect(const int num_elements, - const Index* __restrict__ indices, - const T* __restrict__ original, - T* __restrict__ selected, Args... args) { + const Index* indices, + const T* original, + T* selected, Args... 
args) { for (const int idx : CudaGridRangeX(num_elements)) { SelectHelper(idx, indices[idx], original, selected, args...); } From 736eba374e2976d1e8bd415dd13a0d547ab0c8c9 Mon Sep 17 00:00:00 2001 From: Sami <skama@nvidia.com> Date: Mon, 9 Sep 2019 13:54:46 -0700 Subject: [PATCH 4/5] Addressing review comments --- .../core/kernels/non_max_suppression_op.cu.cc | 533 +++++++----------- .../core/kernels/non_max_suppression_op.h | 3 +- 2 files changed, 204 insertions(+), 332 deletions(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc index 4afb5d16e51..af3b36a464d 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc @@ -15,6 +15,7 @@ limitations under the License. #if GOOGLE_CUDA #define EIGEN_USE_GPU +#include <limits> #include "absl/strings/str_cat.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" @@ -80,19 +81,8 @@ __device__ EIGEN_STRONG_INLINE void Swap(T& a, T& b) { b = c; } -template <bool T> -__device__ float legacy_offset(float); -template <> -__device__ EIGEN_STRONG_INLINE float legacy_offset<true>(float a) { - return a + 1.0; -} -template <> -__device__ EIGEN_STRONG_INLINE float legacy_offset<false>(float a) { - return a; -} - // Check whether two boxes have an IoU greater than threshold. -template <typename T, bool L> +template <typename T> __device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b, const float a_area, const T iou_threshold) { @@ -104,8 +94,8 @@ __device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b, const float yy2 = fminf(a->y2, b->y2); // fdimf computes the positive difference between xx2+1 and xx1. - const float w = fdimf(legacy_offset<L>(xx2), xx1); - const float h = fdimf(legacy_offset<L>(yy2), yy1); + const float w = fdimf(xx2, xx1); + const float h = fdimf(yy2, yy1); const float intersection = w * h; // Testing for aa/bb > t @@ -130,10 +120,10 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) { } template <typename T> __device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) { - constexpr int SHIFTLEN = NumBits(8 * sizeof(T)) - 1; - constexpr int REMAINDER_MASK = 8 * sizeof(T) - 1; - int bin = bit >> SHIFTLEN; - return (bit_mask[bin] >> (bit & REMAINDER_MASK)) & 1; + constexpr int kShiftLen = NumBits(8 * sizeof(T)) - 1; + constexpr int kRemainderMask = 8 * sizeof(T) - 1; + int bin = bit >> kShiftLen; + return (bit_mask[bin] >> (bit & kRemainderMask)) & 1; } // Produce a global bitmask (result_mask) of selected boxes from bitmask @@ -180,11 +170,11 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len, // If flip_box is true boxes may have x1>x2 and or y1>y2. If so change the // coordinates such that for all boxes x1<x2 and y1<y2. Else boxes should have // x1<x2 and y1<y2. -template <bool flip_box, bool legacy_mode> +template <bool flip_box> __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__ - void NMSKernel(const Box* d_desc_sorted_boxes, - const int num_boxes, const float iou_threshold, - const int bit_mask_len, int* d_delete_mask) { + void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes, + const float iou_threshold, const int bit_mask_len, + int* d_delete_mask) { // Storing boxes used by this CUDA block in the shared memory. 
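   // Each block stages kNmsBlockDim row boxes (and their areas) in shared
   // memory; every thread then compares its row box against groups of
   // kNmsBoxesPerThread column boxes and packs the results into one bit each
   // of an int written to d_delete_mask.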
__shared__ Box shared_i_boxes[kNmsBlockDim]; // Same thing with areas @@ -224,8 +214,8 @@ __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__ Box j_box = d_desc_sorted_boxes[j]; const Box i_box = shared_i_boxes[threadIdx.x]; Flipped<flip_box>(j_box); - if (OverThreshold<float, legacy_mode>( - &i_box, &j_box, shared_i_areas[threadIdx.x], iou_threshold)) { + if (OverThreshold<float>(&i_box, &j_box, shared_i_areas[threadIdx.x], + iou_threshold)) { // we have score[j] <= score[i]. above_threshold |= (1U << ib); } @@ -247,8 +237,7 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, template <typename Index, typename T, typename... Args> __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, const Index i_original, - const T* original, - T* selected, + const T* original, T* selected, Args... args) { selected[i_selected] = original[i_original]; SelectHelper(i_selected, i_original, args...); @@ -261,18 +250,15 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected, // IndexMultiSelect(num_elements, indices, original1 ,selected1, original2, // selected2). template <typename Index, typename T, typename... Args> -__global__ void IndexMultiSelect(const int num_elements, - const Index* indices, - const T* original, - T* selected, Args... args) { +__global__ void IndexMultiSelect(const int num_elements, const Index* indices, + const T* original, T* selected, Args... args) { for (const int idx : CudaGridRangeX(num_elements)) { SelectHelper(idx, indices[idx], original, selected, args...); } } template <typename T> -__global__ void Iota(const int num_elements, const T offset, - T* to_fill) { +__global__ void Iota(const int num_elements, const T offset, T* to_fill) { for (int idx : CudaGridRangeX(num_elements)) { to_fill[idx] = static_cast<T>(idx) + offset; } @@ -280,8 +266,7 @@ __global__ void Iota(const int num_elements, const T offset, Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, const float iou_threshold, int* d_selected_indices, int* h_nkeep, - OpKernelContext* context, const int max_boxes, bool flip_boxes, - bool legacy_mode) { + OpKernelContext* context, const int max_boxes, bool flip_boxes) { // Making sure we respect the __align(16)__ // we promised to the compiler. auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr); @@ -309,9 +294,8 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, alloc_attr.set_gpu_compatible(true); // Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 and // using it as a ring buffer. However savings should be a few MB . 
- TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32, - TensorShape({1}), - &h_num_selected, alloc_attr)); + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({1}), &h_num_selected, alloc_attr)); int* d_delete_mask = d_nms_mask.flat<int>().data(); int* h_selected_count = h_num_selected.flat<int>().data(); @@ -327,29 +311,13 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, thread_block.y = kNmsBlockDim; thread_block.z = 1; if (flip_boxes) { - if (!legacy_mode) { - TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true, false>, block_dim, - thread_block, 0, device.stream(), - d_sorted_boxes, num_boxes, iou_threshold, - bit_mask_len, d_delete_mask)); - } else { - TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true, true>, block_dim, - thread_block, 0, device.stream(), - d_sorted_boxes, num_boxes, iou_threshold, - bit_mask_len, d_delete_mask)); - } + TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true>, block_dim, thread_block, 0, + device.stream(), d_sorted_boxes, num_boxes, + iou_threshold, bit_mask_len, d_delete_mask)); } else { - if (!legacy_mode) { - TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false, false>, block_dim, - thread_block, 0, device.stream(), - d_sorted_boxes, num_boxes, iou_threshold, - bit_mask_len, d_delete_mask)); - } else { - TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false, true>, block_dim, - thread_block, 0, device.stream(), - d_sorted_boxes, num_boxes, iou_threshold, - bit_mask_len, d_delete_mask)); - } + TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false>, block_dim, thread_block, 0, + device.stream(), d_sorted_boxes, num_boxes, + iou_threshold, bit_mask_len, d_delete_mask)); } TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); // Overlapping CPU computes and D2H memcpy @@ -408,152 +376,6 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, return Status::OK(); } -class NonMaxSuppressionV2GPUOp : public OpKernel { - public: - explicit NonMaxSuppressionV2GPUOp(OpKernelConstruction* context) - : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - // boxes: [num_boxes, 4] - const Tensor& boxes = context->input(0); - // scores: [num_boxes] - const Tensor& scores = context->input(1); - // max_output_size: scalar - const Tensor& max_output_size = context->input(2); - OP_REQUIRES( - context, TensorShapeUtils::IsScalar(max_output_size.shape()), - errors::InvalidArgument("max_output_size must be 0-D, got shape ", - max_output_size.shape().DebugString())); - // iou_threshold: scalar - const Tensor& iou_threshold = context->input(3); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), - errors::InvalidArgument("iou_threshold must be 0-D, got shape ", - iou_threshold.shape().DebugString())); - const float iou_threshold_val = iou_threshold.scalar<float>()(); - - OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1, - errors::InvalidArgument("iou_threshold must be in [0, 1]")); - OP_REQUIRES(context, boxes.dims() == 2, - errors::InvalidArgument("boxes must be a rank 2 tensor!")); - int num_boxes = boxes.dim_size(0); - OP_REQUIRES(context, boxes.dim_size(1) == 4, - errors::InvalidArgument("boxes must be Nx4")); - OP_REQUIRES(context, scores.dims() == 1, - errors::InvalidArgument("scores must be a vector!")); - OP_REQUIRES( - context, scores.dim_size(0) == num_boxes, - errors::InvalidArgument( - "scores has incompatible shape")); // message must be exactly this - // otherwise tests fail! 
- if (num_boxes == 0) { - Tensor* output_indices = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}), - &output_indices)); - return; - } - const int output_size = max_output_size.scalar<int>()(); - size_t cub_sort_temp_storage_bytes = 0; - auto cuda_stream = GetGpuStream(context); - auto device = context->eigen_gpu_device(); - // Calling cub with nullptrs as inputs will make it return - // workspace size needed for the operation instead of doing the operation. - // In this specific instance, cub_sort_temp_storage_bytes will contain the - // necessary workspace size for sorting after the call. - cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending( - nullptr, cub_sort_temp_storage_bytes, - static_cast<float*>(nullptr), // scores - static_cast<float*>(nullptr), // sorted scores - static_cast<int*>(nullptr), // input indices - static_cast<int*>(nullptr), // sorted indices - num_boxes, // num items - 0, 8 * sizeof(float), // sort all bits - cuda_stream); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); - Tensor d_cub_sort_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp( - DataType::DT_INT8, - TensorShape({(int64)cub_sort_temp_storage_bytes}), - &d_cub_sort_buffer)); - Tensor d_indices; - OP_REQUIRES_OK( - context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), &d_indices)); - Tensor d_sorted_indices; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), - &d_sorted_indices)); - Tensor d_selected_indices; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), - &d_selected_indices)); - Tensor d_sorted_scores; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT, - TensorShape({num_boxes}), - &d_sorted_scores)); - Tensor d_sorted_boxes; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT, - TensorShape({num_boxes, 4}), - &d_sorted_boxes)); - - // this will return sorted scores and their indices - auto config = GetGpuLaunchConfig(num_boxes, device); - // initialize box and score indices - TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count, - config.thread_per_block, 0, device.stream(), - config.virtual_thread_count, 0, - d_indices.flat<int>().data())); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); - cuda_ret = cub::DeviceRadixSort::SortPairsDescending( - d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes, - scores.flat<float>().data(), d_sorted_scores.flat<float>().data(), - d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(), - num_boxes, 0, - 8 * sizeof(float), // sort all bits - cuda_stream); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); - - // get pointers for easy access - const float4* original_boxes = - reinterpret_cast<const float4*>(boxes.flat<float>().data()); - float4* sorted_boxes = - reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data()); - const int* sorted_indices = d_sorted_indices.flat<int>().data(); - // sort boxes using indices - TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>, - config.block_count, config.thread_per_block, 0, - device.stream(), config.virtual_thread_count, - sorted_indices, original_boxes, sorted_boxes)); - - int num_to_keep = 0; - // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2, - // flip boxes if necessary! 
- const bool flip_boxes = true; - auto status = NmsGpu( - d_sorted_boxes.flat<float>().data(), num_boxes, iou_threshold_val, - d_selected_indices.flat<int>().data(), &num_to_keep, context, - output_size, flip_boxes, /*legacy_mode*/ false); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); - if (!status.ok()) { - context->SetStatus(status); - return; - } - Tensor* output_indices = nullptr; - int num_outputs = std::min(num_to_keep, output_size); // no padding! - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({num_outputs}), - &output_indices)); - if (num_outputs == 0) return; - config = GetGpuLaunchConfig(num_outputs, device); - TF_CHECK_OK(GpuLaunchKernel( - IndexMultiSelect<int, int>, config.block_count, config.thread_per_block, - 0, device.stream(), config.virtual_thread_count, - d_selected_indices.flat<int>().data(), sorted_indices, - (*output_indices).flat<int>().data())); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); - } -}; - struct GreaterThanCubOp { float threshold_; __host__ __device__ __forceinline__ GreaterThanCubOp(float threshold) @@ -597,6 +419,180 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op, return Status::OK(); } +Status DoNMS(OpKernelContext* context, const Tensor& boxes, + const Tensor& scores, const int64_t max_output_size, + const float iou_threshold_val, const float score_threshold) { + const int output_size = max_output_size; + int num_boxes = boxes.dim_size(0); + size_t cub_sort_temp_storage_bytes = 0; + auto cuda_stream = GetGpuStream(context); + auto device = context->eigen_gpu_device(); + // Calling cub with nullptrs as inputs will make it return + // workspace size needed for the operation instead of doing the operation. + // In this specific instance, cub_sort_temp_storage_bytes will contain the + // necessary workspace size for sorting after the call. 
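+  // The same two-phase cub pattern in isolation (illustrative names, error
+  // handling and temp-buffer allocation omitted):
+  //   size_t temp_bytes = 0;
+  //   cub::DeviceRadixSort::SortPairsDescending(nullptr, temp_bytes, d_keys_in,
+  //                                             d_keys_out, d_vals_in,
+  //                                             d_vals_out, n);  // size query
+  //   cub::DeviceRadixSort::SortPairsDescending(d_temp, temp_bytes, d_keys_in,
+  //                                             d_keys_out, d_vals_in,
+  //                                             d_vals_out, n);  // actual sort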
+ if (num_boxes == 0) { + Tensor* output_indices = nullptr; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({0}), &output_indices)); + return Status::OK(); + } + + cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending( + nullptr, cub_sort_temp_storage_bytes, + static_cast<float*>(nullptr), // scores + static_cast<float*>(nullptr), // sorted scores + static_cast<int*>(nullptr), // input indices + static_cast<int*>(nullptr), // sorted indices + num_boxes, // num items + 0, 8 * sizeof(float), // sort all bits + cuda_stream); + TF_RETURN_IF_CUDA_ERROR(cuda_ret); + Tensor d_cub_sort_buffer; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}), + &d_cub_sort_buffer)); + Tensor d_indices; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes}), &d_indices)); + Tensor d_sorted_indices; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes}), &d_sorted_indices)); + Tensor d_selected_indices; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_INT32, TensorShape({num_boxes}), &d_selected_indices)); + Tensor d_sorted_scores; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_boxes}), &d_sorted_scores)); + Tensor d_sorted_boxes; + TF_RETURN_IF_ERROR(context->allocate_temp( + DataType::DT_FLOAT, TensorShape({num_boxes, 4}), &d_sorted_boxes)); + + // this will return sorted scores and their indices + auto config = GetGpuLaunchConfig(num_boxes, device); + // initialize box and score indices + TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count, + config.thread_per_block, 0, device.stream(), + config.virtual_thread_count, 0, + d_indices.flat<int>().data())); + TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); + cuda_ret = cub::DeviceRadixSort::SortPairsDescending( + d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes, + scores.flat<float>().data(), d_sorted_scores.flat<float>().data(), + d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(), + num_boxes, 0, + 8 * sizeof(float), // sort all bits + cuda_stream); + TF_RETURN_IF_CUDA_ERROR(cuda_ret); + + // get pointers for easy access + const float4* original_boxes = + reinterpret_cast<const float4*>(boxes.flat<float>().data()); + float4* sorted_boxes = + reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data()); + const int* sorted_indices = d_sorted_indices.flat<int>().data(); + // sort boxes using indices + TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>, config.block_count, + config.thread_per_block, 0, device.stream(), + config.virtual_thread_count, sorted_indices, + original_boxes, sorted_boxes)); + int limited_num_boxes = num_boxes; + // filter boxes by scores if nms v3 + if (score_threshold > std::numeric_limits<float>::min()) { + GreaterThanCubOp score_limit(score_threshold); + TF_RETURN_IF_ERROR(CountIf(context, d_sorted_scores.flat<float>().data(), + score_limit, num_boxes, &limited_num_boxes)); + if (limited_num_boxes == 0) { + Tensor* output_indices = nullptr; + VLOG(1) << "Number of boxes above score threshold " << score_threshold + << " is 0"; + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({0}), &output_indices)); + return Status::OK(); + } else { + VLOG(2) << "Number of boxes above threshold=" << score_threshold + << " is " << limited_num_boxes; + } + } + int num_to_keep = 0; + // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2, + // flip boxes if necessary! 
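+  // (The op accepts boxes given as any diagonally opposite pair of corners, so
+  // coordinates may arrive unordered.)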
+ const bool flip_boxes = true; + auto status = NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes, + iou_threshold_val, d_selected_indices.flat<int>().data(), + &num_to_keep, context, output_size, flip_boxes); + TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); + if (!status.ok()) { + context->SetStatus(status); + return status; + } + Tensor* output_indices = nullptr; + int num_outputs = std::min(num_to_keep, output_size); // no padding! + TF_RETURN_IF_ERROR( + context->allocate_output(0, TensorShape({num_outputs}), &output_indices)); + if (num_outputs == 0) return Status::OK(); + config = GetGpuLaunchConfig(num_outputs, device); + TF_CHECK_OK(GpuLaunchKernel( + IndexMultiSelect<int, int>, config.block_count, config.thread_per_block, + 0, device.stream(), config.virtual_thread_count, + d_selected_indices.flat<int>().data(), sorted_indices, + (*output_indices).flat<int>().data())); + TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()); + return Status::OK(); +} + +class NonMaxSuppressionV2GPUOp : public OpKernel { + public: + explicit NonMaxSuppressionV2GPUOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // boxes: [num_boxes, 4] + const Tensor& boxes = context->input(0); + // scores: [num_boxes] + const Tensor& scores = context->input(1); + // max_output_size: scalar + const Tensor& max_output_size = context->input(2); + OP_REQUIRES( + context, TensorShapeUtils::IsScalar(max_output_size.shape()), + errors::InvalidArgument("max_output_size must be 0-D, got shape ", + max_output_size.shape().DebugString())); + // iou_threshold: scalar + const Tensor& iou_threshold = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()), + errors::InvalidArgument("iou_threshold must be 0-D, got shape ", + iou_threshold.shape().DebugString())); + const float iou_threshold_val = iou_threshold.scalar<float>()(); + + OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1, + errors::InvalidArgument("iou_threshold must be in [0, 1]")); + OP_REQUIRES(context, boxes.dims() == 2, + errors::InvalidArgument("boxes must be a rank 2 tensor!")); + int num_boxes = boxes.dim_size(0); + OP_REQUIRES(context, boxes.dim_size(1) == 4, + errors::InvalidArgument("boxes must be Nx4")); + OP_REQUIRES(context, scores.dims() == 1, + errors::InvalidArgument("scores must be a vector!")); + OP_REQUIRES( + context, scores.dim_size(0) == num_boxes, + errors::InvalidArgument( + "scores has incompatible shape")); // message must be exactly this + // otherwise tests fail! 
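As a concrete aside on the index bookkeeping at the end of DoNMS above: the NMS kernel reports positions in sorted-score space, so the final gather through sorted_indices is what recovers the caller's box ids. A toy host-side walkthrough with made-up numbers:

// mapping_sketch.cc -- illustrative only; the values below are invented.
#include <cstdio>
#include <vector>

int main() {
  // Original scores of five boxes.
  const std::vector<float> scores = {0.1f, 0.9f, 0.5f, 0.8f, 0.3f};
  // Box ids ordered by descending score: 1, 3, 2, 4, 0.
  const std::vector<int> sorted_indices = {1, 3, 2, 4, 0};
  // Suppose NMS kept sorted positions 0 and 2, within max_output_size.
  const std::vector<int> selected_positions = {0, 2};

  for (int pos : selected_positions) {
    const int original_id = sorted_indices[pos];  // the second gather
    std::printf("keep box %d (score %.1f)\n", original_id,
                scores[original_id]);
  }
  return 0;
}

Prints "keep box 1" and "keep box 2": without the second gather the op would leak positions in sorted order rather than the box ids the caller expects.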
+ if (num_boxes == 0) { + Tensor* output_indices = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({0}), &output_indices)); + return; + } + const int64_t output_size = max_output_size.scalar<int>()(); + OP_REQUIRES_OK( + context, + DoNMS(context, boxes, scores, output_size, iou_threshold_val, + /*score_threshold is float min if score threshold is disabled*/ + std::numeric_limits<float>::min())); + } +}; + class NonMaxSuppressionV3GPUOp : public OpKernel { public: explicit NonMaxSuppressionV3GPUOp(OpKernelConstruction* context) @@ -648,131 +644,8 @@ class NonMaxSuppressionV3GPUOp : public OpKernel { return; } const int output_size = max_output_size.scalar<int>()(); - size_t cub_sort_temp_storage_bytes = 0; - auto cuda_stream = tensorflow::GetGpuStream(context); - auto device = context->eigen_gpu_device(); - // Calling cub with nullptrs as inputs will make it return - // workspace size needed for the operation instead of doing the operation. - // In this specific instance, cub_sort_temp_storage_bytes will contain the - // necessary workspace size for sorting after the call. - cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending( - nullptr, cub_sort_temp_storage_bytes, - static_cast<float*>(nullptr), // scores - static_cast<float*>(nullptr), // sorted scores - static_cast<int*>(nullptr), // input indices - static_cast<int*>(nullptr), // sorted indices - num_boxes, // num items - 0, 8 * sizeof(float), // sort all bits - cuda_stream); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); - Tensor d_cub_sort_buffer; - OP_REQUIRES_OK(context, - context->allocate_temp( - DataType::DT_INT8, - TensorShape({(int64)cub_sort_temp_storage_bytes}), - &d_cub_sort_buffer)); - Tensor d_indices; - OP_REQUIRES_OK( - context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), &d_indices)); - Tensor d_sorted_indices; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), - &d_sorted_indices)); - Tensor d_selected_indices; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, - TensorShape({num_boxes}), - &d_selected_indices)); - Tensor d_sorted_scores; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT, - TensorShape({num_boxes}), - &d_sorted_scores)); - Tensor d_sorted_boxes; - OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT, - TensorShape({num_boxes, 4}), - &d_sorted_boxes)); - - // this will return sorted scores and their indices - auto config = GetGpuLaunchConfig(num_boxes, device); - // initialize box and score indices - TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count, - config.thread_per_block, 0, device.stream(), - config.virtual_thread_count, 0, - d_indices.flat<int>().data())); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); - cuda_ret = cub::DeviceRadixSort::SortPairsDescending( - d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes, - scores.flat<float>().data(), d_sorted_scores.flat<float>().data(), - d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(), - num_boxes, 0, - 8 * sizeof(float), // sort all bits - cuda_stream); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret); - // get pointers for easy access - const float4* original_boxes = - reinterpret_cast<const float4*>(boxes.flat<float>().data()); - float4* sorted_boxes = - reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data()); - const int* sorted_indices = d_sorted_indices.flat<int>().data(); - // sort boxes using indices - 
TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>, - config.block_count, config.thread_per_block, 0, - device.stream(), config.virtual_thread_count, - sorted_indices, original_boxes, sorted_boxes)); - - // Unfortunately we had to sort scores to find the number of boxes which has - // a threshold above score_threshold_val. It can be done before sorting but - // that would require either implementing a custom sort or a generic random - // access iterator for cub. For the time being we search for the location of - // the score_threshold_val in the sorted array and limit num_boxes to its - // index. - GreaterThanCubOp score_limit(score_threshold_val); - int limited_num_boxes = 0; - OP_REQUIRES_OK(context, - CountIf(context, d_sorted_scores.flat<float>().data(), - score_limit, num_boxes, &limited_num_boxes)); - if (limited_num_boxes == 0) { - Tensor* output_indices = nullptr; - VLOG(1) << "Number of boxes above score threshold " << score_threshold_val - << " is 0"; - OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}), - &output_indices)); - return; - } else { - VLOG(2) << "Number of boxes above threshold=" << score_threshold_val - << " is " << limited_num_boxes; - } - int num_to_keep = 0; - // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2, - // flip boxes if necessary! - const bool flip_boxes = true; - auto status = NmsGpu( - d_sorted_boxes.flat<float>().data(), limited_num_boxes, - iou_threshold_val, d_selected_indices.flat<int>().data(), &num_to_keep, - context, output_size, flip_boxes, /*legacy_mode*/ false); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); - if (!status.ok()) { - context->SetStatus(status); - return; - } - Tensor* output_indices = nullptr; - int num_outputs = std::min(num_to_keep, output_size); // no padding! 
- OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({num_outputs}), - &output_indices)); - if (num_outputs == 0) { - VLOG(1) << "No outputs!"; - return; - } else { - VLOG(2) << "Num outputs= " << num_outputs; - } - config = GetGpuLaunchConfig(num_outputs, device); - TF_CHECK_OK(GpuLaunchKernel( - IndexMultiSelect<int, int>, config.block_count, config.thread_per_block, - 0, device.stream(), config.virtual_thread_count, - d_selected_indices.flat<int>().data(), sorted_indices, - (*output_indices).flat<int>().data())); - TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError()); + OP_REQUIRES_OK(context, DoNMS(context, boxes, scores, output_size, + iou_threshold_val, score_threshold_val)); } }; diff --git a/tensorflow/core/kernels/non_max_suppression_op.h b/tensorflow/core/kernels/non_max_suppression_op.h index fbf4dbfcd1a..120c4c16223 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.h +++ b/tensorflow/core/kernels/non_max_suppression_op.h @@ -54,8 +54,7 @@ extern const int kNmsBoxesPerTread; Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes, const float iou_threshold, int* d_selected_indices, int* h_num_boxes_to_keep, OpKernelContext* context, - const int max_boxes, bool flip_boxes = false, - bool legacy_mode = false); + const int max_boxes, bool flip_boxes = false); #endif } // namespace tensorflow From 178ed66d97ae36e7a0cf8d9d4e9626699f193bfb Mon Sep 17 00:00:00 2001 From: Sami <skama@nvidia.com> Date: Mon, 9 Sep 2019 15:04:47 -0700 Subject: [PATCH 5/5] Fix numeric_limits::min()->lowest() and clang-format --- .../core/kernels/non_max_suppression_op.cu.cc | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc index af3b36a464d..de8bc5f3428 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc @@ -498,10 +498,10 @@ Status DoNMS(OpKernelContext* context, const Tensor& boxes, original_boxes, sorted_boxes)); int limited_num_boxes = num_boxes; // filter boxes by scores if nms v3 - if (score_threshold > std::numeric_limits<float>::min()) { + if (score_threshold > std::numeric_limits<float>::lowest()) { GreaterThanCubOp score_limit(score_threshold); TF_RETURN_IF_ERROR(CountIf(context, d_sorted_scores.flat<float>().data(), - score_limit, num_boxes, &limited_num_boxes)); + score_limit, num_boxes, &limited_num_boxes)); if (limited_num_boxes == 0) { Tensor* output_indices = nullptr; VLOG(1) << "Number of boxes above score threshold " << score_threshold @@ -510,8 +510,8 @@ Status DoNMS(OpKernelContext* context, const Tensor& boxes, context->allocate_output(0, TensorShape({0}), &output_indices)); return Status::OK(); } else { - VLOG(2) << "Number of boxes above threshold=" << score_threshold - << " is " << limited_num_boxes; + VLOG(2) << "Number of boxes above threshold=" << score_threshold << " is " + << limited_num_boxes; } } int num_to_keep = 0; @@ -580,8 +580,8 @@ class NonMaxSuppressionV2GPUOp : public OpKernel { // otherwise tests fail! 
if (num_boxes == 0) { Tensor* output_indices = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, TensorShape({0}), &output_indices)); + OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}), + &output_indices)); return; } const int64_t output_size = max_output_size.scalar<int>()(); @@ -589,7 +589,7 @@ class NonMaxSuppressionV2GPUOp : public OpKernel { context, DoNMS(context, boxes, scores, output_size, iou_threshold_val, /*score_threshold is float min if score threshold is disabled*/ - std::numeric_limits<float>::min())); + std::numeric_limits<float>::lowest())); } };
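For reference, a small standalone check of the numeric_limits change this final patch makes. It illustrates why min(), which is the smallest positive normal float rather than the most negative value, was a poor "score threshold disabled" sentinel:

// limits_check.cc -- quick illustration of min() vs lowest().
#include <cstdio>
#include <limits>

int main() {
  const float fmin = std::numeric_limits<float>::min();       // ~1.18e-38
  const float flowest = std::numeric_limits<float>::lowest();  // ~-3.40e+38
  std::printf("min()    = %g\nlowest() = %g\n", fmin, flowest);

  // With min() as the sentinel, a legitimate threshold of 0.0 (or any
  // negative threshold) is not "above the sentinel", so the score filter
  // would be silently skipped.
  const float user_threshold = 0.0f;
  std::printf("0.0f > min()    : %s\n",
              user_threshold > fmin ? "true" : "false");
  std::printf("0.0f > lowest() : %s\n",
              user_threshold > flowest ? "true" : "false");
  return 0;
}

With lowest() as the disabled-threshold marker, every real user-supplied threshold, including zero and negative values, compares greater than the sentinel and the filter is applied as intended.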