Add the default CUDA compute logic to the baseline template for the toolchains so they will not be overridden by the generator script.
PiperOrigin-RevId: 278738358 Change-Id: I72d57d19c2a0a44bbc0a38fbb2dfec50371bf977
This commit is contained in:
parent
596dcd2fd1
commit
473223155e
@ -53,6 +53,13 @@ NVCC_PATH = '%{nvcc_path}'
|
||||
PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
|
||||
NVCC_VERSION = '%{cuda_version}'
|
||||
|
||||
|
||||
# TODO(amitpatankar): Benchmark enabling all capabilities by default.
|
||||
# Environment variable for supported TF CUDA Compute Capabilities
|
||||
# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
|
||||
CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES'
|
||||
DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
|
||||
|
||||
def Log(s):
|
||||
print('gpus/crosstool: {0}'.format(s))
|
||||
|
||||
|
@ -37,7 +37,13 @@ GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
|
||||
NVCC_PATH = '%{nvcc_path}'
|
||||
NVCC_VERSION = '%{cuda_version}'
|
||||
NVCC_TEMP_DIR = "%{nvcc_tmp_dir}"
|
||||
supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ]
|
||||
DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
|
||||
|
||||
# Taken from environment variable for supported TF CUDA Compute Capabilities
|
||||
# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
|
||||
supported_cuda_compute_capabilities = os.environ.get(
|
||||
'TF_CUDA_COMPUTE_CAPABILITIES',
|
||||
DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
|
||||
|
||||
def Log(s):
|
||||
print('gpus/crosstool: {0}'.format(s))
|
||||
|
Loading…
Reference in New Issue
Block a user