remote config: replace all uses of os.environ by get_host_environ
This change is in prepartion for rolling out remote config. It will allow us to inject environment variables from repository rules as well as from the shell enviroment. PiperOrigin-RevId: 295782466 Change-Id: I1eb61fca3556473e94f2f12c45ee5eb1fe51625b
This commit is contained in:
parent
24839fe95c
commit
f60fc7a072
|
@ -43,6 +43,7 @@ load(
|
||||||
"execute",
|
"execute",
|
||||||
"get_bash_bin",
|
"get_bash_bin",
|
||||||
"get_cpu_value",
|
"get_cpu_value",
|
||||||
|
"get_host_environ",
|
||||||
"get_python_bin",
|
"get_python_bin",
|
||||||
"is_windows",
|
"is_windows",
|
||||||
"raw_exec",
|
"raw_exec",
|
||||||
|
@ -223,10 +224,9 @@ def find_cc(repository_ctx):
|
||||||
cc_path_envvar = _GCC_HOST_COMPILER_PATH
|
cc_path_envvar = _GCC_HOST_COMPILER_PATH
|
||||||
cc_name = target_cc_name
|
cc_name = target_cc_name
|
||||||
|
|
||||||
if cc_path_envvar in repository_ctx.os.environ:
|
cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
|
||||||
cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
|
if cc_name_from_env:
|
||||||
if cc_name_from_env:
|
cc_name = cc_name_from_env
|
||||||
cc_name = cc_name_from_env
|
|
||||||
if cc_name.startswith("/"):
|
if cc_name.startswith("/"):
|
||||||
# Absolute path, maybe we should make this supported by our which function.
|
# Absolute path, maybe we should make this supported by our which function.
|
||||||
return cc_name
|
return cc_name
|
||||||
|
@ -365,7 +365,7 @@ def _cuda_include_path(repository_ctx, cuda_config):
|
||||||
|
|
||||||
def enable_cuda(repository_ctx):
|
def enable_cuda(repository_ctx):
|
||||||
"""Returns whether to build with CUDA support."""
|
"""Returns whether to build with CUDA support."""
|
||||||
return int(repository_ctx.os.environ.get("TF_NEED_CUDA", False))
|
return int(get_host_environ(repository_ctx, "TF_NEED_CUDA", False))
|
||||||
|
|
||||||
def matches_version(environ_version, detected_version):
|
def matches_version(environ_version, detected_version):
|
||||||
"""Checks whether the user-specified version matches the detected version.
|
"""Checks whether the user-specified version matches the detected version.
|
||||||
|
@ -409,9 +409,9 @@ _DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
|
||||||
|
|
||||||
def compute_capabilities(repository_ctx):
|
def compute_capabilities(repository_ctx):
|
||||||
"""Returns a list of strings representing cuda compute capabilities."""
|
"""Returns a list of strings representing cuda compute capabilities."""
|
||||||
if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
|
capabilities_str = get_host_environ(repository_ctx, _TF_CUDA_COMPUTE_CAPABILITIES)
|
||||||
|
if capabilities_str == None:
|
||||||
return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
|
return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
|
||||||
capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
|
|
||||||
capabilities = capabilities_str.split(",")
|
capabilities = capabilities_str.split(",")
|
||||||
for capability in capabilities:
|
for capability in capabilities:
|
||||||
# Workaround for Skylark's lack of support for regex. This check should
|
# Workaround for Skylark's lack of support for regex. This check should
|
||||||
|
@ -805,18 +805,13 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
|
||||||
)""" % (name, "\n".join(outs), src_dir, out_dir)
|
)""" % (name, "\n".join(outs), src_dir, out_dir)
|
||||||
|
|
||||||
def _flag_enabled(repository_ctx, flag_name):
|
def _flag_enabled(repository_ctx, flag_name):
|
||||||
if flag_name in repository_ctx.os.environ:
|
return get_host_environ(repository_ctx, flag_name) == "1"
|
||||||
value = repository_ctx.os.environ[flag_name].strip()
|
|
||||||
return value == "1"
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _use_cuda_clang(repository_ctx):
|
def _use_cuda_clang(repository_ctx):
|
||||||
return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
|
return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
|
||||||
|
|
||||||
def _tf_sysroot(repository_ctx):
|
def _tf_sysroot(repository_ctx):
|
||||||
if _TF_SYSROOT in repository_ctx.os.environ:
|
return get_host_environ(repository_ctx, _TF_SYSROOT, "")
|
||||||
return repository_ctx.os.environ[_TF_SYSROOT]
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
|
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
|
||||||
capability_flags = [
|
capability_flags = [
|
||||||
|
@ -1006,9 +1001,10 @@ def _create_local_cuda_repository(repository_ctx):
|
||||||
if is_cuda_clang:
|
if is_cuda_clang:
|
||||||
cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
|
cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
|
||||||
|
|
||||||
host_compiler_prefix = "/usr/bin"
|
host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
|
||||||
if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ:
|
if not host_compiler_prefix:
|
||||||
host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip()
|
host_compiler_prefix = "/usr/bin"
|
||||||
|
|
||||||
cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix
|
cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix
|
||||||
|
|
||||||
# Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
|
# Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
|
||||||
|
@ -1157,14 +1153,15 @@ def _cuda_autoconf_impl(repository_ctx):
|
||||||
"""Implementation of the cuda_autoconf repository rule."""
|
"""Implementation of the cuda_autoconf repository rule."""
|
||||||
if not enable_cuda(repository_ctx):
|
if not enable_cuda(repository_ctx):
|
||||||
_create_dummy_repository(repository_ctx)
|
_create_dummy_repository(repository_ctx)
|
||||||
elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
|
elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None:
|
||||||
if (_TF_CUDA_VERSION not in repository_ctx.os.environ or
|
has_cuda_version = get_host_environ(repository_ctx, _TF_CUDA_VERSION) != None
|
||||||
_TF_CUDNN_VERSION not in repository_ctx.os.environ):
|
has_cudnn_version = get_host_environ(repository_ctx, _TF_CUDNN_VERSION) != None
|
||||||
|
if not has_cuda_version or not has_cudnn_version:
|
||||||
auto_configure_fail("%s and %s must also be set if %s is specified" %
|
auto_configure_fail("%s and %s must also be set if %s is specified" %
|
||||||
(_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
|
(_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
|
||||||
_create_remote_cuda_repository(
|
_create_remote_cuda_repository(
|
||||||
repository_ctx,
|
repository_ctx,
|
||||||
repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO],
|
get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_create_local_cuda_repository(repository_ctx)
|
_create_local_cuda_repository(repository_ctx)
|
||||||
|
|
|
@ -26,6 +26,7 @@ load(
|
||||||
"files_exist",
|
"files_exist",
|
||||||
"get_bash_bin",
|
"get_bash_bin",
|
||||||
"get_cpu_value",
|
"get_cpu_value",
|
||||||
|
"get_host_environ",
|
||||||
"raw_exec",
|
"raw_exec",
|
||||||
"realpath",
|
"realpath",
|
||||||
"which",
|
"which",
|
||||||
|
@ -79,10 +80,9 @@ def find_cc(repository_ctx):
|
||||||
cc_path_envvar = _GCC_HOST_COMPILER_PATH
|
cc_path_envvar = _GCC_HOST_COMPILER_PATH
|
||||||
cc_name = target_cc_name
|
cc_name = target_cc_name
|
||||||
|
|
||||||
if cc_path_envvar in repository_ctx.os.environ:
|
cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
|
||||||
cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
|
if cc_name_from_env:
|
||||||
if cc_name_from_env:
|
cc_name = cc_name_from_env
|
||||||
cc_name = cc_name_from_env
|
|
||||||
if cc_name.startswith("/"):
|
if cc_name.startswith("/"):
|
||||||
# Absolute path, maybe we should make this supported by our which function.
|
# Absolute path, maybe we should make this supported by our which function.
|
||||||
return cc_name
|
return cc_name
|
||||||
|
@ -252,13 +252,12 @@ def _rocm_include_path(repository_ctx, rocm_config):
|
||||||
return inc_dirs
|
return inc_dirs
|
||||||
|
|
||||||
def _enable_rocm(repository_ctx):
|
def _enable_rocm(repository_ctx):
|
||||||
if "TF_NEED_ROCM" in repository_ctx.os.environ:
|
enable_rocm = get_host_environ(repository_ctx, "TF_NEED_ROCM")
|
||||||
enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip()
|
if enable_rocm == "1":
|
||||||
if enable_rocm == "1":
|
if get_cpu_value(repository_ctx) != "Linux":
|
||||||
if get_cpu_value(repository_ctx) != "Linux":
|
auto_configure_warning("ROCm configure is only supported on Linux")
|
||||||
auto_configure_warning("ROCm configure is only supported on Linux")
|
return False
|
||||||
return False
|
return True
|
||||||
return True
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _rocm_toolkit_path(repository_ctx, bash_bin):
|
def _rocm_toolkit_path(repository_ctx, bash_bin):
|
||||||
|
@ -270,18 +269,16 @@ def _rocm_toolkit_path(repository_ctx, bash_bin):
|
||||||
Returns:
|
Returns:
|
||||||
A speculative real path of the rocm toolkit install directory.
|
A speculative real path of the rocm toolkit install directory.
|
||||||
"""
|
"""
|
||||||
rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH
|
rocm_toolkit_path = get_host_environ(repository_ctx, _ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH)
|
||||||
if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ:
|
|
||||||
rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip()
|
|
||||||
if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]:
|
if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]:
|
||||||
auto_configure_fail("Cannot find rocm toolkit path.")
|
auto_configure_fail("Cannot find rocm toolkit path.")
|
||||||
return realpath(repository_ctx, rocm_toolkit_path, bash_bin)
|
return realpath(repository_ctx, rocm_toolkit_path, bash_bin)
|
||||||
|
|
||||||
def _amdgpu_targets(repository_ctx):
|
def _amdgpu_targets(repository_ctx):
|
||||||
"""Returns a list of strings representing AMDGPU targets."""
|
"""Returns a list of strings representing AMDGPU targets."""
|
||||||
if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ:
|
amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
|
||||||
|
if not amdgpu_targets_str:
|
||||||
return _DEFAULT_ROCM_AMDGPU_TARGETS
|
return _DEFAULT_ROCM_AMDGPU_TARGETS
|
||||||
amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS]
|
|
||||||
amdgpu_targets = amdgpu_targets_str.split(",")
|
amdgpu_targets = amdgpu_targets_str.split(",")
|
||||||
for amdgpu_target in amdgpu_targets:
|
for amdgpu_target in amdgpu_targets:
|
||||||
if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
|
if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
|
||||||
|
@ -308,9 +305,9 @@ def _hipcc_env(repository_ctx):
|
||||||
"HCC_AMDGPU_TARGET",
|
"HCC_AMDGPU_TARGET",
|
||||||
"HIP_PLATFORM",
|
"HIP_PLATFORM",
|
||||||
]:
|
]:
|
||||||
if name in repository_ctx.os.environ:
|
env_value = get_host_environ(repository_ctx, name)
|
||||||
hipcc_env = (hipcc_env + " " + name + "=\"" +
|
if env_value:
|
||||||
repository_ctx.os.environ[name].strip() + "\";")
|
hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
|
||||||
return hipcc_env.strip()
|
return hipcc_env.strip()
|
||||||
|
|
||||||
def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
|
def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
|
||||||
|
@ -328,7 +325,7 @@ def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
|
||||||
|
|
||||||
# check user-defined hip-clang environment variables
|
# check user-defined hip-clang environment variables
|
||||||
for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]:
|
for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]:
|
||||||
if name in repository_ctx.os.environ:
|
if get_host_environ(repository_ctx, name):
|
||||||
return "True"
|
return "True"
|
||||||
|
|
||||||
# grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
|
# grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
|
||||||
|
@ -367,10 +364,7 @@ def _crosstool_verbose(repository_ctx):
|
||||||
Returns:
|
Returns:
|
||||||
A string containing value of environment variable CROSSTOOL_VERBOSE.
|
A string containing value of environment variable CROSSTOOL_VERBOSE.
|
||||||
"""
|
"""
|
||||||
name = "CROSSTOOL_VERBOSE"
|
return get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0")
|
||||||
if name in repository_ctx.os.environ:
|
|
||||||
return repository_ctx.os.environ[name].strip()
|
|
||||||
return "0"
|
|
||||||
|
|
||||||
def _lib_name(lib, version = "", static = False):
|
def _lib_name(lib, version = "", static = False):
|
||||||
"""Constructs the name of a library on Linux.
|
"""Constructs the name of a library on Linux.
|
||||||
|
@ -701,9 +695,7 @@ def _create_local_rocm_repository(repository_ctx):
|
||||||
|
|
||||||
host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)
|
host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)
|
||||||
|
|
||||||
host_compiler_prefix = "/usr/bin"
|
host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin")
|
||||||
if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ:
|
|
||||||
host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip()
|
|
||||||
|
|
||||||
rocm_defines = {}
|
rocm_defines = {}
|
||||||
|
|
||||||
|
@ -823,10 +815,10 @@ def _rocm_autoconf_impl(repository_ctx):
|
||||||
"""Implementation of the rocm_autoconf repository rule."""
|
"""Implementation of the rocm_autoconf repository rule."""
|
||||||
if not _enable_rocm(repository_ctx):
|
if not _enable_rocm(repository_ctx):
|
||||||
_create_dummy_repository(repository_ctx)
|
_create_dummy_repository(repository_ctx)
|
||||||
elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ:
|
elif get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO) != None:
|
||||||
_create_remote_rocm_repository(
|
_create_remote_rocm_repository(
|
||||||
repository_ctx,
|
repository_ctx,
|
||||||
repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO],
|
get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_create_local_rocm_repository(repository_ctx)
|
_create_local_rocm_repository(repository_ctx)
|
||||||
|
|
|
@ -20,6 +20,7 @@ load(
|
||||||
load(
|
load(
|
||||||
"//third_party/remote_config:common.bzl",
|
"//third_party/remote_config:common.bzl",
|
||||||
"get_cpu_value",
|
"get_cpu_value",
|
||||||
|
"get_host_environ",
|
||||||
)
|
)
|
||||||
|
|
||||||
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
|
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
|
||||||
|
@ -76,9 +77,8 @@ def _nccl_configure_impl(repository_ctx):
|
||||||
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
|
||||||
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
|
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
|
||||||
|
|
||||||
nccl_version = ""
|
nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
|
||||||
if _TF_NCCL_VERSION in repository_ctx.os.environ:
|
if nccl_version:
|
||||||
nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
|
|
||||||
nccl_version = nccl_version.split(".")[0]
|
nccl_version = nccl_version.split(".")[0]
|
||||||
|
|
||||||
cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"])
|
cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"])
|
||||||
|
|
|
@ -135,7 +135,7 @@ def get_environ(repository_ctx, name, default_value = None):
|
||||||
return default_value
|
return default_value
|
||||||
return result.stdout
|
return result.stdout
|
||||||
|
|
||||||
def get_host_environ(repository_ctx, name):
|
def get_host_environ(repository_ctx, name, default_value = None):
|
||||||
"""Returns the value of an environment variable on the host platform.
|
"""Returns the value of an environment variable on the host platform.
|
||||||
|
|
||||||
The host platform is the machine that Bazel runs on.
|
The host platform is the machine that Bazel runs on.
|
||||||
|
@ -147,7 +147,13 @@ def get_host_environ(repository_ctx, name):
|
||||||
Returns:
|
Returns:
|
||||||
The value of the environment variable 'name' on the host platform.
|
The value of the environment variable 'name' on the host platform.
|
||||||
"""
|
"""
|
||||||
return repository_ctx.os.environ.get(name)
|
if name in repository_ctx.os.environ:
|
||||||
|
return repository_ctx.os.environ.get(name).strip()
|
||||||
|
|
||||||
|
if hasattr(repository_ctx.attr, "environ") and name in repository_ctx.attr.environ:
|
||||||
|
return repository_ctx.attr.environ.get(name).strip()
|
||||||
|
|
||||||
|
return default_value
|
||||||
|
|
||||||
def is_windows(repository_ctx):
|
def is_windows(repository_ctx):
|
||||||
"""Returns true if the execution platform is Windows.
|
"""Returns true if the execution platform is Windows.
|
||||||
|
|
|
@ -15,6 +15,7 @@ load(
|
||||||
load(
|
load(
|
||||||
"//third_party/remote_config:common.bzl",
|
"//third_party/remote_config:common.bzl",
|
||||||
"get_cpu_value",
|
"get_cpu_value",
|
||||||
|
"get_host_environ",
|
||||||
)
|
)
|
||||||
|
|
||||||
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
|
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
|
||||||
|
@ -72,14 +73,14 @@ def _create_dummy_repository(repository_ctx):
|
||||||
|
|
||||||
def enable_tensorrt(repository_ctx):
|
def enable_tensorrt(repository_ctx):
|
||||||
"""Returns whether to build with TensorRT support."""
|
"""Returns whether to build with TensorRT support."""
|
||||||
return int(repository_ctx.os.environ.get(_TF_NEED_TENSORRT, False))
|
return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
|
||||||
|
|
||||||
def _tensorrt_configure_impl(repository_ctx):
|
def _tensorrt_configure_impl(repository_ctx):
|
||||||
"""Implementation of the tensorrt_configure repository rule."""
|
"""Implementation of the tensorrt_configure repository rule."""
|
||||||
|
|
||||||
if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
|
if get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO) != None:
|
||||||
# Forward to the pre-configured remote repository.
|
# Forward to the pre-configured remote repository.
|
||||||
remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
|
remote_config_repo = get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO)
|
||||||
repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
|
repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
|
||||||
repository_ctx.template(
|
repository_ctx.template(
|
||||||
"build_defs.bzl",
|
"build_defs.bzl",
|
||||||
|
|
Loading…
Reference in New Issue