Merge pull request #36640 from ROCmSoftwarePlatform:google-upstream-gpuprim
PiperOrigin-RevId: 306656469 Change-Id: I8591530f51c818380b76e3a7c42944503da523ac
This commit is contained in:
commit
0592ae692b
@ -517,6 +517,16 @@ cc_library(
|
||||
deps = ["//third_party/eigen3"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_prim_hdrs",
|
||||
hdrs = ["gpu_prim.h"],
|
||||
deps = if_cuda([
|
||||
"@cub_archive//:cub",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conv_ops_gpu_hdrs",
|
||||
hdrs = ["conv_ops_gpu.h"],
|
||||
@ -1393,12 +1403,9 @@ tf_kernel_library(
|
||||
"where_op_gpu_impl_8.cu.cc",
|
||||
],
|
||||
deps = if_cuda_or_rocm([
|
||||
":cuda_solvers",
|
||||
]) + if_cuda([
|
||||
"@cub_archive//:cub",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]) + ARRAY_DEPS,
|
||||
":cuda_solvers",
|
||||
]) + [":gpu_prim_hdrs"] +
|
||||
ARRAY_DEPS,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -2637,10 +2644,9 @@ tf_kernel_library(
|
||||
deps = DYNAMIC_DEPS + [
|
||||
":fill_functor",
|
||||
":gather_functor",
|
||||
":gpu_prim_hdrs",
|
||||
"//tensorflow/core:framework_internal",
|
||||
] + if_cuda(["@cub_archive//:cub"]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -3082,8 +3088,7 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "generate_box_proposals_op",
|
||||
gpu_srcs = ["generate_box_proposals_op.cu.cc"],
|
||||
deps = if_cuda([
|
||||
"@cub_archive//:cub",
|
||||
deps = [":gpu_prim_hdrs"] + if_cuda([
|
||||
":non_max_suppression_op_gpu",
|
||||
]),
|
||||
)
|
||||
@ -3091,7 +3096,7 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "non_max_suppression_op",
|
||||
prefix = "non_max_suppression_op",
|
||||
deps = IMAGE_DEPS + if_cuda(["@cub_archive//:cub"]),
|
||||
deps = IMAGE_DEPS + [":gpu_prim_hdrs"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -4160,11 +4165,10 @@ tf_kernel_library(
|
||||
name = "reduction_ops",
|
||||
gpu_srcs = ["reduction_gpu_kernels.cu.h"],
|
||||
prefix = "reduction_ops",
|
||||
deps = MATH_DEPS + [":transpose_functor"] + if_cuda([
|
||||
"@cub_archive//:cub",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
deps = MATH_DEPS + [
|
||||
":gpu_prim_hdrs",
|
||||
":transpose_functor",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -4187,11 +4191,7 @@ tf_kernel_library(
|
||||
"scan_ops_gpu_half.cu.cc",
|
||||
"scan_ops_gpu_int.cu.cc",
|
||||
],
|
||||
deps = MATH_DEPS + if_cuda([
|
||||
"@cub_archive//:cub",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
deps = MATH_DEPS + [":gpu_prim_hdrs"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -4792,11 +4792,7 @@ tf_kernel_library(
|
||||
prefix = "softmax_op",
|
||||
deps = NN_DEPS + if_cuda_or_rocm([
|
||||
":reduction_ops",
|
||||
]) + if_cuda([
|
||||
"@cub_archive//:cub",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
]) + [":gpu_prim_hdrs"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -4828,11 +4824,7 @@ tf_kernel_library(
|
||||
"topk_op_gpu_int8.cu.cc",
|
||||
"topk_op_gpu_uint8.cu.cc",
|
||||
],
|
||||
deps = NN_DEPS + if_cuda([
|
||||
"@cub_archive//:cub",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
deps = NN_DEPS + [":gpu_prim_hdrs"],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
@ -4851,41 +4843,38 @@ tf_kernel_library(
|
||||
name = "bincount_op",
|
||||
prefix = "bincount_op",
|
||||
deps = [
|
||||
":gpu_prim_hdrs",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//third_party/eigen3",
|
||||
] + if_cuda(["@cub_archive//:cub"]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "histogram_op",
|
||||
prefix = "histogram_op",
|
||||
deps = [
|
||||
":gpu_prim_hdrs",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//third_party/eigen3",
|
||||
] + if_cuda(["@cub_archive//:cub"]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "l2loss_op",
|
||||
prefix = "l2loss_op",
|
||||
deps = [
|
||||
":gpu_prim_hdrs",
|
||||
":reduction_ops",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:nn_grad",
|
||||
] + if_cuda(["@cub_archive//:cub"]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
@ -6024,16 +6013,13 @@ tf_kernel_library(
|
||||
":random_op",
|
||||
":random_ops",
|
||||
":stateless_random_ops",
|
||||
":gpu_prim_hdrs",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
] + if_cuda_or_rocm([
|
||||
":reduction_ops",
|
||||
]) + if_cuda([
|
||||
"@cub_archive//:cub",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocprim",
|
||||
]),
|
||||
)
|
||||
|
||||
|
@ -17,26 +17,16 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/device/device_histogram.cuh"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/bincount_op.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
@ -35,14 +35,6 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/device/device_radix_sort.cuh"
|
||||
#include "third_party/cub/device/device_reduce.cuh"
|
||||
#include "third_party/cub/iterator/constant_input_iterator.cuh"
|
||||
#include "third_party/cub/thread/thread_operators.cuh"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#endif
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -52,15 +44,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/gather_functor_gpu.cu.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
#include "tensorflow/core/util/transform_output_iterator.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
@ -20,12 +20,10 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "third_party/cub/device/device_radix_sort.cuh"
|
||||
#include "third_party/cub/device/device_segmented_radix_sort.cuh"
|
||||
#include "third_party/cub/device/device_select.cuh"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/non_max_suppression_op.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
47
tensorflow/core/kernels/gpu_prim.h
Normal file
47
tensorflow/core/kernels/gpu_prim.h
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
To in writing unless required by applicable law or agreed,
|
||||
distributed on an, software distributed under the license is "AS IS"
|
||||
BASIS, WITHOUT OF ANY KIND WARRANTIES OR CONDITIONS, either express
|
||||
or implied. For the specific language governing permissions and
|
||||
limitations under the license, the license you must see.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/block/block_load.cuh"
|
||||
#include "third_party/cub/block/block_scan.cuh"
|
||||
#include "third_party/cub/block/block_store.cuh"
|
||||
#include "third_party/cub/device/device_histogram.cuh"
|
||||
#include "third_party/cub/device/device_radix_sort.cuh"
|
||||
#include "third_party/cub/device/device_reduce.cuh"
|
||||
#include "third_party/cub/device/device_segmented_radix_sort.cuh"
|
||||
#include "third_party/cub/device/device_segmented_reduce.cuh"
|
||||
#include "third_party/cub/device/device_select.cuh"
|
||||
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
||||
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
||||
#include "third_party/cub/thread/thread_operators.cuh"
|
||||
#include "third_party/cub/warp/warp_reduce.cuh"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
namespace gpuprim = ::hipcub;
|
||||
|
||||
namespace rocprim {
|
||||
namespace detail {
|
||||
template <>
|
||||
struct radix_key_codec_base<Eigen::half>
|
||||
: radix_key_codec_floating<Eigen::half, unsigned short> {};
|
||||
}; // namespace detail
|
||||
}; // namespace rocprim
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
|
@ -18,26 +18,16 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/device/device_histogram.cuh"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/histogram_op.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
@ -17,19 +17,12 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/l2loss_op.h"
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/l2loss_op.h"
|
||||
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
|
||||
#include "tensorflow/core/kernels/reduction_ops_common.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <stdio.h>
|
||||
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/multinomial_op.h"
|
||||
#include "tensorflow/core/kernels/random_op.h"
|
||||
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
|
||||
@ -29,12 +30,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/random/random_distributions.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace functor {
|
||||
|
@ -19,12 +19,10 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "third_party/cub/device/device_radix_sort.cuh"
|
||||
#include "third_party/cub/device/device_segmented_radix_sort.cuh"
|
||||
#include "third_party/cub/device/device_select.cuh"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/non_max_suppression_op.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
#include "tensorflow/core/util/gpu_launch_config.h"
|
||||
|
@ -23,16 +23,7 @@ limitations under the License.
|
||||
#include <sstream>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/device/device_reduce.cuh"
|
||||
#include "third_party/cub/device/device_segmented_reduce.cuh"
|
||||
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
||||
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
||||
#include "third_party/cub/warp/warp_reduce.cuh"
|
||||
#include "third_party/gpus/cuda/include/cuComplex.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#endif
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/reduction_ops.h"
|
||||
#include "tensorflow/core/lib/core/bits.h"
|
||||
#include "tensorflow/core/util/gpu_device_functions.h"
|
||||
@ -40,12 +31,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/permutation_input_iterator.h"
|
||||
#include "tensorflow/core/util/transform_output_iterator.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
namespace functor {
|
||||
|
||||
|
@ -24,30 +24,15 @@ limitations under the License.
|
||||
#define CUB_USE_COOPERATIVE_GROUPS
|
||||
#endif // CUDA_VERSION >= 9000
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/block/block_load.cuh"
|
||||
#include "third_party/cub/block/block_scan.cuh"
|
||||
#include "third_party/cub/block/block_store.cuh"
|
||||
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
||||
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
||||
#include "third_party/gpus/cuda/include/cuComplex.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/scan_ops.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
#include "tensorflow/core/util/gpu_launch_config.h"
|
||||
#include "tensorflow/core/util/permutation_input_iterator.h"
|
||||
#include "tensorflow/core/util/permutation_output_iterator.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
@ -18,22 +18,18 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
|
||||
#include "tensorflow/core/kernels/reduction_ops_common.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
@ -78,6 +78,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core/kernels:scatter_nd_op",
|
||||
"//tensorflow/core/kernels:slice_op",
|
||||
"//tensorflow/core/kernels:transpose_functor",
|
||||
"//tensorflow/core/kernels:gpu_prim_hdrs",
|
||||
] + if_cuda_or_rocm([
|
||||
"//tensorflow/core/kernels:cuda_solvers",
|
||||
"//tensorflow/core/kernels:cuda_sparse",
|
||||
|
@ -18,30 +18,17 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/device/device_histogram.cuh"
|
||||
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
||||
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/cuda_sparse.h"
|
||||
#include "tensorflow/core/kernels/gpu_device_array.h"
|
||||
#include "tensorflow/core/kernels/gpu_device_array_gpu.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/sparse/kernels.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
@ -17,19 +17,13 @@ limitations under the License.
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/kernels/sparse_xent_op.h"
|
||||
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h"
|
||||
#include "tensorflow/core/kernels/reduction_ops_common.h"
|
||||
#include "tensorflow/core/kernels/sparse_xent_op.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
@ -23,26 +23,25 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "third_party/cub/device/device_segmented_radix_sort.cuh"
|
||||
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
||||
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/topk_op.h"
|
||||
#include "tensorflow/core/lib/gtl/top_n.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// Required for sorting Eigen::half
|
||||
namespace cub {
|
||||
template <>
|
||||
struct NumericTraits<Eigen::half>
|
||||
: BaseTraits<FLOATING_POINT, true, false, unsigned short int, Eigen::half> {
|
||||
};
|
||||
: BaseTraits<FLOATING_POINT, true, false, unsigned short, Eigen::half> {};
|
||||
} // namespace cub
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -21,28 +21,15 @@ limitations under the License.
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/cub/device/device_reduce.cuh"
|
||||
#include "third_party/cub/device/device_select.cuh"
|
||||
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
||||
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
#include "rocm/include/hipcub/hipcub.hpp"
|
||||
#endif
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/gpu_prim.h"
|
||||
#include "tensorflow/core/kernels/where_op.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/gpu_kernel_helper.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
namespace gpuprim = ::cub;
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
namespace gpuprim = ::hipcub;
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
Loading…
Reference in New Issue
Block a user