remote config: replace all uses of os.environ by get_host_environ

This change is in prepartion for rolling out remote config. It will
allow us to inject environment variables from repository rules as
well as from the shell enviroment.

PiperOrigin-RevId: 295782466
Change-Id: I1eb61fca3556473e94f2f12c45ee5eb1fe51625b
This commit is contained in:
Jakob Buchgraber 2020-02-18 11:30:58 -08:00 committed by TensorFlower Gardener
parent 24839fe95c
commit f60fc7a072
5 changed files with 54 additions and 58 deletions

View File

@ -43,6 +43,7 @@ load(
"execute", "execute",
"get_bash_bin", "get_bash_bin",
"get_cpu_value", "get_cpu_value",
"get_host_environ",
"get_python_bin", "get_python_bin",
"is_windows", "is_windows",
"raw_exec", "raw_exec",
@ -223,10 +224,9 @@ def find_cc(repository_ctx):
cc_path_envvar = _GCC_HOST_COMPILER_PATH cc_path_envvar = _GCC_HOST_COMPILER_PATH
cc_name = target_cc_name cc_name = target_cc_name
if cc_path_envvar in repository_ctx.os.environ: cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip() if cc_name_from_env:
if cc_name_from_env: cc_name = cc_name_from_env
cc_name = cc_name_from_env
if cc_name.startswith("/"): if cc_name.startswith("/"):
# Absolute path, maybe we should make this supported by our which function. # Absolute path, maybe we should make this supported by our which function.
return cc_name return cc_name
@ -365,7 +365,7 @@ def _cuda_include_path(repository_ctx, cuda_config):
def enable_cuda(repository_ctx): def enable_cuda(repository_ctx):
"""Returns whether to build with CUDA support.""" """Returns whether to build with CUDA support."""
return int(repository_ctx.os.environ.get("TF_NEED_CUDA", False)) return int(get_host_environ(repository_ctx, "TF_NEED_CUDA", False))
def matches_version(environ_version, detected_version): def matches_version(environ_version, detected_version):
"""Checks whether the user-specified version matches the detected version. """Checks whether the user-specified version matches the detected version.
@ -409,9 +409,9 @@ _DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
def compute_capabilities(repository_ctx): def compute_capabilities(repository_ctx):
"""Returns a list of strings representing cuda compute capabilities.""" """Returns a list of strings representing cuda compute capabilities."""
if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ: capabilities_str = get_host_environ(repository_ctx, _TF_CUDA_COMPUTE_CAPABILITIES)
if capabilities_str == None:
return _DEFAULT_CUDA_COMPUTE_CAPABILITIES return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
capabilities = capabilities_str.split(",") capabilities = capabilities_str.split(",")
for capability in capabilities: for capability in capabilities:
# Workaround for Skylark's lack of support for regex. This check should # Workaround for Skylark's lack of support for regex. This check should
@ -805,18 +805,13 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
)""" % (name, "\n".join(outs), src_dir, out_dir) )""" % (name, "\n".join(outs), src_dir, out_dir)
def _flag_enabled(repository_ctx, flag_name): def _flag_enabled(repository_ctx, flag_name):
if flag_name in repository_ctx.os.environ: return get_host_environ(repository_ctx, flag_name) == "1"
value = repository_ctx.os.environ[flag_name].strip()
return value == "1"
return False
def _use_cuda_clang(repository_ctx): def _use_cuda_clang(repository_ctx):
return _flag_enabled(repository_ctx, "TF_CUDA_CLANG") return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
def _tf_sysroot(repository_ctx): def _tf_sysroot(repository_ctx):
if _TF_SYSROOT in repository_ctx.os.environ: return get_host_environ(repository_ctx, _TF_SYSROOT, "")
return repository_ctx.os.environ[_TF_SYSROOT]
return ""
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
capability_flags = [ capability_flags = [
@ -1006,9 +1001,10 @@ def _create_local_cuda_repository(repository_ctx):
if is_cuda_clang: if is_cuda_clang:
cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"] cuda_defines["%{cuda_toolkit_path}"] = cuda_config.config["cuda_toolkit_path"]
host_compiler_prefix = "/usr/bin" host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX)
if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ: if not host_compiler_prefix:
host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip() host_compiler_prefix = "/usr/bin"
cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix cuda_defines["%{host_compiler_prefix}"] = host_compiler_prefix
# Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
@ -1157,14 +1153,15 @@ def _cuda_autoconf_impl(repository_ctx):
"""Implementation of the cuda_autoconf repository rule.""" """Implementation of the cuda_autoconf repository rule."""
if not enable_cuda(repository_ctx): if not enable_cuda(repository_ctx):
_create_dummy_repository(repository_ctx) _create_dummy_repository(repository_ctx)
elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ: elif get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO) != None:
if (_TF_CUDA_VERSION not in repository_ctx.os.environ or has_cuda_version = get_host_environ(repository_ctx, _TF_CUDA_VERSION) != None
_TF_CUDNN_VERSION not in repository_ctx.os.environ): has_cudnn_version = get_host_environ(repository_ctx, _TF_CUDNN_VERSION) != None
if not has_cuda_version or not has_cudnn_version:
auto_configure_fail("%s and %s must also be set if %s is specified" % auto_configure_fail("%s and %s must also be set if %s is specified" %
(_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO)) (_TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_CONFIG_REPO))
_create_remote_cuda_repository( _create_remote_cuda_repository(
repository_ctx, repository_ctx,
repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO], get_host_environ(repository_ctx, _TF_CUDA_CONFIG_REPO),
) )
else: else:
_create_local_cuda_repository(repository_ctx) _create_local_cuda_repository(repository_ctx)

