Patch CUB with http://github.com/NVlabs/cub/pull/170 to build with clang.
PiperOrigin-RevId: 270958366
This commit is contained in:
parent
93a0fb7e51
commit
b226c6aa06
@ -722,6 +722,7 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
|
||||
tf_http_archive(
|
||||
name = "cub_archive",
|
||||
build_file = clean_dep("//third_party:cub.BUILD"),
|
||||
patch_file = clean_dep("//third_party:cub.pr170.patch"),
|
||||
sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3",
|
||||
strip_prefix = "cub-1.8.0",
|
||||
urls = [
|
||||
|
48
third_party/cub.pr170.patch
vendored
Normal file
48
third_party/cub.pr170.patch
vendored
Normal file
@ -0,0 +1,48 @@
|
||||
From fd6e7a61a16a17fa155cbd717de0c79001af71e6 Mon Sep 17 00:00:00 2001
|
||||
From: Artem Belevich <tra@google.com>
|
||||
Date: Mon, 23 Sep 2019 11:18:56 -0700
|
||||
Subject: [PATCH] Fix CUDA version detection in CUB
|
||||
|
||||
This fixes the problem with CUB using deprecated shfl/vote instructions when CUB
|
||||
is compiled with clang (e.g. some TensorFlow builds).
|
||||
---
|
||||
cub/util_arch.cuh | 3 ++-
|
||||
cub/util_type.cuh | 4 ++--
|
||||
2 files changed, 4 insertions(+), 3 deletions(-)
|
||||
|
||||
diff --git a/cub/util_arch.cuh b/cub/util_arch.cuh
|
||||
index 87c5ea2fb..9ad9d1cbb 100644
|
||||
--- a/cub/util_arch.cuh
|
||||
+++ b/cub/util_arch.cuh
|
||||
@@ -44,7 +44,8 @@ namespace cub {
|
||||
|
||||
#ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document
|
||||
|
||||
-#if (__CUDACC_VER_MAJOR__ >= 9) && !defined(CUB_USE_COOPERATIVE_GROUPS)
|
||||
+#if !defined(CUB_USE_COOPERATIVE_GROUPS) && \
|
||||
+ (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
|
||||
#define CUB_USE_COOPERATIVE_GROUPS
|
||||
#endif
|
||||
|
||||
diff --git a/cub/util_type.cuh b/cub/util_type.cuh
|
||||
index 0ba41e1ed..b2433d735 100644
|
||||
--- a/cub/util_type.cuh
|
||||
+++ b/cub/util_type.cuh
|
||||
@@ -37,7 +37,7 @@
|
||||
#include <limits>
|
||||
#include <cfloat>
|
||||
|
||||
-#if (__CUDACC_VER_MAJOR__ >= 9)
|
||||
+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
@@ -1063,7 +1063,7 @@ struct FpLimits<double>
|
||||
};
|
||||
|
||||
|
||||
-#if (__CUDACC_VER_MAJOR__ >= 9)
|
||||
+#if (__CUDACC_VER_MAJOR__ >= 9 || CUDA_VERSION >= 9000)
|
||||
template <>
|
||||
struct FpLimits<__half>
|
||||
{
|
Loading…
Reference in New Issue
Block a user