Retrieve CUDA targets to build in nvcc wrapper from clang command line option.

Hard-coding it (through a repo rule) in one place is bad enough. The proper solution would be to make CUDA targets a bazel 'feature' and map it to compiler flags in crosstools. The more pressing requirement though is to allow compiling a mix of SASS and PTX binaries, instead of SASS+PTX for every CUDA target.

PiperOrigin-RevId: 311081931
Change-Id: If6aea7bfa08e21984471ce3593e0df3ac2c21798
This commit is contained in:
Christian Sigg 2020-05-12 01:34:43 -07:00 committed by TensorFlower Gardener
parent 4926e23ba4
commit 65773fd394
3 changed files with 29 additions and 47 deletions

View File

@ -53,13 +53,6 @@ NVCC_PATH = '%{nvcc_path}'
PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
NVCC_VERSION = '%{cuda_version}' NVCC_VERSION = '%{cuda_version}'
# TODO(amitpatankar): Benchmark enabling all capabilities by default.
# Environment variable for supported TF CUDA Compute Capabilities
# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
CUDA_COMPUTE_ENV_VAR = 'TF_CUDA_COMPUTE_CAPABILITIES'
DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
def Log(s): def Log(s):
print('gpus/crosstool: {0}'.format(s)) print('gpus/crosstool: {0}'.format(s))
@ -78,7 +71,8 @@ def GetOptionValue(argv, option):
""" """
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument('-' + option, nargs='*', action='append') parser.add_argument(option, nargs='*', action='append')
option = option.lstrip('-').replace('-', '_')
args, _ = parser.parse_known_args(argv) args, _ = parser.parse_known_args(argv)
if not args or not vars(args)[option]: if not args or not vars(args)[option]:
return [] return []
@ -180,17 +174,17 @@ def InvokeNvcc(argv, log=False):
host_compiler_options = GetHostCompilerOptions(argv) host_compiler_options = GetHostCompilerOptions(argv)
nvcc_compiler_options = GetNvccOptions(argv) nvcc_compiler_options = GetNvccOptions(argv)
opt_option = GetOptionValue(argv, 'O') opt_option = GetOptionValue(argv, '-O')
m_options = GetOptionValue(argv, 'm') m_options = GetOptionValue(argv, '-m')
m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']]) m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
include_options = GetOptionValue(argv, 'I') include_options = GetOptionValue(argv, '-I')
out_file = GetOptionValue(argv, 'o') out_file = GetOptionValue(argv, '-o')
depfiles = GetOptionValue(argv, 'MF') depfiles = GetOptionValue(argv, '-MF')
defines = GetOptionValue(argv, 'D') defines = GetOptionValue(argv, '-D')
defines = ''.join([' -D' + define for define in defines]) defines = ''.join([' -D' + define for define in defines])
undefines = GetOptionValue(argv, 'U') undefines = GetOptionValue(argv, '-U')
undefines = ''.join([' -U' + define for define in undefines]) undefines = ''.join([' -U' + define for define in undefines])
std_options = GetOptionValue(argv, 'std') std_options = GetOptionValue(argv, '-std')
# Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang. # Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang.
nvcc_allowed_std_options = ["c++03", "c++11", "c++14"] nvcc_allowed_std_options = ["c++03", "c++11", "c++14"]
std_options = ''.join([' -std=' + define std_options = ''.join([' -std=' + define
@ -198,7 +192,7 @@ def InvokeNvcc(argv, log=False):
# The list of source files get passed after the -c option. I don't know of # The list of source files get passed after the -c option. I don't know of
# any other reliable way to just get the list of source files to be compiled. # any other reliable way to just get the list of source files to be compiled.
src_files = GetOptionValue(argv, 'c') src_files = GetOptionValue(argv, '-c')
# Pass -w through from host to nvcc, but don't do anything fancier with # Pass -w through from host to nvcc, but don't do anything fancier with
# warnings-related flags, since they're not necessarily the same across # warnings-related flags, since they're not necessarily the same across
@ -224,13 +218,12 @@ def InvokeNvcc(argv, log=False):
srcs = ' '.join(src_files) srcs = ' '.join(src_files)
out = ' -o ' + out_file[0] out = ' -o ' + out_file[0]
supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ]
nvccopts = '-D_FORCE_INLINES ' nvccopts = '-D_FORCE_INLINES '
for capability in supported_cuda_compute_capabilities: for capability in GetOptionValue(argv, "--cuda-gpu-arch"):
capability = capability.replace('.', '') capability = capability[len('sm_'):]
nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % ( nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
capability, capability, capability) capability, capability, capability)
nvccopts += ' ' + nvcc_compiler_options nvccopts += nvcc_compiler_options
nvccopts += undefines nvccopts += undefines
nvccopts += defines nvccopts += defines
nvccopts += std_options nvccopts += std_options
@ -272,6 +265,7 @@ def main():
if args.x and args.x[0] == 'cuda': if args.x and args.x[0] == 'cuda':
if args.cuda_log: Log('-x cuda') if args.cuda_log: Log('-x cuda')
leftover = [pipes.quote(s) for s in leftover] leftover = [pipes.quote(s) for s in leftover]
args.cuda_log = True
if args.cuda_log: Log('using nvcc') if args.cuda_log: Log('using nvcc')
return InvokeNvcc(leftover, log=args.cuda_log) return InvokeNvcc(leftover, log=args.cuda_log)

View File

@ -37,13 +37,6 @@ GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
NVCC_PATH = '%{nvcc_path}' NVCC_PATH = '%{nvcc_path}'
NVCC_VERSION = '%{cuda_version}' NVCC_VERSION = '%{cuda_version}'
NVCC_TEMP_DIR = "%{nvcc_tmp_dir}" NVCC_TEMP_DIR = "%{nvcc_tmp_dir}"
DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,6.0'
# Taken from environment variable for supported TF CUDA Compute Capabilities
# eg. export TF_CUDA_COMPUTE_CAPABILITIES=3.5,3.7,5.2,6.0,6.1,7.0
supported_cuda_compute_capabilities = os.environ.get(
'TF_CUDA_COMPUTE_CAPABILITIES',
DEFAULT_CUDA_COMPUTE_CAPABILITIES).split(',')
def Log(s): def Log(s):
print('gpus/crosstool: {0}'.format(s)) print('gpus/crosstool: {0}'.format(s))
@ -53,7 +46,7 @@ def GetOptionValue(argv, option):
"""Extract the list of values for option from options. """Extract the list of values for option from options.
Args: Args:
option: The option whose value to extract, without the leading '/'. option: The option whose value to extract.
Returns: Returns:
1. A list of values, either directly following the option, 1. A list of values, either directly following the option,
@ -62,10 +55,11 @@ def GetOptionValue(argv, option):
2. The leftover options. 2. The leftover options.
""" """
parser = ArgumentParser(prefix_chars='/') parser = ArgumentParser(prefix_chars='-/')
parser.add_argument('/' + option, nargs='*', action='append') parser.add_argument(option, nargs='*', action='append')
option = option.lstrip('-/').replace('-', '_')
args, leftover = parser.parse_known_args(argv) args, leftover = parser.parse_known_args(argv)
if args and vars(args)[option]: if args and vars(args).get(option):
return (sum(vars(args)[option], []), leftover) return (sum(vars(args)[option], []), leftover)
return ([], leftover) return ([], leftover)
@ -122,18 +116,18 @@ def InvokeNvcc(argv, log=False):
nvcc_compiler_options, argv = GetNvccOptions(argv) nvcc_compiler_options, argv = GetNvccOptions(argv)
opt_option, argv = GetOptionValue(argv, 'O') opt_option, argv = GetOptionValue(argv, '/O')
opt = ['-g'] opt = ['-g']
if (len(opt_option) > 0 and opt_option[0] != 'd'): if (len(opt_option) > 0 and opt_option[0] != 'd'):
opt = ['-O2'] opt = ['-O2']
include_options, argv = GetOptionValue(argv, 'I') include_options, argv = GetOptionValue(argv, '/I')
includes = ["-I " + include for include in include_options] includes = ["-I " + include for include in include_options]
defines, argv = GetOptionValue(argv, 'D') defines, argv = GetOptionValue(argv, '/D')
defines = ['-D' + define for define in defines] defines = ['-D' + define for define in defines]
undefines, argv = GetOptionValue(argv, 'U') undefines, argv = GetOptionValue(argv, '/U')
undefines = ['-U' + define for define in undefines] undefines = ['-U' + define for define in undefines]
# The rest of the unrecognized options should be passed to host compiler # The rest of the unrecognized options should be passed to host compiler
@ -142,10 +136,10 @@ def InvokeNvcc(argv, log=False):
m_options = ["-m64"] m_options = ["-m64"]
nvccopts = ['-D_FORCE_INLINES'] nvccopts = ['-D_FORCE_INLINES']
for capability in supported_cuda_compute_capabilities: for capability in GetOptionValue(argv, "--cuda-gpu-arch"):
capability = capability.replace('.', '') capability = capability[len('sm_'):]
nvccopts += [r'-gencode=arch=compute_%s,"code=sm_%s,compute_%s"' % ( nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
capability, capability, capability)] capability, capability, capability)
nvccopts += nvcc_compiler_options nvccopts += nvcc_compiler_options
nvccopts += undefines nvccopts += undefines
nvccopts += defines nvccopts += defines

View File

@ -840,10 +840,7 @@ def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
"--cuda-gpu-arch=sm_" + cap.replace(".", "") "--cuda-gpu-arch=sm_" + cap.replace(".", "")
for cap in compute_capabilities for cap in compute_capabilities
] ]
return str(capability_flags)
# Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
# TODO(csigg): Make this consistent with cuda clang and pass unconditionally.
return "if_cuda_clang(%s)" % str(capability_flags)
def _tpl_path(repository_ctx, filename): def _tpl_path(repository_ctx, filename):
return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename)) return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % filename))
@ -1092,9 +1089,6 @@ def _create_local_cuda_repository(repository_ctx):
"%{cuda_version}": cuda_config.cuda_version, "%{cuda_version}": cuda_config.cuda_version,
"%{nvcc_path}": nvcc_path, "%{nvcc_path}": nvcc_path,
"%{gcc_host_compiler_path}": str(cc), "%{gcc_host_compiler_path}": str(cc),
"%{cuda_compute_capabilities}": ", ".join(
["\"%s\"" % c for c in cuda_config.compute_capabilities],
),
"%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx), "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
} }
repository_ctx.template( repository_ctx.template(