View File

@ -26,6 +26,7 @@ load(
"files_exist", "files_exist",
"get_bash_bin", "get_bash_bin",
"get_cpu_value", "get_cpu_value",
"get_host_environ",
"raw_exec", "raw_exec",
"realpath", "realpath",
"which", "which",
@ -79,10 +80,9 @@ def find_cc(repository_ctx):
cc_path_envvar = _GCC_HOST_COMPILER_PATH cc_path_envvar = _GCC_HOST_COMPILER_PATH
cc_name = target_cc_name cc_name = target_cc_name
if cc_path_envvar in repository_ctx.os.environ: cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar)
cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip() if cc_name_from_env:
if cc_name_from_env: cc_name = cc_name_from_env
cc_name = cc_name_from_env
if cc_name.startswith("/"): if cc_name.startswith("/"):
# Absolute path, maybe we should make this supported by our which function. # Absolute path, maybe we should make this supported by our which function.
return cc_name return cc_name
@ -252,13 +252,12 @@ def _rocm_include_path(repository_ctx, rocm_config):
return inc_dirs return inc_dirs
def _enable_rocm(repository_ctx): def _enable_rocm(repository_ctx):
if "TF_NEED_ROCM" in repository_ctx.os.environ: enable_rocm = get_host_environ(repository_ctx, "TF_NEED_ROCM")
enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip() if enable_rocm == "1":
if enable_rocm == "1": if get_cpu_value(repository_ctx) != "Linux":
if get_cpu_value(repository_ctx) != "Linux": auto_configure_warning("ROCm configure is only supported on Linux")
auto_configure_warning("ROCm configure is only supported on Linux") return False
return False return True
return True
return False return False
def _rocm_toolkit_path(repository_ctx, bash_bin): def _rocm_toolkit_path(repository_ctx, bash_bin):
@ -270,18 +269,16 @@ def _rocm_toolkit_path(repository_ctx, bash_bin):
Returns: Returns:
A speculative real path of the rocm toolkit install directory. A speculative real path of the rocm toolkit install directory.
""" """
rocm_toolkit_path = _DEFAULT_ROCM_TOOLKIT_PATH rocm_toolkit_path = get_host_environ(repository_ctx, _ROCM_TOOLKIT_PATH, _DEFAULT_ROCM_TOOLKIT_PATH)
if _ROCM_TOOLKIT_PATH in repository_ctx.os.environ:
rocm_toolkit_path = repository_ctx.os.environ[_ROCM_TOOLKIT_PATH].strip()
if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]: if files_exist(repository_ctx, [rocm_toolkit_path], bash_bin) != [True]:
auto_configure_fail("Cannot find rocm toolkit path.") auto_configure_fail("Cannot find rocm toolkit path.")
return realpath(repository_ctx, rocm_toolkit_path, bash_bin) return realpath(repository_ctx, rocm_toolkit_path, bash_bin)
def _amdgpu_targets(repository_ctx): def _amdgpu_targets(repository_ctx):
"""Returns a list of strings representing AMDGPU targets.""" """Returns a list of strings representing AMDGPU targets."""
if _TF_ROCM_AMDGPU_TARGETS not in repository_ctx.os.environ: amdgpu_targets_str = get_host_environ(repository_ctx, _TF_ROCM_AMDGPU_TARGETS)
if not amdgpu_targets_str:
return _DEFAULT_ROCM_AMDGPU_TARGETS return _DEFAULT_ROCM_AMDGPU_TARGETS
amdgpu_targets_str = repository_ctx.os.environ[_TF_ROCM_AMDGPU_TARGETS]
amdgpu_targets = amdgpu_targets_str.split(",") amdgpu_targets = amdgpu_targets_str.split(",")
for amdgpu_target in amdgpu_targets: for amdgpu_target in amdgpu_targets:
if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit(): if amdgpu_target[:3] != "gfx" or not amdgpu_target[3:].isdigit():
@ -308,9 +305,9 @@ def _hipcc_env(repository_ctx):
"HCC_AMDGPU_TARGET", "HCC_AMDGPU_TARGET",
"HIP_PLATFORM", "HIP_PLATFORM",
]: ]:
if name in repository_ctx.os.environ: env_value = get_host_environ(repository_ctx, name)
hipcc_env = (hipcc_env + " " + name + "=\"" + if env_value:
repository_ctx.os.environ[name].strip() + "\";") hipcc_env = (hipcc_env + " " + name + "=\"" + env_value + "\";")
return hipcc_env.strip() return hipcc_env.strip()
def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin): def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
@ -328,7 +325,7 @@ def _hipcc_is_hipclang(repository_ctx, rocm_config, bash_bin):
# check user-defined hip-clang environment variables # check user-defined hip-clang environment variables
for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]: for name in ["HIP_CLANG_PATH", "HIP_VDI_HOME"]:
if name in repository_ctx.os.environ: if get_host_environ(repository_ctx, name):
return "True" return "True"
# grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo # grep for "HIP_COMPILER=clang" in /opt/rocm/hip/lib/.hipInfo
@ -367,10 +364,7 @@ def _crosstool_verbose(repository_ctx):
Returns: Returns:
A string containing value of environment variable CROSSTOOL_VERBOSE. A string containing value of environment variable CROSSTOOL_VERBOSE.
""" """
name = "CROSSTOOL_VERBOSE" return get_host_environ(repository_ctx, "CROSSTOOL_VERBOSE", "0")
if name in repository_ctx.os.environ:
return repository_ctx.os.environ[name].strip()
return "0"
def _lib_name(lib, version = "", static = False): def _lib_name(lib, version = "", static = False):
"""Constructs the name of a library on Linux. """Constructs the name of a library on Linux.
@ -701,9 +695,7 @@ def _create_local_rocm_repository(repository_ctx):
host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc)
host_compiler_prefix = "/usr/bin" host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin")
if _GCC_HOST_COMPILER_PREFIX in repository_ctx.os.environ:
host_compiler_prefix = repository_ctx.os.environ[_GCC_HOST_COMPILER_PREFIX].strip()
rocm_defines = {} rocm_defines = {}
@ -823,10 +815,10 @@ def _rocm_autoconf_impl(repository_ctx):
"""Implementation of the rocm_autoconf repository rule.""" """Implementation of the rocm_autoconf repository rule."""
if not _enable_rocm(repository_ctx): if not _enable_rocm(repository_ctx):
_create_dummy_repository(repository_ctx) _create_dummy_repository(repository_ctx)
elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ: elif get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO) != None:
_create_remote_rocm_repository( _create_remote_rocm_repository(
repository_ctx, repository_ctx,
repository_ctx.os.environ[_TF_ROCM_CONFIG_REPO], get_host_environ(repository_ctx, _TF_ROCM_CONFIG_REPO),
) )
else: else:
_create_local_rocm_repository(repository_ctx) _create_local_rocm_repository(repository_ctx)

