Add android rule helpers and cleanup input loops

This change teaches the configure script how to search for Android NDK
and SDK installations and create new WORKSPACE rules pointing to them.
It also refactors many similar loop-over-user-input functions into using
a reusable method (not the more complex ones).

Specifying an SDK directory will further query for the available SDK API
levels and build tools versions, but it won't perform any compatibility
checks.

Like other settings, every android-related setting can be set beforehand
via an env param. The script will not ask for any Android settings if
there are already any android repository rules in the WORKSPACE.

The script will emit a warning if using an NDK version newer than 14 due
to https://github.com/bazelbuild/bazel/issues/4068.

PiperOrigin-RevId: 177989785
This commit is contained in:
Austin Anderson 2017-12-05 11:59:17 -08:00 committed by TensorFlower Gardener
parent 21e831dc4a
commit 6affacedbb

View File

@ -34,6 +34,8 @@ except ImportError:
_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'.tf_configure.bazelrc')
_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'WORKSPACE')
_DEFAULT_CUDA_VERSION = '8.0'
_DEFAULT_CUDNN_VERSION = '6'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
@ -44,6 +46,13 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing '
_TF_OPENCL_VERSION = '1.2'
_DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp'
_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include'
_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
class UserInputError(Exception):
pass
def is_windows():
@ -158,7 +167,7 @@ def get_python_path(environ_cp, python_bin_path):
try:
library_paths = run_shell(
[python_bin_path, '-c',
'import site; print("\\n".join(site.getsitepackages()))']).split("\n")
'import site; print("\\n".join(site.getsitepackages()))']).split('\n')
except subprocess.CalledProcessError:
library_paths = [run_shell(
[python_bin_path, '-c',
@ -557,6 +566,218 @@ def set_clang_cuda_compiler_path(environ_cp):
clang_cuda_compiler_path)
def prompt_loop_or_load_from_env(
environ_cp,
var_name,
var_default,
ask_for_var,
check_success,
error_msg,
suppress_default_error=False,
n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS
):
"""Loop over user prompts for an ENV param until receiving a valid response.
For the env param var_name, read from the environment or verify user input
until receiving valid input. When done, set var_name in the environ_cp to its
new value.
Args:
environ_cp: (Dict) copy of the os.environ.
var_name: (String) string for name of environment variable, e.g. "TF_MYVAR".
var_default: (String) default value string.
ask_for_var: (String) string for how to ask for user input.
check_success: (Function) function that takes one argument and returns a
boolean. Should return True if the value provided is considered valid. May
contain a complex error message if error_msg does not provide enough
information. In that case, set suppress_default_error to True.
error_msg: (String) String with one and only one '%s'. Formatted with each
invalid response upon check_success(input) failure.
suppress_default_error: (Bool) Suppress the above error message in favor of
one from the check_success function.
n_ask_attempts: (Integer) Number of times to query for valid input before
raising an error and quitting.
Returns:
[String] The value of var_name after querying for input.
Raises:
UserInputError: if a query has been attempted n_ask_attempts times without
success, assume that the user has made a scripting error, and will continue
to provide invalid input. Raise the error to avoid infinitely looping.
"""
default = environ_cp.get(var_name) or var_default
full_query = '%s [Default is %s]: ' % (
ask_for_var,
default,
)
for _ in range(n_ask_attempts):
val = get_from_env_or_user_or_default(environ_cp,
var_name,
full_query,
default)
if check_success(val):
break
if not suppress_default_error:
print(error_msg % val)
environ_cp[var_name] = ''
else:
raise UserInputError('Invalid %s setting was provided %d times in a row. '
'Assuming to be a scripting mistake.' %
(var_name, n_ask_attempts))
environ_cp[var_name] = val
return val
def create_android_ndk_rule(environ_cp):
"""Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule."""
if is_windows() or is_cygwin():
default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' %
environ_cp['APPDATA'])
elif is_macos():
default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
else:
default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME']
def valid_ndk_path(path):
return (os.path.exists(path) and
os.path.exists(os.path.join(path, 'source.properties')))
android_ndk_home_path = prompt_loop_or_load_from_env(
environ_cp,
var_name='ANDROID_NDK_HOME',
var_default=default_ndk_path,
ask_for_var='Please specify the home path of the Android NDK to use.',
check_success=valid_ndk_path,
error_msg=('The path %s or its child file "source.properties" '
'does not exist.')
)
write_android_ndk_workspace_rule(android_ndk_home_path)
def create_android_sdk_rule(environ_cp):
"""Set Android variables and write Android SDK WORKSPACE rule."""
if is_windows() or is_cygwin():
default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA'])
elif is_macos():
default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME']
else:
default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME']
def valid_sdk_path(path):
return (os.path.exists(path) and
os.path.exists(os.path.join(path, 'platforms')) and
os.path.exists(os.path.join(path, 'build-tools')))
android_sdk_home_path = prompt_loop_or_load_from_env(
environ_cp,
var_name='ANDROID_SDK_HOME',
var_default=default_sdk_path,
ask_for_var='Please specify the home path of the Android SDK to use.',
check_success=valid_sdk_path,
error_msg=('Either %s does not exist, or it does not contain the '
'subdirectories "platforms" and "build-tools".'))
platforms = os.path.join(android_sdk_home_path, 'platforms')
api_levels = sorted(os.listdir(platforms))
api_levels = [x.replace('android-', '') for x in api_levels]
def valid_api_level(api_level):
return os.path.exists(os.path.join(android_sdk_home_path,
'platforms',
'android-' + api_level))
android_api_level = prompt_loop_or_load_from_env(
environ_cp,
var_name='ANDROID_API_LEVEL',
var_default=api_levels[-1],
ask_for_var=('Please specify the Android SDK API level to use. '
'[Available levels: %s]') % api_levels,
check_success=valid_api_level,
error_msg='Android-%s is not present in the SDK path.')
build_tools = os.path.join(android_sdk_home_path, 'build-tools')
versions = sorted(os.listdir(build_tools))
def valid_build_tools(version):
return os.path.exists(os.path.join(android_sdk_home_path,
'build-tools',
version))
android_build_tools_version = prompt_loop_or_load_from_env(
environ_cp,
var_name='ANDROID_BUILD_TOOLS_VERSION',
var_default=versions[-1],
ask_for_var=('Please specify an Android build tools version to use. '
'[Available versions: %s]') % versions,
check_success=valid_build_tools,
error_msg=('The selected SDK does not have build-tools version %s '
'available.'))
write_android_sdk_workspace_rule(android_sdk_home_path,
android_build_tools_version,
android_api_level)
def write_android_sdk_workspace_rule(android_sdk_home_path,
android_build_tools_version,
android_api_level):
print('Writing android_sdk_workspace rule.\n')
with open(_TF_WORKSPACE, 'a') as f:
f.write("""
android_sdk_repository(
name="androidsdk",
api_level=%s,
path="%s",
build_tools_version="%s")\n
""" % (android_api_level, android_sdk_home_path, android_build_tools_version))
def write_android_ndk_workspace_rule(android_ndk_home_path):
print('Writing android_ndk_workspace rule.')
ndk_api_level = check_ndk_level(android_ndk_home_path)
if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS:
print('WARNING: The API level of the NDK in %s is %s, which is not '
'supported by Bazel (officially supported versions: %s). Please use '
'another version. Compiling Android targets may result in confusing '
'errors.\n' % (android_ndk_home_path, ndk_api_level,
_SUPPORTED_ANDROID_NDK_VERSIONS))
with open(_TF_WORKSPACE, 'a') as f:
f.write("""
android_ndk_repository(
name="androidndk",
path="%s",
api_level=%s)\n
""" % (android_ndk_home_path, ndk_api_level))
def check_ndk_level(android_ndk_home_path):
"""Check the revision number of an Android NDK path."""
properties_path = '%s/source.properties' % android_ndk_home_path
if is_windows() or is_cygwin():
properties_path = cygpath(properties_path)
with open(properties_path, 'r') as f:
filedata = f.read()
revision = re.search(r'Pkg.Revision = (\d+)', filedata)
if revision:
return revision.group(1)
return None
def workspace_has_any_android_rule():
"""Check the WORKSPACE for existing android_*_repository rules."""
with open(_TF_WORKSPACE, 'r') as f:
workspace = f.read()
has_any_rule = re.search(r'^android_[ns]dk_repository',
workspace,
re.MULTILINE)
return has_any_rule
def set_gcc_host_compiler_path(environ_cp):
"""Set GCC_HOST_COMPILER_PATH."""
default_gcc_host_compiler_path = which('gcc') or ''
@ -566,23 +787,16 @@ def set_gcc_host_compiler_path(environ_cp):
# os.readlink is only available in linux
default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink)
ask_gcc_path = (
'Please specify which gcc should be used by nvcc as the '
'host compiler. [Default is %s]: ') % default_gcc_host_compiler_path
while True:
gcc_host_compiler_path = get_from_env_or_user_or_default(
environ_cp, 'GCC_HOST_COMPILER_PATH', ask_gcc_path,
default_gcc_host_compiler_path)
gcc_host_compiler_path = prompt_loop_or_load_from_env(
environ_cp,
var_name='GCC_HOST_COMPILER_PATH',
var_default=default_gcc_host_compiler_path,
ask_for_var=
'Please specify which gcc should be used by nvcc as the host compiler.',
check_success=os.path.exists,
error_msg='Invalid gcc path. %s cannot be found.',
)
if os.path.exists(gcc_host_compiler_path):
break
# Reset and retry
print('Invalid gcc path. %s cannot be found' % gcc_host_compiler_path)
environ_cp['GCC_HOST_COMPILER_PATH'] = ''
# Set GCC_HOST_COMPILER_PATH
environ_cp['GCC_HOST_COMPILER_PATH'] = gcc_host_compiler_path
write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path)
@ -810,124 +1024,110 @@ def set_other_cuda_vars(environ_cp):
def set_host_cxx_compiler(environ_cp):
"""Set HOST_CXX_COMPILER."""
default_cxx_host_compiler = which('g++') or ''
ask_cxx_host_compiler = (
'Please specify which C++ compiler should be used as'
' the host C++ compiler. [Default is %s]: ') % default_cxx_host_compiler
while True:
host_cxx_compiler = get_from_env_or_user_or_default(
environ_cp, 'HOST_CXX_COMPILER', ask_cxx_host_compiler,
default_cxx_host_compiler)
if os.path.exists(host_cxx_compiler):
break
host_cxx_compiler = prompt_loop_or_load_from_env(
environ_cp,
var_name='HOST_CXX_COMPILER',
var_default=default_cxx_host_compiler,
ask_for_var=('Please specify which C++ compiler should be used as the '
'host C++ compiler.'),
check_success=os.path.exists,
error_msg='Invalid C++ compiler path. %s cannot be found.',
)
# Reset and retry
print('Invalid C++ compiler path. %s cannot be found' % host_cxx_compiler)
environ_cp['HOST_CXX_COMPILER'] = ''
# Set HOST_CXX_COMPILER
environ_cp['HOST_CXX_COMPILER'] = host_cxx_compiler
write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler)
def set_host_c_compiler(environ_cp):
"""Set HOST_C_COMPILER."""
default_c_host_compiler = which('gcc') or ''
ask_c_host_compiler = (
'Please specify which C compiler should be used as the'
' host C compiler. [Default is %s]: ') % default_c_host_compiler
while True:
host_c_compiler = get_from_env_or_user_or_default(
environ_cp, 'HOST_C_COMPILER', ask_c_host_compiler,
default_c_host_compiler)
if os.path.exists(host_c_compiler):
break
host_c_compiler = prompt_loop_or_load_from_env(
environ_cp,
var_name='HOST_C_COMPILER',
var_default=default_c_host_compiler,
ask_for_var=('Please specify which C compiler should be used as the host'
'C compiler.'),
check_success=os.path.exists,
error_msg='Invalid C compiler path. %s cannot be found.',
)
# Reset and retry
print('Invalid C compiler path. %s cannot be found' % host_c_compiler)
environ_cp['HOST_C_COMPILER'] = ''
# Set HOST_C_COMPILER
environ_cp['HOST_C_COMPILER'] = host_c_compiler
write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler)
def set_computecpp_toolkit_path(environ_cp):
"""Set COMPUTECPP_TOOLKIT_PATH."""
ask_computecpp_toolkit_path = ('Please specify the location where ComputeCpp '
'for SYCL %s is installed. [Default is %s]: '
) % (_TF_OPENCL_VERSION,
_DEFAULT_COMPUTECPP_TOOLKIT_PATH)
while True:
computecpp_toolkit_path = get_from_env_or_user_or_default(
environ_cp, 'COMPUTECPP_TOOLKIT_PATH', ask_computecpp_toolkit_path,
_DEFAULT_COMPUTECPP_TOOLKIT_PATH)
def toolkit_exists(toolkit_path):
"""Check if a computecpp toolkit path is valid."""
if is_linux():
sycl_rt_lib_path = 'lib/libComputeCpp.so'
else:
sycl_rt_lib_path = ''
sycl_rt_lib_path_full = os.path.join(computecpp_toolkit_path,
sycl_rt_lib_path_full = os.path.join(toolkit_path,
sycl_rt_lib_path)
if os.path.exists(sycl_rt_lib_path_full):
break
exists = os.path.exists(sycl_rt_lib_path_full)
if not exists:
print('Invalid SYCL %s library path. %s cannot be found' %
(_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
return exists
print('Invalid SYCL %s library path. %s cannot be found' %
(_TF_OPENCL_VERSION, sycl_rt_lib_path_full))
environ_cp['COMPUTECPP_TOOLKIT_PATH'] = ''
computecpp_toolkit_path = prompt_loop_or_load_from_env(
environ_cp,
var_name='COMPUTECPP_TOOLKIT_PATH',
var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH,
ask_for_var=(
'Please specify the location where ComputeCpp for SYCL %s is '
'installed.' % _TF_OPENCL_VERSION),
check_success=toolkit_exists,
error_msg='Invalid SYCL compiler path. %s cannot be found.',
suppress_default_error=True)
# Set COMPUTECPP_TOOLKIT_PATH
environ_cp['COMPUTECPP_TOOLKIT_PATH'] = computecpp_toolkit_path
write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
computecpp_toolkit_path)
def set_trisycl_include_dir(environ_cp):
"""Set TRISYCL_INCLUDE_DIR."""
ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
'include directory. (Use --config=sycl_trisycl '
'when building with Bazel) '
'[Default is %s]: ') % (
_DEFAULT_TRISYCL_INCLUDE_DIR)
while True:
trisycl_include_dir = get_from_env_or_user_or_default(
environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
_DEFAULT_TRISYCL_INCLUDE_DIR)
if os.path.exists(trisycl_include_dir):
break
print('Invalid triSYCL include directory, %s cannot be found' %
(trisycl_include_dir))
trisycl_include_dir = prompt_loop_or_load_from_env(
environ_cp,
var_name='TRISYCL_INCLUDE_DIR',
var_default=_DEFAULT_TRISYCL_INCLUDE_DIR,
ask_for_var=('Please specify the location of the triSYCL include '
'directory. (Use --config=sycl_trisycl when building with '
'Bazel)'),
check_success=os.path.exists,
error_msg='Invalid trySYCL include directory. %s cannot be found.',
suppress_default_error=True)
# Set TRISYCL_INCLUDE_DIR
environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir)
def set_mpi_home(environ_cp):
"""Set MPI_HOME."""
default_mpi_home = which('mpirun') or which('mpiexec') or ''
default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home))
ask_mpi_home = ('Please specify the MPI toolkit folder. [Default is %s]: '
) % default_mpi_home
while True:
mpi_home = get_from_env_or_user_or_default(environ_cp, 'MPI_HOME',
ask_mpi_home, default_mpi_home)
def valid_mpi_path(mpi_home):
exists = (os.path.exists(os.path.join(mpi_home, 'include')) and
os.path.exists(os.path.join(mpi_home, 'lib')))
if not exists:
print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
(os.path.join(mpi_home, 'include'),
os.path.exists(os.path.join(mpi_home, 'lib'))))
return exists
if os.path.exists(os.path.join(mpi_home, 'include')) and os.path.exists(
os.path.join(mpi_home, 'lib')):
break
print('Invalid path to the MPI Toolkit. %s or %s cannot be found' %
(os.path.join(mpi_home, 'include'),
os.path.exists(os.path.join(mpi_home, 'lib'))))
environ_cp['MPI_HOME'] = ''
# Set MPI_HOME
environ_cp['MPI_HOME'] = str(mpi_home)
_ = prompt_loop_or_load_from_env(
environ_cp,
var_name='MPI_HOME',
var_default=default_mpi_home,
ask_for_var='Please specify the MPI toolkit folder.',
check_success=valid_mpi_path,
error_msg='',
suppress_default_error=True)
def set_other_mpi_vars(environ_cp):
@ -970,7 +1170,7 @@ def set_mkl():
'support.\nPlease note that MKL on MacOS or windows is still not '
'supported.\nIf you would like to use a local MKL instead of '
'downloading, please set the environment variable \"TF_MKL_ROOT\" every '
'time before build.')
'time before build.\n')
def set_monolithic():
@ -1082,5 +1282,22 @@ def main():
set_monolithic()
create_android_bazelrc_configs()
if workspace_has_any_android_rule():
print('The WORKSPACE file has at least one of ["android_sdk_repository", '
'"android_ndk_repository"] already set. Will not ask to help '
'configure the WORKSPACE. Please delete the existing rules to '
'activate the helper.\n')
else:
if get_var(
environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace',
False,
('Would you like to interactively configure ./WORKSPACE for '
'Android builds?'),
'Searching for NDK and SDK installations.',
'Not configuring the WORKSPACE for Android builds.'):
create_android_ndk_rule(environ_cp)
create_android_sdk_rule(environ_cp)
if __name__ == '__main__':
main()