From 25956c47d6fa218b1fe04fcfd8d9352e6c31b842 Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Wed, 15 Jan 2020 19:06:07 -0800 Subject: [PATCH 1/4] Create a wrapper header for rocprim and cub --- tensorflow/core/kernels/BUILD | 75 ++++++++----------- tensorflow/core/kernels/bincount_op_gpu.cu.cc | 12 +-- .../kernels/dynamic_partition_op_gpu.cu.cc | 15 +--- .../kernels/generate_box_proposals_op.cu.cc | 4 +- tensorflow/core/kernels/gpu_prim.h | 56 ++++++++++++++ .../core/kernels/histogram_op_gpu.cu.cc | 12 +-- tensorflow/core/kernels/l2loss_op_gpu.cu.cc | 7 +- .../core/kernels/multinomial_op_gpu.cu.cc | 7 +- .../core/kernels/non_max_suppression_op.cu.cc | 4 +- .../core/kernels/reduction_gpu_kernels.cu.h | 17 +---- tensorflow/core/kernels/scan_ops_gpu.h | 17 +---- tensorflow/core/kernels/softmax_op_gpu.cu.cc | 7 +- tensorflow/core/kernels/sparse/BUILD | 1 + .../core/kernels/sparse/kernels_gpu.cu.cc | 15 +--- .../core/kernels/sparse_xent_op_gpu.cu.cc | 6 +- tensorflow/core/kernels/topk_op_gpu.h | 4 +- tensorflow/core/kernels/where_op_gpu.cu.h | 15 +--- 17 files changed, 101 insertions(+), 173 deletions(-) create mode 100644 tensorflow/core/kernels/gpu_prim.h diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 532acd1f4f7..becc1868aeb 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -510,6 +510,16 @@ cc_library( deps = ["//third_party/eigen3"], ) +cc_library( + name = "gpu_prim_hdrs", + hdrs = ["gpu_prim.h"], + deps = if_cuda([ + "@cub_archive//:cub", + ]) + if_rocm([ + "@local_config_rocm//rocm:rocprim", + ]) +) + cc_library( name = "conv_ops_gpu_hdrs", hdrs = ["conv_ops_gpu.h"], @@ -1360,11 +1370,8 @@ tf_kernel_library( ], deps = if_cuda_or_rocm([ ":cuda_solvers", - ]) + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]) + ARRAY_DEPS, + ]) + [":gpu_prim_hdrs",] + + ARRAY_DEPS, ) cc_library( @@ -2591,9 +2598,7 @@ tf_kernel_library( deps = DYNAMIC_DEPS + [ 
":fill_functor", ":gather_functor", - ] + if_cuda(["@cub_archive//:cub"]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + ":gpu_prim_hdrs"], ) tf_kernel_library( @@ -3035,8 +3040,8 @@ tf_kernel_library( tf_kernel_library( name = "generate_box_proposals_op", gpu_srcs = ["generate_box_proposals_op.cu.cc"], - deps = if_cuda([ - "@cub_archive//:cub", + deps = [":gpu_prim_hdrs"] + + if_cuda_or_rocm([ ":non_max_suppression_op_gpu", ]), ) @@ -3044,7 +3049,7 @@ tf_kernel_library( tf_kernel_library( name = "non_max_suppression_op", prefix = "non_max_suppression_op", - deps = IMAGE_DEPS + if_cuda(["@cub_archive//:cub"]), + deps = IMAGE_DEPS + [":gpu_prim_hdrs"], ) tf_kernel_library( @@ -4041,11 +4046,9 @@ tf_kernel_library( name = "reduction_ops", gpu_srcs = ["reduction_gpu_kernels.cu.h"], prefix = "reduction_ops", - deps = MATH_DEPS + [":transpose_functor"] + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + deps = MATH_DEPS + + [":gpu_prim_hdrs", + ":transpose_functor"], ) tf_kernel_library( @@ -4068,11 +4071,7 @@ tf_kernel_library( "scan_ops_gpu_half.cu.cc", "scan_ops_gpu_int.cu.cc", ], - deps = MATH_DEPS + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + deps = MATH_DEPS + [":gpu_prim_hdrs"], ) tf_kernel_library( @@ -4493,6 +4492,7 @@ tf_kernel_library( }) + if_cuda([ "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", + ]) + if_cuda_or_rocm([ "//tensorflow/stream_executor:tf_allocator_adapter", "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/core:stream_executor", @@ -4671,11 +4671,7 @@ tf_kernel_library( prefix = "softmax_op", deps = NN_DEPS + if_cuda_or_rocm([ ":reduction_ops", - ]) + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + ]) + [":gpu_prim_hdrs"], ) tf_kernel_library( @@ -4707,11 +4703,7 @@ tf_kernel_library( 
"topk_op_gpu_int8.cu.cc", "topk_op_gpu_uint8.cu.cc", ], - deps = NN_DEPS + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + deps = NN_DEPS + [":gpu_prim_hdrs"], ) tf_kernel_library( @@ -4730,26 +4722,24 @@ tf_kernel_library( name = "bincount_op", prefix = "bincount_op", deps = [ + ":gpu_prim_hdrs", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//third_party/eigen3", - ] + if_cuda(["@cub_archive//:cub"]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + ], ) tf_kernel_library( name = "histogram_op", prefix = "histogram_op", deps = [ + ":gpu_prim_hdrs", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//third_party/eigen3", - ] + if_cuda(["@cub_archive//:cub"]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + ], ) tf_kernel_library( @@ -4757,14 +4747,12 @@ tf_kernel_library( prefix = "l2loss_op", deps = [ ":reduction_ops", + ":gpu_prim_hdrs", "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:nn_grad", - ] + if_cuda(["@cub_archive//:cub"]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", - ]), + "//tensorflow/core:nn_grad"], ) tf_cuda_cc_test( @@ -5892,16 +5880,13 @@ tf_kernel_library( ":random_op", ":random_ops", ":stateless_random_ops", + ":gpu_prim_hdrs", "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ] + if_cuda_or_rocm([ ":reduction_ops", - ]) + if_cuda([ - "@cub_archive//:cub", - ]) + if_rocm([ - "@local_config_rocm//rocm:rocprim", ]), ) diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc index de1457d6ddf..30be2723342 100644 --- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc @@ -17,11 +17,6 @@ limitations under the License. 
#define EIGEN_USE_GPU -#if GOOGLE_CUDA -#include "third_party/cub/device/device_histogram.cuh" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -30,12 +25,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" - -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc index f843ca55ddc..5e75698554a 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -35,14 +35,6 @@ limitations under the License. #define EIGEN_USE_GPU -#if GOOGLE_CUDA -#include "third_party/cub/device/device_radix_sort.cuh" -#include "third_party/cub/device/device_reduce.cuh" -#include "third_party/cub/iterator/constant_input_iterator.cuh" -#include "third_party/cub/thread/thread_operators.cuh" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" @@ -53,12 +45,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/gather_functor_gpu.cu.h" #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/transform_output_iterator.h" - -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc index d3a7574e956..edcd49ffe2b 100644 --- a/tensorflow/core/kernels/generate_box_proposals_op.cu.cc +++ b/tensorflow/core/kernels/generate_box_proposals_op.cu.cc @@ -20,9 +20,7 @@ limitations under the License. #include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/cub/device/device_radix_sort.cuh" -#include "third_party/cub/device/device_segmented_radix_sort.cuh" -#include "third_party/cub/device/device_select.cuh" +#include "gpu_prim.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h new file mode 100644 index 00000000000..35aaf03cb40 --- /dev/null +++ b/tensorflow/core/kernels/gpu_prim.h @@ -0,0 +1,56 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ +#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_ + +#if GOOGLE_CUDA +#include "third_party/cub/device/device_radix_sort.cuh" +#include "third_party/cub/device/device_segmented_radix_sort.cuh" +#include "third_party/cub/device/device_select.cuh" +#include "third_party/cub/device/device_reduce.cuh" +#include "third_party/cub/device/device_segmented_reduce.cuh" +#include "third_party/cub/device/device_histogram.cuh" +#include "third_party/cub/iterator/counting_input_iterator.cuh" +#include "third_party/cub/iterator/transform_input_iterator.cuh" +#include "third_party/cub/warp/warp_reduce.cuh" +#include "third_party/cub/thread/thread_operators.cuh" +#include "third_party/gpus/cuda/include/cusparse.h" +#include "third_party/cub/block/block_load.cuh" +#include "third_party/cub/block/block_scan.cuh" +#include "third_party/cub/block/block_store.cuh" + +namespace gpuprim = ::cub; +#elif TENSORFLOW_USE_ROCM +#include "rocm/include/hipcub/hipcub.hpp" +namespace gpuprim = ::hipcub; +#endif + +#if GOOGLE_CUDA +// Required for sorting Eigen::half +namespace cub { +template <> +struct NumericTraits + : BaseTraits { +}; +} // namespace cub +#elif TENSORFLOW_USE_ROCM +namespace rocprim { + namespace detail { + template<> + struct radix_key_codec_base : radix_key_codec_floating { }; + }; +}; +#endif + +#endif diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc index b3d21a0f561..373d26accb2 100644 --- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc +++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc @@ -18,11 +18,6 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#if GOOGLE_CUDA -#include "third_party/cub/device/device_histogram.cuh" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -31,12 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" - -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc index 8cb46204869..4fbf471d5b3 100644 --- a/tensorflow/core/kernels/l2loss_op_gpu.cu.cc +++ b/tensorflow/core/kernels/l2loss_op_gpu.cu.cc @@ -23,12 +23,7 @@ limitations under the License. #include "tensorflow/core/kernels/reduction_gpu_kernels.cu.h" #include "tensorflow/core/kernels/reduction_ops_common.h" - -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc index 4cb38d5873e..7684d6f5cbf 100644 --- a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc +++ b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc @@ -28,12 +28,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/util/gpu_kernel_helper.h" - -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/non_max_suppression_op.cu.cc b/tensorflow/core/kernels/non_max_suppression_op.cu.cc index 7b2848b2a77..8fdf2ff3938 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cu.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cu.cc @@ -19,9 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/cub/device/device_radix_sort.cuh" -#include "third_party/cub/device/device_segmented_radix_sort.cuh" -#include "third_party/cub/device/device_select.cuh" +#include "gpu_prim.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index e26b9fd5ad1..ded9c974be4 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -23,16 +23,7 @@ limitations under the License. 
#include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#if GOOGLE_CUDA -#include "third_party/cub/device/device_reduce.cuh" -#include "third_party/cub/device/device_segmented_reduce.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#include "third_party/cub/warp/warp_reduce.cuh" -#include "third_party/gpus/cuda/include/cuComplex.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif +#include "gpu_prim.h" #include "tensorflow/core/kernels/reduction_ops.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/util/gpu_device_functions.h" @@ -40,12 +31,6 @@ limitations under the License. #include "tensorflow/core/util/permutation_input_iterator.h" #include "tensorflow/core/util/transform_output_iterator.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { namespace functor { diff --git a/tensorflow/core/kernels/scan_ops_gpu.h b/tensorflow/core/kernels/scan_ops_gpu.h index 27da21982af..23924763aaa 100644 --- a/tensorflow/core/kernels/scan_ops_gpu.h +++ b/tensorflow/core/kernels/scan_ops_gpu.h @@ -24,16 +24,6 @@ limitations under the License. #define CUB_USE_COOPERATIVE_GROUPS #endif // CUDA_VERSION >= 9000 -#if GOOGLE_CUDA -#include "third_party/cub/block/block_load.cuh" -#include "third_party/cub/block/block_scan.cuh" -#include "third_party/cub/block/block_store.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#include "third_party/gpus/cuda/include/cuComplex.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/scan_ops.h" @@ -41,12 +31,7 @@ limitations under the License. 
#include "tensorflow/core/util/gpu_launch_config.h" #include "tensorflow/core/util/permutation_input_iterator.h" #include "tensorflow/core/util/permutation_output_iterator.h" - -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc index 86f9d93b646..13b4dea0ae3 100644 --- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc +++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc @@ -27,12 +27,7 @@ limitations under the License. #include "tensorflow/core/kernels/reduction_ops_common.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" - -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/sparse/BUILD b/tensorflow/core/kernels/sparse/BUILD index 6b4dba69ff2..eae5a5857a3 100644 --- a/tensorflow/core/kernels/sparse/BUILD +++ b/tensorflow/core/kernels/sparse/BUILD @@ -77,6 +77,7 @@ tf_kernel_library( "//tensorflow/core/kernels:scatter_nd_op", "//tensorflow/core/kernels:slice_op", "//tensorflow/core/kernels:transpose_functor", + "//tensorflow/core/kernels:gpu_prim_hdrs", ] + if_cuda_or_rocm([ "//tensorflow/core/kernels:cuda_solvers", "//tensorflow/core/kernels:cuda_sparse", diff --git a/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc b/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc index 99c6d5b9259..42aa8bcd330 100644 --- a/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse/kernels_gpu.cu.cc @@ -18,14 +18,7 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#if GOOGLE_CUDA -#include "third_party/cub/device/device_histogram.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#include "third_party/gpus/cuda/include/cusparse.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif +#include "../gpu_prim.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/cuda_sparse.h" @@ -36,12 +29,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; diff --git a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc index f651358b47f..0faff52c3aa 100644 --- a/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_xent_op_gpu.cu.cc @@ -24,11 +24,7 @@ limitations under the License. #include "tensorflow/core/kernels/reduction_ops_common.h" #include "tensorflow/core/platform/types.h" -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/topk_op_gpu.h b/tensorflow/core/kernels/topk_op_gpu.h index 3156b6d9bd9..054fb893acf 100644 --- a/tensorflow/core/kernels/topk_op_gpu.h +++ b/tensorflow/core/kernels/topk_op_gpu.h @@ -23,9 +23,7 @@ limitations under the License. 
#include #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/cub/device/device_segmented_radix_sort.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" +#include "gpu_prim.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h index 3795733f959..178c8f12eb5 100644 --- a/tensorflow/core/kernels/where_op_gpu.cu.h +++ b/tensorflow/core/kernels/where_op_gpu.cu.h @@ -21,14 +21,6 @@ limitations under the License. #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#if GOOGLE_CUDA -#include "third_party/cub/device/device_reduce.cuh" -#include "third_party/cub/device/device_select.cuh" -#include "third_party/cub/iterator/counting_input_iterator.cuh" -#include "third_party/cub/iterator/transform_input_iterator.cuh" -#elif TENSORFLOW_USE_ROCM -#include "rocm/include/hipcub/hipcub.hpp" -#endif #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" @@ -36,12 +28,7 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" - -#if GOOGLE_CUDA -namespace gpuprim = ::cub; -#elif TENSORFLOW_USE_ROCM -namespace gpuprim = ::hipcub; -#endif +#include "gpu_prim.h" namespace tensorflow { From 04336cf0c6d492653c746e26b6b024a7b9d5fe9a Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Fri, 24 Jan 2020 23:08:15 -0800 Subject: [PATCH 2/4] This struct has been moved to gpu_prim.h --- tensorflow/core/kernels/topk_op_gpu.h | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tensorflow/core/kernels/topk_op_gpu.h b/tensorflow/core/kernels/topk_op_gpu.h index 054fb893acf..2da110e49fe 100644 --- a/tensorflow/core/kernels/topk_op_gpu.h +++ b/tensorflow/core/kernels/topk_op_gpu.h @@ -34,14 +34,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/gpu_kernel_helper.h" -// Required for sorting Eigen::half -namespace cub { -template <> -struct NumericTraits - : BaseTraits { -}; -} // namespace cub - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; From 0b058fb52e7f1669b197ffe7ea42da90fb1e8514 Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Mon, 10 Feb 2020 23:13:41 -0800 Subject: [PATCH 3/4] Revert some changes that don't belong in this PR --- tensorflow/core/kernels/BUILD | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index becc1868aeb..ef555a42d59 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3041,7 +3041,7 @@ tf_kernel_library( name = "generate_box_proposals_op", gpu_srcs = ["generate_box_proposals_op.cu.cc"], deps = [":gpu_prim_hdrs"] - + if_cuda_or_rocm([ + + if_cuda([ ":non_max_suppression_op_gpu", ]), ) @@ -4492,7 +4492,6 @@ tf_kernel_library( }) + if_cuda([ "//tensorflow/core/platform/default/build_config:cublas_plugin", 
"//tensorflow/core/platform/default/build_config:cudnn_plugin", - ]) + if_cuda_or_rocm([ "//tensorflow/stream_executor:tf_allocator_adapter", "//tensorflow/stream_executor:stream_executor_headers", "//tensorflow/core:stream_executor", From 0e0c2c734cefa77bf9992a512da71b65d9b8af2f Mon Sep 17 00:00:00 2001 From: Eugene Kuznetsov Date: Thu, 26 Mar 2020 15:48:35 -0700 Subject: [PATCH 4/4] Fix buildifier errors --- tensorflow/core/kernels/BUILD | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f7393b6ad37..459f4260184 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -524,7 +524,7 @@ cc_library( "@cub_archive//:cub", ]) + if_rocm([ "@local_config_rocm//rocm:rocprim", - ]) + ]), ) cc_library( @@ -1393,9 +1393,9 @@ tf_kernel_library( "where_op_gpu_impl_8.cu.cc", ], deps = if_cuda_or_rocm([ - ":cuda_solvers", - ]) + [":gpu_prim_hdrs",] - + ARRAY_DEPS, + ":cuda_solvers", + ]) + [":gpu_prim_hdrs"] + + ARRAY_DEPS, ) cc_library( @@ -3078,8 +3078,7 @@ tf_kernel_library( tf_kernel_library( name = "generate_box_proposals_op", gpu_srcs = ["generate_box_proposals_op.cu.cc"], - deps = [":gpu_prim_hdrs"] - + if_cuda([ + deps = [":gpu_prim_hdrs"] + if_cuda([ ":non_max_suppression_op_gpu", ]), ) @@ -4162,9 +4161,10 @@ tf_kernel_library( name = "reduction_ops", gpu_srcs = ["reduction_gpu_kernels.cu.h"], prefix = "reduction_ops", - deps = MATH_DEPS + - [":gpu_prim_hdrs", - ":transpose_functor"], + deps = MATH_DEPS + [ + ":gpu_prim_hdrs", + ":transpose_functor", + ], ) tf_kernel_library( @@ -4856,20 +4856,21 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//third_party/eigen3", - ], + ], ) tf_kernel_library( name = "l2loss_op", prefix = "l2loss_op", deps = [ - ":reduction_ops", ":gpu_prim_hdrs", - "//third_party/eigen3", + ":reduction_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", 
"//tensorflow/core:lib_internal", - "//tensorflow/core:nn_grad"], + "//tensorflow/core:nn_grad", + "//third_party/eigen3", + ], ) tf_cuda_cc_test(