remote config: replace all uses of os.environ by get_host_environ
This change is in prepartion for rolling out remote config. It will allow us to inject environment variables from repository rules as well as from the shell enviroment. PiperOrigin-RevId: 295782466 Change-Id: I1eb61fca3556473e94f2f12c45ee5eb1fe51625b
This commit is contained in:
parent
24839fe95c
commit
f60fc7a072
|
@ -43,6 +43,7 @@ load(
|
|||
"execute",
|
||||
"get_bash_bin",
|
||||
"get_cpu_value",
|
||||
"get_host_environ",
|
||||
"get_python_bin",
|
||||
"is_windows",
|
||||
"raw_exec",
|
||||
|
@ -223,10 +224,9 @@ def find_cc(repository_ctx):
|
|||
cc_path_envvar = _GCC_HOST_COMPILER_PATH
|
||||
cc_name = target_cc_name
|
||||
|
||||
if cc_path_envvar in repository_ctx.os.environ:
|
||||
cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
|
||||
if cc_name_from_env:
|
||||
cc_name = cc_name_from_env
|
||||
cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
|
||||
if cc_name_from_env:
|
||||
cc_name = cc_name_from_env
|
||||
if cc_name.startswith("/"):
|
||||
# Absolute path, maybe we should make this supported by our which function.
|
||||
return cc_name
|
||||
|
@ -365,7 +365,7 @@ def _cuda_include_path(repository_ctx, cuda_config):
|
|||
|
||||
def enable_cuda(repository_ctx):
|
||||
"""Returns whether to build with CUDA support."""
|
||||
return int(repository_ctx.os.environ.get("TF_NEED_CUDA", False))
|
||||
return int(get_host_environ(repository_ctx, "TF_NEED_CUDA", False))
|
||||
|
||||
def matches_version(environ_version, detected_version):
|
||||
"""Checks whether the user-specified version matches the detected version.
|
||||
|
@ -409,9 +409,9 @@ _DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
|
|||
|
||||
def compute_capabilities(repository_ctx):
|
||||
"""Returns a list of strings representing cuda compute capabilities."""
|
||||
if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
|
||||
capabilities_str = get_host_environ(repository_ctx, _TF_CUDA_COMPUTE_CAPABILITIES)
|
||||
if capabilities_str == None:
|
||||
return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
|
||||
capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
|
||||
capabilities = capabilities_str.split(",")
|
||||
for capability in capabilities:
|
||||
# Workaround for Skylark's lack of support for regex. This check should
|
||||
|
@ -805,18 +805,13 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
|
|||
)""" % (name, "\n".join(outs), src_dir, out_dir)
|
||||
|
||||
def _flag_enabled(repository_ctx, flag_name):
|
||||
if flag_name in repository_ctx.os.environ:
|
||||
value = repository_ctx.os.environ[flag_name].strip()
|
||||
return value == "1"
|
||||
return False
|
||||
return get_host_environ(repository_ctx, flag_name) == "1"
|
||||
|
||||
def _use_cuda_clang(repository_ctx):
|
||||
return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
|
||||
|
||||
def _tf_sysroot(repository_ctx):
|
||||
if _TF_SYSROOT in repository_ctx.os.environ:
|
||||
return repository_ctx.os.environ[_TF_SYSROOT]
|
||||
return ""
|
||||
return get_host_environ(repository_ctx, _TF_SYSROOT, "")
|
||||
|
||||
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
|
||||
capability_flags = [
|
||||
|
@ -1006,9 +1001,10 @@ def _create_local_cuda_repository(repository_ctx):
|
|||
if is_cuda_clang:
|
||||
cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
|
||||
|
||||
host_compiler_prefix = "/usr/bin"
|
||||
if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ:
|
||||
host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip()
|
||||
host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
|
||||
if not host_compiler_prefix:
|
||||
host_compiler_prefix = "/usr/bin"
|
||||
|
||||
cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix
|
||||
|
||||
# Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
|
||||
|
@ -1157,14 +1153,15 @@ def _cuda_autoconf_impl(repository_ctx):
|
|||
"""Implementation of the cuda_autoconf repository rule."""
|
||||
if not enable_cuda(repository_ctx):
|
||||
_create_dummy_repository(repository_ctx)
|
||||
elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
|
||||
if (_TF_CUDA_VERSION not in repository_ctx.os.environ or
|
||||
_TF_CUDNN_VERSION not in repository_ctx.os.environ):
|
||||
elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None:
|
||||
has_cuda_version = get_host_environ(repository_ctx, _TF_CUDA_VERSION) != None
|
||||
has_cudnn_version = get_host_environ(repository_ctx, _TF_CUDNN_VERSION) != None
|
||||
if not has_cuda_version or not has_cudnn_version:
|
||||
auto_configure_fail("%s and %s must also be set if %s is specified" %
|
||||
(_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
|
||||
_create_remote_cuda_repository(
|
||||
repository_ctx,
|
||||
repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO],
|
||||
get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO),
|
||||
)
|
||||
else:
|
||||
_create_local_cuda_repository(repository_ctx)
|
||||
|
|
|
@ -26,6 +26,7 @@ load(
|
|||
"files_exist",
|
||||
"get_bash_bin",
|
||||
"get_cpu_value",
|
||||
"get_host_environ",
|
||||
"raw_exec",
|
||||
"realpath",
|
||||
"which",
|
||||
|
@ -79,10 +80,9 @@ def find_cc(repository_ctx):
|
|||
cc_path_envvar = _GCC_HOST_COMPILER_PATH
|
||||
cc_name = target_cc_name
|
||||
|
||||
if cc_path_envvar in repository_ctx.os.environ:
|
||||
cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
|
||||
if cc_name_from_env:
|
||||
cc_name = cc_name_from_env
|
||||
cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
|
||||
if cc_name_from_env:
|
||||
cc_name = cc_name_from_env
|
||||
if cc_name.startswith("/"):
|
||||
# Absolute path, maybe we should make this supported by our which function.
|
||||
return cc_name
|
||||
|
@ -252,13 +252,12 @@ def _rocm_include_path(repository_ctx, rocm_config):
|
|||
return inc_dirs
|
||||
|
||||
def _enable_rocm(repository_ctx):
|
||||
if "TF_NEED_ROCM" in repository_ctx.os.environ:
|
||||
enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip()
|
||||
if enable_rocm == "1":
|
||||
if get_cpu_value(repository_ctx) != "Linux":
|
||||
auto_configure_warning("ROCm configure is only supported on Linux")
|
||||
return False
|
||||
return True
|
||||
enable_rocm = get_host_environ(repository_ctx, "TF_NEED_ROCM")
|
||||
if enable_rocm == "1":
|
||||
if get_cpu_value(repository_ctx) != "Linux":
|
||||
auto_configure_warning("ROCm configure is only supported on Linux")
|
||||
return False
|
||||
return True
|
||||
return False
|
||||
|
||||
def _rocm_toolkit_path(repository_ctx, bash_bin):
|
||||
|
@ -270,18 +269,16 @@ def _rocm_toolkit_path(repository_ctx, bash_bin):
|
|||
Returns:
|
||||
A speculative real path of the rocm toolkit install directory.
|
||||
"""
|
||||
rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH
|
||||
if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ:
|
||||
rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip()
|
||||
rocm_toolkit_path = get_host_environ(repository_ctx, _ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH)
|
||||
if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]:
|
||||
auto_configure_fail("Cannot find rocm toolkit path.")
|
||||
return realpath(repository_ctx, rocm_toolkit_path, bash_bin)
|
||||
|
||||
def _amdgpu_targets(repository_ctx):
|
||||
"""Returns a list of strings representing AMDGPU targets."""
|
||||
if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ:
|
||||
amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
|
||||
if not amdgpu_targets_str:
|
||||
return _DEFAULT_ROCM_AMDGPU_TARGETS
|
||||
amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS]
|
||||
amdgpu_targets = amdgpu_targets_str.split(",")
|
||||
for amdgpu_target in amdgpu_targets:
|
||||
if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
|
||||
|
@ -308,9 +305,9 @@ def _hipcc_env(repository_ctx):
|
|||
"HCC_AMDGPU_TARGET",
|
||||
"HIP_PLATFORM",
|
||||
]:
|
||||
if name in repository_ctx.os.environ:
|
||||
hipcc_env = (hipcc_env + " " + name + "=\"" +
|
||||
repository_ctx.os.environ[name].strip() + "\";")
|
||||
env_value = get_host_environ(repository_ctx, name)
|
||||
if env_value:
|
||||
hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
|
||||
return hipcc_env.strip()
|
||||
|
||||
def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
|
||||
|
@ -328,7 +325,7 @@ def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
|
|||
|
||||
# check user-defined hip-clang environment variables
|
||||
for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]:
|
||||
if name in repository_ctx.os.environ:
|
||||
if get_host_environ(repository_ctx, name):
|
||||
return "True"
|
||||
|
||||
# grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
|
||||
|
@ -367,10 +364,7 @@ def _crosstool_verbose(repository_ctx):
|
|||
Returns:
|
||||
A string containing value of environment variable CROSSTOOL_VERBOSE.
|
||||
"""
|
||||
name = "CROSSTOOL_VERBOSE"
|
||||
if name in repository_ctx.os.environ:
|
||||
return repository_ctx.os.environ[name].strip()
|
||||
return "0"
|
||||
return get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0")
|
||||
|
||||
def _lib_name(lib, version = "", static = False):
|
||||
"""Constructs the name of a library on Linux.
|
||||
|
@ -701,9 +695,7 @@ def _create_local_rocm_repository(repository_ctx):
|
|||
|
||||
host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)
|
||||
|
||||
host_compiler_prefix = "/usr/bin"
|
||||
if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ:
|
||||
host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip()
|
||||
host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin")
|
||||
|
||||
rocm_defines = {}
|
||||
|
||||
|
@ -823,10 +815,10 @@ def _rocm_autoconf_impl(repository_ctx):
|
|||
"""Implementation of the rocm_autoconf repository rule."""
|
||||
if not _enable_rocm(repository_ctx):
|
||||
_create_dummy_repository(repository_ctx)
|
||||
elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ:
|
||||
elif get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO) != None:
|
||||
_create_remote_rocm_repository(
|
||||
repository_ctx,
|
||||
repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO],
|
||||
get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO),
|
||||
)
|
||||
else:
|
||||
_create_local_rocm_repository(repository_ctx)
|
||||
|
|
|
@ -20,6 +20,7 @@ load(
|
|||
load(
|
||||
"//third_party/remote_config:common.bzl",
|
||||
"get_cpu_value",
|
||||
"get_host_environ",
|
||||
)
|
||||
|
||||
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
|
||||
|
@ -76,9 +77,8 @@ def _nccl_configure_impl(repository_ctx):
|
|||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
|
||||
|
||||
nccl_version = ""
|
||||
if _TF_NCCL_VERSION in repository_ctx.os.environ:
|
||||
nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
|
||||
nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
|
||||
if nccl_version:
|
||||
nccl_version = nccl_version.split(".")[0]
|
||||
|
||||
cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"])
|
||||
|
|
|
@ -135,7 +135,7 @@ def get_environ(repository_ctx, name, default_value = None):
|
|||
return default_value
|
||||
return result.stdout
|
||||
|
||||
def get_host_environ(repository_ctx, name):
|
||||
def get_host_environ(repository_ctx, name, default_value = None):
|
||||
"""Returns the value of an environment variable on the host platform.
|
||||
|
||||
The host platform is the machine that Bazel runs on.
|
||||
|
@ -147,7 +147,13 @@ def get_host_environ(repository_ctx, name):
|
|||
Returns:
|
||||
The value of the environment variable 'name' on the host platform.
|
||||
"""
|
||||
return repository_ctx.os.environ.get(name)
|
||||
if name in repository_ctx.os.environ:
|
||||
return repository_ctx.os.environ.get(name).strip()
|
||||
|
||||
if hasattr(repository_ctx.attr, "environ") and name in repository_ctx.attr.environ:
|
||||
return repository_ctx.attr.environ.get(name).strip()
|
||||
|
||||
return default_value
|
||||
|
||||
def is_windows(repository_ctx):
|
||||
"""Returns true if the execution platform is Windows.
|
||||
|
|
|
@ -15,6 +15,7 @@ load(
|
|||
load(
|
||||
"//third_party/remote_config:common.bzl",
|
||||
"get_cpu_value",
|
||||
"get_host_environ",
|
||||
)
|
||||
|
||||
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
|
||||
|
@ -72,14 +73,14 @@ def _create_dummy_repository(repository_ctx):
|
|||
|
||||
def enable_tensorrt(repository_ctx):
|
||||
"""Returns whether to build with TensorRT support."""
|
||||
return int(repository_ctx.os.environ.get(_TF_NEED_TENSORRT, False))
|
||||
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
|
||||
|
||||
def _tensorrt_configure_impl(repository_ctx):
|
||||
"""Implementation of the tensorrt_configure repository rule."""
|
||||
|
||||
if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
|
||||
if get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO) != None:
|
||||
# Forward to the pre-configured remote repository.
|
||||
remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
|
||||
remote_config_repo = get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO)
|
||||
repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
|
||||
repository_ctx.template(
|
||||
"build_defs.bzl",
|
||||
|
|
Loading…
Reference in New Issue