From 473223155ef033bab528f0a8272c05c306a792c3 Mon Sep 17 00:00:00 2001
From: Amit Patankar <amitpatankar@google.com>
Date: Tue, 5 Nov 2019 16:55:50 -0800
Subject: [PATCH] Add the default CUDA compute logic to the baseline template
 for the toolchains so they will not be overridden by the generator script.

PiperOrigin-RevId: 278738358
Change-Id: I72d57d19c2a0a44bbc0a38fbb2dfec50371bf977
---
 .../clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl     | 7 +++++++
 .../gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl   | 8 +++++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
index a69be47945b..b1ee3bc6c26 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
@@ -53,6 +53,13 @@ NVCC_PATH = '%{nvcc_path}'
 PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
 NVCC_VERSION = '%{cuda_version}'
 
+
+# TODO(amitpatankar): Benchmark enabling all capabilities by default.
+# Environment variable listing the supported TF CUDA compute capabilities,
+# e.g. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
+CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES'
+DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
+
 def Log(s):
   print('gpus/crosstool: {0}'.format(s))
 
diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
index 4d0b9d2367a..59e380a64cc 100644
--- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
+++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
@@ -37,7 +37,13 @@ GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
 NVCC_PATH = '%{nvcc_path}'
 NVCC_VERSION = '%{cuda_version}'
 NVCC_TEMP_DIR = "%{nvcc_tmp_dir}"
-supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ]
+DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
+
+# Supported CUDA compute capabilities, read from an environment variable,
+# e.g. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
+supported_cuda_compute_capabilities = os.environ.get(
+    'TF_CUDA_COMPUTE_CAPABILITIES',
+    DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
 
 def Log(s):
   print('gpus/crosstool: {0}'.format(s))