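[ROCm] Enable the InTopK op on ROCm: widen the GOOGLE_CUDA preprocessor guards to also cover TENSORFLOW_USE_ROCM, and use the GPU_1D_KERNEL_LOOP / GpuLaunchConfig names in place of the CUDA-specific ones.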

This commit is contained in:
Wen-Heng (Jack) Chung 2019-08-13 16:31:25 +00:00
parent 86d00f0125
commit 8986d4eb87
3 changed files with 8 additions and 8 deletions

View File

@ -116,7 +116,7 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2")
.TypeConstraint<int64>("T"),
InTopK<CPUDevice, float, int64>);
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
namespace functor {
@ -142,6 +142,6 @@ REGISTER_KERNEL_BUILDER(
Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64>("T"),
InTopK<GPUDevice, float, int64>);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
} // namespace tensorflow

View File

@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
@ -41,7 +41,7 @@ __global__ void ComputePredictionMaskKernel(
const TargetT* targets, // dims: [ num_targets ]
int64* mask, // dims: [ num_targets x num_classes ]
int num_targets, int num_classes) {
CUDA_1D_KERNEL_LOOP(i, num_targets * num_classes) {
GPU_1D_KERNEL_LOOP(i, num_targets * num_classes) {
const int batch_index = i / num_classes;
TargetT target_idx = ldg(targets + batch_index);
@ -118,7 +118,7 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
const auto& d = context->eigen_device<GPUDevice>();
// Compute a mask for all predictions.
CudaLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
GpuLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
OP_REQUIRES_OK(
context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
config.block_count, config.thread_per_block, 0,
@ -173,4 +173,4 @@ DEFINE_GPU_KERNELS(float, int64);
} // end namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM