Use CUB from the CUDA Toolkit starting with version 11.0.
PiperOrigin-RevId: 327096097 Change-Id: I444ec3ac3348f76728c931a4bb4aa1b7cbe1b673
This commit is contained in:
parent
4112865ad4
commit
c6769e20bf
@ -490,7 +490,7 @@ cc_library(
|
|||||||
name = "gpu_prim_hdrs",
|
name = "gpu_prim_hdrs",
|
||||||
hdrs = ["gpu_prim.h"],
|
hdrs = ["gpu_prim.h"],
|
||||||
deps = if_cuda([
|
deps = if_cuda([
|
||||||
"@cub_archive//:cub",
|
"@local_config_cuda//cuda:cub_headers",
|
||||||
]) + if_rocm([
|
]) + if_rocm([
|
||||||
"@local_config_rocm//rocm:rocprim",
|
"@local_config_rocm//rocm:rocprim",
|
||||||
]),
|
]),
|
||||||
@ -3896,7 +3896,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
] + if_cuda([
|
] + if_cuda([
|
||||||
"@cub_archive//:cub",
|
"@local_config_cuda//cuda:cub_headers",
|
||||||
"@local_config_cuda//cuda:cudnn_header",
|
"@local_config_cuda//cuda:cudnn_header",
|
||||||
]) + if_rocm([
|
]) + if_rocm([
|
||||||
"@local_config_rocm//rocm:rocprim",
|
"@local_config_rocm//rocm:rocprim",
|
||||||
@ -3986,7 +3986,7 @@ tf_kernel_library(
|
|||||||
] + if_cuda_or_rocm([
|
] + if_cuda_or_rocm([
|
||||||
":reduction_ops",
|
":reduction_ops",
|
||||||
]) + if_cuda([
|
]) + if_cuda([
|
||||||
"@cub_archive//:cub",
|
"@local_config_cuda//cuda:cub_headers",
|
||||||
"//tensorflow/core:stream_executor",
|
"//tensorflow/core:stream_executor",
|
||||||
"//tensorflow/stream_executor/cuda:cuda_stream",
|
"//tensorflow/stream_executor/cuda:cuda_stream",
|
||||||
]) + if_rocm([
|
]) + if_rocm([
|
||||||
@ -4708,7 +4708,7 @@ tf_kernel_library(
|
|||||||
] + if_cuda_or_rocm([
|
] + if_cuda_or_rocm([
|
||||||
":reduction_ops",
|
":reduction_ops",
|
||||||
]) + if_cuda([
|
]) + if_cuda([
|
||||||
"@cub_archive//:cub",
|
"@local_config_cuda//cuda:cub_headers",
|
||||||
]) + if_rocm([
|
]) + if_rocm([
|
||||||
"@local_config_rocm//rocm:rocprim",
|
"@local_config_rocm//rocm:rocprim",
|
||||||
]),
|
]),
|
||||||
|
@ -15,19 +15,19 @@ limitations under the license, the license you must see.
|
|||||||
#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
|
#define TENSORFLOW_CORE_KERNELS_GPU_PRIM_H_
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/cub/block/block_load.cuh"
|
#include "cub/block/block_load.cuh"
|
||||||
#include "third_party/cub/block/block_scan.cuh"
|
#include "cub/block/block_scan.cuh"
|
||||||
#include "third_party/cub/block/block_store.cuh"
|
#include "cub/block/block_store.cuh"
|
||||||
#include "third_party/cub/device/device_histogram.cuh"
|
#include "cub/device/device_histogram.cuh"
|
||||||
#include "third_party/cub/device/device_radix_sort.cuh"
|
#include "cub/device/device_radix_sort.cuh"
|
||||||
#include "third_party/cub/device/device_reduce.cuh"
|
#include "cub/device/device_reduce.cuh"
|
||||||
#include "third_party/cub/device/device_segmented_radix_sort.cuh"
|
#include "cub/device/device_segmented_radix_sort.cuh"
|
||||||
#include "third_party/cub/device/device_segmented_reduce.cuh"
|
#include "cub/device/device_segmented_reduce.cuh"
|
||||||
#include "third_party/cub/device/device_select.cuh"
|
#include "cub/device/device_select.cuh"
|
||||||
#include "third_party/cub/iterator/counting_input_iterator.cuh"
|
#include "cub/iterator/counting_input_iterator.cuh"
|
||||||
#include "third_party/cub/iterator/transform_input_iterator.cuh"
|
#include "cub/iterator/transform_input_iterator.cuh"
|
||||||
#include "third_party/cub/thread/thread_operators.cuh"
|
#include "cub/thread/thread_operators.cuh"
|
||||||
#include "third_party/cub/warp/warp_reduce.cuh"
|
#include "cub/warp/warp_reduce.cuh"
|
||||||
#include "third_party/gpus/cuda/include/cusparse.h"
|
#include "third_party/gpus/cuda/include/cusparse.h"
|
||||||
|
|
||||||
namespace gpuprim = ::cub;
|
namespace gpuprim = ::cub;
|
||||||
|
@ -626,7 +626,7 @@ tf_kernel_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
] + if_cuda([
|
] + if_cuda([
|
||||||
"//tensorflow/stream_executor/cuda:cusparse_lib",
|
"//tensorflow/stream_executor/cuda:cusparse_lib",
|
||||||
"@cub_archive//:cub",
|
"@local_config_cuda//cuda:cub_headers",
|
||||||
]) + if_rocm([
|
]) + if_rocm([
|
||||||
"@local_config_rocm//rocm:hipsparse",
|
"@local_config_rocm//rocm:hipsparse",
|
||||||
]),
|
]),
|
||||||
|
1
third_party/cub.BUILD
vendored
1
third_party/cub.BUILD
vendored
@ -20,7 +20,6 @@ filegroup(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "cub",
|
name = "cub",
|
||||||
hdrs = if_cuda([":cub_header_files"]),
|
hdrs = if_cuda([":cub_header_files"]),
|
||||||
include_prefix = "third_party",
|
|
||||||
deps = [
|
deps = [
|
||||||
"@local_config_cuda//cuda:cuda_headers",
|
"@local_config_cuda//cuda:cuda_headers",
|
||||||
],
|
],
|
||||||
|
48
third_party/cub.pr170.patch
vendored
48
third_party/cub.pr170.patch
vendored
@ -1,48 +0,0 @@
|
|||||||
From fd6e7a61a16a17fa155cbd717de0c79001af71e6 Mon Sep 17 00:00:00 2001
|
|
||||||
From: Artem Belevich <tra@google.com>
|
|
||||||
Date: Mon, 23 Sep 2019 11:18:56 -0700
|
|
||||||
Subject: [PATCH] Fix CUDA version detection in CUB
|
|
||||||
|
|
||||||
This fixes the problem with CUB using deprecated shfl/vote instructions when CUB
|
|
||||||
is compiled with clang (e.g. some TensorFlow builds).
|
|
||||||
---
|
|
||||||
cub/util_arch.cuh | 3 ++-
|
|
||||||
cub/util_type.cuh | 4 ++--
|
|
||||||
2 files changed, 4 insertions(+), 3 deletions(-)
|
|
||||||
|
|
||||||
diff --git a/cub/util_arch.cuh b/cub/util_arch.cuh
|
|
||||||
index 87c5ea2fb..9ad9d1cbb 100644
|
|
||||||
--- a/cub/util_arch.cuh
|
|
||||||
+++ b/cub/util_arch.cuh
|
|
||||||
@@ -44,7 +44,8 @@ namespace cub {
|
|
||||||
|
|
||||||
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
|
|
||||||
|
|
||||||
-#if (__CUDACC_VER_MAJOR__ >= 9) && !defined(CUB_USE_COOPERATIVE_GROUPS)
|
|
||||||
+#if !defined(CUB_USE_COOPERATIVE_GROUPS) && \
|
|
||||||
+ (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
|
|
||||||
#define CUB_USE_COOPERATIVE_GROUPS
|
|
||||||
#endif
|
|
||||||
|
|
||||||
diff --git a/cub/util_type.cuh b/cub/util_type.cuh
|
|
||||||
index 0ba41e1ed..b2433d735 100644
|
|
||||||
--- a/cub/util_type.cuh
|
|
||||||
+++ b/cub/util_type.cuh
|
|
||||||
@@ -37,7 +37,7 @@
|
|
||||||
#include <limits>
|
|
||||||
#include <cfloat>
|
|
||||||
|
|
||||||
-#if (__CUDACC_VER_MAJOR__ >= 9)
|
|
||||||
+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
@@ -1063,7 +1063,7 @@ struct FpLimits<double>
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
-#if (__CUDACC_VER_MAJOR__ >= 9)
|
|
||||||
+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
|
|
||||||
template <>
|
|
||||||
struct FpLimits<__half>
|
|
||||||
{
|
|
6
third_party/gpus/cuda/BUILD.tpl
vendored
6
third_party/gpus/cuda/BUILD.tpl
vendored
@ -176,6 +176,11 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "cub_headers",
|
||||||
|
actual = "%{cub_actual}"
|
||||||
|
)
|
||||||
|
|
||||||
cuda_header_library(
|
cuda_header_library(
|
||||||
name = "cupti_headers",
|
name = "cupti_headers",
|
||||||
hdrs = [":cuda-extras"],
|
hdrs = [":cuda-extras"],
|
||||||
@ -224,3 +229,4 @@ py_library(
|
|||||||
)
|
)
|
||||||
|
|
||||||
%{copy_rules}
|
%{copy_rules}
|
||||||
|
|
||||||
|
5
third_party/gpus/cuda/BUILD.windows.tpl
vendored
5
third_party/gpus/cuda/BUILD.windows.tpl
vendored
@ -171,6 +171,11 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
alias(
|
||||||
|
name = "cub_headers",
|
||||||
|
actual = "%{cub_actual}"
|
||||||
|
)
|
||||||
|
|
||||||
cuda_header_library(
|
cuda_header_library(
|
||||||
name = "cupti_headers",
|
name = "cupti_headers",
|
||||||
hdrs = [":cuda-extras"],
|
hdrs = [":cuda-extras"],
|
||||||
|
7
third_party/gpus/cuda_configure.bzl
vendored
7
third_party/gpus/cuda_configure.bzl
vendored
@ -692,6 +692,7 @@ def _get_cuda_config(repository_ctx, find_cuda_config_script):
|
|||||||
return struct(
|
return struct(
|
||||||
cuda_toolkit_path = toolkit_path,
|
cuda_toolkit_path = toolkit_path,
|
||||||
cuda_version = cuda_version,
|
cuda_version = cuda_version,
|
||||||
|
cuda_version_major = cuda_major,
|
||||||
cublas_version = cublas_version,
|
cublas_version = cublas_version,
|
||||||
cusolver_version = cusolver_version,
|
cusolver_version = cusolver_version,
|
||||||
curand_version = curand_version,
|
curand_version = curand_version,
|
||||||
@ -776,6 +777,7 @@ def _create_dummy_repository(repository_ctx):
|
|||||||
"%{curand_lib}": lib_name("curand", cpu_value),
|
"%{curand_lib}": lib_name("curand", cpu_value),
|
||||||
"%{cupti_lib}": lib_name("cupti", cpu_value),
|
"%{cupti_lib}": lib_name("cupti", cpu_value),
|
||||||
"%{cusparse_lib}": lib_name("cusparse", cpu_value),
|
"%{cusparse_lib}": lib_name("cusparse", cpu_value),
|
||||||
|
"%{cub_actual}": ":cuda_headers",
|
||||||
"%{copy_rules}": """
|
"%{copy_rules}": """
|
||||||
filegroup(name="cuda-include")
|
filegroup(name="cuda-include")
|
||||||
filegroup(name="cublas-include")
|
filegroup(name="cublas-include")
|
||||||
@ -1122,6 +1124,10 @@ def _create_local_cuda_repository(repository_ctx):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cub_actual = "@cub_archive//:cub"
|
||||||
|
if int(cuda_config.cuda_version_major) >= 11:
|
||||||
|
cub_actual = ":cuda_headers"
|
||||||
|
|
||||||
repository_ctx.template(
|
repository_ctx.template(
|
||||||
"cuda/BUILD",
|
"cuda/BUILD",
|
||||||
tpl_paths["cuda:BUILD"],
|
tpl_paths["cuda:BUILD"],
|
||||||
@ -1137,6 +1143,7 @@ def _create_local_cuda_repository(repository_ctx):
|
|||||||
"%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]),
|
"%{curand_lib}": _basename(repository_ctx, cuda_libs["curand"]),
|
||||||
"%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]),
|
"%{cupti_lib}": _basename(repository_ctx, cuda_libs["cupti"]),
|
||||||
"%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]),
|
"%{cusparse_lib}": _basename(repository_ctx, cuda_libs["cusparse"]),
|
||||||
|
"%{cub_actual}": cub_actual,
|
||||||
"%{copy_rules}": "\n".join(copy_rules),
|
"%{copy_rules}": "\n".join(copy_rules),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user