diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl index a69be47945b..b1ee3bc6c26 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl @@ -53,6 +53,13 @@ NVCC_PATH = '%{nvcc_path}' PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) NVCC_VERSION = '%{cuda_version}' + +# TODO(amitpatankar): Benchmark enabling all capabilities by default. +# Environment variable for supported TF CUDA Compute Capabilities +# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0 +CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES' +DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0' + def Log(s): print('gpus/crosstool: {0}'.format(s)) diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl index 4d0b9d2367a..59e380a64cc 100644 --- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -37,7 +37,13 @@ GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}') NVCC_PATH = '%{nvcc_path}' NVCC_VERSION = '%{cuda_version}' NVCC_TEMP_DIR = "%{nvcc_tmp_dir}" -supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ] +DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0' + +# Taken from environment variable for supported TF CUDA Compute Capabilities +# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0 +supported_cuda_compute_capabilities = os.environ.get( + 'TF_CUDA_COMPUTE_CAPABILITIES', + DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',') def Log(s): print('gpus/crosstool: {0}'.format(s))