Merge pull request #30893 from samikama:GPUNMSFixes
PiperOrigin-RevId: 268115881
commit 5d6158e0d4
tensorflow/core/kernels
@@ -15,6 +15,8 @@ limitations under the License.

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include <limits>

#include "absl/strings/str_cat.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "third_party/cub/device/device_radix_sort.cuh"
@@ -82,10 +84,9 @@ __device__ EIGEN_STRONG_INLINE void Swap(T& a, T& b) {

// Check whether two boxes have an IoU greater than threshold.
template <typename T>
__device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* __restrict__ a,
const Box* __restrict__ b,
float a_area,
T iou_threshold) {
__device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* a, const Box* b,
const float a_area,
const T iou_threshold) {
const float b_area = (b->x2 - b->x1) * (b->y2 - b->y1);
if (a_area == 0.0f || b_area == 0.0f) return false;
const float xx1 = fmaxf(a->x1, b->x1);
@@ -94,8 +95,8 @@ __device__ EIGEN_STRONG_INLINE bool OverThreshold(const Box* __restrict__ a,
const float yy2 = fminf(a->y2, b->y2);

// fdimf computes the positive difference between xx2+1 and xx1.
const float w = fdimf(xx2 + 1.0f, xx1);
const float h = fdimf(yy2 + 1.0f, yy1);
const float w = fdimf(xx2, xx1);
const float h = fdimf(yy2, yy1);
const float intersection = w * h;

// Testing for aa/bb > t
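
// Illustrative host-side sketch (not part of this change; BoxF and the helper
// name are assumptions): the same division-free IoU test that the
// "aa/bb > t" comment describes, written as a standalone function.
struct BoxF { float x1, y1, x2, y2; };
inline bool IouOverThresholdRef(const BoxF& a, const BoxF& b, float t) {
  const float a_area = (a.x2 - a.x1) * (a.y2 - a.y1);
  const float b_area = (b.x2 - b.x1) * (b.y2 - b.y1);
  if (a_area == 0.0f || b_area == 0.0f) return false;
  const float w = fmaxf(fminf(a.x2, b.x2) - fmaxf(a.x1, b.x1), 0.0f);
  const float h = fmaxf(fminf(a.y2, b.y2) - fmaxf(a.y1, b.y1), 0.0f);
  const float aa = w * h;                 // intersection area
  const float bb = a_area + b_area - aa;  // union area
  return aa > t * bb;                     // aa / bb > t, without a division
}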
@@ -118,6 +119,47 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
if (box.x1 > box.x2) Swap(box.x1, box.x2);
if (box.y1 > box.y2) Swap(box.y1, box.y2);
}
template <typename T>
__device__ EIGEN_STRONG_INLINE bool CheckBit(T* bit_mask, int bit) {
constexpr int kShiftLen = NumBits(8 * sizeof(T)) - 1;
constexpr int kRemainderMask = 8 * sizeof(T) - 1;
int bin = bit >> kShiftLen;
return (bit_mask[bin] >> (bit & kRemainderMask)) & 1;
}
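
// Illustrative sketch (not part of this change): CheckBit above is the
// shift-and-mask form of the usual word/bit split for a packed bitmask; for
// T = int (32-bit words) it is equivalent to this division/modulo version.
inline bool CheckBitRef(const int* bit_mask, int bit) {
  return (bit_mask[bit / 32] >> (bit % 32)) & 1;
}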

// Produce a global bitmask (result_mask) of selected boxes from bitmask
// generated by NMSKernel. Abort early if max_boxes boxes are selected. Bitmask
// is num_boxes*bit_mask_len bits indicating whether to keep or remove a box.
__global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
const int num_boxes, const int max_boxes,
char* result_mask) {
extern __shared__ int local[];
// set global mask to accept all boxes
for (int box : CudaGridRangeX(bit_mask_len)) {
local[box] = 0xFFFFFFFF;
}
__syncthreads();
int accepted_boxes = 0;
for (int box = 0; box < num_boxes - 1; ++box) {
// if current box is masked by an earlier box, skip it.
if (!CheckBit(local, box)) {
continue;
}
accepted_boxes += 1;
int offset = box * bit_mask_len;
// update global mask with current box's mask
for (int b : CudaGridRangeX(bit_mask_len)) {
local[b] &= ~bitmask[offset + b];
}
__syncthreads();
if (accepted_boxes > max_boxes) break;
}
// copy global mask to result_mask char array. char array is needed for
// cub::DeviceSelect later.
for (int box : CudaGridRangeX(num_boxes)) {
result_mask[box] = CheckBit(local, box);
}
}
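
// Illustrative host-side analogue of NMSReduce (not part of this change;
// plain loops instead of a CUDA block, same bitmask layout assumed): start
// with every box accepted, then let each surviving box clear the boxes its
// row of the bitmask marks as overlapping above the IoU threshold.
inline void NmsReduceRef(const std::vector<unsigned int>& bitmask,
                         int bit_mask_len, int num_boxes, int max_boxes,
                         std::vector<char>* result_mask) {
  std::vector<unsigned int> local(bit_mask_len, 0xFFFFFFFFu);
  int accepted_boxes = 0;
  for (int box = 0; box < num_boxes - 1; ++box) {
    if (!((local[box / 32] >> (box % 32)) & 1)) continue;  // already suppressed
    accepted_boxes += 1;
    for (int b = 0; b < bit_mask_len; ++b) {
      local[b] &= ~bitmask[box * bit_mask_len + b];
    }
    if (accepted_boxes > max_boxes) break;
  }
  result_mask->resize(num_boxes);
  for (int box = 0; box < num_boxes; ++box) {
    (*result_mask)[box] = (local[box / 32] >> (box % 32)) & 1;
  }
}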

// For each box, compute a bitmask of boxes which have an overlap with given box
// above threshold.
@@ -131,9 +173,9 @@ __device__ EIGEN_STRONG_INLINE void Flipped<true>(Box& box) {
// x1<x2 and y1<y2.
template <bool flip_box>
__launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__
void NMSKernel(const Box* __restrict__ d_desc_sorted_boxes,
const int num_boxes, const float iou_threshold,
const int bit_mask_len, int* __restrict__ d_delete_mask) {
void NMSKernel(const Box* d_desc_sorted_boxes, const int num_boxes,
const float iou_threshold, const int bit_mask_len,
int* d_delete_mask) {
// Storing boxes used by this CUDA block in the shared memory.
__shared__ Box shared_i_boxes[kNmsBlockDim];
// Same thing with areas
@@ -173,8 +215,8 @@ __launch_bounds__(kNmsBlockDim* kNmsBlockDim, 4) __global__
Box j_box = d_desc_sorted_boxes[j];
const Box i_box = shared_i_boxes[threadIdx.x];
Flipped<flip_box>(j_box);
if (OverThreshold(&i_box, &j_box, shared_i_areas[threadIdx.x],
iou_threshold)) {
if (OverThreshold<float>(&i_box, &j_box, shared_i_areas[threadIdx.x],
iou_threshold)) {
// we have score[j] <= score[i].
above_threshold |= (1U << ib);
}
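
// Illustrative sketch (not part of this change; layout inferred from
// NMSReduce/CheckBit above): how a (box i, box j) pair maps into the packed
// delete mask that NMSKernel fills, one bit per candidate j in row i.
__device__ inline bool IsSuppressedBy(const int* d_delete_mask,
                                      int bit_mask_len, int i, int j) {
  return (d_delete_mask[i * bit_mask_len + j / 32] >> (j % 32)) & 1;
}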
@@ -196,8 +238,7 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
template <typename Index, typename T, typename... Args>
__device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
const Index i_original,
const T* __restrict__ original,
T* __restrict__ selected,
const T* original, T* selected,
Args... args) {
selected[i_selected] = original[i_original];
SelectHelper(i_selected, i_original, args...);
@@ -210,18 +251,15 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
// IndexMultiSelect(num_elements, indices, original1, selected1, original2,
// selected2).
template <typename Index, typename T, typename... Args>
__global__ void IndexMultiSelect(const int num_elements,
const Index* __restrict__ indices,
const T* __restrict__ original,
T* __restrict__ selected, Args... args) {
__global__ void IndexMultiSelect(const int num_elements, const Index* indices,
const T* original, T* selected, Args... args) {
for (const int idx : CudaGridRangeX(num_elements)) {
SelectHelper(idx, indices[idx], original, selected, args...);
}
}
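
// Illustrative usage sketch (not part of this change; buffer names are
// hypothetical): because the (original, selected) pairs are variadic, one
// launch can gather several arrays through the same index list, e.g.
//   TF_CHECK_OK(GpuLaunchKernel(
//       IndexMultiSelect<int, float4, const float*, float*>,
//       config.block_count, config.thread_per_block, 0, device.stream(),
//       config.virtual_thread_count, d_sorted_indices,
//       d_original_boxes, d_sorted_boxes, d_original_scores, d_sorted_scores));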

template <typename T>
__global__ void Iota(const int num_elements, const T offset,
T* __restrict__ to_fill) {
__global__ void Iota(const int num_elements, const T offset, T* to_fill) {
for (int idx : CudaGridRangeX(num_elements)) {
to_fill[idx] = static_cast<T>(idx) + offset;
}
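
// Host analogue for illustration (not part of this change): Iota fills the
// index buffer with offset, offset + 1, ..., just like std::iota (<numeric>)
// does before the key/value radix sort that follows, e.g.
//   std::vector<int> h_indices(num_boxes);
//   std::iota(h_indices.begin(), h_indices.end(), 0);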
@@ -229,7 +267,7 @@ __global__ void Iota(const int num_elements, const T offset,

Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
const float iou_threshold, int* d_selected_indices, int* h_nkeep,
OpKernelContext* context, bool flip_boxes) {
OpKernelContext* context, const int max_boxes, bool flip_boxes) {
// Making sure we respect the __align(16)__
// we promised to the compiler.
auto iptr = reinterpret_cast<std::uintptr_t>(d_sorted_boxes_float_ptr);
@@ -237,7 +275,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
return errors::InvalidArgument("Boxes should be aligned to 16 Bytes.");
}
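
// Illustrative sketch (not part of this change): the 16-byte alignment test
// performed above, written as a standalone predicate.
inline bool IsAligned16(const void* ptr) {
  return (reinterpret_cast<std::uintptr_t>(ptr) & 15) == 0;
}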
// allocate bitmask arrays on host and on device
Tensor h_nms_mask, d_nms_mask;
Tensor h_num_selected, d_nms_mask;
const int bit_mask_len =
(num_boxes + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread;

@@ -257,12 +295,11 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
alloc_attr.set_gpu_compatible(true);
// Size of this buffer can be reduced to kNmsChunkSize*bit_mask_len*2 and
// using it as a ring buffer. However savings should be a few MB.
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
TensorShape({max_nms_mask_size}),
&h_nms_mask, alloc_attr));
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT32, TensorShape({1}), &h_num_selected, alloc_attr));

int* d_delete_mask = d_nms_mask.flat<int>().data();
int* h_delete_mask = h_nms_mask.flat<int>().data();
int* h_selected_count = h_num_selected.flat<int>().data();
const Box* d_sorted_boxes =
reinterpret_cast<const Box*>(d_sorted_boxes_float_ptr);
dim3 block_dim, thread_block;
@@ -286,58 +323,222 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
// Overlapping CPU computes and D2H memcpy
// both take about the same time
int num_to_copy = std::min(kNmsChunkSize, num_boxes);

config = GetGpuLaunchConfig(num_boxes, device);
Tensor selected_boxes;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({num_boxes}), &selected_boxes));
Tensor d_indices;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
config.thread_per_block, 0, device.stream(),
config.virtual_thread_count, 0,
d_indices.flat<int>().data()));

char* selected = (char*)(selected_boxes.flat<int8>().data());
TF_CHECK_OK(GpuLaunchKernel(NMSReduce, 1, 1024, bit_mask_len * sizeof(int),
device.stream(), d_delete_mask, bit_mask_len,
num_boxes, max_boxes, selected));
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
// do cub::DeviceSelect::Flagged
size_t flagged_buffer_size = 0;
cub::DeviceSelect::Flagged(static_cast<void*>(nullptr), // temp_storage
flagged_buffer_size,
static_cast<int*>(nullptr), // input
static_cast<char*>(nullptr), // selection flag
static_cast<int*>(nullptr), // selected items
static_cast<int*>(nullptr), // num_selected
num_boxes, device.stream());
Tensor cub_scratch;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}),
&cub_scratch));
Tensor d_num_selected;
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
TensorShape({1}), &d_num_selected));

cub::DeviceSelect::Flagged(
(void*)cub_scratch.flat<int8>().data(), // temp_storage
flagged_buffer_size,
d_indices.flat<int>().data(), // input
selected, // selection flag
d_selected_indices, // selected items
d_num_selected.flat<int>().data(), num_boxes, device.stream());
cudaEvent_t copy_done;
cudaEventCreate(&copy_done);
device.memcpyDeviceToHost(&h_delete_mask[0], &d_delete_mask[0],
num_to_copy * bit_mask_len * sizeof(int));
TF_RETURN_IF_CUDA_ERROR(
cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(),
sizeof(int));
TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
int offset = 0;
std::vector<int> h_selected_indices;
// Reserve worst case scenario. Since box count is not huge, this should have
// negligible footprint.
h_selected_indices.reserve(num_boxes);
std::vector<int> to_remove(bit_mask_len, 0);
while (offset < num_boxes) {
const int num_copied = num_to_copy;
int next_offset = offset + num_copied;
num_to_copy = std::min(kNmsChunkSize, num_boxes - next_offset);
if (num_to_copy > 0) {
device.memcpyDeviceToHost(&h_delete_mask[next_offset * bit_mask_len],
&d_delete_mask[next_offset * bit_mask_len],
num_to_copy * bit_mask_len * sizeof(int));
}
// Waiting for previous copy
TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
if (num_to_copy > 0) {
TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
}
// Starting from highest scoring box, mark any box with iou>threshold and
// lower score for deletion if current box is not marked for deletion. Add
// current box to to_keep list.
for (int i = offset; i < next_offset; ++i) {
// See the comment at the beginning of the file.
// Bit shift and logical And operations are used
// instead of division and modulo operations.
int iblock = i >> kNmsBoxesPerThreadShiftBits;
int inblock = i & kNmsBoxesPerThreadModuloMask;
if (!(to_remove[iblock] & (1 << inblock))) {
h_selected_indices.push_back(i);
int* p = &h_delete_mask[i * bit_mask_len];
for (int ib = 0; ib < bit_mask_len; ++ib) {
to_remove[ib] |= p[ib];
}
}
}
offset = next_offset;
}
TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
*h_nkeep = *h_selected_count;
cudaEventDestroy(copy_done);
return Status::OK();
}
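
// Illustrative sketch of the copy/compute overlap used by the chunked loop
// above (not part of this change; buffer names and the chunk consumer are
// hypothetical): the next chunk's D2H copy is issued before the current chunk
// is consumed, and a timing-disabled event marks when each chunk has landed.
inline void OverlappedD2HRef(const int* d_buf, int* h_buf, int num_chunks,
                             int chunk_elems, cudaStream_t stream,
                             void (*ProcessChunkOnHost)(const int*, int)) {
  const size_t chunk_bytes = chunk_elems * sizeof(int);
  cudaEvent_t copy_done;
  cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming);
  cudaMemcpyAsync(h_buf, d_buf, chunk_bytes, cudaMemcpyDeviceToHost, stream);
  cudaEventRecord(copy_done, stream);
  for (int chunk = 0; chunk < num_chunks; ++chunk) {
    const bool more = chunk + 1 < num_chunks;
    if (more) {  // start the next transfer before consuming the current one
      cudaMemcpyAsync(h_buf + (chunk + 1) * chunk_elems,
                      d_buf + (chunk + 1) * chunk_elems, chunk_bytes,
                      cudaMemcpyDeviceToHost, stream);
    }
    cudaEventSynchronize(copy_done);  // the current chunk is now on the host
    if (more) cudaEventRecord(copy_done, stream);
    ProcessChunkOnHost(h_buf + chunk * chunk_elems, chunk_elems);
  }
  cudaEventDestroy(copy_done);
}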

const int nkeep = h_selected_indices.size();
device.memcpyHostToDevice(d_selected_indices, &h_selected_indices[0],
nkeep * sizeof(int));
struct GreaterThanCubOp {
float threshold_;
__host__ __device__ __forceinline__ GreaterThanCubOp(float threshold)
: threshold_(threshold) {}
__host__ __device__ __forceinline__ bool operator()(const float& val) const {
return (val > threshold_);
}
};
// Use DeviceSelect::If to count the number of elements.
// TODO(sami) Not really a good way. Perhaps consider using thrust?
template <typename Op>
Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
int num_elements, int* result) {
Tensor scratch_output;
Tensor workspace;
Tensor element_count;
size_t workspace_size = 0;
auto cuda_stream = tensorflow::GetGpuStream(context);
auto device = context->eigen_gpu_device();
cub::DeviceSelect::If(nullptr, workspace_size, static_cast<float*>(nullptr),
static_cast<float*>(nullptr),
static_cast<int*>(nullptr), num_elements, op);

*h_nkeep = nkeep;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output));
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));
TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
TensorShape({1}), &element_count));
cudaEvent_t copy_done;
TF_RETURN_IF_CUDA_ERROR(
cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
TF_RETURN_IF_CUDA_ERROR(cub::DeviceSelect::If(
workspace.flat<int8>().data(), workspace_size, dev_array,
scratch_output.flat<float>().data(), element_count.flat<int32>().data(),
num_elements, op, cuda_stream));
device.memcpyDeviceToHost(result, element_count.flat<int32>().data(),
sizeof(int));
TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
return Status::OK();
}
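
// Host analogue for illustration (not part of this change): CountIf with
// GreaterThanCubOp plays the role of std::count_if (<algorithm>) with a
// "score > threshold" predicate, run on the device via cub::DeviceSelect::If.
inline int CountAboveThresholdRef(const float* scores, int n, float threshold) {
  return static_cast<int>(std::count_if(
      scores, scores + n, [threshold](float s) { return s > threshold; }));
}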

Status DoNMS(OpKernelContext* context, const Tensor& boxes,
const Tensor& scores, const int64_t max_output_size,
const float iou_threshold_val, const float score_threshold) {
const int output_size = max_output_size;
int num_boxes = boxes.dim_size(0);
size_t cub_sort_temp_storage_bytes = 0;
auto cuda_stream = GetGpuStream(context);
auto device = context->eigen_gpu_device();
// Calling cub with nullptrs as inputs will make it return
// workspace size needed for the operation instead of doing the operation.
// In this specific instance, cub_sort_temp_storage_bytes will contain the
// necessary workspace size for sorting after the call.
if (num_boxes == 0) {
Tensor* output_indices = nullptr;
TF_RETURN_IF_ERROR(
context->allocate_output(0, TensorShape({0}), &output_indices));
return Status::OK();
}

cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
nullptr, cub_sort_temp_storage_bytes,
static_cast<float*>(nullptr), // scores
static_cast<float*>(nullptr), // sorted scores
static_cast<int*>(nullptr), // input indices
static_cast<int*>(nullptr), // sorted indices
num_boxes, // num items
0, 8 * sizeof(float), // sort all bits
cuda_stream);
TF_RETURN_IF_CUDA_ERROR(cuda_ret);
Tensor d_cub_sort_buffer;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT8, TensorShape({(int64)cub_sort_temp_storage_bytes}),
&d_cub_sort_buffer));
Tensor d_indices;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT32, TensorShape({num_boxes}), &d_indices));
Tensor d_sorted_indices;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT32, TensorShape({num_boxes}), &d_sorted_indices));
Tensor d_selected_indices;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_INT32, TensorShape({num_boxes}), &d_selected_indices));
Tensor d_sorted_scores;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_FLOAT, TensorShape({num_boxes}), &d_sorted_scores));
Tensor d_sorted_boxes;
TF_RETURN_IF_ERROR(context->allocate_temp(
DataType::DT_FLOAT, TensorShape({num_boxes, 4}), &d_sorted_boxes));

// this will return sorted scores and their indices
auto config = GetGpuLaunchConfig(num_boxes, device);
// initialize box and score indices
TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
config.thread_per_block, 0, device.stream(),
config.virtual_thread_count, 0,
d_indices.flat<int>().data()));
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
num_boxes, 0,
8 * sizeof(float), // sort all bits
cuda_stream);
TF_RETURN_IF_CUDA_ERROR(cuda_ret);

// get pointers for easy access
const float4* original_boxes =
reinterpret_cast<const float4*>(boxes.flat<float>().data());
float4* sorted_boxes =
reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data());
const int* sorted_indices = d_sorted_indices.flat<int>().data();
// sort boxes using indices
TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>, config.block_count,
config.thread_per_block, 0, device.stream(),
config.virtual_thread_count, sorted_indices,
original_boxes, sorted_boxes));
int limited_num_boxes = num_boxes;
// filter boxes by scores if nms v3
if (score_threshold > std::numeric_limits<float>::lowest()) {
GreaterThanCubOp score_limit(score_threshold);
TF_RETURN_IF_ERROR(CountIf(context, d_sorted_scores.flat<float>().data(),
score_limit, num_boxes, &limited_num_boxes));
if (limited_num_boxes == 0) {
Tensor* output_indices = nullptr;
VLOG(1) << "Number of boxes above score threshold " << score_threshold
<< " is 0";
TF_RETURN_IF_ERROR(
context->allocate_output(0, TensorShape({0}), &output_indices));
return Status::OK();
} else {
VLOG(2) << "Number of boxes above threshold=" << score_threshold << " is "
<< limited_num_boxes;
}
}
int num_to_keep = 0;
// There is no guarantee that boxes are given in the form x1<x2 and/or y1<y2,
// flip boxes if necessary!
const bool flip_boxes = true;
auto status = NmsGpu(d_sorted_boxes.flat<float>().data(), limited_num_boxes,
iou_threshold_val, d_selected_indices.flat<int>().data(),
&num_to_keep, context, output_size, flip_boxes);
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
if (!status.ok()) {
context->SetStatus(status);
return status;
}
Tensor* output_indices = nullptr;
int num_outputs = std::min(num_to_keep, output_size); // no padding!
TF_RETURN_IF_ERROR(
context->allocate_output(0, TensorShape({num_outputs}), &output_indices));
if (num_outputs == 0) return Status::OK();
config = GetGpuLaunchConfig(num_outputs, device);
TF_CHECK_OK(GpuLaunchKernel(
IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
0, device.stream(), config.virtual_thread_count,
d_selected_indices.flat<int>().data(), sorted_indices,
(*output_indices).flat<int>().data()));
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
return Status::OK();
}
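
// Illustrative sketch of the two-pass CUB pattern used throughout DoNMS above
// (not part of this change; a raw cudaMalloc stands in for a TF temp tensor):
// the first call with a null temp_storage pointer only reports the workspace
// size, the second call performs the actual descending key/value sort.
inline cudaError_t SortScoresDescendingRef(const float* d_keys_in,
                                           float* d_keys_out,
                                           const int* d_vals_in,
                                           int* d_vals_out, int n,
                                           cudaStream_t stream) {
  size_t temp_bytes = 0;
  cub::DeviceRadixSort::SortPairsDescending(nullptr, temp_bytes, d_keys_in,
                                            d_keys_out, d_vals_in, d_vals_out,
                                            n, 0, 8 * sizeof(float), stream);
  void* d_temp = nullptr;
  cudaError_t err = cudaMalloc(&d_temp, temp_bytes);
  if (err != cudaSuccess) return err;
  err = cub::DeviceRadixSort::SortPairsDescending(
      d_temp, temp_bytes, d_keys_in, d_keys_out, d_vals_in, d_vals_out, n, 0,
      8 * sizeof(float), stream);
  cudaFree(d_temp);
  return err;
}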

@@ -384,112 +585,84 @@ class NonMaxSuppressionV2GPUOp : public OpKernel {
&output_indices));
return;
}
const int output_size = max_output_size.scalar<int>()();
size_t cub_sort_temp_storage_bytes = 0;
auto cuda_stream = GetGpuStream(context);
auto device = context->eigen_gpu_device();
// Calling cub with nullptrs as inputs will make it return
// workspace size needed for the operation instead of doing the operation.
// In this specific instance, cub_sort_temp_storage_bytes will contain the
// necessary workspace size for sorting after the call.
cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
nullptr, cub_sort_temp_storage_bytes,
static_cast<float*>(nullptr), // scores
static_cast<float*>(nullptr), // sorted scores
static_cast<int*>(nullptr), // input indices
static_cast<int*>(nullptr), // sorted indices
num_boxes, // num items
0, 8 * sizeof(float), // sort all bits
cuda_stream);
TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret);
Tensor d_cub_sort_buffer;
OP_REQUIRES_OK(context,
context->allocate_temp(
DataType::DT_INT8,
TensorShape({(int64)cub_sort_temp_storage_bytes}),
&d_cub_sort_buffer));
Tensor d_indices;
const int64_t output_size = max_output_size.scalar<int>()();
OP_REQUIRES_OK(
context, context->allocate_temp(DataType::DT_INT32,
TensorShape({num_boxes}), &d_indices));
Tensor d_sorted_indices;
OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32,
TensorShape({num_boxes}),
&d_sorted_indices));
Tensor d_selected_indices;
OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32,
TensorShape({num_boxes}),
&d_selected_indices));
Tensor d_sorted_scores;
OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT,
TensorShape({num_boxes}),
&d_sorted_scores));
Tensor d_sorted_boxes;
OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_FLOAT,
TensorShape({num_boxes, 4}),
&d_sorted_boxes));

// this will return sorted scores and their indices
auto config = GetGpuLaunchConfig(num_boxes, device);
// initialize box and score indices
TF_CHECK_OK(GpuLaunchKernel(Iota<int>, config.block_count,
config.thread_per_block, 0, device.stream(),
config.virtual_thread_count, 0,
d_indices.flat<int>().data()));
TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
num_boxes, 0,
8 * sizeof(float), // sort all bits
cuda_stream);
TF_OP_REQUIRES_CUDA_SUCCESS(context, cuda_ret);

// get pointers for easy access
const float4* original_boxes =
reinterpret_cast<const float4*>(boxes.flat<float>().data());
float4* sorted_boxes =
reinterpret_cast<float4*>(d_sorted_boxes.flat<float>().data());
const int* sorted_indices = d_sorted_indices.flat<int>().data();
// sort boxes using indices
TF_CHECK_OK(GpuLaunchKernel(IndexMultiSelect<int, float4>,
config.block_count, config.thread_per_block, 0,
device.stream(), config.virtual_thread_count,
sorted_indices, original_boxes, sorted_boxes));

int num_to_keep = 0;
// There is no guarantee that boxes are given in the form x1<x2 and/or y1<y2,
// flip boxes if necessary!
const bool flip_boxes = true;
auto status =
NmsGpu(d_sorted_boxes.flat<float>().data(), num_boxes,
iou_threshold_val, d_selected_indices.flat<int>().data(),
&num_to_keep, context, flip_boxes);
TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
if (!status.ok()) {
context->SetStatus(status);
return;
}
Tensor* output_indices = nullptr;
int num_outputs = std::min(num_to_keep, output_size); // no padding!
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({num_outputs}),
&output_indices));
if (num_outputs == 0) return;
config = GetGpuLaunchConfig(num_outputs, device);
TF_CHECK_OK(GpuLaunchKernel(
IndexMultiSelect<int, int>, config.block_count, config.thread_per_block,
0, device.stream(), config.virtual_thread_count,
d_selected_indices.flat<int>().data(), sorted_indices,
(*output_indices).flat<int>().data()));
TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());
context,
DoNMS(context, boxes, scores, output_size, iou_threshold_val,
/*score_threshold is float min if score threshold is disabled*/
std::numeric_limits<float>::lowest()));
}
};

REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_GPU),
NonMaxSuppressionV2GPUOp);
class NonMaxSuppressionV3GPUOp : public OpKernel {
public:
explicit NonMaxSuppressionV3GPUOp(OpKernelConstruction* context)
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {
// boxes: [num_boxes, 4]
const Tensor& boxes = context->input(0);
// scores: [num_boxes]
const Tensor& scores = context->input(1);
// max_output_size: scalar
const Tensor& max_output_size = context->input(2);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(max_output_size.shape()),
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
max_output_size.shape().DebugString()));
// iou_threshold: scalar
const Tensor& iou_threshold = context->input(3);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
const float iou_threshold_val = iou_threshold.scalar<float>()();

const Tensor& score_threshold = context->input(4);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(score_threshold.shape()),
errors::InvalidArgument("score_threshold must be 0-D, got shape ",
score_threshold.shape().DebugString()));
const float score_threshold_val = score_threshold.scalar<float>()();

OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
errors::InvalidArgument("iou_threshold must be in [0, 1]"));
OP_REQUIRES(context, boxes.dims() == 2,
errors::InvalidArgument("boxes must be a rank 2 tensor!"));
int num_boxes = boxes.dim_size(0);
OP_REQUIRES(context, boxes.dim_size(1) == 4,
errors::InvalidArgument("boxes must be Nx4"));
OP_REQUIRES(context, scores.dims() == 1,
errors::InvalidArgument("scores must be a vector!"));
OP_REQUIRES(
context, scores.dim_size(0) == num_boxes,
errors::InvalidArgument(
"scores has incompatible shape")); // message must be exactly this
// otherwise tests fail!
if (num_boxes == 0) {
Tensor* output_indices = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({0}),
&output_indices));
return;
}
const int output_size = max_output_size.scalar<int>()();
OP_REQUIRES_OK(context, DoNMS(context, boxes, scores, output_size,
iou_threshold_val, score_threshold_val));
}
};

REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
.TypeConstraint<float>("T")
.Device(DEVICE_GPU)
.HostMemory("iou_threshold")
.HostMemory("max_output_size"),
NonMaxSuppressionV2GPUOp);
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
.TypeConstraint<float>("T")
.Device(DEVICE_GPU)
.HostMemory("iou_threshold")
.HostMemory("max_output_size")
.HostMemory("score_threshold"),
NonMaxSuppressionV3GPUOp);

} // namespace tensorflow
#endif

@@ -54,7 +54,7 @@ extern const int kNmsBoxesPerTread;
Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
const float iou_threshold, int* d_selected_indices,
int* h_num_boxes_to_keep, OpKernelContext* context,
bool flip_boxes = false);
const int max_boxes, bool flip_boxes = false);
#endif

} // namespace tensorflow

@@ -203,6 +203,222 @@ TEST_F(NonMaxSuppressionV2GPUOpTest, TestEmptyInput) {
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

//
// NonMaxSuppressionV3GPUOp Tests
// Copied from CPU tests

class NonMaxSuppressionV3GPUOpTest : public OpsTestBase {
protected:
void MakeOp() {
SetDevice(DEVICE_GPU,
std::unique_ptr<tensorflow::Device>(DeviceFactory::NewDevice(
"GPU", {}, "/job:a/replica:0/task:0")));

TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV3")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
}
};

TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectFromThreeClusters) {
MakeOp();
AddInputFromArray<float>(
TensorShape({6, 4}),
{0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
AddInputFromArray<int>(TensorShape({}), {3});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({3}));
test::FillValues<int>(&expected, {3, 0, 5});
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

TEST_F(NonMaxSuppressionV3GPUOpTest,
TestSelectFromThreeClustersWithScoreThreshold) {
MakeOp();
AddInputFromArray<float>(
TensorShape({6, 4}),
{0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
AddInputFromArray<int>(TensorShape({}), {3});
AddInputFromArray<float>(TensorShape({}), {0.5f});
AddInputFromArray<float>(TensorShape({}), {0.4f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({2}));
test::FillValues<int>(&expected, {3, 0});
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

TEST_F(NonMaxSuppressionV3GPUOpTest,
TestSelectFromThreeClustersWithScoreThresholdZeroScores) {
MakeOp();
AddInputFromArray<float>(
TensorShape({6, 4}),
{0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
AddInputFromArray<float>(TensorShape({6}), {.1, 0, 0, .3, .2, -5.0});
// If we ask for more boxes than we actually expect to get back,
// should still only get 2 boxes back.
AddInputFromArray<int>(TensorShape({}), {6});
AddInputFromArray<float>(TensorShape({}), {0.5f});
AddInputFromArray<float>(TensorShape({}), {-3.0f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({2}));
test::FillValues<int>(&expected, {3, 0});

test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

TEST_F(NonMaxSuppressionV3GPUOpTest,
TestSelectFromThreeClustersFlippedCoordinates) {
MakeOp();
AddInputFromArray<float>(TensorShape({6, 4}),
{1, 1, 0, 0, 0, 0.1f, 1, 1.1f, 0, .9f, 1, -0.1f,
0, 10, 1, 11, 1, 10.1f, 0, 11.1f, 1, 101, 0, 100});
AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
AddInputFromArray<int>(TensorShape({}), {3});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({3}));
test::FillValues<int>(&expected, {3, 0, 5});
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

TEST_F(NonMaxSuppressionV3GPUOpTest,
TestSelectAtMostTwoBoxesFromThreeClusters) {
MakeOp();
AddInputFromArray<float>(
TensorShape({6, 4}),
{0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
AddInputFromArray<int>(TensorShape({}), {2});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({2}));
test::FillValues<int>(&expected, {3, 0});
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

TEST_F(NonMaxSuppressionV3GPUOpTest,
TestSelectAtMostThirtyBoxesFromThreeClusters) {
MakeOp();
AddInputFromArray<float>(
TensorShape({6, 4}),
{0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
AddInputFromArray<int>(TensorShape({}), {30});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({3}));
test::FillValues<int>(&expected, {3, 0, 5});
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectSingleBox) {
MakeOp();
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<float>(TensorShape({1}), {.9f});
AddInputFromArray<int>(TensorShape({}), {3});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({1}));
test::FillValues<int>(&expected, {0});
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

TEST_F(NonMaxSuppressionV3GPUOpTest, TestSelectFromTenIdenticalBoxes) {
MakeOp();

int num_boxes = 10;
std::vector<float> corners(num_boxes * 4);
std::vector<float> scores(num_boxes);
for (int i = 0; i < num_boxes; ++i) {
corners[i * 4 + 0] = 0;
corners[i * 4 + 1] = 0;
corners[i * 4 + 2] = 1;
corners[i * 4 + 3] = 1;
scores[i] = .9;
}
AddInputFromArray<float>(TensorShape({num_boxes, 4}), corners);
AddInputFromArray<float>(TensorShape({num_boxes}), scores);
AddInputFromArray<int>(TensorShape({}), {3});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({1}));
test::FillValues<int>(&expected, {0});
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

TEST_F(NonMaxSuppressionV3GPUOpTest, TestInconsistentBoxAndScoreShapes) {
MakeOp();
AddInputFromArray<float>(
TensorShape({6, 4}),
{0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
AddInputFromArray<float>(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f});
AddInputFromArray<int>(TensorShape({}), {30});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
Status s = RunOpKernel();

ASSERT_FALSE(s.ok());
EXPECT_TRUE(absl::StrContains(s.ToString(), "scores has incompatible shape"))
<< s;
}

TEST_F(NonMaxSuppressionV3GPUOpTest, TestInvalidIOUThreshold) {
MakeOp();
AddInputFromArray<float>(TensorShape({1, 4}), {0, 0, 1, 1});
AddInputFromArray<float>(TensorShape({1}), {.9f});
AddInputFromArray<int>(TensorShape({}), {3});
AddInputFromArray<float>(TensorShape({}), {1.2f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
Status s = RunOpKernel();

ASSERT_FALSE(s.ok());
EXPECT_TRUE(
absl::StrContains(s.ToString(), "iou_threshold must be in [0, 1]"))
<< s;
}

TEST_F(NonMaxSuppressionV3GPUOpTest, TestEmptyInput) {
MakeOp();
AddInputFromArray<float>(TensorShape({0, 4}), {});
AddInputFromArray<float>(TensorShape({0}), {});
AddInputFromArray<int>(TensorShape({}), {30});
AddInputFromArray<float>(TensorShape({}), {.5f});
AddInputFromArray<float>(TensorShape({}), {0.0f});
TF_ASSERT_OK(RunOpKernel());

Tensor expected(allocator(), DT_INT32, TensorShape({0}));
test::FillValues<int>(&expected, {});
test::ExpectTensorEqual<int>(expected, *GetOutput(0));
}

#endif

} // namespace tensorflow