[ROCm] enable InTopK op on ROCm.
This commit is contained in:
parent
86d00f0125
commit
8986d4eb87
tensorflow/core/kernels
@ -116,7 +116,7 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2")
|
||||
.TypeConstraint<int64>("T"),
|
||||
InTopK<CPUDevice, float, int64>);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
@ -142,6 +142,6 @@ REGISTER_KERNEL_BUILDER(
|
||||
Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64>("T"),
|
||||
InTopK<GPUDevice, float, int64>);
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || TENSORFLOW_USE_ROCM
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
@ -41,7 +41,7 @@ __global__ void ComputePredictionMaskKernel(
|
||||
const TargetT* targets, // dims: [ num_targets ]
|
||||
int64* mask, // dims: [ num_targets x num_classes ]
|
||||
int num_targets, int num_classes) {
|
||||
CUDA_1D_KERNEL_LOOP(i, num_targets * num_classes) {
|
||||
GPU_1D_KERNEL_LOOP(i, num_targets * num_classes) {
|
||||
const int batch_index = i / num_classes;
|
||||
TargetT target_idx = ldg(targets + batch_index);
|
||||
|
||||
@ -118,7 +118,7 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
|
||||
const auto& d = context->eigen_device<GPUDevice>();
|
||||
|
||||
// Compute a mask for all predictions.
|
||||
CudaLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
|
||||
GpuLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
|
||||
OP_REQUIRES_OK(
|
||||
context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
|
||||
config.block_count, config.thread_per_block, 0,
|
||||
@ -173,4 +173,4 @@ DEFINE_GPU_KERNELS(float, int64);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
Loading…
Reference in New Issue
Block a user