Adding ROCM support for the population_count op
This commit is contained in:
parent
dac4bd7750
commit
e7f1163401
@ -122,7 +122,7 @@ struct PopulationCount<CPUDevice, T> {
|
|||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define REGISTER_POPULATION_COUNT(type) \
|
#define REGISTER_POPULATION_COUNT(type) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
@ -158,6 +158,6 @@ TF_CALL_int64(DECLARE_GPU_SPEC);
|
|||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -35,14 +35,14 @@ namespace functor {
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void PopulationCountKernel(const int size, const T* input,
|
__global__ void PopulationCountKernel(const int size, const T* input,
|
||||||
uint8* output) {
|
uint8* output) {
|
||||||
CUDA_1D_KERNEL_LOOP(i, size) { output[i] = __popc(ldg(input + i)); }
|
GPU_1D_KERNEL_LOOP(i, size) { output[i] = __popc(ldg(input + i)); }
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
__global__ void PopulationCountKernel(const int size, const int8* input,
|
__global__ void PopulationCountKernel(const int size, const int8* input,
|
||||||
uint8* output) {
|
uint8* output) {
|
||||||
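// __popc takes an unsigned int, so a negative int8 would be sign-extended and over-counted; load it as uint8 instead.
|
// __popc takes an unsigned int, so a negative int8 would be sign-extended and over-counted; load it as uint8 instead.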
|
||||||
CUDA_1D_KERNEL_LOOP(i, size) {
|
GPU_1D_KERNEL_LOOP(i, size) {
|
||||||
output[i] = __popc(ldg(reinterpret_cast<const uint8*>(input + i)));
|
output[i] = __popc(ldg(reinterpret_cast<const uint8*>(input + i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -51,7 +51,7 @@ template <>
|
|||||||
__global__ void PopulationCountKernel(const int size, const int16* input,
|
__global__ void PopulationCountKernel(const int size, const int16* input,
|
||||||
uint8* output) {
|
uint8* output) {
|
||||||
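// __popc takes an unsigned int, so a negative int16 would be sign-extended and over-counted; load it as uint16 instead.
|
// __popc takes an unsigned int, so a negative int16 would be sign-extended and over-counted; load it as uint16 instead.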
|
||||||
CUDA_1D_KERNEL_LOOP(i, size) {
|
GPU_1D_KERNEL_LOOP(i, size) {
|
||||||
output[i] = __popc(ldg(reinterpret_cast<const uint16*>(input + i)));
|
output[i] = __popc(ldg(reinterpret_cast<const uint16*>(input + i)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -59,7 +59,7 @@ __global__ void PopulationCountKernel(const int size, const int16* input,
|
|||||||
template <>
|
template <>
|
||||||
__global__ void PopulationCountKernel<int64>(const int size, const int64* input,
|
__global__ void PopulationCountKernel<int64>(const int size, const int64* input,
|
||||||
uint8* output) {
|
uint8* output) {
|
||||||
CUDA_1D_KERNEL_LOOP(i, size) { output[i] = __popcll(ldg(input + i)); }
|
GPU_1D_KERNEL_LOOP(i, size) { output[i] = __popcll(ldg(input + i)); }
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DEFINE_GPU_SPECS(T) \
|
#define DEFINE_GPU_SPECS(T) \
|
||||||
@ -69,8 +69,8 @@ __global__ void PopulationCountKernel<int64>(const int size, const int64* input,
|
|||||||
TTypes<uint8>::Flat output) { \
|
TTypes<uint8>::Flat output) { \
|
||||||
const GPUDevice& d = c->eigen_device<GPUDevice>(); \
|
const GPUDevice& d = c->eigen_device<GPUDevice>(); \
|
||||||
int64 total_count = input.size(); \
|
int64 total_count = input.size(); \
|
||||||
GpuLaunchConfig config = GetCudaLaunchConfig(total_count, d); \
|
GpuLaunchConfig config = GetGpuLaunchConfig(total_count, d); \
|
||||||
TF_CHECK_OK(CudaLaunchKernel(PopulationCountKernel<T>, config.block_count, \
|
TF_CHECK_OK(GpuLaunchKernel(PopulationCountKernel<T>, config.block_count, \
|
||||||
config.thread_per_block, 0, d.stream(), \
|
config.thread_per_block, 0, d.stream(), \
|
||||||
total_count, input.data(), output.data())); \
|
total_count, input.data(), output.data())); \
|
||||||
}
|
}
|
||||||
@ -88,4 +88,4 @@ TF_CALL_int64(DEFINE_GPU_SPECS);
|
|||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
Loading…
x
Reference in New Issue
Block a user