Fixing and enabling NonMaxSuppression for ROCm

parent 580dc945a6
commit 577579c575
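
This change ports the NonMaxSuppression GPU kernels so the same source builds
under both CUDA and ROCm. The pattern is mechanical throughout the diff:
CUDA-only #if GOOGLE_CUDA guards widen to #if GOOGLE_CUDA ||
TENSORFLOW_USE_ROCM, CudaGridRangeX becomes the vendor-neutral GpuGridRangeX,
direct cub:: calls route through the gpuprim namespace alias, and raw
cudaEvent* calls become gpuEvent*. Below is a minimal sketch of the aliasing
layer such renames rely on; the real mappings live in TensorFlow headers such
as gpu_prim.h and gpu_device_functions.h, and the exact spellings here are
illustrative assumptions, not verbatim TensorFlow code.

// Illustrative sketch only: one header can make CUB/hipCUB and the event API
// interchangeable behind gpuprim:: and gpu* names.
#if GOOGLE_CUDA
#include "cub/cub.cuh"
namespace gpuprim = ::cub;     // gpuprim::DeviceSelect -> cub::DeviceSelect
using gpuEvent_t = cudaEvent_t;
#define gpuEventCreateWithFlags cudaEventCreateWithFlags
#define gpuEventDisableTiming cudaEventDisableTiming
#define gpuEventRecord cudaEventRecord
#define gpuEventSynchronize cudaEventSynchronize
#define gpuEventDestroy cudaEventDestroy
#elif TENSORFLOW_USE_ROCM
#include "hipcub/hipcub.hpp"
namespace gpuprim = ::hipcub;  // same call sites, hipCUB/rocPRIM backend
using gpuEvent_t = hipEvent_t;
#define gpuEventCreateWithFlags hipEventCreateWithFlags
#define gpuEventDisableTiming hipEventDisableTiming
#define gpuEventRecord hipEventRecord
#define gpuEventSynchronize hipEventSynchronize
#define gpuEventDestroy hipEventDestroy
#endif

With a layer like this in place, a single call site such as
gpuprim::DeviceSelect::Flagged compiles against CUB on CUDA and hipCUB on
ROCm with no further source changes.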
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #define EIGEN_USE_GPU
 #include <limits>
 
@@ -28,7 +28,12 @@ limitations under the License.
 #include "tensorflow/core/util/gpu_launch_config.h"
 #include "tensorflow/stream_executor/stream_executor.h"
 
-struct __align__(16) Box {
+struct
+#if GOOGLE_CUDA
+__align__(16)
+#endif
+Box {
   float x1, y1, x2, y2;
 };
 
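
Note on the Box struct above: __align__(16) is a CUDA-ism that nvcc expands
to an alignment attribute; presumably the ROCm toolchain did not accept it
here, which is why the commit compiles the attribute only under GOOGLE_CUDA,
keeping the 16-byte alignment hint for the CUDA build while the ROCm build
gets the plain struct. A hypothetical portable spelling using standard C++
alignas is sketched here for comparison; it is an alternative, not what the
commit does.

// Hypothetical alternative: standard C++11 alignment instead of the
// CUDA-only __align__ macro (the commit guards the macro instead).
struct alignas(16) Box {
  float x1, y1, x2, y2;
};
static_assert(sizeof(Box) == 16, "four floats stay one 16-byte vector");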
@@ -114,7 +119,7 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
                           char* result_mask) {
   extern __shared__ int local[];
   // set global mask to accept all boxes
-  for (int box : CudaGridRangeX(bit_mask_len)) {
+  for (int box : GpuGridRangeX(bit_mask_len)) {
     local[box] = 0xFFFFFFFF;
   }
   __syncthreads();
@@ -127,7 +132,7 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
     accepted_boxes += 1;
     int offset = box * bit_mask_len;
     // update global mask with current box's mask
-    for (int b : CudaGridRangeX(bit_mask_len)) {
+    for (int b : GpuGridRangeX(bit_mask_len)) {
       local[b] &= ~bitmask[offset + b];
     }
     __syncthreads();
@@ -135,7 +140,7 @@ __global__ void NMSReduce(const int* bitmask, const int bit_mask_len,
   }
   // copy global mask to result_max char array. char array is needed for
   // cub::DeviceSelect later.
-  for (int box : CudaGridRangeX(num_boxes)) {
+  for (int box : GpuGridRangeX(num_boxes)) {
     result_mask[box] = CheckBit(local, box);
   }
 }
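
The CudaGridRangeX -> GpuGridRangeX renames above (and below) are purely
mechanical: both names denote TensorFlow's grid-stride range helper, which
lets a kernel launched with any grid geometry cover all count elements. A
rough sketch of such a range, assuming only standard CUDA/HIP builtins
(TensorFlow's real implementation lives in gpu_device_functions.h; the name
here is illustrative):

// Minimal grid-stride range: thread tid visits tid, tid + stride,
// tid + 2 * stride, ... so the whole grid covers [0, count) exactly once.
struct GpuGridRangeXSketch {
  struct Iter {
    int i, stride;
    __device__ int operator*() const { return i; }
    __device__ Iter& operator++() {
      i += stride;
      return *this;
    }
    __device__ bool operator!=(const Iter& end) const { return i < end.i; }
  };
  int count;
  __device__ Iter begin() const {
    return {static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x),
            static_cast<int>(gridDim.x * blockDim.x)};
  }
  __device__ Iter end() const { return {count, 0}; }
};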
@@ -232,14 +237,14 @@ __device__ EIGEN_STRONG_INLINE void SelectHelper(const Index i_selected,
 template <typename Index, typename T, typename... Args>
 __global__ void IndexMultiSelect(const int num_elements, const Index* indices,
                                  const T* original, T* selected, Args... args) {
-  for (const int idx : CudaGridRangeX(num_elements)) {
+  for (const int idx : GpuGridRangeX(num_elements)) {
     SelectHelper(idx, indices[idx], original, selected, args...);
   }
 }
 
 template <typename T>
 __global__ void Iota(const int num_elements, const T offset, T* to_fill) {
-  for (int idx : CudaGridRangeX(num_elements)) {
+  for (int idx : GpuGridRangeX(num_elements)) {
     to_fill[idx] = static_cast<T>(idx) + offset;
   }
 }
@@ -322,13 +327,13 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
   // do Cub::deviceSelect::flagged
   size_t flagged_buffer_size = 0;
-  cub::DeviceSelect::Flagged(static_cast<void*>(nullptr),  // temp_storage
+  gpuprim::DeviceSelect::Flagged(static_cast<void*>(nullptr),  // temp_storage
                              flagged_buffer_size,
                              static_cast<int*>(nullptr),   // input
                              static_cast<char*>(nullptr),  // selection flag
                              static_cast<int*>(nullptr),   // selected items
                              static_cast<int*>(nullptr),   // num_selected
                              num_boxes, device.stream());
   Tensor cub_scratch;
   TF_RETURN_IF_ERROR(context->allocate_temp(
       DataType::DT_INT8, TensorShape({(int64)flagged_buffer_size}),
@@ -337,22 +342,22 @@ Status NmsGpu(const float* d_sorted_boxes_float_ptr, const int num_boxes,
   TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                             TensorShape({1}), &d_num_selected));
 
-  cub::DeviceSelect::Flagged(
+  gpuprim::DeviceSelect::Flagged(
       (void*)cub_scratch.flat<int8>().data(),  // temp_storage
       flagged_buffer_size,
       d_indices.flat<int>().data(),  // input
       selected,                      // selection flag
       d_selected_indices,            // selected items
       d_num_selected.flat<int>().data(), num_boxes, device.stream());
-  cudaEvent_t copy_done;
+  gpuEvent_t copy_done;
   TF_RETURN_IF_CUDA_ERROR(
-      cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
+      gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
   device.memcpyDeviceToHost(h_selected_count, d_num_selected.flat<int>().data(),
                             sizeof(int));
-  TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-  TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
+  TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
+  TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
   *h_nkeep = *h_selected_count;
-  cudaEventDestroy(copy_done);
+  gpuEventDestroy(copy_done);
   return Status::OK();
 }
 
@@ -375,9 +380,10 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
   size_t workspace_size = 0;
   auto cuda_stream = tensorflow::GetGpuStream(context);
   auto device = context->eigen_gpu_device();
-  cub::DeviceSelect::If(nullptr, workspace_size, static_cast<float*>(nullptr),
-                        static_cast<float*>(nullptr),
-                        static_cast<int*>(nullptr), num_elements, op);
+  gpuprim::DeviceSelect::If(nullptr, workspace_size,
+                            static_cast<float*>(nullptr),
+                            static_cast<float*>(nullptr),
+                            static_cast<int*>(nullptr), num_elements, op);
 
   TF_RETURN_IF_ERROR(context->allocate_temp(
       DataType::DT_FLOAT, TensorShape({num_elements}), &scratch_output));
@@ -385,17 +391,17 @@ Status CountIf(OpKernelContext* context, const float* dev_array, const Op& op,
       DataType::DT_INT8, TensorShape({(int64)workspace_size}), &workspace));
   TF_RETURN_IF_ERROR(context->allocate_temp(DataType::DT_INT32,
                                             TensorShape({1}), &element_count));
-  cudaEvent_t copy_done;
+  gpuEvent_t copy_done;
   TF_RETURN_IF_CUDA_ERROR(
-      cudaEventCreateWithFlags(&copy_done, cudaEventDisableTiming));
-  TF_RETURN_IF_CUDA_ERROR(cub::DeviceSelect::If(
+      gpuEventCreateWithFlags(&copy_done, gpuEventDisableTiming));
+  TF_RETURN_IF_CUDA_ERROR(gpuprim::DeviceSelect::If(
       workspace.flat<int8>().data(), workspace_size, dev_array,
       scratch_output.flat<float>().data(), element_count.flat<int32>().data(),
       num_elements, op, cuda_stream));
   device.memcpyDeviceToHost(result, element_count.flat<int32>().data(),
                             sizeof(int));
-  TF_RETURN_IF_CUDA_ERROR(cudaEventRecord(copy_done, device.stream()));
-  TF_RETURN_IF_CUDA_ERROR(cudaEventSynchronize(copy_done));
+  TF_RETURN_IF_CUDA_ERROR(gpuEventRecord(copy_done, device.stream()));
+  TF_RETURN_IF_CUDA_ERROR(gpuEventSynchronize(copy_done));
   return Status::OK();
 }
 
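
CountIf leans on the selection primitive's two-phase contract, identical in
CUB and hipCUB: call DeviceSelect::If once with a null workspace pointer to
learn the scratch size, allocate, then call again to do the work. Here is a
sketch of that calling pattern under the gpuprim alias from the top of this
page; GreaterThan is an illustrative predicate, not a TensorFlow type, and
gpuError_t, gpuStream_t, gpuMalloc, and gpuFree are assumed aliases in the
same spirit as the event names sketched earlier.

// Two-phase gpuprim::DeviceSelect::If, mirroring the pattern CountIf uses.
struct GreaterThan {
  float threshold;
  __host__ __device__ bool operator()(float x) const { return x > threshold; }
};

gpuError_t CountAbove(const float* d_in, float* d_out, int* d_num_out, int n,
                      float threshold, gpuStream_t stream) {
  GreaterThan op{threshold};
  // Phase 1: null temp storage; the primitive only writes the required size.
  size_t scratch_bytes = 0;
  gpuprim::DeviceSelect::If(nullptr, scratch_bytes, d_in, d_out, d_num_out, n,
                            op, stream);
  void* d_scratch = nullptr;
  gpuMalloc(&d_scratch, scratch_bytes);  // TF uses allocate_temp instead
  // Phase 2: same arguments plus real scratch; now the selection runs.
  gpuError_t err = gpuprim::DeviceSelect::If(d_scratch, scratch_bytes, d_in,
                                             d_out, d_num_out, n, op, stream);
  gpuFree(d_scratch);
  return err;
}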
@@ -418,7 +424,7 @@ Status DoNMS(OpKernelContext* context, const Tensor& boxes,
     return Status::OK();
   }
 
-  cudaError_t cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
+  cudaError_t cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
       nullptr, cub_sort_temp_storage_bytes,
       static_cast<float*>(nullptr),  // scores
       static_cast<float*>(nullptr),  // sorted scores
@@ -458,7 +464,7 @@ Status DoNMS(OpKernelContext* context, const Tensor& boxes,
                                     config.virtual_thread_count, 0,
                                     d_indices.flat<int>().data()));
   TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
-  cuda_ret = cub::DeviceRadixSort::SortPairsDescending(
+  cuda_ret = gpuprim::DeviceRadixSort::SortPairsDescending(
       d_cub_sort_buffer.flat<int8>().data(), cub_sort_temp_storage_bytes,
       scores.flat<float>().data(), d_sorted_scores.flat<float>().data(),
       d_indices.flat<int>().data(), d_sorted_indices.flat<int>().data(),
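
Note that the sort path keeps its cudaError_t cuda_ret and
TF_RETURN_IF_CUDA_ERROR(cudaGetLastError()) spellings even though it now also
builds for ROCm; that only compiles if TensorFlow's GPU portability headers
alias these cuda* names to hip* equivalents on ROCm builds, which appears to
be what this commit relies on.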

@@ -35,7 +35,7 @@ struct NonMaxSuppression {
 
 }  // namespace functor
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 extern const int kNmsBoxesPerTread;
 
 // Given descending sorted box list, apply non-maximal-suppression with given

@@ -35,7 +35,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 // These tests are copied from non_max_suppression_op_test.cc file and modified
 // to use GPU ops. See other file for test details.
 

@@ -122,7 +122,8 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
   }
 #else
   CHECK_NE(device_type, DEVICE_GPU)
-      << "Requesting GPU on binary compiled without GOOGLE_CUDA or TENSORFLOW_USE_ROCM.";
+      << "Requesting GPU on binary compiled without GOOGLE_CUDA or "
+         "TENSORFLOW_USE_ROCM.";
   allocator_ = device_->GetAllocator(AllocatorAttributes());
 #endif
 }