Pure GPU NMS implementation.

parent b2b8e56f5f
commit c2a4931076
@@ -128,6 +128,47 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
   if (box.x1 > box.x2) Swap(box.x1, box.x2);
   if (box.y1 > box.y2) Swap(box.y1, box.y2);
 }
+
+template <typename T>
+__device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) {
+  constexpr int SHIFTLEN = NumBits(8 * sizeof(T)) - 1;
+  constexpr int REMAINDER_MASK = 8 * sizeof(T) - 1;
+  int bin = bit >> SHIFTLEN;
+  return (bit_mask[bin] >> (bit & REMAINDER_MASK)) & 1;
+}
+
+// Produce a global bitmask (result_mask) of selected boxes from the bitmask
+// generated by NMSKernel. Abort early if max_boxes boxes are selected. The
+// bitmask is num_boxes*bit_mask_len bits indicating whether to keep or remove a box.
+__global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
+                          const int num_boxes, const int max_boxes,
+                          char* result_mask) {
+  extern __shared__ int local[];
+  // set global mask to accept all boxes
+  for (int box : CudaGridRangeX(bit_mask_len)) {
+    local[box] = 0xFFFFFFFF;
+  }
+  __syncthreads();
+  int accepted_boxes = 0;
+  for (int box = 0; box < num_boxes - 1; ++box) {
+    // if current box is masked by an earlier box, skip it.
+    if (!CheckBit(local, box)) {
+      continue;
+    }
+    accepted_boxes += 1;
+    int offset = box * bit_mask_len;
+    // update global mask with current box's mask
+    for (int b : CudaGridRangeX(bit_mask_len)) {
+      local[b] &= ~bitmask[offset + b];
+    }
+    __syncthreads();
+    if (accepted_boxes > max_boxes) break;
+  }
+  // copy global mask to result_mask char array. char array is needed for
+  // cub::DeviceSelect later.
+  for (int box : CudaGridRangeX(num_boxes)) {
+    result_mask[box] = CheckBit(local, box);
+  }
+}
+
 // For each box, compute a bitmask of boxes which has an overlap with given box
 // above threshold.
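For reference, the selection NMSReduce performs over the overlap bitmask can be reproduced serially on the host. The sketch below is illustrative only and is not part of this commit (the name SelectBoxesReference is hypothetical); it assumes 32-bit mask words and mirrors CheckBit's shift/mask indexing, the progressive clearing of the keep mask, and the early abort once max_boxes boxes are accepted.

// Hypothetical host-side reference of what NMSReduce computes (not TF code).
// bitmask[i * bit_mask_len + b] has bit j set when box i suppresses box b*32+j.
#include <cstdint>
#include <vector>

std::vector<char> SelectBoxesReference(const std::vector<uint32_t>& bitmask,
                                       int num_boxes, int bit_mask_len,
                                       int max_boxes) {
  // Start by keeping every box, then clear bits as boxes get suppressed.
  std::vector<uint32_t> keep(bit_mask_len, 0xFFFFFFFFu);
  auto check_bit = [&](int bit) {
    // Same trick as CheckBit: word index = bit / 32, bit offset = bit % 32.
    return (keep[bit >> 5] >> (bit & 31)) & 1u;
  };
  int accepted = 0;
  for (int box = 0; box < num_boxes - 1; ++box) {
    if (!check_bit(box)) continue;  // already suppressed by a better box
    ++accepted;
    for (int b = 0; b < bit_mask_len; ++b) {
      keep[b] &= ~bitmask[box * bit_mask_len + b];  // drop its overlaps
    }
    if (accepted > max_boxes) break;  // early abort, as in the kernel
  }
  std::vector<char> result(num_boxes);
  for (int box = 0; box < num_boxes; ++box) result[box] = check_bit(box);
  return result;
}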
@@ -235,7 +276,8 @@ __global__ void Iota(const int num_elements, const T offset, T* to_fill) {

 Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
               const float iou_threshold, int* d_selected_indices, int* h_nkeep,
-              OpKernelContext* context, bool flip_boxes, bool legacy_mode) {
+              OpKernelContext* context, const int max_boxes, bool flip_boxes,
+              bool legacy_mode) {
   // Making sure we respect the __align(16)__
   // we promised to the compiler.
   auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr);
@@ -243,7 +285,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
     return errors::InvalidArgument("Boxes should be aligned to 16 Bytes.");
   }
   // allocate bitmask arrays on host and on device
-  Tensor h_nms_mask, d_nms_mask;
+  Tensor h_num_selected, d_nms_mask;
   const int bit_mask_len =
       (num_boxes + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread;

@@ -264,11 +306,11 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   // Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 and
   // using it as a ring buffer. However savings should be a few MB .
   TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
-                                            TensorShape({max_nms_mask_size}),
-                                            &h_nms_mask, alloc_attr));
+                                            TensorShape({1}),
+                                            &h_num_selected, alloc_attr));

   int* d_delete_mask = d_nms_mask.flat<int>().data();
-  int* h_delete_mask = h_nms_mask.flat<int>().data();
+  int* h_selected_count = h_num_selected.flat<int>().data();
   const Box* d_sorted_boxes =
       reinterpret_cast<const Box*>(d_sorted_boxes_float_ptr);
   dim3 block_dim, thread_block;
@@ -308,59 +350,57 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
   // Overlapping CPU computes and D2H memcpy
   // both take about the same time
-  int num_to_copy = std::min(kNmsChunkSize, num_boxes);
+  config = GetGpuLaunchConfig(num_boxes, device);
+  Tensor selected_boxes;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT8, TensorShape({num_boxes}), &selected_boxes));
+  Tensor d_indices;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
+  TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
+                              config.thread_per_block, 0, device.stream(),
+                              config.virtual_thread_count, 0,
+                              d_indices.flat<int>().data()));
+
+  char* selected = (char*)(selected_boxes.flat<int8>().data());
+  TF_CHECK_OK(GpuLaunchKernel(NMSReduce, 1, 1024, bit_mask_len * sizeof(int),
+                              device.stream(), d_delete_mask, bit_mask_len,
+                              num_boxes, max_boxes, selected));
+  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
+  // do cub::DeviceSelect::Flagged
+  size_t flagged_buffer_size = 0;
+  cub::DeviceSelect::Flagged(static_cast<void*>(nullptr),  // temp_storage
+                             flagged_buffer_size,
+                             static_cast<int*>(nullptr),   // input
+                             static_cast<char*>(nullptr),  // selection flag
+                             static_cast<int*>(nullptr),   // selected items
+                             static_cast<int*>(nullptr),   // num_selected
+                             num_boxes, device.stream());
+  Tensor cub_scratch;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}),
+      &cub_scratch));
+  Tensor d_num_selected;
+  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
+                                            TensorShape({1}), &d_num_selected));
+
+  cub::DeviceSelect::Flagged(
+      (void*)cub_scratch.flat<int8>().data(),  // temp_storage
+      flagged_buffer_size,
+      d_indices.flat<int>().data(),  // input
+      selected,                      // selection flag
+      d_selected_indices,            // selected items
+      d_num_selected.flat<int>().data(), num_boxes, device.stream());
   cudaEvent_t copy_done;
   TF_RETURN_IF_CUDA_ERROR(
       cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
-  device.memcpyDeviceToHost(&h_delete_mask[0], &d_delete_mask[0],
-                            num_to_copy * bit_mask_len * sizeof(int));
+  device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(),
+                            sizeof(int));
   TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-  int offset = 0;
-  std::vector<int> h_selected_indices;
-  // Reserve worst case scenario. Since box count is not huge, this should have
-  // negligible footprint.
-  h_selected_indices.reserve(num_boxes);
-  std::vector<int> to_remove(bit_mask_len, 0);
-  while (offset < num_boxes) {
-    const int num_copied = num_to_copy;
-    int next_offset = offset + num_copied;
-    num_to_copy = std::min(kNmsChunkSize, num_boxes - next_offset);
-    if (num_to_copy > 0) {
-      device.memcpyDeviceToHost(&h_delete_mask[next_offset * bit_mask_len],
-                                &d_delete_mask[next_offset * bit_mask_len],
-                                num_to_copy * bit_mask_len * sizeof(int));
-    }
-    // Waiting for previous copy
-    TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
-    if (num_to_copy > 0) {
-      TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-    }
-    // Starting from highest scoring box, mark any box with iou>threshold and
-    // lower score for deletion if current box is not marked for deletion. Add
-    // current box to to_keep list.
-    for (int i = offset; i < next_offset; ++i) {
-      // See the comment at the beginning of the file.
-      // Bit shift and logical And operations are used
-      // instead of division and modulo operations.
-      int iblock = i >> kNmsBoxesPerThreadShiftBits;
-      int inblock = i & kNmsBoxesPerThreadModuloMask;
-      if (!(to_remove[iblock] & (1 << inblock))) {
-        h_selected_indices.push_back(i);
-        int* p = &h_delete_mask[i * bit_mask_len];
-        for (int ib = 0; ib < bit_mask_len; ++ib) {
-          to_remove[ib] |= p[ib];
-        }
-      }
-    }
-    offset = next_offset;
-  }
+  TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
+  *h_nkeep = *h_selected_count;
   cudaEventDestroy(copy_done);
-
-  const int nkeep = h_selected_indices.size();
-  device.memcpyHostToDevice(d_selected_indices, &h_selected_indices[0],
-                            nkeep * sizeof(int));
-
-  *h_nkeep = nkeep;
   return Status::OK();
 }

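The two cub::DeviceSelect::Flagged calls above use CUB's usual two-pass protocol: the first call, with a null temp_storage pointer, only reports the required scratch size; the second call performs the compaction of d_indices into d_selected_indices using the per-box flags written by NMSReduce. A minimal standalone sketch of that protocol (plain CUDA outside TensorFlow; error handling omitted, buffer names illustrative):

// Sketch of the CUB size-query-then-run pattern (illustrative, not TF code).
#include <cub/cub.cuh>

void CompactFlagged(const int* d_in, const char* d_flags, int* d_out,
                    int* d_num_out, int num_items, cudaStream_t stream) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // Pass 1: with d_temp == nullptr, Flagged only writes the required size.
  cub::DeviceSelect::Flagged(d_temp, temp_bytes, d_in, d_flags, d_out,
                             d_num_out, num_items, stream);
  cudaMalloc(&d_temp, temp_bytes);
  // Pass 2: the actual stream compaction.
  cub::DeviceSelect::Flagged(d_temp, temp_bytes, d_in, d_flags, d_out,
                             d_num_out, num_items, stream);
  cudaFree(d_temp);
}

In NmsGpu the scratch buffer comes from context->allocate_temp rather than cudaMalloc, so it is tracked by TensorFlow's allocator.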
@@ -485,10 +525,10 @@ class NonMaxSuppressionV2GPUOp : public OpKernel {
     // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2,
     // flip boxes if necessary!
     const bool flip_boxes = true;
-    auto status =
-        NmsGpu(d_sorted_boxes.flat<float>().data(), num_boxes,
-               iou_threshold_val, d_selected_indices.flat<int>().data(),
-               &num_to_keep, context, flip_boxes, /*legacy_mode*/ false);
+    auto status = NmsGpu(
+        d_sorted_boxes.flat<float>().data(), num_boxes, iou_threshold_val,
+        d_selected_indices.flat<int>().data(), &num_to_keep, context,
+        output_size, flip_boxes, /*legacy_mode*/ false);
     TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
     if (!status.ok()) {
       context->SetStatus(status);
@@ -535,13 +575,10 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,

   TF_RETURN_IF_ERROR(context->allocate_temp(
       DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output));
-  TF_RETURN_IF_ERROR(
-      context,
-      context->allocate_temp(DataType::DT_INT8,
-                             TensorShape({(int64)workspace_size}), &workspace));
-  TF_RETURN_IF_ERROR(
-      context, context->allocate_temp(DataType::DT_INT32, TensorShape({1}),
-                                      &element_count));
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));
+  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
+                                            TensorShape({1}), &element_count));
   cudaEvent_t copy_done;
   TF_RETURN_IF_CUDA_ERROR(
       cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
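The CountIf hunk above also drops a stray context argument: TF_RETURN_IF_ERROR wraps a single Status-returning expression, unlike OP_REQUIRES_OK, which takes the OpKernelContext first. A hedged sketch of the corrected pattern (a fragment that only builds inside TensorFlow; the helper name AllocateScratch is hypothetical):

// Sketch: TF_RETURN_IF_ERROR takes only the Status from allocate_temp.
Status AllocateScratch(OpKernelContext* context, int64 workspace_size,
                       Tensor* workspace, Tensor* element_count) {
  TF_RETURN_IF_ERROR(context->allocate_temp(
      DataType::DT_INT8, TensorShape({workspace_size}), workspace));
  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                            TensorShape({1}), element_count));
  return Status::OK();
}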
@@ -697,15 +734,18 @@ class NonMaxSuppressionV3GPUOp : public OpKernel {
       OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
                                                        &output_indices));
       return;
+    } else {
+      VLOG(2) << "Number of boxes above threshold=" << score_threshold_val
+              << " is " << limited_num_boxes;
     }
     int num_to_keep = 0;
     // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2,
     // flip boxes if necessary!
     const bool flip_boxes = true;
-    auto status =
-        NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes,
-               iou_threshold_val, d_selected_indices.flat<int>().data(),
-               &num_to_keep, context, flip_boxes, /*legacy_mode*/ false);
+    auto status = NmsGpu(
+        d_sorted_boxes.flat<float>().data(), limited_num_boxes,
+        iou_threshold_val, d_selected_indices.flat<int>().data(), &num_to_keep,
+        context, output_size, flip_boxes, /*legacy_mode*/ false);
     TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
     if (!status.ok()) {
       context->SetStatus(status);
@@ -716,7 +756,12 @@ class NonMaxSuppressionV3GPUOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, TensorShape({num_outputs}),
                                             &output_indices));
-    if (num_outputs == 0) return;
+    if (num_outputs == 0) {
+      VLOG(1) << "No outputs!";
+      return;
+    } else {
+      VLOG(2) << "Num outputs= " << num_outputs;
+    }
     config = GetGpuLaunchConfig(num_outputs, device);
     TF_CHECK_OK(GpuLaunchKernel(
         IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
@@ -16,10 +16,10 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_NON_MAX_SUPPRESSION_OP_H_
 #define TENSORFLOW_CORE_KERNELS_NON_MAX_SUPPRESSION_OP_H_

-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

 namespace tensorflow {
 namespace functor {
@@ -54,7 +54,8 @@ extern const int kNmsBoxesPerTread;
 Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
               const float iou_threshold, int* d_selected_indices,
               int* h_num_boxes_to_keep, OpKernelContext* context,
-              bool flip_boxes = false,bool legacy_mode=false);
+              const int max_boxes, bool flip_boxes = false,
+              bool legacy_mode = false);
 #endif

 } // namespace tensorflow