[ROCm] enable InTopK op on ROCm.
This commit is contained in:
parent
86d00f0125
commit
8986d4eb87
tensorflow/core/kernels
@ -116,7 +116,7 @@ REGISTER_KERNEL_BUILDER(Name("InTopKV2")
|
|||||||
.TypeConstraint<int64>("T"),
|
.TypeConstraint<int64>("T"),
|
||||||
InTopK<CPUDevice, float, int64>);
|
InTopK<CPUDevice, float, int64>);
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
// Forward declarations of the functor specializations for GPU.
|
// Forward declarations of the functor specializations for GPU.
|
||||||
namespace functor {
|
namespace functor {
|
||||||
@ -142,6 +142,6 @@ REGISTER_KERNEL_BUILDER(
|
|||||||
Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64>("T"),
|
Name("InTopKV2").Device(DEVICE_GPU).TypeConstraint<int64>("T"),
|
||||||
InTopK<GPUDevice, float, int64>);
|
InTopK<GPUDevice, float, int64>);
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_IN_TOPK_OP_H_
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
|
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || TENSORFLOW_USE_ROCM
|
||||||
|
|
||||||
#define EIGEN_USE_GPU
|
#define EIGEN_USE_GPU
|
||||||
|
|
||||||
@ -41,7 +41,7 @@ __global__ void ComputePredictionMaskKernel(
|
|||||||
const TargetT* targets, // dims: [ num_targets ]
|
const TargetT* targets, // dims: [ num_targets ]
|
||||||
int64* mask, // dims: [ num_targets x num_classes ]
|
int64* mask, // dims: [ num_targets x num_classes ]
|
||||||
int num_targets, int num_classes) {
|
int num_targets, int num_classes) {
|
||||||
CUDA_1D_KERNEL_LOOP(i, num_targets * num_classes) {
|
GPU_1D_KERNEL_LOOP(i, num_targets * num_classes) {
|
||||||
const int batch_index = i / num_classes;
|
const int batch_index = i / num_classes;
|
||||||
TargetT target_idx = ldg(targets + batch_index);
|
TargetT target_idx = ldg(targets + batch_index);
|
||||||
|
|
||||||
@ -118,7 +118,7 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
|
|||||||
const auto& d = context->eigen_device<GPUDevice>();
|
const auto& d = context->eigen_device<GPUDevice>();
|
||||||
|
|
||||||
// Compute a mask for all predictions.
|
// Compute a mask for all predictions.
|
||||||
CudaLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
|
GpuLaunchConfig config = GetGpuLaunchConfig(num_targets * num_classes, d);
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
|
context, GpuLaunchKernel(ComputePredictionMaskKernel<T, TargetT>,
|
||||||
config.block_count, config.thread_per_block, 0,
|
config.block_count, config.thread_per_block, 0,
|
||||||
@ -173,4 +173,4 @@ DEFINE_GPU_KERNELS(float, int64);
|
|||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
|
Loading…
Reference in New Issue
Block a user