Support the new CUDA compute capability options in configure.

sm_35,sm_50,sm_60,compute_70

PiperOrigin-RevId: 313660333
Change-Id: I08b6ccd62fac60645147c30c434055b4e608b190
This commit is contained in:
Amit Patankar 2020-05-28 14:29:56 -07:00 committed by TensorFlower Gardener
parent 68e13f00e1
commit a5393e9046

View File

@ -484,8 +484,8 @@ def check_bazel_version(min_version, max_version):
stderr = open(os.devnull, 'wb') stderr = open(os.devnull, 'wb')
curr_version = run_shell(['bazel', '--version'], curr_version = run_shell(['bazel', '--version'],
allow_non_zero = True, allow_non_zero=True,
stderr = stderr) stderr=stderr)
if curr_version.startswith('bazel '): if curr_version.startswith('bazel '):
curr_version = curr_version.split('bazel ')[1] curr_version = curr_version.split('bazel ')[1]
@ -1011,17 +1011,15 @@ def set_tf_cuda_compute_capabilities(environ_cp):
default_cuda_compute_capabilities = native_cuda_compute_capabilities default_cuda_compute_capabilities = native_cuda_compute_capabilities
ask_cuda_compute_capabilities = ( ask_cuda_compute_capabilities = (
'Please specify a list of comma-separated ' 'Please specify a list of comma-separated CUDA compute capabilities '
'CUDA compute capabilities you want to ' 'you want to build with.\nYou can find the compute capability of your '
'build with.\nYou can find the compute ' 'device at: https://developer.nvidia.com/cuda-gpus. Each capability '
'capability of your device at: ' 'can be specified as "x.y" or "compute_xy" to include both virtual and'
'https://developer.nvidia.com/cuda-gpus.\nPlease' ' binary GPU code, or as "sm_xy" to only include the binary '
' note that each additional compute ' 'code.\nPlease note that each additional compute capability '
'capability significantly increases your ' 'significantly increases your build time and binary size, and that '
'build time and binary size, and that ' 'TensorFlow only supports compute capabilities >= 3.5 [Default is: '
'TensorFlow only supports compute ' '%s]: ' % default_cuda_compute_capabilities)
'capabilities >= 3.5 [Default is: %s]: ' %
default_cuda_compute_capabilities)
tf_cuda_compute_capabilities = get_from_env_or_user_or_default( tf_cuda_compute_capabilities = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES',
ask_cuda_compute_capabilities, default_cuda_compute_capabilities) ask_cuda_compute_capabilities, default_cuda_compute_capabilities)
@ -1033,8 +1031,23 @@ def set_tf_cuda_compute_capabilities(environ_cp):
for compute_capability in tf_cuda_compute_capabilities.split(','): for compute_capability in tf_cuda_compute_capabilities.split(','):
m = re.match('[0-9]+.[0-9]+', compute_capability) m = re.match('[0-9]+.[0-9]+', compute_capability)
if not m: if not m:
print('Invalid compute capability: %s' % compute_capability) # We now support sm_35,sm_50,sm_60,compute_70.
all_valid = False sm_compute_match = re.match('(sm|compute)_?([0-9]+[0-9]+)',
compute_capability)
if not sm_compute_match:
print('Invalid compute capability: %s' % compute_capability)
all_valid = False
else:
ver = int(m.group(2))
if ver < 30:
print(
'ERROR: TensorFlow only supports small CUDA compute'
' capabilities of sm_30 and higher. Please re-specify the list'
' of compute capabilities excluding version %s.' % ver)
all_valid = False
if ver < 35:
print('WARNING: XLA does not support CUDA compute capabilities '
'lower than sm_35. Disable XLA when running on older GPUs.')
else: else:
ver = float(m.group(0)) ver = float(m.group(0))
if ver < 3.0: if ver < 3.0:
@ -1225,7 +1238,8 @@ def is_reduced_optimize_huge_functions_available(environ_cp):
only, as of 2019-11-19). TensorFlow needs this flag to massively reduce only, as of 2019-11-19). TensorFlow needs this flag to massively reduce
compile times, but until 16.4 is officially released, we can't depend on it. compile times, but until 16.4 is officially released, we can't depend on it.
See also https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion See also
https://groups.google.com/a/tensorflow.org/d/topic/build/SsW98Eo7l3o/discussion
Because it's very annoying to check this manually (to check the MSVC installed Because it's very annoying to check this manually (to check the MSVC installed
versions, you need to use the registry, and it's not clear if Bazel will be versions, you need to use the registry, and it's not clear if Bazel will be
@ -1372,7 +1386,7 @@ def main():
current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION, current_bazel_version = check_bazel_version(_TF_MIN_BAZEL_VERSION,
_TF_MAX_BAZEL_VERSION) _TF_MAX_BAZEL_VERSION)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print("Error checking bazel version: ", e.output.decode('UTF-8').strip()) print('Error checking bazel version: ', e.output.decode('UTF-8').strip())
raise e raise e
_TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version) _TF_CURRENT_BAZEL_VERSION = convert_version_to_int(current_bazel_version)