Adding ROCm support for the population_count op

Deven Desai 2019-06-10 15:39:12 +00:00
parent dac4bd7750
commit e7f1163401
2 changed files with 10 additions and 10 deletions
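
The change itself is mechanical in two respects: the CUDA-only preprocessor guards (#if GOOGLE_CUDA ... #endif) are widened to GOOGLE_CUDA || TENSORFLOW_USE_ROCM so the GPU kernels and their registrations also compile in ROCm builds, and the CUDA-specific helpers (CUDA_1D_KERNEL_LOOP, GetCudaLaunchConfig, CudaLaunchKernel) are replaced by their platform-neutral GPU_* counterparts (GPU_1D_KERNEL_LOOP, GetGpuLaunchConfig, GpuLaunchKernel), which resolve to CUDA or HIP depending on the build. As a rough sketch only, not TensorFlow's actual macro definition, GPU_1D_KERNEL_LOOP(i, size) behaves like a hand-written grid-stride loop:

    // Illustrative sketch, not TensorFlow's macro definition: the grid-stride
    // loop that GPU_1D_KERNEL_LOOP(i, size) stands for. It compiles unchanged
    // under nvcc (CUDA) and hipcc (ROCm) because it only uses the shared
    // blockIdx/blockDim/gridDim builtins and the __popc intrinsic.
    __global__ void PopulationCountSketch(const int size, const int* input,
                                          unsigned char* output) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
           i += blockDim.x * gridDim.x) {
        output[i] = static_cast<unsigned char>(__popc(input[i]));
      }
    }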

tensorflow/core/kernels/population_count_op.cc

@@ -122,7 +122,7 @@ struct PopulationCount<CPUDevice, T> {
 }  // namespace functor
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define REGISTER_POPULATION_COUNT(type) \
   REGISTER_KERNEL_BUILDER(              \
@@ -158,6 +158,6 @@ TF_CALL_int64(DECLARE_GPU_SPEC);
 }  // namespace functor
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 }  // namespace tensorflow
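
On the registration side, nothing changes except visibility: the REGISTER_KERNEL_BUILDER block that used to be compiled only under GOOGLE_CUDA is now also compiled when TENSORFLOW_USE_ROCM is defined, while CPU-only builds still skip it. A minimal sketch of the guarded pattern follows; the op class name and its template parameters are placeholders for illustration, not copied from the file.

    // Hedged sketch of a guarded GPU registration; PopulationCountOp and its
    // template parameters are assumed names, not the file's exact definition.
    #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    REGISTER_KERNEL_BUILDER(
        Name("PopulationCount").Device(DEVICE_GPU).TypeConstraint<int32>("T"),
        PopulationCountOp<GPUDevice, int32>);
    #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM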

tensorflow/core/kernels/population_count_op_gpu.cu.cc

@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#if GOOGLE_CUDA
+#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
 #define EIGEN_USE_GPU
@@ -35,14 +35,14 @@ namespace functor {
 template <typename T>
 __global__ void PopulationCountKernel(const int size, const T* input,
                                       uint8* output) {
-  CUDA_1D_KERNEL_LOOP(i, size) { output[i] = __popc(ldg(input + i)); }
+  GPU_1D_KERNEL_LOOP(i, size) { output[i] = __popc(ldg(input + i)); }
 }
 
 template <>
 __global__ void PopulationCountKernel(const int size, const int8* input,
                                       uint8* output) {
   // For some reason, __popc on a negative int8 gets confused.
-  CUDA_1D_KERNEL_LOOP(i, size) {
+  GPU_1D_KERNEL_LOOP(i, size) {
     output[i] = __popc(ldg(reinterpret_cast<const uint8*>(input + i)));
   }
 }
@@ -51,7 +51,7 @@ template <>
 __global__ void PopulationCountKernel(const int size, const int16* input,
                                       uint8* output) {
   // For some reason, __popc on a negative int16 gets confused.
-  CUDA_1D_KERNEL_LOOP(i, size) {
+  GPU_1D_KERNEL_LOOP(i, size) {
     output[i] = __popc(ldg(reinterpret_cast<const uint16*>(input + i)));
   }
 }
@@ -59,7 +59,7 @@ __global__ void PopulationCountKernel(const int size, const int16* input,
 template <>
 __global__ void PopulationCountKernel<int64>(const int size, const int64* input,
                                              uint8* output) {
-  CUDA_1D_KERNEL_LOOP(i, size) { output[i] = __popcll(ldg(input + i)); }
+  GPU_1D_KERNEL_LOOP(i, size) { output[i] = __popcll(ldg(input + i)); }
 }
 
 #define DEFINE_GPU_SPECS(T)                                                    \
@@ -69,8 +69,8 @@ __global__ void PopulationCountKernel<int64>(const int size, const int64* input,
       TTypes<uint8>::Flat output) {                                            \
     const GPUDevice& d = c->eigen_device<GPUDevice>();                         \
     int64 total_count = input.size();                                          \
-    GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d);              \
-    TF_CHECK_OK(CudaLaunchKernel(PopulationCountKernel<T>, config.block_count, \
+    GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d);               \
+    TF_CHECK_OK(GpuLaunchKernel(PopulationCountKernel<T>, config.block_count,  \
                                  config.thread_per_block, 0, d.stream(),       \
                                  total_count, input.data(), output.data()));   \
   }
@@ -88,4 +88,4 @@ TF_CALL_int64(DEFINE_GPU_SPECS);
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
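
On the launch side, GetGpuLaunchConfig sizes the grid for the element count and GpuLaunchKernel dispatches PopulationCountKernel on the op's stream, mapping to the CUDA or HIP launch mechanism under the hood. The reinterpret_casts in the int8/int16 kernels exist because a negative value sign-extends when promoted to __popc's 32-bit argument, so the extended high bits would be counted too; reading the same bytes as unsigned avoids that. The standalone program below, with made-up names and a fixed launch configuration, sketches the same pattern outside TensorFlow; it is written against the CUDA runtime for brevity, and a ROCm build would do the equivalent through the HIP runtime API.

    // popcount_sketch.cu -- illustrative only, not TensorFlow code. Shows the
    // pattern that GetGpuLaunchConfig/GpuLaunchKernel wrap: pick a launch
    // configuration, run a grid-stride kernel, copy the result back.
    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include <cuda_runtime.h>

    __global__ void PopulationCountSketch(const int size, const int32_t* input,
                                          uint8_t* output) {
      // Hand-written form of the GPU_1D_KERNEL_LOOP(i, size) grid-stride loop.
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
           i += blockDim.x * gridDim.x) {
        output[i] = static_cast<uint8_t>(__popc(input[i]));
      }
    }

    int main() {
      const int n = 8;
      std::vector<int32_t> host_in = {0, 1, 3, 7, 255, -1, 1024, 0x0F0F0F0F};
      std::vector<uint8_t> host_out(n);

      int32_t* dev_in = nullptr;
      uint8_t* dev_out = nullptr;
      cudaMalloc(&dev_in, n * sizeof(int32_t));
      cudaMalloc(&dev_out, n * sizeof(uint8_t));
      cudaMemcpy(dev_in, host_in.data(), n * sizeof(int32_t),
                 cudaMemcpyHostToDevice);

      // Fixed launch configuration; GetGpuLaunchConfig instead derives this
      // from the element count and the device's properties.
      const int threads_per_block = 256;
      const int blocks = (n + threads_per_block - 1) / threads_per_block;
      PopulationCountSketch<<<blocks, threads_per_block>>>(n, dev_in, dev_out);

      cudaMemcpy(host_out.data(), dev_out, n * sizeof(uint8_t),
                 cudaMemcpyDeviceToHost);
      for (int i = 0; i < n; ++i) {
        std::printf("popcount(%d) = %d\n", host_in[i], host_out[i]);
      }
      cudaFree(dev_in);
      cudaFree(dev_out);
      return 0;
    }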