Merge pull request #34885 from samikama:NCCL_cuda10.2

PiperOrigin-RevId: 284216577
Change-Id: I7286bd01720a92f15ea70409efb5d6fc17df303e
This commit is contained in:
TensorFlower Gardener 2019-12-06 11:03:13 -08:00
commit e37dc0f10a
2 changed files with 25 additions and 9 deletions

View File

@ -104,19 +104,21 @@ def _device_link_impl(ctx):
tmp_fatbin = ctx.actions.declare_file("%s.fatbin" % name)
fatbin_h = ctx.actions.declare_file("%s_fatbin.h" % name)
bin2c = ctx.file._bin2c
ctx.actions.run(
outputs = [tmp_fatbin, fatbin_h],
inputs = cubins,
executable = ctx.file._fatbinary,
arguments = [
arguments_list = [
"-64",
"--cmdline=--compile-only",
"--link",
"--compress-all",
"--bin2c-path=%s" % bin2c.dirname,
"--create=%s" % tmp_fatbin.path,
"--embedded-fatbin=%s" % fatbin_h.path,
] + images,
]
if %{use_bin2c_path}:
arguments_list.append("--bin2c-path=%s" % bin2c.dirname)
ctx.actions.run(
outputs = [tmp_fatbin, fatbin_h],
inputs = cubins,
executable = ctx.file._fatbinary,
arguments = arguments_list + images,
tools = [bin2c],
mnemonic = "fatbinary",
)

View File

@ -72,6 +72,11 @@ def _nccl_configure_impl(repository_ctx):
nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
nccl_version = nccl_version.split(".")[0]
cuda_config = find_cuda_config(repository_ctx, ["cuda"])
cuda_version = cuda_config["cuda_version"].split(".")
cuda_major = cuda_version[0]
cuda_minor = cuda_version[1]
if nccl_version == "":
# Alias to open source build from @nccl_archive.
repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT)
@ -84,9 +89,18 @@ def _nccl_configure_impl(repository_ctx):
# Round-about way to make the list unique.
gpu_architectures = dict(zip(gpu_architectures, gpu_architectures)).keys()
repository_ctx.template("build_defs.bzl", _label("build_defs.bzl.tpl"), {
config_wrap = {
"%{gpu_architectures}": str(gpu_architectures),
})
"%{use_bin2c_path}": "False",
}
if (int(cuda_major), int(cuda_minor)) <= (10, 1):
config_wrap["%{use_bin2c_path}"] = "True"
repository_ctx.template(
"build_defs.bzl",
_label("build_defs.bzl.tpl"),
config_wrap,
)
else:
# Create target for locally installed NCCL.
config = find_cuda_config(repository_ctx, ["nccl"])