Make nccl bindings compilable with cuda 10.2
This commit is contained in:
parent
c5816f6baa
commit
67edc16326
16
third_party/nccl/build_defs.bzl.tpl
vendored
16
third_party/nccl/build_defs.bzl.tpl
vendored
@ -104,19 +104,21 @@ def _device_link_impl(ctx):
|
|||||||
tmp_fatbin = ctx.actions.declare_file("%s.fatbin" % name)
|
tmp_fatbin = ctx.actions.declare_file("%s.fatbin" % name)
|
||||||
fatbin_h = ctx.actions.declare_file("%s_fatbin.h" % name)
|
fatbin_h = ctx.actions.declare_file("%s_fatbin.h" % name)
|
||||||
bin2c = ctx.file._bin2c
|
bin2c = ctx.file._bin2c
|
||||||
ctx.actions.run(
|
arguments_list = [
|
||||||
outputs = [tmp_fatbin, fatbin_h],
|
|
||||||
inputs = cubins,
|
|
||||||
executable = ctx.file._fatbinary,
|
|
||||||
arguments = [
|
|
||||||
"-64",
|
"-64",
|
||||||
"--cmdline=--compile-only",
|
"--cmdline=--compile-only",
|
||||||
"--link",
|
"--link",
|
||||||
"--compress-all",
|
"--compress-all",
|
||||||
"--bin2c-path=%s" % bin2c.dirname,
|
|
||||||
"--create=%s" % tmp_fatbin.path,
|
"--create=%s" % tmp_fatbin.path,
|
||||||
"--embedded-fatbin=%s" % fatbin_h.path,
|
"--embedded-fatbin=%s" % fatbin_h.path,
|
||||||
] + images,
|
]
|
||||||
|
if %{use_bin2c_path}:
|
||||||
|
arguments_list.append("--bin2c-path=%s" % bin2c.dirname)
|
||||||
|
ctx.actions.run(
|
||||||
|
outputs = [tmp_fatbin, fatbin_h],
|
||||||
|
inputs = cubins,
|
||||||
|
executable = ctx.file._fatbinary,
|
||||||
|
arguments = arguments_list + images,
|
||||||
tools = [bin2c],
|
tools = [bin2c],
|
||||||
mnemonic = "fatbinary",
|
mnemonic = "fatbinary",
|
||||||
)
|
)
|
||||||
|
|||||||
18
third_party/nccl/nccl_configure.bzl
vendored
18
third_party/nccl/nccl_configure.bzl
vendored
@ -72,6 +72,11 @@ def _nccl_configure_impl(repository_ctx):
|
|||||||
nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
|
nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
|
||||||
nccl_version = nccl_version.split(".")[0]
|
nccl_version = nccl_version.split(".")[0]
|
||||||
|
|
||||||
|
cuda_config = find_cuda_config(repository_ctx, ["cuda"])
|
||||||
|
cuda_version = cuda_config["cuda_version"].split(".")
|
||||||
|
cuda_major = cuda_version[0]
|
||||||
|
cuda_minor = cuda_version[1]
|
||||||
|
|
||||||
if nccl_version == "":
|
if nccl_version == "":
|
||||||
# Alias to open source build from @nccl_archive.
|
# Alias to open source build from @nccl_archive.
|
||||||
repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT)
|
repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT)
|
||||||
@ -84,9 +89,18 @@ def _nccl_configure_impl(repository_ctx):
|
|||||||
|
|
||||||
# Round-about way to make the list unique.
|
# Round-about way to make the list unique.
|
||||||
gpu_architectures = dict(zip(gpu_architectures, gpu_architectures)).keys()
|
gpu_architectures = dict(zip(gpu_architectures, gpu_architectures)).keys()
|
||||||
repository_ctx.template("build_defs.bzl", _label("build_defs.bzl.tpl"), {
|
config_wrap = {
|
||||||
"%{gpu_architectures}": str(gpu_architectures),
|
"%{gpu_architectures}": str(gpu_architectures),
|
||||||
})
|
"%{use_bin2c_path}": "False",
|
||||||
|
}
|
||||||
|
if (int(cuda_major), int(cuda_minor)) <= (10, 1):
|
||||||
|
config_wrap["%{use_bin2c_path}"] = "True"
|
||||||
|
|
||||||
|
repository_ctx.template(
|
||||||
|
"build_defs.bzl",
|
||||||
|
_label("build_defs.bzl.tpl"),
|
||||||
|
config_wrap,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Create target for locally installed NCCL.
|
# Create target for locally installed NCCL.
|
||||||
config = find_cuda_config(repository_ctx, ["nccl"])
|
config = find_cuda_config(repository_ctx, ["nccl"])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user