rocm_configure: fix quadratic runtime due to label resolution
PiperOrigin-RevId: 294196207 Change-Id: I23de3ce524a58b04655fe027ffd6ce7ff040784c
This commit is contained in:
parent
c7ce11b6f3
commit
1cb495c7eb
62
third_party/gpus/rocm_configure.bzl
vendored
62
third_party/gpus/rocm_configure.bzl
vendored
@ -495,22 +495,18 @@ def _get_rocm_config(repository_ctx):
|
||||
amdgpu_targets = _amdgpu_targets(repository_ctx),
|
||||
)
|
||||
|
||||
def _tpl_path(repository_ctx, labelname):
|
||||
return repository_ctx.path(Label("//third_party/gpus/%s.tpl" % labelname))
|
||||
|
||||
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
|
||||
if not out:
|
||||
out = tpl.replace(":", "/")
|
||||
repository_ctx.template(
|
||||
out,
|
||||
Label("//third_party/gpus/%s.tpl" % tpl),
|
||||
_tpl_path(repository_ctx, tpl),
|
||||
substitutions,
|
||||
)
|
||||
|
||||
def _file(repository_ctx, label):
|
||||
repository_ctx.template(
|
||||
label.replace(":", "/"),
|
||||
Label("//third_party/gpus/%s.tpl" % label),
|
||||
{},
|
||||
)
|
||||
|
||||
_DUMMY_CROSSTOOL_BZL_FILE = """
|
||||
def error_gpu_disabled():
|
||||
fail("ERROR: Building with --config=rocm but TensorFlow is not configured " +
|
||||
@ -622,6 +618,16 @@ def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets):
|
||||
|
||||
def _create_local_rocm_repository(repository_ctx):
|
||||
"""Creates the repository containing files set up to build with ROCm."""
|
||||
|
||||
tpl_paths = {labelname: _tpl_path(repository_ctx, labelname) for labelname in [
|
||||
"rocm:build_defs.bzl",
|
||||
"rocm:BUILD",
|
||||
"crosstool:BUILD.rocm",
|
||||
"crosstool:hipcc_cc_toolchain_config.bzl",
|
||||
"crosstool:clang/bin/crosstool_wrapper_driver_rocm",
|
||||
"rocm:rocm_config.h",
|
||||
]}
|
||||
|
||||
rocm_config = _get_rocm_config(repository_ctx)
|
||||
|
||||
# Copy header and library files to execroot.
|
||||
@ -680,9 +686,9 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
))
|
||||
|
||||
# Set up BUILD file for rocm/
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"rocm:build_defs.bzl",
|
||||
repository_ctx.template(
|
||||
"rocm/build_defs.bzl",
|
||||
tpl_paths["rocm:build_defs.bzl"],
|
||||
{
|
||||
"%{rocm_is_configured}": "True",
|
||||
"%{rocm_extra_copts}": _compute_rocm_extra_copts(
|
||||
@ -691,9 +697,9 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
),
|
||||
},
|
||||
)
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"rocm:BUILD",
|
||||
repository_ctx.template(
|
||||
"rocm/BUILD",
|
||||
tpl_paths["rocm:BUILD"],
|
||||
{
|
||||
"%{hip_lib}": rocm_libs["hip"].file_name,
|
||||
"%{rocblas_lib}": rocm_libs["rocblas"].file_name,
|
||||
@ -759,24 +765,22 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
verify_build_defines(rocm_defines)
|
||||
|
||||
# Only expand template variables in the BUILD file
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"crosstool:BUILD.rocm",
|
||||
repository_ctx.template(
|
||||
"crosstool/BUILD",
|
||||
tpl_paths["crosstool:BUILD.rocm"],
|
||||
rocm_defines,
|
||||
out = "crosstool/BUILD",
|
||||
)
|
||||
|
||||
# No templating of cc_toolchain_config - use attributes and templatize the
|
||||
# BUILD file.
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"crosstool:hipcc_cc_toolchain_config.bzl",
|
||||
out = "crosstool/cc_toolchain_config.bzl",
|
||||
repository_ctx.template(
|
||||
"crosstool/cc_toolchain_config.bzl",
|
||||
tpl_paths["crosstool:hipcc_cc_toolchain_config.bzl"],
|
||||
)
|
||||
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"crosstool:clang/bin/crosstool_wrapper_driver_rocm",
|
||||
repository_ctx.template(
|
||||
"crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
|
||||
tpl_paths["crosstool:clang/bin/crosstool_wrapper_driver_rocm"],
|
||||
{
|
||||
"%{cpu_compiler}": str(cc),
|
||||
"%{hipcc_path}": rocm_config.rocm_toolkit_path + "/bin/hipcc",
|
||||
@ -794,21 +798,19 @@ def _create_local_rocm_repository(repository_ctx):
|
||||
["\"%s\"" % c for c in rocm_config.amdgpu_targets],
|
||||
),
|
||||
},
|
||||
out = "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
|
||||
)
|
||||
|
||||
# Set up rocm_config.h, which is used by
|
||||
# tensorflow/stream_executor/dso_loader.cc.
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"rocm:rocm_config.h",
|
||||
repository_ctx.template(
|
||||
"rocm/rocm/rocm_config.h",
|
||||
tpl_paths["rocm:rocm_config.h"],
|
||||
{
|
||||
"%{rocm_amdgpu_targets}": ",".join(
|
||||
["\"%s\"" % c for c in rocm_config.amdgpu_targets],
|
||||
),
|
||||
"%{rocm_toolkit_path}": rocm_config.rocm_toolkit_path,
|
||||
},
|
||||
"rocm/rocm/rocm_config.h",
|
||||
)
|
||||
|
||||
def _create_remote_rocm_repository(repository_ctx, remote_config_repo):
|
||||
|
Loading…
Reference in New Issue
Block a user