Addressing review comments
commit 736eba374e (parent e882f17498)
@@ -15,6 +15,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA
 #define EIGEN_USE_GPU
+#include <limits>
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -80,19 +81,8 @@ __device__ EIGEN_STRONG_INLINE void Swap(T& a, T& b) {
   b = c;
 }
 
-template <bool T>
-__device__ float legacy_offset(float);
-template <>
-__device__ EIGEN_STRONG_INLINE float legacy_offset<true>(float a) {
-  return a + 1.0;
-}
-template <>
-__device__ EIGEN_STRONG_INLINE float legacy_offset<false>(float a) {
-  return a;
-}
-
 // Check whether two boxes have an IoU greater than threshold.
-template <typename T, bool L>
+template <typename T>
 __device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b,
                                                   const float a_area,
                                                   const T iou_threshold) {
@@ -104,8 +94,8 @@ __device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b,
   const float yy2 = fminf(a->y2, b->y2);
 
   // fdimf computes the positive difference between xx2+1 and xx1.
-  const float w = fdimf(legacy_offset<L>(xx2), xx1);
-  const float h = fdimf(legacy_offset<L>(yy2), yy1);
+  const float w = fdimf(xx2, xx1);
+  const float h = fdimf(yy2, yy1);
   const float intersection = w * h;
 
   // Testing for aa/bb > t
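The hunk above removes the legacy one-pixel offset from the intersection width and height; the "Testing for aa/bb > t" comment refers to comparing IoU against the threshold by cross-multiplying rather than dividing. A minimal host-side sketch of that pattern, with illustrative names rather than the kernel's own Box type:

#include <cmath>

// Illustrative box type; assumes the same x1 < x2, y1 < y2 layout the kernel expects.
struct BoxSketch { float x1, y1, x2, y2; };

// True when IoU(a, b) > t. std::fdim(x, y) = max(x - y, 0), so non-overlapping
// boxes yield a zero-area intersection. The comparison is cross-multiplied
// (intersection > t * union) to avoid a division.
inline bool IouOverThreshold(const BoxSketch& a, const BoxSketch& b, float t) {
  const float xx1 = std::fmax(a.x1, b.x1);
  const float yy1 = std::fmax(a.y1, b.y1);
  const float xx2 = std::fmin(a.x2, b.x2);
  const float yy2 = std::fmin(a.y2, b.y2);
  const float w = std::fdim(xx2, xx1);
  const float h = std::fdim(yy2, yy1);
  const float intersection = w * h;
  const float a_area = (a.x2 - a.x1) * (a.y2 - a.y1);
  const float b_area = (b.x2 - b.x1) * (b.y2 - b.y1);
  return intersection > t * (a_area + b_area - intersection);
}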
@@ -130,10 +120,10 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
 }
 template <typename T>
 __device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) {
-  constexpr int SHIFTLEN = NumBits(8 * sizeof(T)) - 1;
-  constexpr int REMAINDER_MASK = 8 * sizeof(T) - 1;
-  int bin = bit >> SHIFTLEN;
-  return (bit_mask[bin] >> (bit & REMAINDER_MASK)) & 1;
+  constexpr int kShiftLen = NumBits(8 * sizeof(T)) - 1;
+  constexpr int kRemainderMask = 8 * sizeof(T) - 1;
+  int bin = bit >> kShiftLen;
+  return (bit_mask[bin] >> (bit & kRemainderMask)) & 1;
 }
 
 // Produce a global bitmask (result_mask) of selected boxes from bitmask
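The renamed constants above implement standard bit-mask addressing: the bit index is split into a word index (shift) and a position within the word (mask). A sketch specialized to 32-bit words, on the assumption that NumBits(8 * sizeof(T)) - 1 evaluates to 5 for a 32-bit T:

#include <cstdint>

// Word index = bit / 32 (via shift), position = bit % 32 (via mask).
inline bool CheckBit32(const uint32_t* bit_mask, int bit) {
  constexpr int kShiftLen = 5;        // assumed value of NumBits(8 * sizeof(uint32_t)) - 1
  constexpr int kRemainderMask = 31;  // 8 * sizeof(uint32_t) - 1
  const int bin = bit >> kShiftLen;
  return (bit_mask[bin] >> (bit & kRemainderMask)) & 1u;
}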
@@ -180,11 +170,11 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
 // If flip_box is true boxes may have x1>x2 and or y1>y2. If so change the
 // coordinates such that for all boxes x1<x2 and y1<y2. Else boxes should have
 // x1<x2 and y1<y2.
-template <bool flip_box, bool legacy_mode>
+template <bool flip_box>
 __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__
-    void NMSKernel(const Box* d_desc_sorted_boxes,
-                   const int num_boxes, const float iou_threshold,
-                   const int bit_mask_len, int* d_delete_mask) {
+    void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes,
+                   const float iou_threshold, const int bit_mask_len,
+                   int* d_delete_mask) {
   // Storing boxes used by this CUDA block in the shared memory.
   __shared__ Box shared_i_boxes[kNmsBlockDim];
   // Same thing with areas
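The comment in this hunk describes the flip_box contract: when enabled, coordinates are swapped so every box satisfies x1 < x2 and y1 < y2 before the IoU test. A standalone sketch of that normalization (BoxSketch is the same illustrative type used in the earlier sketch, not the kernel's Box):

#include <utility>

struct BoxSketch { float x1, y1, x2, y2; };

// Swap coordinates in place so that x1 <= x2 and y1 <= y2 always hold.
inline void FlipIfNeeded(BoxSketch& box) {
  if (box.x1 > box.x2) std::swap(box.x1, box.x2);
  if (box.y1 > box.y2) std::swap(box.y1, box.y2);
}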
@@ -224,8 +214,8 @@ __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__
       Box j_box = d_desc_sorted_boxes[j];
       const Box i_box = shared_i_boxes[threadIdx.x];
       Flipped<flip_box>(j_box);
-      if (OverThreshold<float, legacy_mode>(
-              &i_box, &j_box, shared_i_areas[threadIdx.x], iou_threshold)) {
+      if (OverThreshold<float>(&i_box, &j_box, shared_i_areas[threadIdx.x],
+                               iou_threshold)) {
         // we have score[j] <= score[i].
         above_threshold |= (1U << ib);
       }
@@ -247,8 +237,7 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
 template <typename Index, typename T, typename... Args>
 __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
                                                  const Index i_original,
-                                                 const T* original,
-                                                 T* selected,
+                                                 const T* original, T* selected,
                                                  Args... args) {
   selected[i_selected] = original[i_original];
   SelectHelper(i_selected, i_original, args...);
@@ -261,18 +250,15 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
 // IndexMultiSelect(num_elements, indices, original1 ,selected1, original2,
 // selected2).
 template <typename Index, typename T, typename... Args>
-__global__ void IndexMultiSelect(const int num_elements,
-                                 const Index* indices,
-                                 const T* original,
-                                 T* selected, Args... args) {
+__global__ void IndexMultiSelect(const int num_elements, const Index* indices,
+                                 const T* original, T* selected, Args... args) {
   for (const int idx : CudaGridRangeX(num_elements)) {
     SelectHelper(idx, indices[idx], original, selected, args...);
   }
 }
 
 template <typename T>
-__global__ void Iota(const int num_elements, const T offset,
-                     T* to_fill) {
+__global__ void Iota(const int num_elements, const T offset, T* to_fill) {
   for (int idx : CudaGridRangeX(num_elements)) {
     to_fill[idx] = static_cast<T>(idx) + offset;
   }
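The comment above shows the variadic call shape, IndexMultiSelect(num_elements, indices, original1, selected1, original2, selected2, ...): one index list gathers from any number of parallel arrays. A host-side sketch of the same recursion, with illustrative names:

// Base case: no array pairs left to copy.
template <typename Index>
void GatherHelper(Index /*dst*/, Index /*src*/) {}

// Copy original[src] into selected[dst], then recurse on the remaining pairs.
template <typename Index, typename T, typename... Args>
void GatherHelper(Index dst, Index src, const T* original, T* selected,
                  Args... args) {
  selected[dst] = original[src];
  GatherHelper(dst, src, args...);
}

// Usage: gather boxes and scores with a single index list.
// for (int i = 0; i < n; ++i)
//   GatherHelper(i, indices[i], boxes, sorted_boxes, scores, sorted_scores);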
@@ -280,8 +266,7 @@ __global__ void Iota(const int num_elements, const T offset,
 
 Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
               const float iou_threshold, int* d_selected_indices, int* h_nkeep,
-              OpKernelContext* context, const int max_boxes, bool flip_boxes,
-              bool legacy_mode) {
+              OpKernelContext* context, const int max_boxes, bool flip_boxes) {
   // Making sure we respect the __align(16)__
   // we promised to the compiler.
   auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr);
@@ -309,9 +294,8 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   alloc_attr.set_gpu_compatible(true);
   // Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 and
   // using it as a ring buffer. However savings should be a few MB .
-  TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
-                                            TensorShape({1}),
-                                            &h_num_selected, alloc_attr));
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT32, TensorShape({1}), &h_num_selected, alloc_attr));
 
   int* d_delete_mask = d_nms_mask.flat<int>().data();
   int* h_selected_count = h_num_selected.flat<int>().data();
@@ -327,29 +311,13 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   thread_block.y = kNmsBlockDim;
   thread_block.z = 1;
   if (flip_boxes) {
-    if (!legacy_mode) {
-      TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true, false>, block_dim,
-                                  thread_block, 0, device.stream(),
-                                  d_sorted_boxes, num_boxes, iou_threshold,
-                                  bit_mask_len, d_delete_mask));
-    } else {
-      TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true, true>, block_dim,
-                                  thread_block, 0, device.stream(),
-                                  d_sorted_boxes, num_boxes, iou_threshold,
-                                  bit_mask_len, d_delete_mask));
-    }
+    TF_CHECK_OK(GpuLaunchKernel(NMSKernel<true>, block_dim, thread_block, 0,
+                                device.stream(), d_sorted_boxes, num_boxes,
+                                iou_threshold, bit_mask_len, d_delete_mask));
   } else {
-    if (!legacy_mode) {
-      TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false, false>, block_dim,
-                                  thread_block, 0, device.stream(),
-                                  d_sorted_boxes, num_boxes, iou_threshold,
-                                  bit_mask_len, d_delete_mask));
-    } else {
-      TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false, true>, block_dim,
-                                  thread_block, 0, device.stream(),
-                                  d_sorted_boxes, num_boxes, iou_threshold,
-                                  bit_mask_len, d_delete_mask));
-    }
+    TF_CHECK_OK(GpuLaunchKernel(NMSKernel<false>, block_dim, thread_block, 0,
+                                device.stream(), d_sorted_boxes, num_boxes,
+                                iou_threshold, bit_mask_len, d_delete_mask));
   }
   TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
   // Overlapping CPU computes and D2H memcpy
@@ -408,152 +376,6 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   return Status::OK();
 }
 
-class NonMaxSuppressionV2GPUOp : public OpKernel {
- public:
-  explicit NonMaxSuppressionV2GPUOp(OpKernelConstruction* context)
-      : OpKernel(context) {}
-
-  void Compute(OpKernelContext* context) override {
-    // boxes: [num_boxes, 4]
-    const Tensor& boxes = context->input(0);
-    // scores: [num_boxes]
-    const Tensor& scores = context->input(1);
-    // max_output_size: scalar
-    const Tensor& max_output_size = context->input(2);
-    OP_REQUIRES(
-        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
-        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
-                                max_output_size.shape().DebugString()));
-    // iou_threshold: scalar
-    const Tensor& iou_threshold = context->input(3);
-    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
-                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
-                                        iou_threshold.shape().DebugString()));
-    const float iou_threshold_val = iou_threshold.scalar<float>()();
-
-    OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
-                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
-    OP_REQUIRES(context, boxes.dims() == 2,
-                errors::InvalidArgument("boxes must be a rank 2 tensor!"));
-    int num_boxes = boxes.dim_size(0);
-    OP_REQUIRES(context, boxes.dim_size(1) == 4,
-                errors::InvalidArgument("boxes must be Nx4"));
-    OP_REQUIRES(context, scores.dims() == 1,
-                errors::InvalidArgument("scores must be a vector!"));
-    OP_REQUIRES(
-        context, scores.dim_size(0) == num_boxes,
-        errors::InvalidArgument(
-            "scores has incompatible shape"));  // message must be exactly this
-                                                // otherwise tests fail!
-    if (num_boxes == 0) {
-      Tensor* output_indices = nullptr;
-      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
-                                                       &output_indices));
-      return;
-    }
-    const int output_size = max_output_size.scalar<int>()();
-    size_t cub_sort_temp_storage_bytes = 0;
-    auto cuda_stream = GetGpuStream(context);
-    auto device = context->eigen_gpu_device();
-    // Calling cub with nullptrs as inputs will make it return
-    // workspace size needed for the operation instead of doing the operation.
-    // In this specific instance, cub_sort_temp_storage_bytes will contain the
-    // necessary workspace size for sorting after the call.
-    cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
-        nullptr, cub_sort_temp_storage_bytes,
-        static_cast<float*>(nullptr),  // scores
-        static_cast<float*>(nullptr),  // sorted scores
-        static_cast<int*>(nullptr),    // input indices
-        static_cast<int*>(nullptr),    // sorted indices
-        num_boxes,                     // num items
-        0, 8 * sizeof(float),          // sort all bits
-        cuda_stream);
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret);
-    Tensor d_cub_sort_buffer;
-    OP_REQUIRES_OK(context,
-                   context->allocate_temp(
-                       DataType::DT_INT8,
-                       TensorShape({(int64)cub_sort_temp_storage_bytes}),
-                       &d_cub_sort_buffer));
-    Tensor d_indices;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataType::DT_INT32,
-                                        TensorShape({num_boxes}), &d_indices));
-    Tensor d_sorted_indices;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32,
-                                                   TensorShape({num_boxes}),
-                                                   &d_sorted_indices));
-    Tensor d_selected_indices;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32,
-                                                   TensorShape({num_boxes}),
-                                                   &d_selected_indices));
-    Tensor d_sorted_scores;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT,
-                                                   TensorShape({num_boxes}),
-                                                   &d_sorted_scores));
-    Tensor d_sorted_boxes;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT,
-                                                   TensorShape({num_boxes, 4}),
-                                                   &d_sorted_boxes));
-
-    // this will return sorted scores and their indices
-    auto config = GetGpuLaunchConfig(num_boxes, device);
-    // initialize box and score indices
-    TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
-                                config.thread_per_block, 0, device.stream(),
-                                config.virtual_thread_count, 0,
-                                d_indices.flat<int>().data()));
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
-    cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
-        d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
-        scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
-        d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
-        num_boxes, 0,
-        8 * sizeof(float),  // sort all bits
-        cuda_stream);
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret);
-
-    // get pointers for easy access
-    const float4* original_boxes =
-        reinterpret_cast<const float4*>(boxes.flat<float>().data());
-    float4* sorted_boxes =
-        reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data());
-    const int* sorted_indices = d_sorted_indices.flat<int>().data();
-    // sort boxes using indices
-    TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>,
-                                config.block_count, config.thread_per_block, 0,
-                                device.stream(), config.virtual_thread_count,
-                                sorted_indices, original_boxes, sorted_boxes));
-
-    int num_to_keep = 0;
-    // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2,
-    // flip boxes if necessary!
-    const bool flip_boxes = true;
-    auto status = NmsGpu(
-        d_sorted_boxes.flat<float>().data(), num_boxes, iou_threshold_val,
-        d_selected_indices.flat<int>().data(), &num_to_keep, context,
-        output_size, flip_boxes, /*legacy_mode*/ false);
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
-    if (!status.ok()) {
-      context->SetStatus(status);
-      return;
-    }
-    Tensor* output_indices = nullptr;
-    int num_outputs = std::min(num_to_keep, output_size);  // no padding!
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(0, TensorShape({num_outputs}),
-                                            &output_indices));
-    if (num_outputs == 0) return;
-    config = GetGpuLaunchConfig(num_outputs, device);
-    TF_CHECK_OK(GpuLaunchKernel(
-        IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
-        0, device.stream(), config.virtual_thread_count,
-        d_selected_indices.flat<int>().data(), sorted_indices,
-        (*output_indices).flat<int>().data()));
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
-  }
-};
-
 struct GreaterThanCubOp {
   float threshold_;
   __host__ __device__ __forceinline__ GreaterThanCubOp(float threshold)
@@ -597,6 +419,180 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
   return Status::OK();
 }
 
+Status DoNMS(OpKernelContext* context, const Tensor& boxes,
+             const Tensor& scores, const int64_t max_output_size,
+             const float iou_threshold_val, const float score_threshold) {
+  const int output_size = max_output_size;
+  int num_boxes = boxes.dim_size(0);
+  size_t cub_sort_temp_storage_bytes = 0;
+  auto cuda_stream = GetGpuStream(context);
+  auto device = context->eigen_gpu_device();
+  // Calling cub with nullptrs as inputs will make it return
+  // workspace size needed for the operation instead of doing the operation.
+  // In this specific instance, cub_sort_temp_storage_bytes will contain the
+  // necessary workspace size for sorting after the call.
+  if (num_boxes == 0) {
+    Tensor* output_indices = nullptr;
+    TF_RETURN_IF_ERROR(
+        context->allocate_output(0, TensorShape({0}), &output_indices));
+    return Status::OK();
+  }
+
+  cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
+      nullptr, cub_sort_temp_storage_bytes,
+      static_cast<float*>(nullptr),  // scores
+      static_cast<float*>(nullptr),  // sorted scores
+      static_cast<int*>(nullptr),    // input indices
+      static_cast<int*>(nullptr),    // sorted indices
+      num_boxes,                     // num items
+      0, 8 * sizeof(float),          // sort all bits
+      cuda_stream);
+  TF_RETURN_IF_CUDA_ERROR(cuda_ret);
+  Tensor d_cub_sort_buffer;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}),
+      &d_cub_sort_buffer));
+  Tensor d_indices;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
+  Tensor d_sorted_indices;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT32, TensorShape({num_boxes}), &d_sorted_indices));
+  Tensor d_selected_indices;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_INT32, TensorShape({num_boxes}), &d_selected_indices));
+  Tensor d_sorted_scores;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_FLOAT, TensorShape({num_boxes}), &d_sorted_scores));
+  Tensor d_sorted_boxes;
+  TF_RETURN_IF_ERROR(context->allocate_temp(
+      DataType::DT_FLOAT, TensorShape({num_boxes, 4}), &d_sorted_boxes));
+
+  // this will return sorted scores and their indices
+  auto config = GetGpuLaunchConfig(num_boxes, device);
+  // initialize box and score indices
+  TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
+                              config.thread_per_block, 0, device.stream(),
+                              config.virtual_thread_count, 0,
+                              d_indices.flat<int>().data()));
+  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
+  cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
+      d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
+      scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
+      d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
+      num_boxes, 0,
+      8 * sizeof(float),  // sort all bits
+      cuda_stream);
+  TF_RETURN_IF_CUDA_ERROR(cuda_ret);
+
+  // get pointers for easy access
+  const float4* original_boxes =
+      reinterpret_cast<const float4*>(boxes.flat<float>().data());
+  float4* sorted_boxes =
+      reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data());
+  const int* sorted_indices = d_sorted_indices.flat<int>().data();
+  // sort boxes using indices
+  TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>, config.block_count,
+                              config.thread_per_block, 0, device.stream(),
+                              config.virtual_thread_count, sorted_indices,
+                              original_boxes, sorted_boxes));
+  int limited_num_boxes = num_boxes;
+  // filter boxes by scores if nms v3
+  if (score_threshold > std::numeric_limits<float>::min()) {
+    GreaterThanCubOp score_limit(score_threshold);
+    TF_RETURN_IF_ERROR(CountIf(context, d_sorted_scores.flat<float>().data(),
+                               score_limit, num_boxes, &limited_num_boxes));
+    if (limited_num_boxes == 0) {
+      Tensor* output_indices = nullptr;
+      VLOG(1) << "Number of boxes above score threshold " << score_threshold
+              << " is 0";
+      TF_RETURN_IF_ERROR(
+          context->allocate_output(0, TensorShape({0}), &output_indices));
+      return Status::OK();
+    } else {
+      VLOG(2) << "Number of boxes above threshold=" << score_threshold
+              << " is " << limited_num_boxes;
+    }
+  }
+  int num_to_keep = 0;
+  // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2,
+  // flip boxes if necessary!
+  const bool flip_boxes = true;
+  auto status = NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes,
+                       iou_threshold_val, d_selected_indices.flat<int>().data(),
+                       &num_to_keep, context, output_size, flip_boxes);
+  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
+  if (!status.ok()) {
+    context->SetStatus(status);
+    return status;
+  }
+  Tensor* output_indices = nullptr;
+  int num_outputs = std::min(num_to_keep, output_size);  // no padding!
+  TF_RETURN_IF_ERROR(
+      context->allocate_output(0, TensorShape({num_outputs}), &output_indices));
+  if (num_outputs == 0) return Status::OK();
+  config = GetGpuLaunchConfig(num_outputs, device);
+  TF_CHECK_OK(GpuLaunchKernel(
+      IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
+      0, device.stream(), config.virtual_thread_count,
+      d_selected_indices.flat<int>().data(), sorted_indices,
+      (*output_indices).flat<int>().data()));
+  TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
+  return Status::OK();
+}
+
+class NonMaxSuppressionV2GPUOp : public OpKernel {
+ public:
+  explicit NonMaxSuppressionV2GPUOp(OpKernelConstruction* context)
+      : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    // boxes: [num_boxes, 4]
+    const Tensor& boxes = context->input(0);
+    // scores: [num_boxes]
+    const Tensor& scores = context->input(1);
+    // max_output_size: scalar
+    const Tensor& max_output_size = context->input(2);
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+        errors::InvalidArgument("max_output_size must be 0-D, got shape ",
+                                max_output_size.shape().DebugString()));
+    // iou_threshold: scalar
+    const Tensor& iou_threshold = context->input(3);
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
+                errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
+                                        iou_threshold.shape().DebugString()));
+    const float iou_threshold_val = iou_threshold.scalar<float>()();
+
+    OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
+                errors::InvalidArgument("iou_threshold must be in [0, 1]"));
+    OP_REQUIRES(context, boxes.dims() == 2,
+                errors::InvalidArgument("boxes must be a rank 2 tensor!"));
+    int num_boxes = boxes.dim_size(0);
+    OP_REQUIRES(context, boxes.dim_size(1) == 4,
+                errors::InvalidArgument("boxes must be Nx4"));
+    OP_REQUIRES(context, scores.dims() == 1,
+                errors::InvalidArgument("scores must be a vector!"));
+    OP_REQUIRES(
+        context, scores.dim_size(0) == num_boxes,
+        errors::InvalidArgument(
+            "scores has incompatible shape"));  // message must be exactly this
+                                                // otherwise tests fail!
+    if (num_boxes == 0) {
+      Tensor* output_indices = nullptr;
+      OP_REQUIRES_OK(context,
+                     context->allocate_output(0, TensorShape({0}), &output_indices));
+      return;
+    }
+    const int64_t output_size = max_output_size.scalar<int>()();
+    OP_REQUIRES_OK(
+        context,
+        DoNMS(context, boxes, scores, output_size, iou_threshold_val,
+              /*score_threshold is float min if score threshold is disabled*/
+              std::numeric_limits<float>::min()));
+  }
+};
+
 class NonMaxSuppressionV3GPUOp : public OpKernel {
  public:
   explicit NonMaxSuppressionV3GPUOp(OpKernelConstruction* context)
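The new DoNMS above follows cub's two-phase calling convention that the "Calling cub with nullptrs" comment describes: the first SortPairsDescending call only reports the temp-storage size, the second performs the sort. A trimmed-down sketch of that pattern outside the op (cudaMalloc stands in for allocate_temp; names are illustrative):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Sort scores descending and carry their indices along, using the
// query-then-run pattern of cub::DeviceRadixSort.
cudaError_t SortScoresDescending(const float* scores_in, float* scores_out,
                                 const int* indices_in, int* indices_out,
                                 int num_items, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Pass 1: null temp storage => only compute temp_bytes.
  cudaError_t err = cub::DeviceRadixSort::SortPairsDescending(
      nullptr, temp_bytes, scores_in, scores_out, indices_in, indices_out,
      num_items, 0, 8 * sizeof(float), stream);
  if (err != cudaSuccess) return err;
  void* d_temp = nullptr;
  err = cudaMalloc(&d_temp, temp_bytes);
  if (err != cudaSuccess) return err;
  // Pass 2: same arguments plus the workspace => do the actual sort.
  err = cub::DeviceRadixSort::SortPairsDescending(
      d_temp, temp_bytes, scores_in, scores_out, indices_in, indices_out,
      num_items, 0, 8 * sizeof(float), stream);
  cudaFree(d_temp);
  return err;
}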
@@ -648,131 +644,8 @@ class NonMaxSuppressionV3GPUOp : public OpKernel {
       return;
     }
     const int output_size = max_output_size.scalar<int>()();
-    size_t cub_sort_temp_storage_bytes = 0;
-    auto cuda_stream = tensorflow::GetGpuStream(context);
-    auto device = context->eigen_gpu_device();
-    // Calling cub with nullptrs as inputs will make it return
-    // workspace size needed for the operation instead of doing the operation.
-    // In this specific instance, cub_sort_temp_storage_bytes will contain the
-    // necessary workspace size for sorting after the call.
-    cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
-        nullptr, cub_sort_temp_storage_bytes,
-        static_cast<float*>(nullptr),  // scores
-        static_cast<float*>(nullptr),  // sorted scores
-        static_cast<int*>(nullptr),    // input indices
-        static_cast<int*>(nullptr),    // sorted indices
-        num_boxes,                     // num items
-        0, 8 * sizeof(float),          // sort all bits
-        cuda_stream);
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret);
-    Tensor d_cub_sort_buffer;
-    OP_REQUIRES_OK(context,
-                   context->allocate_temp(
-                       DataType::DT_INT8,
-                       TensorShape({(int64)cub_sort_temp_storage_bytes}),
-                       &d_cub_sort_buffer));
-    Tensor d_indices;
-    OP_REQUIRES_OK(
-        context, context->allocate_temp(DataType::DT_INT32,
-                                        TensorShape({num_boxes}), &d_indices));
-    Tensor d_sorted_indices;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32,
-                                                   TensorShape({num_boxes}),
-                                                   &d_sorted_indices));
-    Tensor d_selected_indices;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32,
-                                                   TensorShape({num_boxes}),
-                                                   &d_selected_indices));
-    Tensor d_sorted_scores;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT,
-                                                   TensorShape({num_boxes}),
-                                                   &d_sorted_scores));
-    Tensor d_sorted_boxes;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT,
-                                                   TensorShape({num_boxes, 4}),
-                                                   &d_sorted_boxes));
-
-    // this will return sorted scores and their indices
-    auto config = GetGpuLaunchConfig(num_boxes, device);
-    // initialize box and score indices
-    TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
-                                config.thread_per_block, 0, device.stream(),
-                                config.virtual_thread_count, 0,
-                                d_indices.flat<int>().data()));
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
-    cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
-        d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
-        scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
-        d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
-        num_boxes, 0,
-        8 * sizeof(float),  // sort all bits
-        cuda_stream);
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret);
-    // get pointers for easy access
-    const float4* original_boxes =
-        reinterpret_cast<const float4*>(boxes.flat<float>().data());
-    float4* sorted_boxes =
-        reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data());
-    const int* sorted_indices = d_sorted_indices.flat<int>().data();
-    // sort boxes using indices
-    TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>,
-                                config.block_count, config.thread_per_block, 0,
-                                device.stream(), config.virtual_thread_count,
-                                sorted_indices, original_boxes, sorted_boxes));
-
-    // Unfortunately we had to sort scores to find the number of boxes which has
-    // a threshold above score_threshold_val. It can be done before sorting but
-    // that would require either implementing a custom sort or a generic random
-    // access iterator for cub. For the time being we search for the location of
-    // the score_threshold_val in the sorted array and limit num_boxes to its
-    // index.
-    GreaterThanCubOp score_limit(score_threshold_val);
-    int limited_num_boxes = 0;
-    OP_REQUIRES_OK(context,
-                   CountIf(context, d_sorted_scores.flat<float>().data(),
-                           score_limit, num_boxes, &limited_num_boxes));
-    if (limited_num_boxes == 0) {
-      Tensor* output_indices = nullptr;
-      VLOG(1) << "Number of boxes above score threshold " << score_threshold_val
-              << " is 0";
-      OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
-                                                       &output_indices));
-      return;
-    } else {
-      VLOG(2) << "Number of boxes above threshold=" << score_threshold_val
-              << " is " << limited_num_boxes;
-    }
-    int num_to_keep = 0;
-    // There is no guarantee that boxes are given in the for x1<x2 and/or y1<y2,
-    // flip boxes if necessary!
-    const bool flip_boxes = true;
-    auto status = NmsGpu(
-        d_sorted_boxes.flat<float>().data(), limited_num_boxes,
-        iou_threshold_val, d_selected_indices.flat<int>().data(), &num_to_keep,
-        context, output_size, flip_boxes, /*legacy_mode*/ false);
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
-    if (!status.ok()) {
-      context->SetStatus(status);
-      return;
-    }
-    Tensor* output_indices = nullptr;
-    int num_outputs = std::min(num_to_keep, output_size);  // no padding!
-    OP_REQUIRES_OK(context,
-                   context->allocate_output(0, TensorShape({num_outputs}),
-                                            &output_indices));
-    if (num_outputs == 0) {
-      VLOG(1) << "No outputs!";
-      return;
-    } else {
-      VLOG(2) << "Num outputs= " << num_outputs;
-    }
-    config = GetGpuLaunchConfig(num_outputs, device);
-    TF_CHECK_OK(GpuLaunchKernel(
-        IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
-        0, device.stream(), config.virtual_thread_count,
-        d_selected_indices.flat<int>().data(), sorted_indices,
-        (*output_indices).flat<int>().data()));
-    TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
+    OP_REQUIRES_OK(context, DoNMS(context, boxes, scores, output_size,
+                                  iou_threshold_val, score_threshold_val));
   }
 };
 
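The removed comment block explains why the score threshold is applied after sorting: with scores already in descending order, the number of boxes above the threshold is just the position where scores drop below it, which is what CountIf supplies. A host-side sketch of the same idea using a binary search (illustrative only, not the CountIf implementation):

#include <algorithm>
#include <functional>
#include <vector>

// scores_desc must be sorted in descending order. Returns how many leading
// entries are strictly greater than threshold.
inline int CountAboveThreshold(const std::vector<float>& scores_desc,
                               float threshold) {
  // lower_bound with greater<float> finds the first element <= threshold.
  auto it = std::lower_bound(scores_desc.begin(), scores_desc.end(), threshold,
                             std::greater<float>());
  return static_cast<int>(it - scores_desc.begin());
}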
@@ -54,8 +54,7 @@ extern const int kNmsBoxesPerTread;
 Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
               const float iou_threshold, int* d_selected_indices,
               int* h_num_boxes_to_keep, OpKernelContext* context,
-              const int max_boxes, bool flip_boxes = false,
-              bool legacy_mode = false);
+              const int max_boxes, bool flip_boxes = false);
 #endif
 
 }  // namespace tensorflow