From f60fc7a072182df99ddbef50a873e8a544341855 Mon Sep 17 00:00:00 2001 From: Jakob Buchgraber Date: Tue, 18 Feb 2020 11:30:58 -0800 Subject: [PATCH] remote config: replace all uses of os.environ by get_host_environ This change is in prepartion for rolling out remote config. It will allow us to inject environment variables from repository rules as well as from the shell enviroment. PiperOrigin-RevId: 295782466 Change-Id: I1eb61fca3556473e94f2f12c45ee5eb1fe51625b --- third_party/gpus/cuda_configure.bzl | 39 ++++++++-------- third_party/gpus/rocm_configure.bzl | 50 +++++++++------------ third_party/nccl/nccl_configure.bzl | 6 +-- third_party/remote_config/common.bzl | 10 ++++- third_party/tensorrt/tensorrt_configure.bzl | 7 +-- 5 files changed, 54 insertions(+), 58 deletions(-) diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 6fbe306457f..1f132e96f2c 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -43,6 +43,7 @@ load( "execute", "get_bash_bin", "get_cpu_value", + "get_host_environ", "get_python_bin", "is_windows", "raw_exec", @@ -223,10 +224,9 @@ def find_cc(repository_ctx): cc_path_envvar = _GCC_HOST_COMPILER_PATH cc_name = target_cc_name - if cc_path_envvar in repository_ctx.os.environ: - cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip() - if cc_name_from_env: - cc_name = cc_name_from_env + cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) + if cc_name_from_env: + cc_name = cc_name_from_env if cc_name.startswith("/"): # Absolute path, maybe we should make this supported by our which function. return cc_name @@ -365,7 +365,7 @@ def _cuda_include_path(repository_ctx, cuda_config): def enable_cuda(repository_ctx): """Returns whether to build with CUDA support.""" - return int(repository_ctx.os.environ.get("TF_NEED_CUDA", False)) + return int(get_host_environ(repository_ctx, "TF_NEED_CUDA", False)) def matches_version(environ_version, detected_version): """Checks whether the user-specified version matches the detected version. @@ -409,9 +409,9 @@ _DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR" def compute_capabilities(repository_ctx): """Returns a list of strings representing cuda compute capabilities.""" - if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ: + capabilities_str = get_host_environ(repository_ctx, _TF_CUDA_COMPUTE_CAPABILITIES) + if capabilities_str == None: return _DEFAULT_CUDA_COMPUTE_CAPABILITIES - capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES] capabilities = capabilities_str.split(",") for capability in capabilities: # Workaround for Skylark's lack of support for regex. This check should @@ -805,18 +805,13 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir): )""" % (name, "\n".join(outs), src_dir, out_dir) def _flag_enabled(repository_ctx, flag_name): - if flag_name in repository_ctx.os.environ: - value = repository_ctx.os.environ[flag_name].strip() - return value == "1" - return False + return get_host_environ(repository_ctx, flag_name) == "1" def _use_cuda_clang(repository_ctx): return _flag_enabled(repository_ctx, "TF_CUDA_CLANG") def _tf_sysroot(repository_ctx): - if _TF_SYSROOT in repository_ctx.os.environ: - return repository_ctx.os.environ[_TF_SYSROOT] - return "" + return get_host_environ(repository_ctx, _TF_SYSROOT, "") def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): capability_flags = [ @@ -1006,9 +1001,10 @@ def _create_local_cuda_repository(repository_ctx): if is_cuda_clang: cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"] - host_compiler_prefix = "/usr/bin" - if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ: - host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip() + host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX) + if not host_compiler_prefix: + host_compiler_prefix = "/usr/bin" + cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see @@ -1157,14 +1153,15 @@ def _cuda_autoconf_impl(repository_ctx): """Implementation of the cuda_autoconf repository rule.""" if not enable_cuda(repository_ctx): _create_dummy_repository(repository_ctx) - elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ: - if (_TF_CUDA_VERSION not in repository_ctx.os.environ or - _TF_CUDNN_VERSION not in repository_ctx.os.environ): + elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None: + has_cuda_version = get_host_environ(repository_ctx, _TF_CUDA_VERSION) != None + has_cudnn_version = get_host_environ(repository_ctx, _TF_CUDNN_VERSION) != None + if not has_cuda_version or not has_cudnn_version: auto_configure_fail("%s and %s must also be set if %s is specified" % (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO)) _create_remote_cuda_repository( repository_ctx, - repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO], + get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO), ) else: _create_local_cuda_repository(repository_ctx) diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index de885f71d18..063271b83f2 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -26,6 +26,7 @@ load( "files_exist", "get_bash_bin", "get_cpu_value", + "get_host_environ", "raw_exec", "realpath", "which", @@ -79,10 +80,9 @@ def find_cc(repository_ctx): cc_path_envvar = _GCC_HOST_COMPILER_PATH cc_name = target_cc_name - if cc_path_envvar in repository_ctx.os.environ: - cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip() - if cc_name_from_env: - cc_name = cc_name_from_env + cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) + if cc_name_from_env: + cc_name = cc_name_from_env if cc_name.startswith("/"): # Absolute path, maybe we should make this supported by our which function. return cc_name @@ -252,13 +252,12 @@ def _rocm_include_path(repository_ctx, rocm_config): return inc_dirs def _enable_rocm(repository_ctx): - if "TF_NEED_ROCM" in repository_ctx.os.environ: - enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip() - if enable_rocm == "1": - if get_cpu_value(repository_ctx) != "Linux": - auto_configure_warning("ROCm configure is only supported on Linux") - return False - return True + enable_rocm = get_host_environ(repository_ctx, "TF_NEED_ROCM") + if enable_rocm == "1": + if get_cpu_value(repository_ctx) != "Linux": + auto_configure_warning("ROCm configure is only supported on Linux") + return False + return True return False def _rocm_toolkit_path(repository_ctx, bash_bin): @@ -270,18 +269,16 @@ def _rocm_toolkit_path(repository_ctx, bash_bin): Returns: A speculative real path of the rocm toolkit install directory. """ - rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH - if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ: - rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip() + rocm_toolkit_path = get_host_environ(repository_ctx, _ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH) if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]: auto_configure_fail("Cannot find rocm toolkit path.") return realpath(repository_ctx, rocm_toolkit_path, bash_bin) def _amdgpu_targets(repository_ctx): """Returns a list of strings representing AMDGPU targets.""" - if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ: + amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS) + if not amdgpu_targets_str: return _DEFAULT_ROCM_AMDGPU_TARGETS - amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS] amdgpu_targets = amdgpu_targets_str.split(",") for amdgpu_target in amdgpu_targets: if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit(): @@ -308,9 +305,9 @@ def _hipcc_env(repository_ctx): "HCC_AMDGPU_TARGET", "HIP_PLATFORM", ]: - if name in repository_ctx.os.environ: - hipcc_env = (hipcc_env + " " + name + "=\"" + - repository_ctx.os.environ[name].strip() + "\";") + env_value = get_host_environ(repository_ctx, name) + if env_value: + hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";") return hipcc_env.strip() def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin): @@ -328,7 +325,7 @@ def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin): # check user-defined hip-clang environment variables for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]: - if name in repository_ctx.os.environ: + if get_host_environ(repository_ctx, name): return "True" # grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo @@ -367,10 +364,7 @@ def _crosstool_verbose(repository_ctx): Returns: A string containing value of environment variable CROSSTOOL_VERBOSE. """ - name = "CROSSTOOL_VERBOSE" - if name in repository_ctx.os.environ: - return repository_ctx.os.environ[name].strip() - return "0" + return get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0") def _lib_name(lib, version = "", static = False): """Constructs the name of a library on Linux. @@ -701,9 +695,7 @@ def _create_local_rocm_repository(repository_ctx): host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) - host_compiler_prefix = "/usr/bin" - if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ: - host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip() + host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin") rocm_defines = {} @@ -823,10 +815,10 @@ def _rocm_autoconf_impl(repository_ctx): """Implementation of the rocm_autoconf repository rule.""" if not _enable_rocm(repository_ctx): _create_dummy_repository(repository_ctx) - elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ: + elif get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO) != None: _create_remote_rocm_repository( repository_ctx, - repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO], + get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO), ) else: _create_local_rocm_repository(repository_ctx) diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl index 952276a0701..363a65f1f43 100644 --- a/third_party/nccl/nccl_configure.bzl +++ b/third_party/nccl/nccl_configure.bzl @@ -20,6 +20,7 @@ load( load( "//third_party/remote_config:common.bzl", "get_cpu_value", + "get_host_environ", ) _CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH" @@ -76,9 +77,8 @@ def _nccl_configure_impl(repository_ctx): # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py")) - nccl_version = "" - if _TF_NCCL_VERSION in repository_ctx.os.environ: - nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip() + nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "") + if nccl_version: nccl_version = nccl_version.split(".")[0] cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"]) diff --git a/third_party/remote_config/common.bzl b/third_party/remote_config/common.bzl index 6f6e4be2304..353e9bb1a63 100644 --- a/third_party/remote_config/common.bzl +++ b/third_party/remote_config/common.bzl @@ -135,7 +135,7 @@ def get_environ(repository_ctx, name, default_value = None): return default_value return result.stdout -def get_host_environ(repository_ctx, name): +def get_host_environ(repository_ctx, name, default_value = None): """Returns the value of an environment variable on the host platform. The host platform is the machine that Bazel runs on. @@ -147,7 +147,13 @@ def get_host_environ(repository_ctx, name): Returns: The value of the environment variable 'name' on the host platform. """ - return repository_ctx.os.environ.get(name) + if name in repository_ctx.os.environ: + return repository_ctx.os.environ.get(name).strip() + + if hasattr(repository_ctx.attr, "environ") and name in repository_ctx.attr.environ: + return repository_ctx.attr.environ.get(name).strip() + + return default_value def is_windows(repository_ctx): """Returns true if the execution platform is Windows. diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl index 1d780e855cc..b3375dc224f 100644 --- a/third_party/tensorrt/tensorrt_configure.bzl +++ b/third_party/tensorrt/tensorrt_configure.bzl @@ -15,6 +15,7 @@ load( load( "//third_party/remote_config:common.bzl", "get_cpu_value", + "get_host_environ", ) _TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH" @@ -72,14 +73,14 @@ def _create_dummy_repository(repository_ctx): def enable_tensorrt(repository_ctx): """Returns whether to build with TensorRT support.""" - return int(repository_ctx.os.environ.get(_TF_NEED_TENSORRT, False)) + return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False)) def _tensorrt_configure_impl(repository_ctx): """Implementation of the tensorrt_configure repository rule.""" - if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ: + if get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO) != None: # Forward to the pre-configured remote repository. - remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO] + remote_config_repo = get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO) repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {}) repository_ctx.template( "build_defs.bzl",