Fixing and enabling NonMaxSuppression for ROCm

Eugene Kuznetsov 2020-01-22 12:22:56 -08:00
parent 580dc945a6
commit 577579c575
4 changed files with 40 additions and 33 deletions

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include <limits>
@@ -28,7 +28,12 @@ limitations under the License.
 #include "tensorflow/core/util/gpu_launch_config.h"
 #include "tensorflow/stream_executor/stream_executor.h"
-struct __align__(16) Box {
+struct
+#if GOOGLE_CUDA
+    __align__(16)
+#endif
+        Box {
   float x1, y1, x2, y2;
 };
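Note: the guard keeps the CUDA-specific __align__(16) attribute out of the ROCm compile path. For comparison, a portable C++11 spelling of the same 16-byte alignment would look like the sketch below (illustrative only; not what the commit uses):

    // Portable alternative sketch: alignas(16) is understood by both nvcc and
    // hipcc, so no per-toolchain guard would be needed.
    struct alignas(16) Box {
      float x1, y1, x2, y2;
    };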
@@ -114,7 +119,7 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
                           char* result_mask) {
   extern __shared__ int local[];
   // set global mask to accept all boxes
-  for (int box : CudaGridRangeX(bit_mask_len)) {
+  for (int box : GpuGridRangeX(bit_mask_len)) {
     local[box] = 0xFFFFFFFF;
   }
   __syncthreads();
@@ -127,7 +132,7 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
     accepted_boxes += 1;
     int offset = box * bit_mask_len;
     // update global mask with current box's mask
-    for (int b : CudaGridRangeX(bit_mask_len)) {
+    for (int b : GpuGridRangeX(bit_mask_len)) {
       local[b] &= ~bitmask[offset + b];
     }
     __syncthreads();
@@ -135,7 +140,7 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
   }
   // copy global mask to result_max char array. char array is needed for
   // cub::DeviceSelect later.
-  for (int box : CudaGridRangeX(num_boxes)) {
+  for (int box : GpuGridRangeX(num_boxes)) {
     result_mask[box] = CheckBit(local, box);
   }
 }
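Note: GpuGridRangeX is the platform-neutral name for what was previously CudaGridRangeX; both yield a grid-stride index range. A rough sketch of what the loop expands to (simplified, assuming the usual TF GPU kernel helpers):

    // Equivalent grid-stride loop: each thread starts at its global index and
    // advances by the total number of threads in the grid, so any `count` is
    // covered regardless of the launch shape.
    __global__ void FillMaskKernel(int count, int* local) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count;
           i += gridDim.x * blockDim.x) {
        local[i] = 0xFFFFFFFF;
      }
    }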
@@ -232,14 +237,14 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
 template <typename Index, typename T, typename... Args>
 __global__ void IndexMultiSelect(const int num_elements, const Index* indices,
                                  const T* original, T* selected, Args... args) {
-  for (const int idx : CudaGridRangeX(num_elements)) {
+  for (const int idx : GpuGridRangeX(num_elements)) {
     SelectHelper(idx, indices[idx], original, selected, args...);
   }
 }
 template <typename T>
 __global__ void Iota(const int num_elements, const T offset, T* to_fill) {
-  for (int idx : CudaGridRangeX(num_elements)) {
+  for (int idx : GpuGridRangeX(num_elements)) {
     to_fill[idx] = static_cast<T>(idx) + offset;
   }
 }
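Note: kernels such as Iota above are launched through TF's GPU launch helpers, which work unchanged under ROCm. A hedged sketch of such a launch (the helper names GetGpuLaunchConfig and GpuLaunchKernel come from TF's gpu_launch_config/gpu_kernel_helper utilities; the argument values here are illustrative, not taken from this commit):

    // Query a launch shape for num_boxes elements, then launch Iota<int> on
    // the device stream to fill d_indices with 0 .. num_boxes-1.
    GpuLaunchConfig config = GetGpuLaunchConfig(num_boxes, device);
    TF_RETURN_IF_ERROR(GpuLaunchKernel(Iota<int>, config.block_count,
                                       config.thread_per_block, 0,
                                       device.stream(),
                                       config.virtual_thread_count,
                                       /*offset=*/0,
                                       d_indices.flat<int>().data()));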
@@ -322,7 +327,7 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
   // do Cub::deviceSelect::flagged
   size_t flagged_buffer_size = 0;
-  cub::DeviceSelect::Flagged(static_cast<void*>(nullptr),  // temp_storage
+  gpuprim::DeviceSelect::Flagged(static_cast<void*>(nullptr),  // temp_storage
                              flagged_buffer_size,
                              static_cast<int*>(nullptr),   // input
                              static_cast<char*>(nullptr),  // selection flag
@@ -337,22 +342,22 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                             TensorShape({1}), &d_num_selected));
-  cub::DeviceSelect::Flagged(
+  gpuprim::DeviceSelect::Flagged(
       (void*)cub_scratch.flat<int8>().data(),  // temp_storage
       flagged_buffer_size,
       d_indices.flat<int>().data(),  // input
       selected,                      // selection flag
       d_selected_indices,            // selected items
       d_num_selected.flat<int>().data(), num_boxes, device.stream());
-  cudaEvent_t copy_done;
+  gpuEvent_t copy_done;
   TF_RETURN_IF_CUDA_ERROR(
-      cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
+      gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
   device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(),
                             sizeof(int));
-  TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-  TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
+  TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
+  TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
   *h_nkeep = *h_selected_count;
-  cudaEventDestroy(copy_done);
+  gpuEventDestroy(copy_done);
   return Status::OK();
 }
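Note: gpuprim:: and the gpuEvent* names are thin aliases that select the CUDA or ROCm backend at compile time. Conceptually they resolve roughly like the sketch below (header paths and the exact alias mechanism are assumptions, not part of this commit):

    #if GOOGLE_CUDA
    #include "cub/cub.cuh"
    namespace gpuprim = ::cub;            // DeviceSelect, DeviceRadixSort, ...
    using gpuEvent_t = cudaEvent_t;
    #define gpuEventCreateWithFlags cudaEventCreateWithFlags
    #define gpuEventDisableTiming cudaEventDisableTiming
    #define gpuEventRecord cudaEventRecord
    #define gpuEventSynchronize cudaEventSynchronize
    #define gpuEventDestroy cudaEventDestroy
    #elif TENSORFLOW_USE_ROCM
    #include "hipcub/hipcub.hpp"
    namespace gpuprim = ::hipcub;         // same primitives, HIP backend
    using gpuEvent_t = hipEvent_t;
    #define gpuEventCreateWithFlags hipEventCreateWithFlags
    #define gpuEventDisableTiming hipEventDisableTiming
    #define gpuEventRecord hipEventRecord
    #define gpuEventSynchronize hipEventSynchronize
    #define gpuEventDestroy hipEventDestroy
    #endif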
@@ -375,7 +380,8 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
   size_t workspace_size = 0;
   auto cuda_stream = tensorflow::GetGpuStream(context);
   auto device = context->eigen_gpu_device();
-  cub::DeviceSelect::If(nullptr, workspace_size, static_cast<float*>(nullptr),
+  gpuprim::DeviceSelect::If(nullptr, workspace_size,
+                            static_cast<float*>(nullptr),
                             static_cast<float*>(nullptr),
                             static_cast<int*>(nullptr), num_elements, op);
@@ -385,17 +391,17 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
       DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));
   TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                             TensorShape({1}), &element_count));
-  cudaEvent_t copy_done;
+  gpuEvent_t copy_done;
   TF_RETURN_IF_CUDA_ERROR(
-      cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
-  TF_RETURN_IF_CUDA_ERROR(cub::DeviceSelect::If(
+      gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
+  TF_RETURN_IF_CUDA_ERROR(gpuprim::DeviceSelect::If(
       workspace.flat<int8>().data(), workspace_size, dev_array,
       scratch_output.flat<float>().data(), element_count.flat<int32>().data(),
       num_elements, op, cuda_stream));
   device.memcpyDeviceToHost(result, element_count.flat<int32>().data(),
                             sizeof(int));
-  TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-  TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
+  TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
+  TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
   return Status::OK();
 }
@@ -418,7 +424,7 @@ Status DoNMS(OpKernelContext* context, const Tensor& boxes,
     return Status::OK();
   }
-  cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
+  cudaError_t cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
       nullptr, cub_sort_temp_storage_bytes,
       static_cast<float*>(nullptr),  // scores
      static_cast<float*>(nullptr),  // sorted scores
@@ -458,7 +464,7 @@ Status DoNMS(OpKernelContext* context, const Tensor& boxes,
                              config.virtual_thread_count, 0,
                              d_indices.flat<int>().data()));
   TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
-  cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
+  cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
       d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
       scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
       d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
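Note: both SortPairsDescending calls follow the standard CUB/hipCUB two-pass pattern: the first call, with a null temp-storage pointer, only reports the required scratch size, and the second call performs the sort. A small sketch of that pattern (function name and parameters are illustrative, not the commit's code):

    // Pass 1: d_temp == nullptr, so the library only writes the required
    // scratch size into temp_bytes. Pass 2 (after allocating temp_bytes of
    // device memory) performs the actual descending sort of score/index pairs.
    cudaError_t SortScoresDescending(void* d_temp, size_t& temp_bytes,
                                     const float* scores_in, float* scores_out,
                                     const int* idx_in, int* idx_out,
                                     int num_items, cudaStream_t stream) {
      return gpuprim::DeviceRadixSort::SortPairsDescending(
          d_temp, temp_bytes, scores_in, scores_out, idx_in, idx_out,
          num_items, /*begin_bit=*/0, /*end_bit=*/8 * sizeof(float), stream);
    }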

@@ -35,7 +35,7 @@ struct NonMaxSuppression {
 }  // namespace functor
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 extern const int kNmsBoxesPerTread;
 // Given descending sorted box list, apply non-maximal-suppression with given

@@ -35,7 +35,7 @@ limitations under the License.
 namespace tensorflow {
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // These tests are copied from non_max_suppression_op_test.cc file and modified
 // to use GPU ops. See other file for test details.

@@ -122,7 +122,8 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
   }
 #else
   CHECK_NE(device_type, DEVICE_GPU)
-      << "Requesting GPU on binary compiled without GOOGLE_CUDA or TENSORFLOW_USE_ROCM.";
+      << "Requesting GPU on binary compiled without GOOGLE_CUDA or "
+         "TENSORFLOW_USE_ROCM.";
   allocator_ = device_->GetAllocator(AllocatorAttributes());
 #endif
 }