diff --git a/third_party/gpus/rocm/build_defs.bzl.tpl b/third_party/gpus/rocm/build_defs.bzl.tpl index 08c59f95a07..ce4c1b04399 100644 --- a/third_party/gpus/rocm/build_defs.bzl.tpl +++ b/third_party/gpus/rocm/build_defs.bzl.tpl @@ -34,6 +34,10 @@ def rocm_is_configured(): """Returns true if ROCm was enabled during the configure process.""" return %{rocm_is_configured} +def rocm_gpu_architectures(): + """Returns a list of supported GPU architectures.""" + return %{rocm_gpu_architectures} + def if_rocm_is_configured(x): """Tests if the ROCm was enabled during the configure process. diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 752f48aa25b..1b429d346e3 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -9,8 +9,7 @@ * `TF_ROCM_VERSION`: The version of the ROCm toolkit. If this is blank, then use the system default. * `TF_MIOPEN_VERSION`: The version of the MIOpen library. - * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. Default is - `gfx803,gfx900`. + * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ load( @@ -44,7 +43,6 @@ _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" _DEFAULT_ROCM_VERSION = "" _DEFAULT_MIOPEN_VERSION = "" _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" -_DEFAULT_ROCM_AMDGPU_TARGETS = ["gfx803", "gfx900"] def verify_build_defines(params): """Verify all variables that crosstool/BUILD.rocm.tpl expects are substituted. @@ -228,11 +226,14 @@ def _rocm_toolkit_path(repository_ctx, bash_bin): auto_configure_fail("Cannot find rocm toolkit path.") return rocm_toolkit_path -def _amdgpu_targets(repository_ctx): +def _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin): """Returns a list of strings representing AMDGPU targets.""" amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS) if not amdgpu_targets_str: - return _DEFAULT_ROCM_AMDGPU_TARGETS + cmd = "%s/bin/rocm_agent_enumerator" % rocm_toolkit_path + result = execute(repository_ctx, [bash_bin, "-c", cmd]) + targets = [target for target in result.stdout.strip().split("\n") if target != "gfx000"] + amdgpu_targets_str = ",".join(targets) amdgpu_targets = amdgpu_targets_str.split(",") for amdgpu_target in amdgpu_targets: if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit(): @@ -416,7 +417,7 @@ def _get_rocm_config(repository_ctx, bash_bin): rocm_toolkit_path = _rocm_toolkit_path(repository_ctx, bash_bin) return struct( rocm_toolkit_path = rocm_toolkit_path, - amdgpu_targets = _amdgpu_targets(repository_ctx), + amdgpu_targets = _amdgpu_targets(repository_ctx, rocm_toolkit_path, bash_bin), ) def _tpl_path(repository_ctx, labelname): @@ -464,6 +465,7 @@ def _create_dummy_repository(repository_ctx): { "%{rocm_is_configured}": "False", "%{rocm_extra_copts}": "[]", + "%{rocm_gpu_architectures}": "[]", }, ) _tpl( @@ -532,12 +534,8 @@ def _genrule(src_dir, genrule_name, command, outs): ) def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): - if False: - amdgpu_target_flags = ["--amdgpu-target=" + + amdgpu_target_flags = ["--amdgpu-target=" + amdgpu_target for amdgpu_target in amdgpu_targets] - else: - # AMDGPU targets are handled in the "crosstool_wrapper_driver_is_not_gcc" - amdgpu_target_flags = [] return str(amdgpu_target_flags) def _create_local_rocm_repository(repository_ctx): @@ -621,6 +619,7 @@ def _create_local_rocm_repository(repository_ctx): repository_ctx, rocm_config.amdgpu_targets, ), + "%{rocm_gpu_architectures}": str(rocm_config.amdgpu_targets), }, ) repository_ctx.template( @@ -719,10 +718,7 @@ def _create_local_rocm_repository(repository_ctx): "%{hcc_runtime_library}": "mcwamp", "%{crosstool_verbose}": _crosstool_verbose(repository_ctx), "%{gcc_host_compiler_path}": str(cc), - "%{rocm_amdgpu_targets}": ",".join( - ["\"%s\"" % c for c in rocm_config.amdgpu_targets], - ), - }, + }, ) # Set up rocm_config.h, which is used by