View File

@ -20,6 +20,7 @@ load(
load( load(
"//third_party/remote_config:common.bzl", "//third_party/remote_config:common.bzl",
"get_cpu_value", "get_cpu_value",
"get_host_environ",
) )
_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH" _CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
@ -76,9 +77,8 @@ def _nccl_configure_impl(repository_ctx):
# See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778 # See https://github.com/tensorflow/tensorflow/commit/62bd3534525a036f07d9851b3199d68212904778
find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py")) find_cuda_config_path = repository_ctx.path(Label("@org_tensorflow//third_party/gpus:find_cuda_config.py"))
nccl_version = "" nccl_version = get_host_environ(repository_ctx, _TF_NCCL_VERSION, "")
if _TF_NCCL_VERSION in repository_ctx.os.environ: if nccl_version:
nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
nccl_version = nccl_version.split(".")[0] nccl_version = nccl_version.split(".")[0]
cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"]) cuda_config = find_cuda_config(repository_ctx, find_cuda_config_path, ["cuda"])

View File

@ -135,7 +135,7 @@ def get_environ(repository_ctx, name, default_value = None):
return default_value return default_value
return result.stdout return result.stdout
def get_host_environ(repository_ctx, name): def get_host_environ(repository_ctx, name, default_value = None):
"""Returns the value of an environment variable on the host platform. """Returns the value of an environment variable on the host platform.
The host platform is the machine that Bazel runs on. The host platform is the machine that Bazel runs on.
@ -147,7 +147,13 @@ def get_host_environ(repository_ctx, name):
Returns: Returns:
The value of the environment variable 'name' on the host platform. The value of the environment variable 'name' on the host platform.
""" """
return repository_ctx.os.environ.get(name) if name in repository_ctx.os.environ:
return repository_ctx.os.environ.get(name).strip()
if hasattr(repository_ctx.attr, "environ") and name in repository_ctx.attr.environ:
return repository_ctx.attr.environ.get(name).strip()
return default_value
def is_windows(repository_ctx): def is_windows(repository_ctx):
"""Returns true if the execution platform is Windows. """Returns true if the execution platform is Windows.

View File

@ -15,6 +15,7 @@ load(
load( load(
"//third_party/remote_config:common.bzl", "//third_party/remote_config:common.bzl",
"get_cpu_value", "get_cpu_value",
"get_host_environ",
) )
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH" _TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
@ -72,14 +73,14 @@ def _create_dummy_repository(repository_ctx):
def enable_tensorrt(repository_ctx): def enable_tensorrt(repository_ctx):
"""Returns whether to build with TensorRT support.""" """Returns whether to build with TensorRT support."""
return int(repository_ctx.os.environ.get(_TF_NEED_TENSORRT, False)) return int(get_host_environ(repository_ctx, _TF_NEED_TENSORRT, False))
def _tensorrt_configure_impl(repository_ctx): def _tensorrt_configure_impl(repository_ctx):
"""Implementation of the tensorrt_configure repository rule.""" """Implementation of the tensorrt_configure repository rule."""
if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ: if get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO) != None:
# Forward to the pre-configured remote repository. # Forward to the pre-configured remote repository.
remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO] remote_config_repo = get_host_environ(repository_ctx, _TF_TENSORRT_CONFIG_REPO)
repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {}) repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
repository_ctx.template( repository_ctx.template(
"build_defs.bzl", "build_defs.bzl",