diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a76cd1f18f9..c64380890bb 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -517,6 +517,16 @@ cc_library( deps = ["//third_party/eigen3"], ) +cc_library( + name = "gpu_prim_hdrs", + hdrs = ["gpu_prim.h"], + deps = if_cuda([ + "@cub_archive//:cub", + ]) + if_rocm([ + "@local_config_rocm//rocm:rocprim", + ]), +) + cc_library( name = "conv_ops_gpu_hdrs", hdrs = ["conv_ops_gpu.h"], @@ -1393,12 +1403,9 @@ tf_kernel_library( "where_op_gpu_impl_8.cu.cc", ], deps = if_cuda_or_rocm([ - ":cuda_solvers", - ]) + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]) + ARRAY_DEPS, + ":cuda_solvers", + ]) + [":gpu_prim_hdrs"] + + ARRAY_DEPS, ) cc_library( @@ -2637,10 +2644,9 @@ tf_kernel_library( deps = DYNAMIC_DEPS + [ ":fill_functor", ":gather_functor", + ":gpu_prim_hdrs", "//tensorflow/core:framework_internal", - ] + if_cuda(["@cub_archive//:cub"]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + ], ) tf_kernel_library( @@ -3082,8 +3088,7 @@ tf_kernel_library( tf_kernel_library( name = "generate_box_proposals_op", gpu_srcs = ["generate_box_proposals_op.cu.cc"], - deps = if_cuda([ - "@cub_archive//:cub", + deps = [":gpu_prim_hdrs"] + if_cuda([ ":non_max_suppression_op_gpu", ]), ) @@ -3091,7 +3096,7 @@ tf_kernel_library( tf_kernel_library( name = "non_max_suppression_op", prefix = "non_max_suppression_op", - deps = IMAGE_DEPS + if_cuda(["@cub_archive//:cub"]), + deps = IMAGE_DEPS + [":gpu_prim_hdrs"], ) tf_kernel_library( @@ -4160,11 +4165,10 @@ tf_kernel_library( name = "reduction_ops", gpu_srcs = ["reduction_gpu_kernels.cu.h"], prefix = "reduction_ops", - deps = MATH_DEPS + [":transpose_functor"] + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + deps = MATH_DEPS + [ + ":gpu_prim_hdrs", + ":transpose_functor", + ], ) tf_kernel_library( @@ -4187,11 +4191,7 @@ tf_kernel_library( "scan_ops_gpu_half.cu.cc", "scan_ops_gpu_int.cu.cc", ], - deps = MATH_DEPS + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + deps = MATH_DEPS + [":gpu_prim_hdrs"], ) tf_kernel_library( @@ -4792,11 +4792,7 @@ tf_kernel_library( prefix = "softmax_op", deps = NN_DEPS + if_cuda_or_rocm([ ":reduction_ops", - ]) + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + ]) + [":gpu_prim_hdrs"], ) tf_kernel_library( @@ -4828,11 +4824,7 @@ tf_kernel_library( "topk_op_gpu_int8.cu.cc", "topk_op_gpu_uint8.cu.cc", ], - deps = NN_DEPS + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + deps = NN_DEPS + [":gpu_prim_hdrs"], ) tf_kernel_library( @@ -4851,41 +4843,38 @@ tf_kernel_library( name = "bincount_op", prefix = "bincount_op", deps = [ + ":gpu_prim_hdrs", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//third_party/eigen3", - ] + if_cuda(["@cub_archive//:cub"]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + ], ) tf_kernel_library( name = "histogram_op", prefix = "histogram_op", deps = [ + ":gpu_prim_hdrs", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//third_party/eigen3", - ] + if_cuda(["@cub_archive//:cub"]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + ], ) tf_kernel_library( name = "l2loss_op", prefix = "l2loss_op", deps = [ + ":gpu_prim_hdrs", ":reduction_ops", - "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:nn_grad", - ] + if_cuda(["@cub_archive//:cub"]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + "//third_party/eigen3", + ], ) tf_cuda_cc_test( @@ -6024,16 +6013,13 @@ tf_kernel_library( ":random_op", ":random_ops", ":stateless_random_ops", + ":gpu_prim_hdrs", "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ] + if_cuda_or_rocm([ ":reduction_ops", - ]) + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", ]), ) diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc index de1457d6ddf..56e209819d9 100644 --- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc @@ -17,26 +17,16 @@ limitations under the License. #define EIGEN_USE_GPU -#if GOOGLE_CUDA -#include "third_party/cub/device/device_histogram.cuh" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/bincount_op.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc index e0fb36eca57..98c2fb57833 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -35,14 +35,6 @@ limitations under the License. #define EIGEN_USE_GPU -#if GOOGLE_CUDA -#include "third_party/cub/device/device_radix_sort.cuh" -#include "third_party/cub/device/device_reduce.cuh" -#include "third_party/cub/iterator/constant_input_iterator.cuh" -#include "third_party/cub/thread/thread_operators.cuh" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" @@ -52,15 +44,10 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/kernels/gather_functor_gpu.cu.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/transform_output_iterator.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc index d3a7574e956..b862c42d299 100644 --- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -20,12 +20,10 @@ limitations under the License. #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/cub/device/device_radix_sort.cuh" -#include "third_party/cub/device/device_segmented_radix_sort.cuh" -#include "third_party/cub/device/device_select.cuh" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/non_max_suppression_op.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h new file mode 100644 index 00000000000..82fcb21e0ac --- /dev/null +++ b/tensorflow/core/kernels/gpu_prim.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +To in writing unless required by applicable law or agreed, +distributed on an, software distributed under the license is "AS IS" +BASIS, WITHOUT OF ANY KIND WARRANTIES OR CONDITIONS, either express +or implied. For the specific language governing permissions and +limitations under the license, the license you must see. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ +#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ + +#if GOOGLE_CUDA +#include "third_party/cub/block/block_load.cuh" +#include "third_party/cub/block/block_scan.cuh" +#include "third_party/cub/block/block_store.cuh" +#include "third_party/cub/device/device_histogram.cuh" +#include "third_party/cub/device/device_radix_sort.cuh" +#include "third_party/cub/device/device_reduce.cuh" +#include "third_party/cub/device/device_segmented_radix_sort.cuh" +#include "third_party/cub/device/device_segmented_reduce.cuh" +#include "third_party/cub/device/device_select.cuh" +#include "third_party/cub/iterator/counting_input_iterator.cuh" +#include "third_party/cub/iterator/transform_input_iterator.cuh" +#include "third_party/cub/thread/thread_operators.cuh" +#include "third_party/cub/warp/warp_reduce.cuh" +#include "third_party/gpus/cuda/include/cusparse.h" + +namespace gpuprim = ::cub; +#elif TENSORFLOW_USE_ROCM +#include "rocm/include/hipcub/hipcub.hpp" +namespace gpuprim = ::hipcub; + +namespace rocprim { +namespace detail { +template <> +struct radix_key_codec_base + : radix_key_codec_floating {}; +}; // namespace detail +}; // namespace rocprim +#endif // GOOGLE_CUDA + +#endif // TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc index b3d21a0f561..e8a1c630e70 100644 --- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc +++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc @@ -18,26 +18,16 @@ limitations under the License. #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#if GOOGLE_CUDA -#include "third_party/cub/device/device_histogram.cuh" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/histogram_op.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc index 8cb46204869..a2c288c36d1 100644 --- a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc +++ b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc @@ -17,19 +17,12 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/l2loss_op.h" - #include "tensorflow/core/framework/register_types.h" - +#include "tensorflow/core/kernels/gpu_prim.h" +#include "tensorflow/core/kernels/l2loss_op.h" #include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" #include "tensorflow/core/kernels/reduction_ops_common.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc index 4cb38d5873e..95bc0ed357a 100644 --- a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc +++ b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/multinomial_op.h" #include "tensorflow/core/kernels/random_op.h" #include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" @@ -29,12 +30,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/util/gpu_kernel_helper.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { namespace functor { diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc index 7b2848b2a77..53559b20419 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc @@ -19,12 +19,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/cub/device/device_radix_sort.cuh" -#include "third_party/cub/device/device_segmented_radix_sort.cuh" -#include "third_party/cub/device/device_select.cuh" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/non_max_suppression_op.h" #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_launch_config.h" diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index e26b9fd5ad1..c043c6a8e33 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -23,16 +23,7 @@ limitations under the License. #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#if GOOGLE_CUDA -#include "third_party/cub/device/device_reduce.cuh" -#include "third_party/cub/device/device_segmented_reduce.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#include "third_party/cub/warp/warp_reduce.cuh" -#include "third_party/gpus/cuda/include/cuComplex.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/reduction_ops.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/util/gpu_device_functions.h" @@ -40,12 +31,6 @@ limitations under the License. #include "tensorflow/core/util/permutation_input_iterator.h" #include "tensorflow/core/util/transform_output_iterator.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { namespace functor { diff --git a/tensorflow/core/kernels/scan_ops_gpu.h b/tensorflow/core/kernels/scan_ops_gpu.h index 27da21982af..aca2a8985de 100644 --- a/tensorflow/core/kernels/scan_ops_gpu.h +++ b/tensorflow/core/kernels/scan_ops_gpu.h @@ -24,30 +24,15 @@ limitations under the License. #define CUB_USE_COOPERATIVE_GROUPS #endif // CUDA_VERSION >= 9000 -#if GOOGLE_CUDA -#include "third_party/cub/block/block_load.cuh" -#include "third_party/cub/block/block_scan.cuh" -#include "third_party/cub/block/block_store.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#include "third_party/gpus/cuda/include/cuComplex.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/scan_ops.h" #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_launch_config.h" #include "tensorflow/core/util/permutation_input_iterator.h" #include "tensorflow/core/util/permutation_output_iterator.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc index 0ec2b008aee..0c09fd2852b 100644 --- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc @@ -18,22 +18,18 @@ limitations under the License. #define EIGEN_USE_GPU +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" #include "tensorflow/core/kernels/reduction_ops_common.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { namespace { diff --git a/tensorflow/core/kernels/sparse/BUILD b/tensorflow/core/kernels/sparse/BUILD index fc7fe089f64..1d281bc1d61 100644 --- a/tensorflow/core/kernels/sparse/BUILD +++ b/tensorflow/core/kernels/sparse/BUILD @@ -78,6 +78,7 @@ tf_kernel_library( "//tensorflow/core/kernels:scatter_nd_op", "//tensorflow/core/kernels:slice_op", "//tensorflow/core/kernels:transpose_functor", + "//tensorflow/core/kernels:gpu_prim_hdrs", ] + if_cuda_or_rocm([ "//tensorflow/core/kernels:cuda_solvers", "//tensorflow/core/kernels:cuda_sparse", diff --git a/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc b/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc index 99c6d5b9259..1c014db3d0a 100644 --- a/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc @@ -18,30 +18,17 @@ limitations under the License. #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#if GOOGLE_CUDA -#include "third_party/cub/device/device_histogram.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#include "third_party/gpus/cuda/include/cusparse.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/cuda_sparse.h" #include "tensorflow/core/kernels/gpu_device_array.h" #include "tensorflow/core/kernels/gpu_device_array_gpu.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/sparse/kernels.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc index f651358b47f..862048603f5 100644 --- a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc @@ -17,19 +17,13 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/core/kernels/sparse_xent_op.h" - #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" #include "tensorflow/core/kernels/reduction_ops_common.h" +#include "tensorflow/core/kernels/sparse_xent_op.h" #include "tensorflow/core/platform/types.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/topk_op_gpu.h b/tensorflow/core/kernels/topk_op_gpu.h index 3156b6d9bd9..d26dd7a8bc3 100644 --- a/tensorflow/core/kernels/topk_op_gpu.h +++ b/tensorflow/core/kernels/topk_op_gpu.h @@ -23,26 +23,25 @@ limitations under the License. #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/cub/device/device_segmented_radix_sort.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/topk_op.h" #include "tensorflow/core/lib/gtl/top_n.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" +#if GOOGLE_CUDA // Required for sorting Eigen::half namespace cub { template <> struct NumericTraits - : BaseTraits { -}; + : BaseTraits {}; } // namespace cub +#endif // GOOGLE_CUDA namespace tensorflow { diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h index f13f504c1d7..98f216c1e5b 100644 --- a/tensorflow/core/kernels/where_op_gpu.cu.h +++ b/tensorflow/core/kernels/where_op_gpu.cu.h @@ -21,28 +21,15 @@ limitations under the License. #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#if GOOGLE_CUDA -#include "third_party/cub/device/device_reduce.cuh" -#include "third_party/cub/device/device_select.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_prim.h" #include "tensorflow/core/kernels/where_op.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice;