Pure GPU NMS implementation.

This commit is contained in:
Sami 2019-08-22 18:47:03 -07:00
parent b2b8e56f5f
commit c2a4931076
2 changed files with 117 additions and 71 deletions

View File

@ -128,6 +128,47 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
if (box.x1 > box.x2) Swap(box.x1, box.x2);
if (box.y1 > box.y2) Swap(box.y1, box.y2);
}
template <typename T>
__device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) {
constexpr int SHIFTLEN = NumBits(8 * sizeof(T)) - 1;
constexpr int REMAINDER_MASK = 8 * sizeof(T) - 1;
int bin = bit >> SHIFTLEN;
return (bit_mask[bin] >> (bit & REMAINDER_MASK)) & 1;
}
// Produce a global bitmask (result_mask) of selected boxes from bitmask
// generated by NMSKernel Abort early if max_boxes boxes are selected. Bitmask
// is num_boxes*bit_mask_len bits indicating whether to keep or remove a box.
__global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
const int num_boxes, const int max_boxes,
char* result_mask) {
extern __shared__ int local[];
// set global mask to accept all boxes
for (int box : CudaGridRangeX(bit_mask_len)) {
local[box] = 0xFFFFFFFF;
}
__syncthreads();
int accepted_boxes = 0;
for (int box = 0; box < num_boxes - 1; ++box) {
// if current box is masked by an earlier box, skip it.
if (!CheckBit(local, box)) {
continue;
}
accepted_boxes += 1;
int offset = box * bit_mask_len;
// update global mask with current box's mask
for (int b : CudaGridRangeX(bit_mask_len)) {
local[b] &= ~bitmask[offset + b];
}
__syncthreads();
if (accepted_boxes > max_boxes) break;
}
// copy global mask to result_max char array. char array is needed for
// cub::DeviceSelect later.
for (int box : CudaGridRangeX(num_boxes)) {
result_mask[box] = CheckBit(local, box);
}
}
// For each box, compute a bitmask of boxes which has an overlap with given box
// above threshold.
@ -235,7 +276,8 @@ __global__ void Iota(const int num_elements, const T offset, T* to_fill) {
Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
const float iou_threshold, int* d_selected_indices, int* h_nkeep,
OpKernelContext* context, bool flip_boxes, bool legacy_mode) {
OpKernelContext* context, const int max_boxes, bool flip_boxes,
bool legacy_mode) {
// Making sure we respect the __align(16)__
// we promised to the compiler.
auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr);
@ -243,7 +285,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
return errors::InvalidArgument("Boxes should be aligned to 16 Bytes.");
}
// allocate bitmask arrays on host and on device
Tensor h_nms_mask, d_nms_mask;
Tensor h_num_selected, d_nms_mask;
const int bit_mask_len =
(num_boxes + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread;
@ -264,11 +306,11 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
// Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 and
// using it as a ring buffer. However savings should be a few MB .
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
TensorShape({max_nms_mask_size}),
&h_nms_mask, alloc_attr));
TensorShape({1}),
&h_num_selected, alloc_attr));
int* d_delete_mask = d_nms_mask.flat<int>().data();
int* h_delete_mask = h_nms_mask.flat<int>().data();
int* h_selected_count = h_num_selected.flat<int>().data();
const Box* d_sorted_boxes =
reinterpret_cast<const Box*>(d_sorted_boxes_float_ptr);
dim3 block_dim, thread_block;
@ -308,59 +350,57 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
// Overlapping CPU computes and D2H memcpy
// both take about the same time
int num_to_copy = std::min(kNmsChunkSize, num_boxes);
config = GetGpuLaunchConfig(num_boxes, device);
Tensor selected_boxes;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({num_boxes}), &selected_boxes));
Tensor d_indices;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
config.thread_per_block, 0, device.stream(),
config.virtual_thread_count, 0,
d_indices.flat<int>().data()));
char* selected = (char*)(selected_boxes.flat<int8>().data());
TF_CHECK_OK(GpuLaunchKernel(NMSReduce, 1, 1024, bit_mask_len * sizeof(int),
device.stream(), d_delete_mask, bit_mask_len,
num_boxes, max_boxes, selected));
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
// do Cub::deviceSelect::flagged
size_t flagged_buffer_size = 0;
cub::DeviceSelect::Flagged(static_cast<void*>(nullptr), // temp_storage
flagged_buffer_size,
static_cast<int*>(nullptr), // input
static_cast<char*>(nullptr), // selection flag
static_cast<int*>(nullptr), // selected items
static_cast<int*>(nullptr), // num_selected
num_boxes, device.stream());
Tensor cub_scratch;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}),
&cub_scratch));
Tensor d_num_selected;
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
TensorShape({1}), &d_num_selected));
cub::DeviceSelect::Flagged(
(void*)cub_scratch.flat<int8>().data(), // temp_storage
flagged_buffer_size,
d_indices.flat<int>().data(), // input
selected, // selection flag
d_selected_indices, // selected items
d_num_selected.flat<int>().data(), num_boxes, device.stream());
cudaEvent_t copy_done;
TF_RETURN_IF_CUDA_ERROR(
cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
device.memcpyDeviceToHost(&h_delete_mask[0], &d_delete_mask[0],
num_to_copy * bit_mask_len * sizeof(int));
device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(),
sizeof(int));
TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
int offset = 0;
std::vector<int> h_selected_indices;
// Reserve worst case scenario. Since box count is not huge, this should have
// negligible footprint.
h_selected_indices.reserve(num_boxes);
std::vector<int> to_remove(bit_mask_len, 0);
while (offset < num_boxes) {
const int num_copied = num_to_copy;
int next_offset = offset + num_copied;
num_to_copy = std::min(kNmsChunkSize, num_boxes - next_offset);
if (num_to_copy > 0) {
device.memcpyDeviceToHost(&h_delete_mask[next_offset * bit_mask_len],
&d_delete_mask[next_offset * bit_mask_len],
num_to_copy * bit_mask_len * sizeof(int));
}
// Waiting for previous copy
TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
if (num_to_copy > 0) {
TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
}
// Starting from highest scoring box, mark any box with iou>threshold and
// lower score for deletion if current box is not marked for deletion. Add
// current box to to_keep list.
for (int i = offset; i < next_offset; ++i) {
// See the comment at the beginning of the file.
// Bit shift and logical And operations are used
// instead of division and modulo operations.
int iblock = i >> kNmsBoxesPerThreadShiftBits;
int inblock = i & kNmsBoxesPerThreadModuloMask;
if (!(to_remove[iblock] & (1 << inblock))) {
h_selected_indices.push_back(i);
int* p = &h_delete_mask[i * bit_mask_len];
for (int ib = 0; ib < bit_mask_len; ++ib) {
to_remove[ib] |= p[ib];
}
}
}
offset = next_offset;
}
TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
*h_nkeep = *h_selected_count;
cudaEventDestroy(copy_done);
const int nkeep = h_selected_indices.size();
device.memcpyHostToDevice(d_selected_indices, &h_selected_indices[0],
nkeep * sizeof(int));
*h_nkeep = nkeep;
return Status::OK();
}
@ -485,10 +525,10 @@ class NonMaxSuppressionV2GPUOp : public OpKernel {
// There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2,
// flip boxes if necessary!
const bool flip_boxes = true;
auto status =
NmsGpu(d_sorted_boxes.flat<float>().data(), num_boxes,
iou_threshold_val, d_selected_indices.flat<int>().data(),
&num_to_keep, context, flip_boxes, /*legacy_mode*/ false);
auto status = NmsGpu(
d_sorted_boxes.flat<float>().data(), num_boxes, iou_threshold_val,
d_selected_indices.flat<int>().data(), &num_to_keep, context,
output_size, flip_boxes, /*legacy_mode*/ false);
TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
if (!status.ok()) {
context->SetStatus(status);
@ -535,13 +575,10 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output));
TF_RETURN_IF_ERROR(
context,
context->allocate_temp(DataType::DT_INT8,
TensorShape({(int64)workspace_size}), &workspace));
TF_RETURN_IF_ERROR(
context, context->allocate_temp(DataType::DT_INT32, TensorShape({1}),
&element_count));
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
TensorShape({1}), &element_count));
cudaEvent_t copy_done;
TF_RETURN_IF_CUDA_ERROR(
cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
@ -697,15 +734,18 @@ class NonMaxSuppressionV3GPUOp : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
&output_indices));
return;
} else {
VLOG(2) << "Number of boxes above threshold=" << score_threshold_val
<< " is " << limited_num_boxes;
}
int num_to_keep = 0;
// There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2,
// flip boxes if necessary!
const bool flip_boxes = true;
auto status =
NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes,
iou_threshold_val, d_selected_indices.flat<int>().data(),
&num_to_keep, context, flip_boxes, /*legacy_mode*/ false);
auto status = NmsGpu(
d_sorted_boxes.flat<float>().data(), limited_num_boxes,
iou_threshold_val, d_selected_indices.flat<int>().data(), &num_to_keep,
context, output_size, flip_boxes, /*legacy_mode*/ false);
TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
if (!status.ok()) {
context->SetStatus(status);
@ -716,7 +756,12 @@ class NonMaxSuppressionV3GPUOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({num_outputs}),
&output_indices));
if (num_outputs == 0) return;
if (num_outputs == 0) {
VLOG(1) << "No outputs!";
return;
} else {
VLOG(2) << "Num outputs= " << num_outputs;
}
config = GetGpuLaunchConfig(num_outputs, device);
TF_CHECK_OK(GpuLaunchKernel(
IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,

View File

@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_NON_MAX_SUPPRESSION_OP_H_
#define TENSORFLOW_CORE_KERNELS_NON_MAX_SUPPRESSION_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
namespace functor {
@ -54,7 +54,8 @@ extern const int kNmsBoxesPerTread;
Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
const float iou_threshold, int* d_selected_indices,
int* h_num_boxes_to_keep, OpKernelContext* context,
bool flip_boxes = false,bool legacy_mode=false);
const int max_boxes, bool flip_boxes = false,
bool legacy_mode = false);
#endif
} // namespace tensorflow