remote config: replace all uses of os.environ by get_host_environ

This change is in preparation for rolling out remote config. It will
allow us to inject environment variables from repository rules as
well as from the shell environment.

PiperOrigin-RevId: 295782466
Change-Id: I1eb61fca3556473e94f2f12c45ee5eb1fe51625b
Jakob Buchgraber authored on 2020-02-18 11:30:58 -08:00; committed by TensorFlower Gardener
parent 24839fe95c
commit f60fc7a072
5 changed files with 54 additions and 58 deletions
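
The mechanism behind this: get_host_environ first consults the shell environment and then falls back to values a repository rule carries in an "environ" attribute, so remote config can inject variables without touching the shell. The sketch below only illustrates that wiring; the rule name, the attribute setup, and the chosen variable/default are assumptions for the example, not part of this change.

    # Hypothetical configure rule (illustrative names, not from this commit).
    load("//third_party/remote_config:common.bzl", "get_host_environ")

    def _example_configure_impl(repository_ctx):
        # Lookup order after this change: shell environment first, then the
        # values injected through the rule's "environ" attribute, then the
        # default passed here ("0").
        need_cuda = get_host_environ(repository_ctx, "TF_NEED_CUDA", "0")
        repository_ctx.file("BUILD", "# need_cuda = %s\n" % need_cuda)

    example_configure = repository_rule(
        implementation = _example_configure_impl,
        attrs = {
            "environ": attr.string_dict(),
        },
        environ = ["TF_NEED_CUDA"],
    )

Instantiated as example_configure(name = "local_config_example", environ = {"TF_NEED_CUDA": "1"}), the rule behaves as if TF_NEED_CUDA=1 were exported in the shell, which is the injection path the message above refers to.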

View File

@@ -43,6 +43,7 @@ load(
     "execute",
     "get_bash_bin",
     "get_cpu_value",
+    "get_host_environ",
     "get_python_bin",
     "is_windows",
     "raw_exec",
@@ -223,10 +224,9 @@ def find_cc(repository_ctx):
         cc_path_envvar = _GCC_HOST_COMPILER_PATH
     cc_name = target_cc_name
 
-    if cc_path_envvar in repository_ctx.os.environ:
-        cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
-        if cc_name_from_env:
-            cc_name = cc_name_from_env
+    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
+    if cc_name_from_env:
+        cc_name = cc_name_from_env
     if cc_name.startswith("/"):
         # Absolute path, maybe we should make this supported by our which function.
         return cc_name
@@ -365,7 +365,7 @@ def _cuda_include_path(repository_ctx, cuda_config):
 
 def enable_cuda(repository_ctx):
     """Returns whether to build with CUDA support."""
-    return int(repository_ctx.os.environ.get("TF_NEED_CUDA", False))
+    return int(get_host_environ(repository_ctx, "TF_NEED_CUDA", False))
 
 def matches_version(environ_version, detected_version):
     """Checks whether the user-specified version matches the detected version.
@@ -409,9 +409,9 @@ _DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
 
 def compute_capabilities(repository_ctx):
     """Returns a list of strings representing cuda compute capabilities."""
-    if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
+    capabilities_str = get_host_environ(repository_ctx, _TF_CUDA_COMPUTE_CAPABILITIES)
+    if capabilities_str == None:
         return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
-    capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
     capabilities = capabilities_str.split(",")
     for capability in capabilities:
         # Workaround for Skylark's lack of support for regex. This check should
@@ -805,18 +805,13 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
 )""" % (name, "\n".join(outs), src_dir, out_dir)
 
 def _flag_enabled(repository_ctx, flag_name):
-    if flag_name in repository_ctx.os.environ:
-        value = repository_ctx.os.environ[flag_name].strip()
-        return value == "1"
-    return False
+    return get_host_environ(repository_ctx, flag_name) == "1"
 
 def _use_cuda_clang(repository_ctx):
     return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
 
 def _tf_sysroot(repository_ctx):
-    if _TF_SYSROOT in repository_ctx.os.environ:
-        return repository_ctx.os.environ[_TF_SYSROOT]
-    return ""
+    return get_host_environ(repository_ctx, _TF_SYSROOT, "")
 
 def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
     capability_flags = [
@@ -1006,9 +1001,10 @@ def _create_local_cuda_repository(repository_ctx):
     if is_cuda_clang:
         cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
 
-    host_compiler_prefix = "/usr/bin"
-    if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ:
-        host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip()
+    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
+    if not host_compiler_prefix:
+        host_compiler_prefix = "/usr/bin"
+
     cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix
 
     # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
@@ -1157,14 +1153,15 @@ def _cuda_autoconf_impl(repository_ctx):
     """Implementation of the cuda_autoconf repository rule."""
     if not enable_cuda(repository_ctx):
         _create_dummy_repository(repository_ctx)
-    elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
-        if (_TF_CUDA_VERSION not in repository_ctx.os.environ or
-            _TF_CUDNN_VERSION not in repository_ctx.os.environ):
+    elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None:
+        has_cuda_version = get_host_environ(repository_ctx, _TF_CUDA_VERSION) != None
+        has_cudnn_version = get_host_environ(repository_ctx, _TF_CUDNN_VERSION) != None
+        if not has_cuda_version or not has_cudnn_version:
             auto_configure_fail("%s and %s must also be set if %s is specified" %
                                 (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
         _create_remote_cuda_repository(
             repository_ctx,
-            repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO],
+            get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO),
         )
     else:
         _create_local_cuda_repository(repository_ctx)

View File

@@ -26,6 +26,7 @@ load(
     "files_exist",
     "get_bash_bin",
     "get_cpu_value",
+    "get_host_environ",
     "raw_exec",
     "realpath",
     "which",
@@ -79,10 +80,9 @@ def find_cc(repository_ctx):
     cc_path_envvar = _GCC_HOST_COMPILER_PATH
     cc_name = target_cc_name
 
-    if cc_path_envvar in repository_ctx.os.environ:
-        cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
-        if cc_name_from_env:
-            cc_name = cc_name_from_env
+    cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
+    if cc_name_from_env:
+        cc_name = cc_name_from_env
     if cc_name.startswith("/"):
         # Absolute path, maybe we should make this supported by our which function.
         return cc_name
@@ -252,13 +252,12 @@ def _rocm_include_path(repository_ctx, rocm_config):
     return inc_dirs
 
 def _enable_rocm(repository_ctx):
-    if "TF_NEED_ROCM" in repository_ctx.os.environ:
-        enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip()
-        if enable_rocm == "1":
-            if get_cpu_value(repository_ctx) != "Linux":
-                auto_configure_warning("ROCm configure is only supported on Linux")
-                return False
-            return True
+    enable_rocm = get_host_environ(repository_ctx, "TF_NEED_ROCM")
+    if enable_rocm == "1":
+        if get_cpu_value(repository_ctx) != "Linux":
+            auto_configure_warning("ROCm configure is only supported on Linux")
+            return False
+        return True
     return False
 
 def _rocm_toolkit_path(repository_ctx, bash_bin):
@@ -270,18 +269,16 @@ def _rocm_toolkit_path(repository_ctx, bash_bin):
     Returns:
       A speculative real path of the rocm toolkit install directory.
     """
-    rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH
-    if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ:
-        rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip()
+    rocm_toolkit_path = get_host_environ(repository_ctx, _ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH)
     if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]:
         auto_configure_fail("Cannot find rocm toolkit path.")
     return realpath(repository_ctx, rocm_toolkit_path, bash_bin)
 
 def _amdgpu_targets(repository_ctx):
     """Returns a list of strings representing AMDGPU targets."""
-    if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ:
+    amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
+    if not amdgpu_targets_str:
         return _DEFAULT_ROCM_AMDGPU_TARGETS
-    amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS]
     amdgpu_targets = amdgpu_targets_str.split(",")
     for amdgpu_target in amdgpu_targets:
         if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
@@ -308,9 +305,9 @@ def _hipcc_env(repository_ctx):
         "HCC_AMDGPU_TARGET",
         "HIP_PLATFORM",
     ]:
-        if name in repository_ctx.os.environ:
-            hipcc_env = (hipcc_env + " " + name + "=\"" +
-                         repository_ctx.os.environ[name].strip() + "\";")
+        env_value = get_host_environ(repository_ctx, name)
+        if env_value:
+            hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
     return hipcc_env.strip()
 
 def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
@@ -328,7 +325,7 @@ def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
 
     # check user-defined hip-clang environment variables
     for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]:
-        if name in repository_ctx.os.environ:
+        if get_host_environ(repository_ctx, name):
             return "True"
 
     # grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
@@ -367,10 +364,7 @@ def _crosstool_verbose(repository_ctx):
     Returns:
       A string containing value of environment variable CROSSTOOL_VERBOSE.
     """
-    name = "CROSSTOOL_VERBOSE"
-    if name in repository_ctx.os.environ:
-        return repository_ctx.os.environ[name].strip()
-    return "0"
+    return get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0")
 
 def _lib_name(lib, version = "", static = False):
     """Constructs the name of a library on Linux.
@@ -701,9 +695,7 @@ def _create_local_rocm_repository(repository_ctx):
     host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)
 
-    host_compiler_prefix = "/usr/bin"
-    if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ:
-        host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip()
+    host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin")
 
     rocm_defines = {}
@@ -823,10 +815,10 @@ def _rocm_autoconf_impl(repository_ctx):
     """Implementation of the rocm_autoconf repository rule."""
     if not _enable_rocm(repository_ctx):
         _create_dummy_repository(repository_ctx)
-    elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ:
+    elif get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO) != None:
         _create_remote_rocm_repository(
             repository_ctx,
-            repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO],
+            get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO),
         )
     else:
         _create_local_rocm_repository(repository_ctx)

View File

@@ -20,6 +20,7 @@ load(
 load(
     "//third_party/remote_config:common.bzl",
     "get_cpu_value",
+    "get_host_environ",
 )
 
 _CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
@@ -76,9 +77,8 @@ def _nccl_configure_impl(repository_ctx):
     # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
     find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
-    nccl_version = ""
-    if _TF_NCCL_VERSION in repository_ctx.os.environ:
-        nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
+    nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
     if nccl_version:
         nccl_version = nccl_version.split(".")[0]
 
     cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"])

View File

@@ -135,7 +135,7 @@ def get_environ(repository_ctx, name, default_value = None):
         return default_value
     return result.stdout
 
-def get_host_environ(repository_ctx, name):
+def get_host_environ(repository_ctx, name, default_value = None):
     """Returns the value of an environment variable on the host platform.
 
     The host platform is the machine that Bazel runs on.
@@ -147,7 +147,13 @@ def get_host_environ(repository_ctx, name):
     Returns:
       The value of the environment variable 'name' on the host platform.
     """
-    return repository_ctx.os.environ.get(name)
+    if name in repository_ctx.os.environ:
+        return repository_ctx.os.environ.get(name).strip()
+
+    if hasattr(repository_ctx.attr, "environ") and name in repository_ctx.attr.environ:
+        return repository_ctx.attr.environ.get(name).strip()
+
+    return default_value
 
 def is_windows(repository_ctx):
     """Returns true if the execution platform is Windows.

View File

@@ -15,6 +15,7 @@ load(
 load(
     "//third_party/remote_config:common.bzl",
     "get_cpu_value",
+    "get_host_environ",
 )
 
 _TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
@@ -72,14 +73,14 @@ def _create_dummy_repository(repository_ctx):
 
 def enable_tensorrt(repository_ctx):
     """Returns whether to build with TensorRT support."""
-    return int(repository_ctx.os.environ.get(_TF_NEED_TENSORRT, False))
+    return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
 
 def _tensorrt_configure_impl(repository_ctx):
     """Implementation of the tensorrt_configure repository rule."""
-    if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
+    if get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO) != None:
         # Forward to the pre-configured remote repository.
-        remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
+        remote_config_repo = get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO)
         repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
         repository_ctx.template(
             "build_defs.bzl",