Formatting BUILD and bzl files.

PiperOrigin-RevId: 234543907
This commit is contained in:
A. Unique TensorFlower 2019-02-18 22:19:08 -08:00 committed by TensorFlower Gardener
parent 6f7a0a83c1
commit a0b0a50328
26 changed files with 2088 additions and 2032 deletions

View File

@ -27,4 +27,3 @@ plugins = {
# "deps":[], # "deps":[],
#}, #},
} }

View File

@ -1,12 +1,11 @@
"""build_defs for service/cpu.""" """build_defs for service/cpu."""
def runtime_copts(): def runtime_copts():
"""Returns copts used for CPU runtime libraries.""" """Returns copts used for CPU runtime libraries."""
return (["-DEIGEN_AVOID_STL_ARRAY"] + select({ return (["-DEIGEN_AVOID_STL_ARRAY"] + select({
"//tensorflow:android_arm": ["-mfpu=neon"], "//tensorflow:android_arm": ["-mfpu=neon"],
"//conditions:default": [] "//conditions:default": [],
}) + select({ }) + select({
"//tensorflow:android": ["-O2"], "//tensorflow:android": ["-O2"],
"//conditions:default": [] "//conditions:default": [],
})) }))

View File

@ -33,4 +33,3 @@
# } # }
plugins = {} plugins = {}

View File

@ -16,8 +16,7 @@ def cuda_library_path(name, version = cuda_sdk_version()):
return "lib/lib{}.dylib".format(name) return "lib/lib{}.dylib".format(name)
else: else:
return "lib/lib{}.{}.dylib".format(name, version) return "lib/lib{}.{}.dylib".format(name, version)
else: elif not version:
if not version:
return "lib64/lib{}.so".format(name) return "lib64/lib{}.so".format(name)
else: else:
return "lib64/lib{}.so.{}".format(name, version) return "lib64/lib{}.so.{}".format(name, version)
@ -34,8 +33,7 @@ def cudnn_library_path(version = cudnn_sdk_version()):
return "lib/libcudnn.dylib" return "lib/libcudnn.dylib"
else: else:
return "lib/libcudnn.{}.dylib".format(version) return "lib/libcudnn.{}.dylib".format(version)
else: elif not version:
if not version:
return "lib64/libcudnn.so" return "lib64/libcudnn.so"
else: else:
return "lib64/libcudnn.so.{}".format(version) return "lib64/libcudnn.so.{}".format(version)
@ -46,8 +44,7 @@ def cupti_library_path(version = cuda_sdk_version()):
return "extras/CUPTI/lib/libcupti.dylib" return "extras/CUPTI/lib/libcupti.dylib"
else: else:
return "extras/CUPTI/lib/libcupti.{}.dylib".format(version) return "extras/CUPTI/lib/libcupti.{}.dylib".format(version)
else: elif not version:
if not version:
return "extras/CUPTI/lib64/libcupti.so" return "extras/CUPTI/lib64/libcupti.so"
else: else:
return "extras/CUPTI/lib64/libcupti.so.{}".format(version) return "extras/CUPTI/lib64/libcupti.so.{}".format(version)

View File

@ -17,14 +17,14 @@ load(
# and then archive those source files into # and then archive those source files into
# ops/gen_sources.srcjar # ops/gen_sources.srcjar
# #
def tf_java_op_gen_srcjar(name, def tf_java_op_gen_srcjar(
name,
gen_tool, gen_tool,
base_package, base_package,
api_def_srcs = [], api_def_srcs = [],
out_dir = "ops/", out_dir = "ops/",
out_src_dir = "src/main/java/", out_src_dir = "src/main/java/",
visibility = ["//tensorflow/java:__pkg__"]): visibility = ["//tensorflow/java:__pkg__"]):
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
srcs = api_def_srcs[:] srcs = api_def_srcs[:]
@ -38,7 +38,8 @@ def tf_java_op_gen_srcjar(name,
# same directory. # same directory.
api_def_args.append( api_def_args.append(
"$$(dirname $$(echo $(locations " + api_def_src + "$$(dirname $$(echo $(locations " + api_def_src +
") | cut -d\" \" -f1))") ") | cut -d\" \" -f1))",
)
api_def_args_str = ",".join(api_def_args) api_def_args_str = ",".join(api_def_args)
gen_cmds += ["$(location " + gen_tool + ")" + gen_cmds += ["$(location " + gen_tool + ")" +
@ -57,6 +58,7 @@ def tf_java_op_gen_srcjar(name,
tools = [ tools = [
"@local_jdk//:jar", "@local_jdk//:jar",
"@local_jdk//:jdk", "@local_jdk//:jdk",
gen_tool gen_tool,
] + tf_binary_additional_srcs(), ] + tf_binary_additional_srcs(),
cmd=" && ".join(gen_cmds)) cmd = " && ".join(gen_cmds),
)

View File

@ -12,7 +12,10 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
# consumers of the tf_gen_op_wrapper_py rule would be simplified if we don't # consumers of the tf_gen_op_wrapper_py rule would be simplified if we don't
# hard code the ops/ directory. # hard code the ops/ directory.
def tf_gen_op_wrapper_private_py(name, out=None, deps=[], def tf_gen_op_wrapper_private_py(
name,
out = None,
deps = [],
require_shape_functions = True, require_shape_functions = True,
visibility = []): visibility = []):
if not name.endswith("_gen"): if not name.endswith("_gen"):
@ -20,7 +23,8 @@ def tf_gen_op_wrapper_private_py(name, out=None, deps=[],
if not visibility: if not visibility:
visibility = ["//visibility:private"] visibility = ["//visibility:private"]
bare_op_name = name[:-4] # Strip off the _gen bare_op_name = name[:-4] # Strip off the _gen
tf_gen_op_wrapper_py(name=bare_op_name, tf_gen_op_wrapper_py(
name = bare_op_name,
out = out, out = out,
visibility = visibility, visibility = visibility,
deps = deps, deps = deps,

View File

@ -45,6 +45,7 @@ load(
"//third_party/ngraph:build_defs.bzl", "//third_party/ngraph:build_defs.bzl",
"if_ngraph", "if_ngraph",
) )
def register_extension_info(**kwargs): def register_extension_info(**kwargs):
pass pass

View File

@ -14,16 +14,20 @@ def tf_cc_logged_benchmark(
fail("Must provide a name") fail("Must provide a name")
if not target: if not target:
fail("Must provide a target") fail("Must provide a target")
if (not ":" in target if (not ":" in target or
or not target.startswith("//") not target.startswith("//") or
or target.endswith(":all") target.endswith(":all") or
or target.endswith(".")): target.endswith(".")):
fail(" ".join(("Target must be a single well-defined test, e.g.,", fail(" ".join((
"//path/to:test. Received: %s" % target))) "Target must be a single well-defined test, e.g.,",
"//path/to:test. Received: %s" % target,
)))
all_tags = ( all_tags = (
depset(tags) + depset( depset(tags) + depset(
["benchmark-test", "local", "manual", "regression-test"])).to_list() ["benchmark-test", "local", "manual", "regression-test"],
)
).to_list()
tf_py_test( tf_py_test(
name = name, name = name,
@ -41,8 +45,9 @@ def tf_cc_logged_benchmark(
], ],
main = "run_and_gather_logs.py", main = "run_and_gather_logs.py",
additional_deps = [ additional_deps = [
"//tensorflow/tools/test:run_and_gather_logs" "//tensorflow/tools/test:run_and_gather_logs",
]) ],
)
# Create a benchmark test target of a TensorFlow python test (*py_tests) # Create a benchmark test target of a TensorFlow python test (*py_tests)
def tf_py_logged_benchmark( def tf_py_logged_benchmark(
@ -60,4 +65,5 @@ def tf_py_logged_benchmark(
benchmarks = benchmarks, benchmarks = benchmarks,
tags = tags, tags = tags,
test_log_output_prefix = test_log_output_prefix, test_log_output_prefix = test_log_output_prefix,
benchmark_type="python_benchmark") benchmark_type = "python_benchmark",
)

View File

@ -7,7 +7,6 @@ load("//third_party:nccl/nccl_configure.bzl", "nccl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository") load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/py:python_configure.bzl", "python_configure")
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure") load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure")
load("//third_party/toolchains/remote:configure.bzl", "remote_execution_configure") load("//third_party/toolchains/remote:configure.bzl", "remote_execution_configure")

View File

@ -40,14 +40,18 @@ def _android_autoconf_impl(repository_ctx):
sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME) sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION) sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
build_tools_version = repository_ctx.os.environ.get( build_tools_version = repository_ctx.os.environ.get(
_ANDROID_BUILD_TOOLS_VERSION) _ANDROID_BUILD_TOOLS_VERSION,
)
ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME) ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION) ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)
sdk_rule = "pass" sdk_rule = "pass"
if all([sdk_home, sdk_api_level, build_tools_version]): if all([sdk_home, sdk_api_level, build_tools_version]):
sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % ( sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
sdk_home, sdk_api_level, build_tools_version) sdk_home,
sdk_api_level,
build_tools_version,
)
ndk_rule = "pass" ndk_rule = "pass"
if all([ndk_home, ndk_api_level]): if all([ndk_home, ndk_api_level]):
@ -55,14 +59,16 @@ def _android_autoconf_impl(repository_ctx):
repository_ctx.template( repository_ctx.template(
"BUILD", "BUILD",
Label("//third_party/android:android_configure.BUILD.tpl")) Label("//third_party/android:android_configure.BUILD.tpl"),
)
repository_ctx.template( repository_ctx.template(
"android.bzl", "android.bzl",
Label("//third_party/android:android.bzl.tpl"), Label("//third_party/android:android.bzl.tpl"),
substitutions = { substitutions = {
"MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule, "MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule,
"MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule, "MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule,
}) },
)
android_configure = repository_rule( android_configure = repository_rule(
implementation = _android_autoconf_impl, implementation = _android_autoconf_impl,

View File

@ -24,34 +24,42 @@ def _get_python_bin(repository_ctx):
_fail("Cannot find python in PATH, please make sure " + _fail("Cannot find python in PATH, please make sure " +
"python is installed and add its directory in PATH, or --define " + "python is installed and add its directory in PATH, or --define " +
"%s='/something/else'.\nPATH=%s" % ( "%s='/something/else'.\nPATH=%s" % (
_PYTHON_BIN_PATH, repository_ctx.os.environ.get("PATH", ""))) _PYTHON_BIN_PATH,
repository_ctx.os.environ.get("PATH", ""),
))
def _git_conf_impl(repository_ctx): def _git_conf_impl(repository_ctx):
repository_ctx.template( repository_ctx.template(
"BUILD", "BUILD",
Label("//third_party/git:BUILD.tpl")) Label("//third_party/git:BUILD.tpl"),
)
tensorflow_root_path = str(repository_ctx.path( tensorflow_root_path = str(repository_ctx.path(
Label("@org_tensorflow//:BUILD")))[:-len("BUILD")] Label("@org_tensorflow//:BUILD"),
))[:-len("BUILD")]
python_script_path = repository_ctx.path( python_script_path = repository_ctx.path(
Label("@org_tensorflow//tensorflow/tools/git:gen_git_source.py")) Label("@org_tensorflow//tensorflow/tools/git:gen_git_source.py"),
)
generated_files_path = repository_ctx.path("gen") generated_files_path = repository_ctx.path("gen")
r = repository_ctx.execute( r = repository_ctx.execute(
["test", "-f", "%s/.git/logs/HEAD" % tensorflow_root_path]) ["test", "-f", "%s/.git/logs/HEAD" % tensorflow_root_path],
)
if r.return_code == 0: if r.return_code == 0:
unused_var = repository_ctx.path(Label("//:.git/HEAD")) # pylint: disable=unused-variable unused_var = repository_ctx.path(Label("//:.git/HEAD")) # pylint: disable=unused-variable
result = repository_ctx.execute([ result = repository_ctx.execute([
_get_python_bin(repository_ctx), _get_python_bin(repository_ctx),
python_script_path, "--configure", tensorflow_root_path, python_script_path,
"--gen_root_path", generated_files_path], quiet=False) "--configure",
tensorflow_root_path,
"--gen_root_path",
generated_files_path,
], quiet = False)
if not result.return_code == 0: if not result.return_code == 0:
_fail(result.stderr) _fail(result.stderr)
git_configure = repository_rule( git_configure = repository_rule(
implementation = _git_conf_impl, implementation = _git_conf_impl,
environ = [ environ = [

View File

@ -140,22 +140,23 @@ def _get_python_bin(repository_ctx):
"%s='/something/else'.\nPATH=%s" % ( "%s='/something/else'.\nPATH=%s" % (
_PYTHON_BIN_PATH, _PYTHON_BIN_PATH,
repository_ctx.os.environ.get("PATH", ""), repository_ctx.os.environ.get("PATH", ""),
)) ),
)
def _get_nvcc_tmp_dir_for_windows(repository_ctx): def _get_nvcc_tmp_dir_for_windows(repository_ctx):
"""Return the tmp directory for nvcc to generate intermediate source files.""" """Return the tmp directory for nvcc to generate intermediate source files."""
escaped_tmp_dir = escape_string( escaped_tmp_dir = escape_string(
get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
"\\", "\\\\"),) "\\",
"\\\\",
),
)
return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir" return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"
def _get_msvc_compiler(repository_ctx): def _get_msvc_compiler(repository_ctx):
vc_path = find_vc_path(repository_ctx) vc_path = find_vc_path(repository_ctx)
return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/") return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")
def _get_win_cuda_defines(repository_ctx): def _get_win_cuda_defines(repository_ctx):
"""Return CROSSTOOL defines for Windows""" """Return CROSSTOOL defines for Windows"""
@ -178,7 +179,7 @@ def _get_win_cuda_defines(repository_ctx):
if not vc_path: if not vc_path:
auto_configure_fail( auto_configure_fail(
"Visual C++ build tools not found on your machine." + "Visual C++ build tools not found on your machine." +
"Please check your installation following https://docs.bazel.build/versions/master/windows.html#using" "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using",
) )
return {} return {}
@ -188,46 +189,47 @@ def _get_win_cuda_defines(repository_ctx):
escaped_lib_paths = escape_string(env["LIB"]) escaped_lib_paths = escape_string(env["LIB"])
escaped_tmp_dir = escape_string( escaped_tmp_dir = escape_string(
get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace( get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
"\\", "\\\\"),) "\\",
"\\\\",
),
)
msvc_cl_path = _get_python_bin(repository_ctx) msvc_cl_path = _get_python_bin(repository_ctx)
msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace( msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
"\\", "/") "\\",
"/",
)
msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace( msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace(
"\\", "/") "\\",
"/",
)
msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace( msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace(
"\\", "/") "\\",
"/",
)
# nvcc will generate some temporary source files under %{nvcc_tmp_dir} # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
# The generated files are guranteed to have unique name, so they can share the same tmp directory # The generated files are guranteed to have unique name, so they can share the same tmp directory
escaped_cxx_include_directories = [ escaped_cxx_include_directories = [
"cxx_builtin_include_directory: \"%s\"" % "cxx_builtin_include_directory: \"%s\"" %
_get_nvcc_tmp_dir_for_windows(repository_ctx) _get_nvcc_tmp_dir_for_windows(repository_ctx),
] ]
for path in escaped_include_paths.split(";"): for path in escaped_include_paths.split(";"):
if path: if path:
escaped_cxx_include_directories.append( escaped_cxx_include_directories.append(
"cxx_builtin_include_directory: \"%s\"" % path) "cxx_builtin_include_directory: \"%s\"" % path,
)
return { return {
"%{msvc_env_tmp}": "%{msvc_env_tmp}": escaped_tmp_dir,
escaped_tmp_dir, "%{msvc_env_path}": escaped_paths,
"%{msvc_env_path}": "%{msvc_env_include}": escaped_include_paths,
escaped_paths, "%{msvc_env_lib}": escaped_lib_paths,
"%{msvc_env_include}": "%{msvc_cl_path}": msvc_cl_path,
escaped_include_paths, "%{msvc_ml_path}": msvc_ml_path,
"%{msvc_env_lib}": "%{msvc_link_path}": msvc_link_path,
escaped_lib_paths, "%{msvc_lib_path}": msvc_lib_path,
"%{msvc_cl_path}": "%{cxx_builtin_include_directory}": "\n".join(escaped_cxx_include_directories),
msvc_cl_path,
"%{msvc_ml_path}":
msvc_ml_path,
"%{msvc_link_path}":
msvc_link_path,
"%{msvc_lib_path}":
msvc_lib_path,
"%{cxx_builtin_include_directory}":
"\n".join(escaped_cxx_include_directories),
} }
# TODO(dzc): Once these functions have been factored out of Bazel's # TODO(dzc): Once these functions have been factored out of Bazel's
@ -261,7 +263,6 @@ def find_cc(repository_ctx):
" environment variable").format(target_cc_name, cc_path_envvar)) " environment variable").format(target_cc_name, cc_path_envvar))
return cc return cc
_INC_DIR_MARKER_BEGIN = "#include <...>" _INC_DIR_MARKER_BEGIN = "#include <...>"
# OSX add " (framework directory)" at the end of line, strip it. # OSX add " (framework directory)" at the end of line, strip it.
@ -275,7 +276,6 @@ def _cxx_inc_convert(path):
path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
return path return path
def _normalize_include_path(repository_ctx, path): def _normalize_include_path(repository_ctx, path):
"""Normalizes include paths before writing them to the crosstool. """Normalizes include paths before writing them to the crosstool.
@ -291,7 +291,6 @@ def _normalize_include_path(repository_ctx, path):
return path[len(crosstool_folder) + 1:] return path[len(crosstool_folder) + 1:]
return path return path
def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp): def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
"""Compute the list of default C or C++ include directories.""" """Compute the list of default C or C++ include directories."""
if lang_is_cpp: if lang_is_cpp:
@ -319,7 +318,6 @@ def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
for p in inc_dirs.split("\n") for p in inc_dirs.split("\n")
] ]
def get_cxx_inc_directories(repository_ctx, cc): def get_cxx_inc_directories(repository_ctx, cc):
"""Compute the list of default C and C++ include directories.""" """Compute the list of default C and C++ include directories."""
@ -331,10 +329,11 @@ def get_cxx_inc_directories(repository_ctx, cc):
includes_cpp_set = depset(includes_cpp) includes_cpp_set = depset(includes_cpp)
return includes_cpp + [ return includes_cpp + [
inc for inc in includes_c if inc not in includes_cpp_set inc
for inc in includes_c
if inc not in includes_cpp_set
] ]
def auto_configure_fail(msg): def auto_configure_fail(msg):
"""Output failure message when cuda configuration fails.""" """Output failure message when cuda configuration fails."""
red = "\033[0;31m" red = "\033[0;31m"
@ -361,7 +360,6 @@ def _host_compiler_includes(repository_ctx, cc):
inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir) inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
return "\n".join(inc_entries) return "\n".join(inc_entries)
def _cuda_include_path(repository_ctx, cuda_config): def _cuda_include_path(repository_ctx, cuda_config):
"""Generates the cxx_builtin_include_directory entries for cuda inc dirs. """Generates the cxx_builtin_include_directory entries for cuda inc dirs.
@ -390,23 +388,25 @@ def _cuda_include_path(repository_ctx, cuda_config):
if one_line.startswith("#$ _TARGET_DIR_="): if one_line.startswith("#$ _TARGET_DIR_="):
target_dir = ( target_dir = (
cuda_config.cuda_toolkit_path + "/" + one_line.replace( cuda_config.cuda_toolkit_path + "/" + one_line.replace(
"#$ _TARGET_DIR_=", "") + "/include") "#$ _TARGET_DIR_=",
"",
) + "/include"
)
inc_entries = [] inc_entries = []
if target_dir != "": if target_dir != "":
inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % target_dir) inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % target_dir)
default_include = cuda_config.cuda_toolkit_path + "/include" default_include = cuda_config.cuda_toolkit_path + "/include"
inc_entries.append( inc_entries.append(
" cxx_builtin_include_directory: \"%s\"" % default_include) " cxx_builtin_include_directory: \"%s\"" % default_include,
)
return "\n".join(inc_entries) return "\n".join(inc_entries)
def enable_cuda(repository_ctx): def enable_cuda(repository_ctx):
if "TF_NEED_CUDA" in repository_ctx.os.environ: if "TF_NEED_CUDA" in repository_ctx.os.environ:
enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip() enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
return enable_cuda == "1" return enable_cuda == "1"
return False return False
def cuda_toolkit_path(repository_ctx): def cuda_toolkit_path(repository_ctx):
"""Finds the cuda toolkit directory. """Finds the cuda toolkit directory.
@ -423,7 +423,6 @@ def cuda_toolkit_path(repository_ctx):
auto_configure_fail("Cannot find cuda toolkit path.") auto_configure_fail("Cannot find cuda toolkit path.")
return str(repository_ctx.path(cuda_toolkit_path).realpath) return str(repository_ctx.path(cuda_toolkit_path).realpath)
def _cudnn_install_basedir(repository_ctx): def _cudnn_install_basedir(repository_ctx):
"""Finds the cudnn install directory.""" """Finds the cudnn install directory."""
cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH
@ -433,7 +432,6 @@ def _cudnn_install_basedir(repository_ctx):
auto_configure_fail("Cannot find cudnn install path.") auto_configure_fail("Cannot find cudnn install path.")
return cudnn_install_path return cudnn_install_path
def matches_version(environ_version, detected_version): def matches_version(environ_version, detected_version):
"""Checks whether the user-specified version matches the detected version. """Checks whether the user-specified version matches the detected version.
@ -470,7 +468,6 @@ def matches_version(environ_version, detected_version):
return False return False
return True return True
_NVCC_VERSION_PREFIX = "Cuda compilation tools, release " _NVCC_VERSION_PREFIX = "Cuda compilation tools, release "
def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value): def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
@ -499,7 +496,8 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
if version_line.find(_NVCC_VERSION_PREFIX) == -1: if version_line.find(_NVCC_VERSION_PREFIX) == -1:
auto_configure_fail( auto_configure_fail(
"Could not parse CUDA version from nvcc --version. Got: %s" % "Could not parse CUDA version from nvcc --version. Got: %s" %
result.stdout,) result.stdout,
)
# Parse the CUDA version from the line containing the CUDA version. # Parse the CUDA version from the line containing the CUDA version.
prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, "") prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, "")
@ -507,7 +505,8 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
if len(parts) != 2 or len(parts[0]) < 2: if len(parts) != 2 or len(parts[0]) < 2:
auto_configure_fail( auto_configure_fail(
"Could not parse CUDA version from nvcc --version. Got: %s" % "Could not parse CUDA version from nvcc --version. Got: %s" %
result.stdout,) result.stdout,
)
full_version = parts[1].strip() full_version = parts[1].strip()
if full_version.startswith("V"): if full_version.startswith("V"):
full_version = full_version[1:] full_version = full_version[1:]
@ -520,7 +519,8 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
if environ_version and not matches_version(environ_version, full_version): if environ_version and not matches_version(environ_version, full_version):
auto_configure_fail( auto_configure_fail(
("CUDA version detected from nvcc (%s) does not match " + ("CUDA version detected from nvcc (%s) does not match " +
"TF_CUDA_VERSION (%s)") % (full_version, environ_version),) "TF_CUDA_VERSION (%s)") % (full_version, environ_version),
)
# We only use the version consisting of the major and minor version numbers. # We only use the version consisting of the major and minor version numbers.
version_parts = full_version.split(".") version_parts = full_version.split(".")
@ -532,7 +532,6 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
version = "%s.%s" % (version_parts[0], version_parts[1]) version = "%s.%s" % (version_parts[0], version_parts[1])
return version return version
_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR" _DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
_DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR" _DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL" _DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
@ -559,15 +558,23 @@ def find_cuda_define(repository_ctx, header_dir, header_file, define):
auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path))) auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
result = repository_ctx.execute( result = repository_ctx.execute(
# Grep one more lines as some #defines are splitted into two lines. # Grep one more lines as some #defines are splitted into two lines.
["grep", "--color=never", "-A1", "-E", define, [
str(h_path)],) "grep",
"--color=never",
"-A1",
"-E",
define,
str(h_path),
],
)
if result.stderr: if result.stderr:
auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr)) auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
# Parse the version from the line defining the macro. # Parse the version from the line defining the macro.
if result.stdout.find(define) == -1: if result.stdout.find(define) == -1:
auto_configure_fail( auto_configure_fail(
"Cannot find line containing '%s' in %s" % (define, h_path)) "Cannot find line containing '%s' in %s" % (define, h_path),
)
# Split results to lines # Split results to lines
lines = result.stdout.split("\n") lines = result.stdout.split("\n")
@ -592,11 +599,11 @@ def find_cuda_define(repository_ctx, header_dir, header_file, define):
if version_end == 0: if version_end == 0:
auto_configure_fail( auto_configure_fail(
"Cannot extract the version from line containing '%s' in %s" % "Cannot extract the version from line containing '%s' in %s" %
(define, str(h_path)),) (define, str(h_path)),
)
version = version[:version_end].strip() version = version[:version_end].strip()
return version return version
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value): def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
"""Detects the version of cuDNN installed on the system. """Detects the version of cuDNN installed on the system.
@ -639,17 +646,18 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip() environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
if environ_version and not matches_version(environ_version, full_version): if environ_version and not matches_version(environ_version, full_version):
cudnn_h_path = repository_ctx.path( cudnn_h_path = repository_ctx.path(
"%s/include/cudnn.h" % cudnn_install_basedir) "%s/include/cudnn.h" % cudnn_install_basedir,
)
auto_configure_fail(("cuDNN version detected from %s (%s) does not match " + auto_configure_fail(("cuDNN version detected from %s (%s) does not match " +
"TF_CUDNN_VERSION (%s)") % "TF_CUDNN_VERSION (%s)") %
(str(cudnn_h_path), full_version, environ_version),) (str(cudnn_h_path), full_version, environ_version))
# Only use the major version to match the SONAME of the library. # Only use the major version to match the SONAME of the library.
version = major_version version = major_version
if cpu_value == "Windows": if cpu_value == "Windows":
version = "64_" + version version = "64_" + version
return version return version
def compute_capabilities(repository_ctx): def compute_capabilities(repository_ctx):
"""Returns a list of strings representing cuda compute capabilities.""" """Returns a list of strings representing cuda compute capabilities."""
if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ: if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
@ -665,7 +673,6 @@ def compute_capabilities(repository_ctx):
auto_configure_fail("Invalid compute capability: %s" % capability) auto_configure_fail("Invalid compute capability: %s" % capability)
return capabilities return capabilities
def get_cpu_value(repository_ctx): def get_cpu_value(repository_ctx):
"""Returns the name of the host operating system. """Returns the name of the host operating system.
@ -683,12 +690,10 @@ def get_cpu_value(repository_ctx):
result = repository_ctx.execute(["uname", "-s"]) result = repository_ctx.execute(["uname", "-s"])
return result.stdout.strip() return result.stdout.strip()
def _is_windows(repository_ctx): def _is_windows(repository_ctx):
"""Returns true if the host operating system is windows.""" """Returns true if the host operating system is windows."""
return get_cpu_value(repository_ctx) == "Windows" return get_cpu_value(repository_ctx) == "Windows"
def lib_name(base_name, cpu_value, version = None, static = False): def lib_name(base_name, cpu_value, version = None, static = False):
"""Constructs the platform-specific name of a library. """Constructs the platform-specific name of a library.
@ -740,10 +745,10 @@ def find_lib(repository_ctx, paths, check_soname = True):
return path return path
if mismatches: if mismatches:
auto_configure_fail( auto_configure_fail(
"None of the libraries match their SONAME: " + ", ".join(mismatches)) "None of the libraries match their SONAME: " + ", ".join(mismatches),
)
auto_configure_fail("No library found under: " + ", ".join(paths)) auto_configure_fail("No library found under: " + ", ".join(paths))
def _find_cuda_lib( def _find_cuda_lib(
lib, lib,
repository_ctx, repository_ctx,
@ -766,10 +771,10 @@ def _find_cuda_lib(
""" """
file_name = lib_name(lib, cpu_value, version, static) file_name = lib_name(lib, cpu_value, version, static)
return find_lib(repository_ctx, [ return find_lib(repository_ctx, [
"%s/%s%s" % (basedir, path, file_name) for path in CUDA_LIB_PATHS "%s/%s%s" % (basedir, path, file_name)
for path in CUDA_LIB_PATHS
], check_soname = version and not static) ], check_soname = version and not static)
def _find_cupti_header_dir(repository_ctx, cuda_config): def _find_cupti_header_dir(repository_ctx, cuda_config):
"""Returns the path to the directory containing cupti.h """Returns the path to the directory containing cupti.h
@ -786,11 +791,12 @@ def _find_cupti_header_dir(repository_ctx, cuda_config):
cuda_toolkit_path = cuda_config.cuda_toolkit_path cuda_toolkit_path = cuda_config.cuda_toolkit_path
for relative_path in CUPTI_HEADER_PATHS: for relative_path in CUPTI_HEADER_PATHS:
if repository_ctx.path( if repository_ctx.path(
"%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists: "%s/%scupti.h" % (cuda_toolkit_path, relative_path),
).exists:
return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1] return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
auto_configure_fail("Cannot find cupti.h under %s" % ", ".join( auto_configure_fail("Cannot find cupti.h under %s" % ", ".join(
[cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS])) [cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS],
))
def _find_cupti_lib(repository_ctx, cuda_config): def _find_cupti_lib(repository_ctx, cuda_config):
"""Finds the cupti library on the system. """Finds the cupti library on the system.
@ -812,10 +818,10 @@ def _find_cupti_lib(repository_ctx, cuda_config):
) )
basedir = cuda_config.cuda_toolkit_path basedir = cuda_config.cuda_toolkit_path
return find_lib(repository_ctx, [ return find_lib(repository_ctx, [
"%s/%s%s" % (basedir, path, file_name) for path in CUPTI_LIB_PATHS "%s/%s%s" % (basedir, path, file_name)
for path in CUPTI_LIB_PATHS
]) ])
def _find_libs(repository_ctx, cuda_config): def _find_libs(repository_ctx, cuda_config):
"""Returns the CUDA and cuDNN libraries on the system. """Returns the CUDA and cuDNN libraries on the system.
@ -828,23 +834,21 @@ def _find_libs(repository_ctx, cuda_config):
""" """
cpu_value = cuda_config.cpu_value cpu_value = cuda_config.cpu_value
return { return {
"cuda": "cuda": _find_cuda_lib(
_find_cuda_lib(
"cuda", "cuda",
repository_ctx, repository_ctx,
cpu_value, cpu_value,
cuda_config.cuda_toolkit_path, cuda_config.cuda_toolkit_path,
None), None,
"cudart": ),
_find_cuda_lib( "cudart": _find_cuda_lib(
"cudart", "cudart",
repository_ctx, repository_ctx,
cpu_value, cpu_value,
cuda_config.cuda_toolkit_path, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version, cuda_config.cuda_version,
), ),
"cudart_static": "cudart_static": _find_cuda_lib(
_find_cuda_lib(
"cudart_static", "cudart_static",
repository_ctx, repository_ctx,
cpu_value, cpu_value,
@ -852,51 +856,44 @@ def _find_libs(repository_ctx, cuda_config):
cuda_config.cuda_version, cuda_config.cuda_version,
static = True, static = True,
), ),
"cublas": "cublas": _find_cuda_lib(
_find_cuda_lib(
"cublas", "cublas",
repository_ctx, repository_ctx,
cpu_value, cpu_value,
cuda_config.cuda_toolkit_path, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version, cuda_config.cuda_version,
), ),
"cusolver": "cusolver": _find_cuda_lib(
_find_cuda_lib(
"cusolver", "cusolver",
repository_ctx, repository_ctx,
cpu_value, cpu_value,
cuda_config.cuda_toolkit_path, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version, cuda_config.cuda_version,
), ),
"curand": "curand": _find_cuda_lib(
_find_cuda_lib(
"curand", "curand",
repository_ctx, repository_ctx,
cpu_value, cpu_value,
cuda_config.cuda_toolkit_path, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version, cuda_config.cuda_version,
), ),
"cufft": "cufft": _find_cuda_lib(
_find_cuda_lib(
"cufft", "cufft",
repository_ctx, repository_ctx,
cpu_value, cpu_value,
cuda_config.cuda_toolkit_path, cuda_config.cuda_toolkit_path,
cuda_config.cuda_version, cuda_config.cuda_version,
), ),
"cudnn": "cudnn": _find_cuda_lib(
_find_cuda_lib(
"cudnn", "cudnn",
repository_ctx, repository_ctx,
cpu_value, cpu_value,
cuda_config.cudnn_install_basedir, cuda_config.cudnn_install_basedir,
cuda_config.cudnn_version, cuda_config.cudnn_version,
), ),
"cupti": "cupti": _find_cupti_lib(repository_ctx, cuda_config),
_find_cupti_lib(repository_ctx, cuda_config),
} }
def _find_cuda_include_path(repository_ctx, cuda_config): def _find_cuda_include_path(repository_ctx, cuda_config):
"""Returns the path to the directory containing cuda.h """Returns the path to the directory containing cuda.h
@ -910,11 +907,11 @@ def _find_cuda_include_path(repository_ctx, cuda_config):
cuda_toolkit_path = cuda_config.cuda_toolkit_path cuda_toolkit_path = cuda_config.cuda_toolkit_path
for relative_path in CUDA_INCLUDE_PATHS: for relative_path in CUDA_INCLUDE_PATHS:
if repository_ctx.path( if repository_ctx.path(
"%s/%scuda.h" % (cuda_toolkit_path, relative_path)).exists: "%s/%scuda.h" % (cuda_toolkit_path, relative_path),
).exists:
return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1] return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path) auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path)
def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir): def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
"""Returns the path to the directory containing cudnn.h """Returns the path to the directory containing cudnn.h
@ -928,13 +925,13 @@ def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
""" """
for relative_path in CUDA_INCLUDE_PATHS: for relative_path in CUDA_INCLUDE_PATHS:
if repository_ctx.path( if repository_ctx.path(
"%s/%scudnn.h" % (cudnn_install_basedir, relative_path)).exists: "%s/%scudnn.h" % (cudnn_install_basedir, relative_path),
).exists:
return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1] return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1]
if repository_ctx.path("/usr/include/cudnn.h").exists: if repository_ctx.path("/usr/include/cudnn.h").exists:
return "/usr/include" return "/usr/include"
auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir) auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
def _find_nvvm_libdevice_dir(repository_ctx, cuda_config): def _find_nvvm_libdevice_dir(repository_ctx, cuda_config):
"""Returns the path to the directory containing libdevice in bitcode format. """Returns the path to the directory containing libdevice in bitcode format.
@ -948,18 +945,20 @@ def _find_nvvm_libdevice_dir(repository_ctx, cuda_config):
cuda_toolkit_path = cuda_config.cuda_toolkit_path cuda_toolkit_path = cuda_config.cuda_toolkit_path
for libdevice_file in NVVM_LIBDEVICE_FILES: for libdevice_file in NVVM_LIBDEVICE_FILES:
for relative_path in NVVM_LIBDEVICE_PATHS: for relative_path in NVVM_LIBDEVICE_PATHS:
if repository_ctx.path("%s/%s%s" % (cuda_toolkit_path, relative_path, if repository_ctx.path("%s/%s%s" % (
libdevice_file)).exists: cuda_toolkit_path,
relative_path,
libdevice_file,
)).exists:
return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1] return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
auto_configure_fail( auto_configure_fail(
"Cannot find libdevice*.bc files under %s" % cuda_toolkit_path) "Cannot find libdevice*.bc files under %s" % cuda_toolkit_path,
)
def _cudart_static_linkopt(cpu_value): def _cudart_static_linkopt(cpu_value):
"""Returns additional platform-specific linkopts for cudart.""" """Returns additional platform-specific linkopts for cudart."""
return "" if cpu_value == "Darwin" else "\"-lrt\"," return "" if cpu_value == "Darwin" else "\"-lrt\","
def _get_cuda_config(repository_ctx): def _get_cuda_config(repository_ctx):
"""Detects and returns information about the CUDA installation on the system. """Detects and returns information about the CUDA installation on the system.
@ -979,8 +978,11 @@ def _get_cuda_config(repository_ctx):
toolkit_path = cuda_toolkit_path(repository_ctx) toolkit_path = cuda_toolkit_path(repository_ctx)
cuda_version = _cuda_version(repository_ctx, toolkit_path, cpu_value) cuda_version = _cuda_version(repository_ctx, toolkit_path, cpu_value)
cudnn_install_basedir = _cudnn_install_basedir(repository_ctx) cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir, cudnn_version = _cudnn_version(
cpu_value) repository_ctx,
cudnn_install_basedir,
cpu_value,
)
return struct( return struct(
cuda_toolkit_path = toolkit_path, cuda_toolkit_path = toolkit_path,
cudnn_install_basedir = cudnn_install_basedir, cudnn_install_basedir = cudnn_install_basedir,
@ -990,7 +992,6 @@ def _get_cuda_config(repository_ctx):
cpu_value = cpu_value, cpu_value = cpu_value,
) )
def _tpl(repository_ctx, tpl, substitutions = {}, out = None): def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
if not out: if not out:
out = tpl.replace(":", "/") out = tpl.replace(":", "/")
@ -1000,7 +1001,6 @@ def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
substitutions, substitutions,
) )
def _file(repository_ctx, label): def _file(repository_ctx, label):
repository_ctx.template( repository_ctx.template(
label.replace(":", "/"), label.replace(":", "/"),
@ -1008,7 +1008,6 @@ def _file(repository_ctx, label):
{}, {},
) )
_DUMMY_CROSSTOOL_BZL_FILE = """ _DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled(): def error_gpu_disabled():
fail("ERROR: Building with --config=cuda but TensorFlow is not configured " + fail("ERROR: Building with --config=cuda but TensorFlow is not configured " +
@ -1050,34 +1049,22 @@ def _create_dummy_repository(repository_ctx):
repository_ctx, repository_ctx,
"cuda:BUILD", "cuda:BUILD",
{ {
"%{cuda_driver_lib}": "%{cuda_driver_lib}": lib_name("cuda", cpu_value),
lib_name("cuda", cpu_value), "%{cudart_static_lib}": lib_name(
"%{cudart_static_lib}":
lib_name(
"cudart_static", "cudart_static",
cpu_value, cpu_value,
static = True, static = True,
), ),
"%{cudart_static_linkopt}": "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
_cudart_static_linkopt(cpu_value), "%{cudart_lib}": lib_name("cudart", cpu_value),
"%{cudart_lib}": "%{cublas_lib}": lib_name("cublas", cpu_value),
lib_name("cudart", cpu_value), "%{cusolver_lib}": lib_name("cusolver", cpu_value),
"%{cublas_lib}": "%{cudnn_lib}": lib_name("cudnn", cpu_value),
lib_name("cublas", cpu_value), "%{cufft_lib}": lib_name("cufft", cpu_value),
"%{cusolver_lib}": "%{curand_lib}": lib_name("curand", cpu_value),
lib_name("cusolver", cpu_value), "%{cupti_lib}": lib_name("cupti", cpu_value),
"%{cudnn_lib}": "%{copy_rules}": "",
lib_name("cudnn", cpu_value), "%{cuda_headers}": "",
"%{cufft_lib}":
lib_name("cufft", cpu_value),
"%{curand_lib}":
lib_name("curand", cpu_value),
"%{cupti_lib}":
lib_name("cupti", cpu_value),
"%{copy_rules}":
"",
"%{cuda_headers}":
"",
}, },
) )
@ -1090,7 +1077,8 @@ def _create_dummy_repository(repository_ctx):
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cuda", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudart", cpu_value))
repository_ctx.file( repository_ctx.file(
"cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value)) "cuda/cuda/lib/%s" % lib_name("cudart_static", cpu_value),
)
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cublas", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cusolver", cpu_value))
repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value)) repository_ctx.file("cuda/cuda/lib/%s" % lib_name("cudnn", cpu_value))
@ -1104,17 +1092,13 @@ def _create_dummy_repository(repository_ctx):
repository_ctx, repository_ctx,
"cuda:cuda_config.h", "cuda:cuda_config.h",
{ {
"%{cuda_version}": "%{cuda_version}": _DEFAULT_CUDA_VERSION,
_DEFAULT_CUDA_VERSION, "%{cudnn_version}": _DEFAULT_CUDNN_VERSION,
"%{cudnn_version}": "%{cuda_compute_capabilities}": ",".join([
_DEFAULT_CUDNN_VERSION,
"%{cuda_compute_capabilities}":
",".join([
"CudaVersion(\"%s\")" % c "CudaVersion(\"%s\")" % c
for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
]), ]),
"%{cuda_toolkit_path}": "%{cuda_toolkit_path}": _DEFAULT_CUDA_TOOLKIT_PATH,
_DEFAULT_CUDA_TOOLKIT_PATH,
}, },
"cuda/cuda/cuda_config.h", "cuda/cuda/cuda_config.h",
) )
@ -1128,7 +1112,6 @@ def _create_dummy_repository(repository_ctx):
) )
repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
def _execute( def _execute(
repository_ctx, repository_ctx,
cmdline, cmdline,
@ -1153,10 +1136,10 @@ def _execute(
error_msg.strip() if error_msg else "Repository command failed", error_msg.strip() if error_msg else "Repository command failed",
result.stderr.strip(), result.stderr.strip(),
error_details if error_details else "", error_details if error_details else "",
]),) ]),
)
return result return result
def _norm_path(path): def _norm_path(path):
"""Returns a path with '/' and remove the trailing slash.""" """Returns a path with '/' and remove the trailing slash."""
path = path.replace("\\", "/") path = path.replace("\\", "/")
@ -1167,6 +1150,7 @@ def _norm_path(path):
def make_copy_files_rule(repository_ctx, name, srcs, outs): def make_copy_files_rule(repository_ctx, name, srcs, outs):
"""Returns a rule to copy a set of files.""" """Returns a rule to copy a set of files."""
cmds = [] cmds = []
# Copy files. # Copy files.
for src, out in zip(srcs, outs): for src, out in zip(srcs, outs):
cmds.append('cp -f "%s" $(location %s)' % (src, out)) cmds.append('cp -f "%s" $(location %s)' % (src, out))
@ -1185,6 +1169,7 @@ def make_copy_dir_rule(repository_ctx, name, src_dir, out_dir):
out_dir = _norm_path(out_dir) out_dir = _norm_path(out_dir)
outs = _read_dir(repository_ctx, src_dir) outs = _read_dir(repository_ctx, src_dir)
outs = [(' "%s",' % out.replace(src_dir, out_dir)) for out in outs] outs = [(' "%s",' % out.replace(src_dir, out_dir)) for out in outs]
# '@D' already contains the relative path for a single file, see # '@D' already contains the relative path for a single file, see
# http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables # http://docs.bazel.build/versions/master/be/make-variables.html#predefined_genrule_variables
out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)" out_dir = "$(@D)/%s" % out_dir if len(outs) > 1 else "$(@D)"
@ -1223,18 +1208,15 @@ def _read_dir(repository_ctx, src_dir):
result = find_result.stdout result = find_result.stdout
return sorted(result.splitlines()) return sorted(result.splitlines())
def _flag_enabled(repository_ctx, flag_name): def _flag_enabled(repository_ctx, flag_name):
if flag_name in repository_ctx.os.environ: if flag_name in repository_ctx.os.environ:
value = repository_ctx.os.environ[flag_name].strip() value = repository_ctx.os.environ[flag_name].strip()
return value == "1" return value == "1"
return False return False
def _use_cuda_clang(repository_ctx): def _use_cuda_clang(repository_ctx):
return _flag_enabled(repository_ctx, "TF_CUDA_CLANG") return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
if _use_cuda_clang(repository_ctx): if _use_cuda_clang(repository_ctx):
capability_flags = [ capability_flags = [
@ -1247,7 +1229,6 @@ def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
capability_flags = [] capability_flags = []
return str(capability_flags) return str(capability_flags)
def _create_local_cuda_repository(repository_ctx): def _create_local_cuda_repository(repository_ctx):
"""Creates the repository containing files set up to build with CUDA.""" """Creates the repository containing files set up to build with CUDA."""
cuda_config = _get_cuda_config(repository_ctx) cuda_config = _get_cuda_config(repository_ctx)
@ -1299,7 +1280,7 @@ def _create_local_cuda_repository(repository_ctx):
repository_ctx, repository_ctx,
name = "cuda-bin", name = "cuda-bin",
src_dir = cuda_config.cuda_toolkit_path + "/bin", src_dir = cuda_config.cuda_toolkit_path + "/bin",
out_dir = "cuda/bin" out_dir = "cuda/bin",
)) ))
# Copy cudnn.h if cuDNN was not installed to CUDA_TOOLKIT_PATH. # Copy cudnn.h if cuDNN was not installed to CUDA_TOOLKIT_PATH.
@ -1319,10 +1300,8 @@ def _create_local_cuda_repository(repository_ctx):
repository_ctx, repository_ctx,
"cuda:build_defs.bzl", "cuda:build_defs.bzl",
{ {
"%{cuda_is_configured}": "%{cuda_is_configured}": "True",
"True", "%{cuda_extra_copts}": _compute_cuda_extra_copts(
"%{cuda_extra_copts}":
_compute_cuda_extra_copts(
repository_ctx, repository_ctx,
cuda_config.compute_capabilities, cuda_config.compute_capabilities,
), ),
@ -1332,29 +1311,19 @@ def _create_local_cuda_repository(repository_ctx):
repository_ctx, repository_ctx,
"cuda:BUILD.windows" if _is_windows(repository_ctx) else "cuda:BUILD", "cuda:BUILD.windows" if _is_windows(repository_ctx) else "cuda:BUILD",
{ {
"%{cuda_driver_lib}": "%{cuda_driver_lib}": cuda_libs["cuda"].basename,
cuda_libs["cuda"].basename, "%{cudart_static_lib}": cuda_libs["cudart_static"].basename,
"%{cudart_static_lib}": "%{cudart_static_linkopt}": _cudart_static_linkopt(cuda_config.cpu_value),
cuda_libs["cudart_static"].basename, "%{cudart_lib}": cuda_libs["cudart"].basename,
"%{cudart_static_linkopt}": "%{cublas_lib}": cuda_libs["cublas"].basename,
_cudart_static_linkopt(cuda_config.cpu_value,), "%{cusolver_lib}": cuda_libs["cusolver"].basename,
"%{cudart_lib}": "%{cudnn_lib}": cuda_libs["cudnn"].basename,
cuda_libs["cudart"].basename, "%{cufft_lib}": cuda_libs["cufft"].basename,
"%{cublas_lib}": "%{curand_lib}": cuda_libs["curand"].basename,
cuda_libs["cublas"].basename, "%{cupti_lib}": cuda_libs["cupti"].basename,
"%{cusolver_lib}": "%{copy_rules}": "\n".join(copy_rules),
cuda_libs["cusolver"].basename, "%{cuda_headers}": (
"%{cudnn_lib}": '":cuda-include",\n' + ' ":cudnn-include",'
cuda_libs["cudnn"].basename,
"%{cufft_lib}":
cuda_libs["cufft"].basename,
"%{curand_lib}":
cuda_libs["curand"].basename,
"%{cupti_lib}":
cuda_libs["cupti"].basename,
"%{copy_rules}":
"\n".join(copy_rules),
"%{cuda_headers}": ('":cuda-include",\n' + ' ":cudnn-include",'
), ),
}, },
"cuda/BUILD", "cuda/BUILD",
@ -1375,6 +1344,7 @@ def _create_local_cuda_repository(repository_ctx):
host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath) host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath)
cuda_defines = {} cuda_defines = {}
# Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
# https://github.com/bazelbuild/bazel/issues/760). # https://github.com/bazelbuild/bazel/issues/760).
# However, this stops our custom clang toolchain from picking the provided # However, this stops our custom clang toolchain from picking the provided
@ -1399,14 +1369,15 @@ def _create_local_cuda_repository(repository_ctx):
cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{extra_no_canonical_prefixes_flags}"] = ""
_tpl(repository_ctx, "crosstool:BUILD", { _tpl(repository_ctx, "crosstool:BUILD", {
"%{linker_files}": ":empty", "%{linker_files}": ":empty",
"%{win_linker_files}": ":empty" "%{win_linker_files}": ":empty",
}) })
repository_ctx.file( repository_ctx.file(
"crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "") "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc",
"",
)
repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "") repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
else: else:
cuda_defines[ cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
"%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
cuda_defines["%{host_compiler_warnings}"] = "" cuda_defines["%{host_compiler_warnings}"] = ""
# nvcc has the system include paths built in and will automatically # nvcc has the system include paths built in and will automatically
@ -1414,9 +1385,12 @@ def _create_local_cuda_repository(repository_ctx):
# system paths to the allowed compiler specific include paths. # system paths to the allowed compiler specific include paths.
cuda_defines["%{host_compiler_includes}"] = ( cuda_defines["%{host_compiler_includes}"] = (
host_compiler_includes + "\n" + _cuda_include_path( host_compiler_includes + "\n" + _cuda_include_path(
repository_ctx, cuda_config) + repository_ctx,
cuda_config,
) +
"\n cxx_builtin_include_directory: \"%s\"" % cupti_header_dir + "\n cxx_builtin_include_directory: \"%s\"" % cupti_header_dir +
"\n cxx_builtin_include_directory: \"%s\"" % cudnn_header_dir) "\n cxx_builtin_include_directory: \"%s\"" % cudnn_header_dir
)
# For gcc, do not canonicalize system header paths; some versions of gcc # For gcc, do not canonicalize system header paths; some versions of gcc
# pick the shortest possible path for system includes when creating the # pick the shortest possible path for system includes when creating the
@ -1424,12 +1398,14 @@ def _create_local_cuda_repository(repository_ctx):
# time quickly grow longer than the root of the tree, this can lead to # time quickly grow longer than the root of the tree, this can lead to
# bazel's header check failing. # bazel's header check failing.
cuda_defines["%{extra_no_canonical_prefixes_flags}"] = ( cuda_defines["%{extra_no_canonical_prefixes_flags}"] = (
"flag: \"-fno-canonical-system-headers\"") "flag: \"-fno-canonical-system-headers\""
)
nvcc_path = str( nvcc_path = str(
repository_ctx.path("%s/bin/nvcc%s" % ( repository_ctx.path("%s/bin/nvcc%s" % (
cuda_config.cuda_toolkit_path, cuda_config.cuda_toolkit_path,
".exe" if _is_windows(repository_ctx) else "", ".exe" if _is_windows(repository_ctx) else "",
))) )),
)
_tpl( _tpl(
repository_ctx, repository_ctx,
"crosstool:BUILD", "crosstool:BUILD",
@ -1439,19 +1415,14 @@ def _create_local_cuda_repository(repository_ctx):
}, },
) )
wrapper_defines = { wrapper_defines = {
"%{cpu_compiler}": "%{cpu_compiler}": str(cc),
str(cc), "%{cuda_version}": cuda_config.cuda_version,
"%{cuda_version}": "%{nvcc_path}": nvcc_path,
cuda_config.cuda_version, "%{gcc_host_compiler_path}": str(cc),
"%{nvcc_path}": "%{cuda_compute_capabilities}": ", ".join(
nvcc_path, ["\"%s\"" % c for c in cuda_config.compute_capabilities],
"%{gcc_host_compiler_path}": ),
str(cc), "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
"%{cuda_compute_capabilities}":
", ".join(
["\"%s\"" % c for c in cuda_config.compute_capabilities],),
"%{nvcc_tmp_dir}":
_get_nvcc_tmp_dir_for_windows(repository_ctx),
} }
_tpl( _tpl(
repository_ctx, repository_ctx,
@ -1477,32 +1448,25 @@ def _create_local_cuda_repository(repository_ctx):
repository_ctx, repository_ctx,
"cuda:cuda_config.h", "cuda:cuda_config.h",
{ {
"%{cuda_version}": "%{cuda_version}": cuda_config.cuda_version,
cuda_config.cuda_version, "%{cudnn_version}": cuda_config.cudnn_version,
"%{cudnn_version}": "%{cuda_compute_capabilities}": ",".join([
cuda_config.cudnn_version,
"%{cuda_compute_capabilities}":
",".join([
"CudaVersion(\"%s\")" % c "CudaVersion(\"%s\")" % c
for c in cuda_config.compute_capabilities for c in cuda_config.compute_capabilities
],), ]),
"%{cuda_toolkit_path}": "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
cuda_config.cuda_toolkit_path,
}, },
"cuda/cuda/cuda_config.h", "cuda/cuda/cuda_config.h",
) )
def _create_remote_cuda_repository(repository_ctx, remote_config_repo): def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
"""Creates pointers to a remotely configured repo set up to build with CUDA.""" """Creates pointers to a remotely configured repo set up to build with CUDA."""
_tpl( _tpl(
repository_ctx, repository_ctx,
"cuda:build_defs.bzl", "cuda:build_defs.bzl",
{ {
"%{cuda_is_configured}": "%{cuda_is_configured}": "True",
"True", "%{cuda_extra_copts}": _compute_cuda_extra_copts(
"%{cuda_extra_copts}":
_compute_cuda_extra_copts(
repository_ctx, repository_ctx,
compute_capabilities(repository_ctx), compute_capabilities(repository_ctx),
), ),
@ -1524,7 +1488,6 @@ def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
{}, {},
) )
def _cuda_autoconf_impl(repository_ctx): def _cuda_autoconf_impl(repository_ctx):
"""Implementation of the cuda_autoconf repository rule.""" """Implementation of the cuda_autoconf repository rule."""
if not enable_cuda(repository_ctx): if not enable_cuda(repository_ctx):
@ -1537,7 +1500,6 @@ def _cuda_autoconf_impl(repository_ctx):
else: else:
_create_local_cuda_repository(repository_ctx) _create_local_cuda_repository(repository_ctx)
cuda_configure = repository_rule( cuda_configure = repository_rule(
implementation = _cuda_autoconf_impl, implementation = _cuda_autoconf_impl,
environ = [ environ = [

View File

@ -242,8 +242,13 @@ def _hipcc_env(repository_ctx):
A string containing environment variables for hipcc. A string containing environment variables for hipcc.
""" """
hipcc_env = "" hipcc_env = ""
for name in ["HIP_CLANG_PATH", "DEVICE_LIB_PATH", "HIP_VDI_HOME",\ for name in [
"HIPCC_VERBOSE", "HIPCC_COMPILE_FLAGS_APPEND"]: "HIP_CLANG_PATH",
"DEVICE_LIB_PATH",
"HIP_VDI_HOME",
"HIPCC_VERBOSE",
"HIPCC_COMPILE_FLAGS_APPEND",
]:
if name in repository_ctx.os.environ: if name in repository_ctx.os.environ:
hipcc_env = hipcc_env + " " + name + "=\"" + \ hipcc_env = hipcc_env + " " + name + "=\"" + \
repository_ctx.os.environ[name].strip() + "\";" repository_ctx.os.environ[name].strip() + "\";"
@ -636,7 +641,6 @@ def _create_local_rocm_repository(repository_ctx):
outs = rocm_lib_outs, outs = rocm_lib_outs,
)) ))
# Set up BUILD file for rocm/ # Set up BUILD file for rocm/
_tpl( _tpl(
repository_ctx, repository_ctx,

View File

@ -8,10 +8,10 @@ cc_library(
srcs = ["libnccl.so.%{version}"], srcs = ["libnccl.so.%{version}"],
hdrs = ["nccl.h"], hdrs = ["nccl.h"],
include_prefix = "third_party/nccl", include_prefix = "third_party/nccl",
visibility = ["//visibility:public"],
deps = [ deps = [
"@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cuda_headers",
], ],
visibility = ["//visibility:public"],
) )
genrule( genrule(
@ -23,4 +23,3 @@ genrule(
cmd = """cp "%{hdr_path}/nccl.h" "$(@D)/nccl.h" && cmd = """cp "%{hdr_path}/nccl.h" "$(@D)/nccl.h" &&
cp "%{install_path}/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """, cp "%{install_path}/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
) )

View File

@ -11,15 +11,14 @@ _PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
_PYTHON_LIB_PATH = "PYTHON_LIB_PATH" _PYTHON_LIB_PATH = "PYTHON_LIB_PATH"
_TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO" _TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO"
def _tpl(repository_ctx, tpl, substitutions = {}, out = None): def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
if not out: if not out:
out = tpl out = tpl
repository_ctx.template( repository_ctx.template(
out, out,
Label("//third_party/py:%s.tpl" % tpl), Label("//third_party/py:%s.tpl" % tpl),
substitutions) substitutions,
)
def _fail(msg): def _fail(msg):
"""Output failure message when auto configuration fails.""" """Output failure message when auto configuration fails."""
@ -27,7 +26,6 @@ def _fail(msg):
no_color = "\033[0m" no_color = "\033[0m"
fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg)) fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg))
def _is_windows(repository_ctx): def _is_windows(repository_ctx):
"""Returns true if the host operating system is windows.""" """Returns true if the host operating system is windows."""
os_name = repository_ctx.os.name.lower() os_name = repository_ctx.os.name.lower()
@ -35,8 +33,11 @@ def _is_windows(repository_ctx):
return True return True
return False return False
def _execute(
def _execute(repository_ctx, cmdline, error_msg=None, error_details=None, repository_ctx,
cmdline,
error_msg = None,
error_details = None,
empty_stdout_fine = False): empty_stdout_fine = False):
"""Executes an arbitrary shell command. """Executes an arbitrary shell command.
@ -55,10 +56,10 @@ def _execute(repository_ctx, cmdline, error_msg=None, error_details=None,
_fail("\n".join([ _fail("\n".join([
error_msg.strip() if error_msg else "Repository command failed", error_msg.strip() if error_msg else "Repository command failed",
result.stderr.strip(), result.stderr.strip(),
error_details if error_details else ""])) error_details if error_details else "",
]))
return result return result
def _read_dir(repository_ctx, src_dir): def _read_dir(repository_ctx, src_dir):
"""Returns a string with all files in a directory. """Returns a string with all files in a directory.
@ -69,38 +70,41 @@ def _read_dir(repository_ctx, src_dir):
if _is_windows(repository_ctx): if _is_windows(repository_ctx):
src_dir = src_dir.replace("/", "\\") src_dir = src_dir.replace("/", "\\")
find_result = _execute( find_result = _execute(
repository_ctx, ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"], repository_ctx,
empty_stdout_fine=True) ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
empty_stdout_fine = True,
)
# src_files will be used in genrule.outs where the paths must # src_files will be used in genrule.outs where the paths must
# use forward slashes. # use forward slashes.
result = find_result.stdout.replace("\\", "/") result = find_result.stdout.replace("\\", "/")
else: else:
find_result = _execute( find_result = _execute(
repository_ctx, ["find", src_dir, "-follow", "-type", "f"], repository_ctx,
empty_stdout_fine=True) ["find", src_dir, "-follow", "-type", "f"],
empty_stdout_fine = True,
)
result = find_result.stdout result = find_result.stdout
return result return result
def _genrule(src_dir, genrule_name, command, outs): def _genrule(src_dir, genrule_name, command, outs):
"""Returns a string with a genrule. """Returns a string with a genrule.
Genrule executes the given command and produces the given outputs. Genrule executes the given command and produces the given outputs.
""" """
return ( return (
'genrule(\n' + "genrule(\n" +
' name = "' + ' name = "' +
genrule_name + '",\n' + genrule_name + '",\n' +
' outs = [\n' + " outs = [\n" +
outs + outs +
'\n ],\n' + "\n ],\n" +
' cmd = """\n' + ' cmd = """\n' +
command + command +
'\n """,\n' + '\n """,\n' +
')\n' ")\n"
) )
def _norm_path(path): def _norm_path(path):
"""Returns a path with '/' and remove the trailing slash.""" """Returns a path with '/' and remove the trailing slash."""
path = path.replace("\\", "/") path = path.replace("\\", "/")
@ -108,9 +112,13 @@ def _norm_path(path):
path = path[:-1] path = path[:-1]
return path return path
def _symlink_genrule_for_dir(
def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name, repository_ctx,
src_files = [], dest_files = []): src_dir,
dest_dir,
genrule_name,
src_files = [],
dest_files = []):
"""Returns a genrule to symlink(or copy if on Windows) a set of files. """Returns a genrule to symlink(or copy if on Windows) a set of files.
If src_dir is passed, files will be read from the given directory; otherwise If src_dir is passed, files will be read from the given directory; otherwise
@ -119,9 +127,10 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
if src_dir != None: if src_dir != None:
src_dir = _norm_path(src_dir) src_dir = _norm_path(src_dir)
dest_dir = _norm_path(dest_dir) dest_dir = _norm_path(dest_dir)
files = '\n'.join(sorted(_read_dir(repository_ctx, src_dir).splitlines())) files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
# Create a list with the src_dir stripped to use for outputs. # Create a list with the src_dir stripped to use for outputs.
dest_files = files.replace(src_dir, '').splitlines() dest_files = files.replace(src_dir, "").splitlines()
src_files = files.splitlines() src_files = files.splitlines()
command = [] command = []
outs = [] outs = []
@ -129,16 +138,20 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
if dest_files[i] != "": if dest_files[i] != "":
# If we have only one file to link we do not want to use the dest_dir, as # If we have only one file to link we do not want to use the dest_dir, as
# $(@D) will include the full path to the file. # $(@D) will include the full path to the file.
dest = '$(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else '$(@D)/' + dest_files[i] dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
# Copy the headers to create a sandboxable setup. # Copy the headers to create a sandboxable setup.
cmd = 'cp -f' cmd = "cp -f"
command.append(cmd + ' "%s" "%s"' % (src_files[i], dest)) command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
outs.append(' "' + dest_dir + dest_files[i] + '",') outs.append(' "' + dest_dir + dest_files[i] + '",')
genrule = _genrule(src_dir, genrule_name, " && ".join(command), genrule = _genrule(
"\n".join(outs)) src_dir,
genrule_name,
" && ".join(command),
"\n".join(outs),
)
return genrule return genrule
def _get_python_bin(repository_ctx): def _get_python_bin(repository_ctx):
"""Gets the python bin path.""" """Gets the python bin path."""
python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH) python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
@ -150,8 +163,9 @@ def _get_python_bin(repository_ctx):
_fail("Cannot find python in PATH, please make sure " + _fail("Cannot find python in PATH, please make sure " +
"python is installed and add its directory in PATH, or --define " + "python is installed and add its directory in PATH, or --define " +
"%s='/something/else'.\nPATH=%s" % ( "%s='/something/else'.\nPATH=%s" % (
_PYTHON_BIN_PATH, repository_ctx.os.environ.get("PATH", ""))) _PYTHON_BIN_PATH,
repository_ctx.os.environ.get("PATH", ""),
))
def _get_bash_bin(repository_ctx): def _get_bash_bin(repository_ctx):
"""Gets the bash bin path.""" """Gets the bash bin path."""
@ -166,8 +180,9 @@ def _get_bash_bin(repository_ctx):
_fail("Cannot find bash in PATH, please make sure " + _fail("Cannot find bash in PATH, please make sure " +
"bash is installed and add its directory in PATH, or --define " + "bash is installed and add its directory in PATH, or --define " +
"%s='/path/to/bash'.\nPATH=%s" % ( "%s='/path/to/bash'.\nPATH=%s" % (
_BAZEL_SH, repository_ctx.os.environ.get("PATH", ""))) _BAZEL_SH,
repository_ctx.os.environ.get("PATH", ""),
))
def _get_python_lib(repository_ctx, python_bin): def _get_python_lib(repository_ctx, python_bin):
"""Gets the python lib path.""" """Gets the python lib path."""
@ -200,10 +215,9 @@ def _get_python_lib(repository_ctx, python_bin):
"if len(paths) >=1:\n" + "if len(paths) >=1:\n" +
" print(paths[0])\n" + " print(paths[0])\n" +
"END") "END")
cmd = '%s - %s' % (python_bin, print_lib) cmd = "%s - %s" % (python_bin, print_lib)
result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd]) result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd])
return result.stdout.strip('\n') return result.stdout.strip("\n")
def _check_python_lib(repository_ctx, python_lib): def _check_python_lib(repository_ctx, python_lib):
"""Checks the python lib path.""" """Checks the python lib path."""
@ -212,55 +226,65 @@ def _check_python_lib(repository_ctx, python_lib):
if result.return_code == 1: if result.return_code == 1:
_fail("Invalid python library path: %s" % python_lib) _fail("Invalid python library path: %s" % python_lib)
def _check_python_bin(repository_ctx, python_bin): def _check_python_bin(repository_ctx, python_bin):
"""Checks the python bin path.""" """Checks the python bin path."""
cmd = '[[ -x "%s" ]] && [[ ! -d "%s" ]]' % (python_bin, python_bin) cmd = '[[ -x "%s" ]] && [[ ! -d "%s" ]]' % (python_bin, python_bin)
result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd]) result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd])
if result.return_code == 1: if result.return_code == 1:
_fail("--define %s='%s' is not executable. Is it the python binary?" % ( _fail("--define %s='%s' is not executable. Is it the python binary?" % (
_PYTHON_BIN_PATH, python_bin)) _PYTHON_BIN_PATH,
python_bin,
))
def _get_python_include(repository_ctx, python_bin): def _get_python_include(repository_ctx, python_bin):
"""Gets the python include path.""" """Gets the python include path."""
result = _execute( result = _execute(
repository_ctx, repository_ctx,
[python_bin, "-c", [
'from __future__ import print_function;' + python_bin,
'from distutils import sysconfig;' + "-c",
'print(sysconfig.get_python_inc())'], "from __future__ import print_function;" +
"from distutils import sysconfig;" +
"print(sysconfig.get_python_inc())",
],
error_msg = "Problem getting python include path.", error_msg = "Problem getting python include path.",
error_details = ("Is the Python binary path set up right? " + error_details = ("Is the Python binary path set up right? " +
"(See ./configure or " + _PYTHON_BIN_PATH + ".) " + "(See ./configure or " + _PYTHON_BIN_PATH + ".) " +
"Is distutils installed?")) "Is distutils installed?"),
)
return result.stdout.splitlines()[0] return result.stdout.splitlines()[0]
def _get_python_import_lib_name(repository_ctx, python_bin): def _get_python_import_lib_name(repository_ctx, python_bin):
"""Get Python import library name (pythonXY.lib) on Windows.""" """Get Python import library name (pythonXY.lib) on Windows."""
result = _execute( result = _execute(
repository_ctx, repository_ctx,
[python_bin, "-c", [
'import sys;' + python_bin,
"-c",
"import sys;" +
'print("python" + str(sys.version_info[0]) + ' + 'print("python" + str(sys.version_info[0]) + ' +
' str(sys.version_info[1]) + ".lib")'], ' str(sys.version_info[1]) + ".lib")',
],
error_msg = "Problem getting python import library.", error_msg = "Problem getting python import library.",
error_details = ("Is the Python binary path set up right? " + error_details = ("Is the Python binary path set up right? " +
"(See ./configure or " + _PYTHON_BIN_PATH + ".) ")) "(See ./configure or " + _PYTHON_BIN_PATH + ".) "),
)
return result.stdout.splitlines()[0] return result.stdout.splitlines()[0]
def _get_numpy_include(repository_ctx, python_bin): def _get_numpy_include(repository_ctx, python_bin):
"""Gets the numpy include path.""" """Gets the numpy include path."""
return _execute(repository_ctx, return _execute(
[python_bin, "-c", repository_ctx,
'from __future__ import print_function;' + [
'import numpy;' + python_bin,
' print(numpy.get_include());'], "-c",
"from __future__ import print_function;" +
"import numpy;" +
" print(numpy.get_include());",
],
error_msg = "Problem getting numpy include path.", error_msg = "Problem getting numpy include path.",
error_details="Is numpy installed?").stdout.splitlines()[0] error_details = "Is numpy installed?",
).stdout.splitlines()[0]
def _create_local_python_repository(repository_ctx): def _create_local_python_repository(repository_ctx):
"""Creates the repository containing files set up to build with Python.""" """Creates the repository containing files set up to build with Python."""
@ -269,43 +293,56 @@ def _create_local_python_repository(repository_ctx):
python_lib = _get_python_lib(repository_ctx, python_bin) python_lib = _get_python_lib(repository_ctx, python_bin)
_check_python_lib(repository_ctx, python_lib) _check_python_lib(repository_ctx, python_lib)
python_include = _get_python_include(repository_ctx, python_bin) python_include = _get_python_include(repository_ctx, python_bin)
numpy_include = _get_numpy_include(repository_ctx, python_bin) + '/numpy' numpy_include = _get_numpy_include(repository_ctx, python_bin) + "/numpy"
python_include_rule = _symlink_genrule_for_dir( python_include_rule = _symlink_genrule_for_dir(
repository_ctx, python_include, 'python_include', 'python_include') repository_ctx,
python_include,
"python_include",
"python_include",
)
python_import_lib_genrule = "" python_import_lib_genrule = ""
# To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib # To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib
# See https://docs.python.org/3/extending/windows.html # See https://docs.python.org/3/extending/windows.html
if _is_windows(repository_ctx): if _is_windows(repository_ctx):
python_include = _norm_path(python_include) python_include = _norm_path(python_include)
python_import_lib_name = _get_python_import_lib_name(repository_ctx, python_bin) python_import_lib_name = _get_python_import_lib_name(repository_ctx, python_bin)
python_import_lib_src = python_include.rsplit('/', 1)[0] + "/libs/" + python_import_lib_name python_import_lib_src = python_include.rsplit("/", 1)[0] + "/libs/" + python_import_lib_name
python_import_lib_genrule = _symlink_genrule_for_dir( python_import_lib_genrule = _symlink_genrule_for_dir(
repository_ctx, None, '', 'python_import_lib', repository_ctx,
[python_import_lib_src], [python_import_lib_name]) None,
"",
"python_import_lib",
[python_import_lib_src],
[python_import_lib_name],
)
numpy_include_rule = _symlink_genrule_for_dir( numpy_include_rule = _symlink_genrule_for_dir(
repository_ctx, numpy_include, 'numpy_include/numpy', 'numpy_include') repository_ctx,
numpy_include,
"numpy_include/numpy",
"numpy_include",
)
_tpl(repository_ctx, "BUILD", { _tpl(repository_ctx, "BUILD", {
"%{PYTHON_INCLUDE_GENRULE}": python_include_rule, "%{PYTHON_INCLUDE_GENRULE}": python_include_rule,
"%{PYTHON_IMPORT_LIB_GENRULE}": python_import_lib_genrule, "%{PYTHON_IMPORT_LIB_GENRULE}": python_import_lib_genrule,
"%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule, "%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule,
}) })
def _create_remote_python_repository(repository_ctx, remote_config_repo): def _create_remote_python_repository(repository_ctx, remote_config_repo):
"""Creates pointers to a remotely configured repo set up to build with Python. """Creates pointers to a remotely configured repo set up to build with Python.
""" """
repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {}) repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
def _python_autoconf_impl(repository_ctx): def _python_autoconf_impl(repository_ctx):
"""Implementation of the python_autoconf repository rule.""" """Implementation of the python_autoconf repository rule."""
if _TF_PYTHON_CONFIG_REPO in repository_ctx.os.environ: if _TF_PYTHON_CONFIG_REPO in repository_ctx.os.environ:
_create_remote_python_repository(repository_ctx, _create_remote_python_repository(
repository_ctx.os.environ[_TF_PYTHON_CONFIG_REPO]) repository_ctx,
repository_ctx.os.environ[_TF_PYTHON_CONFIG_REPO],
)
else: else:
_create_local_python_repository(repository_ctx) _create_local_python_repository(repository_ctx)
python_configure = repository_rule( python_configure = repository_rule(
implementation = _python_autoconf_impl, implementation = _python_autoconf_impl,
environ = [ environ = [

View File

@ -11,7 +11,7 @@ def if_sycl(if_true, if_false = []):
return select({ return select({
"@local_config_sycl//sycl:using_sycl_ccpp": if_true, "@local_config_sycl//sycl:using_sycl_ccpp": if_true,
"@local_config_sycl//sycl:using_sycl_trisycl": if_true[0:1], "@local_config_sycl//sycl:using_sycl_trisycl": if_true[0:1],
"//conditions:default": if_false "//conditions:default": if_false,
}) })
def if_ccpp(if_true, if_false = []): def if_ccpp(if_true, if_false = []):
@ -24,5 +24,5 @@ def if_ccpp(if_true, if_false = []):
return select({ return select({
"@local_config_sycl//sycl:using_sycl_ccpp": if_true, "@local_config_sycl//sycl:using_sycl_ccpp": if_true,
"@local_config_sycl//sycl:using_sycl_trisycl": if_false, "@local_config_sycl//sycl:using_sycl_trisycl": if_false,
"//conditions:default": if_false "//conditions:default": if_false,
}) })

View File

@ -30,6 +30,7 @@ def auto_configure_fail(msg):
red = "\033[0;31m" red = "\033[0;31m"
no_color = "\033[0m" no_color = "\033[0m"
fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg)) fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))
# END cc_configure common functions (see TODO above). # END cc_configure common functions (see TODO above).
def find_c(repository_ctx): def find_c(repository_ctx):
@ -79,7 +80,6 @@ def find_python_lib(repository_ctx):
return repository_ctx.os.environ[_PYTHON_LIB_PATH].strip() return repository_ctx.os.environ[_PYTHON_LIB_PATH].strip()
fail("Environment variable PYTHON_LIB_PATH was not specified re-run ./configure") fail("Environment variable PYTHON_LIB_PATH was not specified re-run ./configure")
def _check_lib(repository_ctx, toolkit_path, lib): def _check_lib(repository_ctx, toolkit_path, lib):
"""Checks if lib exists under sycl_toolkit_path or fail if it doesn't. """Checks if lib exists under sycl_toolkit_path or fail if it doesn't.
@ -120,13 +120,15 @@ def _tpl(repository_ctx, tpl, substitutions={}, out=None):
repository_ctx.template( repository_ctx.template(
out, out,
Label("//third_party/sycl/%s.tpl" % tpl), Label("//third_party/sycl/%s.tpl" % tpl),
substitutions) substitutions,
)
def _file(repository_ctx, label): def _file(repository_ctx, label):
repository_ctx.template( repository_ctx.template(
label.replace(":", "/"), label.replace(":", "/"),
Label("//third_party/sycl/%s" % label), Label("//third_party/sycl/%s" % label),
{}) {},
)
_DUMMY_CROSSTOOL_BZL_FILE = """ _DUMMY_CROSSTOOL_BZL_FILE = """
def error_sycl_disabled(): def error_sycl_disabled():
@ -147,7 +149,6 @@ def error_sycl_disabled():
) )
""" """
_DUMMY_CROSSTOOL_BUILD_FILE = """ _DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_sycl_disabled.bzl", "error_sycl_disabled") load("//crosstool:error_sycl_disabled.bzl", "error_sycl_disabled")
@ -169,11 +170,12 @@ def _create_dummy_repository(repository_ctx):
# If sycl_configure is not configured to build with SYCL support, and the user # If sycl_configure is not configured to build with SYCL support, and the user
# attempts to build with --config=sycl, add a dummy build rule to intercept # attempts to build with --config=sycl, add a dummy build rule to intercept
# this and fail with an actionable error message. # this and fail with an actionable error message.
repository_ctx.file("crosstool/error_sycl_disabled.bzl", repository_ctx.file(
_DUMMY_CROSSTOOL_BZL_FILE) "crosstool/error_sycl_disabled.bzl",
_DUMMY_CROSSTOOL_BZL_FILE,
)
repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
def _sycl_autoconf_imp(repository_ctx): def _sycl_autoconf_imp(repository_ctx):
"""Implementation of the sycl_autoconf rule.""" """Implementation of the sycl_autoconf rule."""
if not _enable_sycl(repository_ctx): if not _enable_sycl(repository_ctx):
@ -187,22 +189,28 @@ def _sycl_autoconf_imp(repository_ctx):
_file(repository_ctx, "sycl:LICENSE.text") _file(repository_ctx, "sycl:LICENSE.text")
if _enable_compute_cpp(repository_ctx): if _enable_compute_cpp(repository_ctx):
_tpl(repository_ctx, "crosstool:computecpp", _tpl(
repository_ctx,
"crosstool:computecpp",
{ {
"%{host_cxx_compiler}": find_cc(repository_ctx), "%{host_cxx_compiler}": find_cc(repository_ctx),
"%{host_c_compiler}" : find_c(repository_ctx) "%{host_c_compiler}": find_c(repository_ctx),
}) },
)
computecpp_root = find_computecpp_root(repository_ctx); computecpp_root = find_computecpp_root(repository_ctx)
_check_dir(repository_ctx, computecpp_root) _check_dir(repository_ctx, computecpp_root)
_tpl(repository_ctx, "crosstool:CROSSTOOL", _tpl(
repository_ctx,
"crosstool:CROSSTOOL",
{ {
"%{sycl_include_dir}": computecpp_root, "%{sycl_include_dir}": computecpp_root,
"%{sycl_impl}": "computecpp", "%{sycl_impl}": "computecpp",
"%{c++_std}": "-std=c++11", "%{c++_std}": "-std=c++11",
"%{python_lib_path}": find_python_lib(repository_ctx), "%{python_lib_path}": find_python_lib(repository_ctx),
}) },
)
# symlink libraries # symlink libraries
_check_lib(repository_ctx, computecpp_root + "/lib", "libComputeCpp.so") _check_lib(repository_ctx, computecpp_root + "/lib", "libComputeCpp.so")
@ -210,29 +218,32 @@ def _sycl_autoconf_imp(repository_ctx):
_symlink_dir(repository_ctx, computecpp_root + "/include", "sycl/include") _symlink_dir(repository_ctx, computecpp_root + "/include", "sycl/include")
_symlink_dir(repository_ctx, computecpp_root + "/bin", "sycl/bin") _symlink_dir(repository_ctx, computecpp_root + "/bin", "sycl/bin")
else: else:
trisycl_include_dir = find_trisycl_include_dir(repository_ctx)
trisycl_include_dir = find_trisycl_include_dir(repository_ctx);
_check_dir(repository_ctx, trisycl_include_dir) _check_dir(repository_ctx, trisycl_include_dir)
_tpl(repository_ctx, "crosstool:trisycl", _tpl(
repository_ctx,
"crosstool:trisycl",
{ {
"%{host_cxx_compiler}": find_cc(repository_ctx), "%{host_cxx_compiler}": find_cc(repository_ctx),
"%{host_c_compiler}": find_c(repository_ctx), "%{host_c_compiler}": find_c(repository_ctx),
"%{trisycl_include_dir}" : trisycl_include_dir "%{trisycl_include_dir}": trisycl_include_dir,
}) },
)
_tpl(
_tpl(repository_ctx, "crosstool:CROSSTOOL", repository_ctx,
"crosstool:CROSSTOOL",
{ {
"%{sycl_include_dir}": trisycl_include_dir, "%{sycl_include_dir}": trisycl_include_dir,
"%{sycl_impl}": "trisycl", "%{sycl_impl}": "trisycl",
"%{c++_std}": "-std=c++1y", "%{c++_std}": "-std=c++1y",
"%{python_lib_path}": find_python_lib(repository_ctx), "%{python_lib_path}": find_python_lib(repository_ctx),
}) },
)
_symlink_dir(repository_ctx, trisycl_include_dir, "sycl/include") _symlink_dir(repository_ctx, trisycl_include_dir, "sycl/include")
sycl_configure = repository_rule( sycl_configure = repository_rule(
implementation = _sycl_autoconf_imp, implementation = _sycl_autoconf_imp,
local = True, local = True,

View File

@ -10,13 +10,13 @@
load( load(
"//third_party/gpus:cuda_configure.bzl", "//third_party/gpus:cuda_configure.bzl",
"auto_configure_fail", "auto_configure_fail",
"get_cpu_value",
"find_cuda_define", "find_cuda_define",
"find_lib", "find_lib",
"get_cpu_value",
"lib_name", "lib_name",
"matches_version",
"make_copy_dir_rule", "make_copy_dir_rule",
"make_copy_files_rule", "make_copy_files_rule",
"matches_version",
) )
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH" _TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
@ -30,7 +30,6 @@ _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR" _DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH" _DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
def _headers_exist(repository_ctx, path): def _headers_exist(repository_ctx, path):
"""Returns whether all TensorRT header files could be found in 'path'. """Returns whether all TensorRT header files could be found in 'path'.
@ -46,7 +45,6 @@ def _headers_exist(repository_ctx, path):
return False return False
return True return True
def _find_trt_header_dir(repository_ctx, trt_install_path): def _find_trt_header_dir(repository_ctx, trt_install_path):
"""Returns the path to the directory containing headers of TensorRT. """Returns the path to the directory containing headers of TensorRT.
@ -69,8 +67,8 @@ def _find_trt_header_dir(repository_ctx, trt_install_path):
if _headers_exist(repository_ctx, path): if _headers_exist(repository_ctx, path):
return path return path
auto_configure_fail( auto_configure_fail(
"Cannot find NvInfer.h with TensorRT install path %s" % trt_install_path) "Cannot find NvInfer.h with TensorRT install path %s" % trt_install_path,
)
def _trt_lib_version(repository_ctx, trt_install_path): def _trt_lib_version(repository_ctx, trt_install_path):
"""Detects the library (e.g. libnvinfer) version of TensorRT. """Detects the library (e.g. libnvinfer) version of TensorRT.
@ -83,23 +81,36 @@ def _trt_lib_version(repository_ctx, trt_install_path):
A string containing the library version of TensorRT. A string containing the library version of TensorRT.
""" """
trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path) trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
major_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h", major_version = find_cuda_define(
_DEFINE_TENSORRT_SONAME_MAJOR) repository_ctx,
minor_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h", trt_header_dir,
_DEFINE_TENSORRT_SONAME_MINOR) "NvInfer.h",
patch_version = find_cuda_define(repository_ctx, trt_header_dir, "NvInfer.h", _DEFINE_TENSORRT_SONAME_MAJOR,
_DEFINE_TENSORRT_SONAME_PATCH) )
minor_version = find_cuda_define(
repository_ctx,
trt_header_dir,
"NvInfer.h",
_DEFINE_TENSORRT_SONAME_MINOR,
)
patch_version = find_cuda_define(
repository_ctx,
trt_header_dir,
"NvInfer.h",
_DEFINE_TENSORRT_SONAME_PATCH,
)
full_version = "%s.%s.%s" % (major_version, minor_version, patch_version) full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
environ_version = repository_ctx.os.environ[_TF_TENSORRT_VERSION].strip() environ_version = repository_ctx.os.environ[_TF_TENSORRT_VERSION].strip()
if not matches_version(environ_version, full_version): if not matches_version(environ_version, full_version):
auto_configure_fail( auto_configure_fail(
("TensorRT library version detected from %s/%s (%s) does not match " + ("TensorRT library version detected from %s/%s (%s) does not match " +
"TF_TENSORRT_VERSION (%s). To fix this rerun configure again.") % "TF_TENSORRT_VERSION (%s). To fix this rerun configure again.") %
(trt_header_dir, "NvInfer.h", full_version, environ_version)) (trt_header_dir, "NvInfer.h", full_version, environ_version),
)
# Only use the major version to match the SONAME of the library. # Only use the major version to match the SONAME of the library.
return major_version return major_version
def _find_trt_libs(repository_ctx, cpu_value, trt_install_path, trt_lib_version): def _find_trt_libs(repository_ctx, cpu_value, trt_install_path, trt_lib_version):
"""Finds the given TensorRT library on the system. """Finds the given TensorRT library on the system.
@ -121,11 +132,12 @@ def _find_trt_libs(repository_ctx, cpu_value, trt_install_path, trt_lib_version)
result[file_name] = path result[file_name] = path
return result return result
def _tpl(repository_ctx, tpl, substitutions): def _tpl(repository_ctx, tpl, substitutions):
repository_ctx.template(tpl, Label("//third_party/tensorrt:%s.tpl" % tpl), repository_ctx.template(
substitutions) tpl,
Label("//third_party/tensorrt:%s.tpl" % tpl),
substitutions,
)
def _create_dummy_repository(repository_ctx): def _create_dummy_repository(repository_ctx):
"""Create a dummy TensorRT repository.""" """Create a dummy TensorRT repository."""
@ -134,7 +146,7 @@ def _create_dummy_repository(repository_ctx):
_tpl(repository_ctx, "BUILD", { _tpl(repository_ctx, "BUILD", {
"%{tensorrt_genrules}": "", "%{tensorrt_genrules}": "",
"%{tensorrt_headers}": "[]", "%{tensorrt_headers}": "[]",
"%{tensorrt_libs}": "[]" "%{tensorrt_libs}": "[]",
}) })
def _tensorrt_configure_impl(repository_ctx): def _tensorrt_configure_impl(repository_ctx):
@ -162,7 +174,8 @@ def _tensorrt_configure_impl(repository_ctx):
trt_install_path = repository_ctx.os.environ[_TENSORRT_INSTALL_PATH].strip() trt_install_path = repository_ctx.os.environ[_TENSORRT_INSTALL_PATH].strip()
if not repository_ctx.path(trt_install_path).exists: if not repository_ctx.path(trt_install_path).exists:
auto_configure_fail( auto_configure_fail(
"Cannot find TensorRT install path %s." % trt_install_path) "Cannot find TensorRT install path %s." % trt_install_path,
)
# Copy the library files. # Copy the library files.
trt_lib_version = _trt_lib_version(repository_ctx, trt_install_path) trt_lib_version = _trt_lib_version(repository_ctx, trt_install_path)
@ -182,10 +195,12 @@ def _tensorrt_configure_impl(repository_ctx):
# Copy the header files header files. # Copy the header files header files.
trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path) trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
trt_header_srcs = [ trt_header_srcs = [
"%s/%s" % (trt_header_dir, header) for header in _TF_TENSORRT_HEADERS "%s/%s" % (trt_header_dir, header)
for header in _TF_TENSORRT_HEADERS
] ]
trt_header_outs = [ trt_header_outs = [
"tensorrt/include/" + header for header in _TF_TENSORRT_HEADERS "tensorrt/include/" + header
for header in _TF_TENSORRT_HEADERS
] ]
copy_rules.append( copy_rules.append(
make_copy_files_rule( make_copy_files_rule(
@ -193,7 +208,8 @@ def _tensorrt_configure_impl(repository_ctx):
name = "tensorrt_include", name = "tensorrt_include",
srcs = trt_header_srcs, srcs = trt_header_srcs,
outs = trt_header_outs, outs = trt_header_outs,
)) ),
)
# Set up config file. # Set up config file.
_tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"}) _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"})
@ -205,7 +221,6 @@ def _tensorrt_configure_impl(repository_ctx):
"%{tensorrt_libs}": str(trt_lib_outs), "%{tensorrt_libs}": str(trt_lib_outs),
}) })
tensorrt_configure = repository_rule( tensorrt_configure = repository_rule(
implementation = _tensorrt_configure_impl, implementation = _tensorrt_configure_impl,
environ = [ environ = [

View File

@ -5,26 +5,33 @@ def _clang6_configure(ctx):
# method to generate a gigantic CROSSTOOL file that allows # method to generate a gigantic CROSSTOOL file that allows
# Clang to support everything. # Clang to support everything.
ctx.symlink( ctx.symlink(
ctx.os.environ.get('TF_LLVM_PATH', ctx.os.environ.get(
'/usr/lib/llvm-6.0'), "TF_LLVM_PATH",
'clang6/llvm') "/usr/lib/llvm-6.0",
),
"clang6/llvm",
)
ctx.symlink( ctx.symlink(
ctx.os.environ.get('STRIP', '/usr/bin/strip'), ctx.os.environ.get("STRIP", "/usr/bin/strip"),
'clang6/sbin/strip') "clang6/sbin/strip",
)
ctx.symlink( ctx.symlink(
ctx.os.environ.get('OBJDUMP', '/usr/bin/objdump'), ctx.os.environ.get("OBJDUMP", "/usr/bin/objdump"),
'clang6/sbin/objdump') "clang6/sbin/objdump",
ctx.symlink(ctx.attr._build, 'clang6/BUILD') )
ctx.template('clang6/CROSSTOOL', ctx.attr._crosstool, { ctx.symlink(ctx.attr._build, "clang6/BUILD")
'%package(@local_config_clang6//clang6)%': str(ctx.path('clang6')), ctx.template("clang6/CROSSTOOL", ctx.attr._crosstool, {
"%package(@local_config_clang6//clang6)%": str(ctx.path("clang6")),
}) })
clang6_configure = repository_rule( clang6_configure = repository_rule(
implementation = _clang6_configure, implementation = _clang6_configure,
attrs = { attrs = {
'_build': attr.label( "_build": attr.label(
default=str(Label('//third_party/toolchains/clang6:clang.BUILD'))), default = str(Label("//third_party/toolchains/clang6:clang.BUILD")),
'_crosstool': attr.label( ),
default=str(Label('//third_party/toolchains/clang6:CROSSTOOL.tpl'))), "_crosstool": attr.label(
default = str(Label("//third_party/toolchains/clang6:CROSSTOOL.tpl")),
),
}, },
) )

View File

@ -7,8 +7,8 @@ def _tpl(repository_ctx, tpl, substitutions={}, out=None):
repository_ctx.template( repository_ctx.template(
out, out,
Label("//third_party/toolchains/cpus/arm:%s.tpl" % tpl), Label("//third_party/toolchains/cpus/arm:%s.tpl" % tpl),
substitutions) substitutions,
)
def _arm_compiler_configure_impl(repository_ctx): def _arm_compiler_configure_impl(repository_ctx):
# We need to find a cross-compilation include directory for Python, so look # We need to find a cross-compilation include directory for Python, so look
@ -23,12 +23,12 @@ def _arm_compiler_configure_impl(repository_ctx):
python_include_path = "/usr/include/python2.7" python_include_path = "/usr/include/python2.7"
_tpl(repository_ctx, "CROSSTOOL", { _tpl(repository_ctx, "CROSSTOOL", {
"%{ARM_COMPILER_PATH}%": str(repository_ctx.path( "%{ARM_COMPILER_PATH}%": str(repository_ctx.path(
repository_ctx.attr.remote_config_repo)), repository_ctx.attr.remote_config_repo,
)),
"%{PYTHON_INCLUDE_PATH}%": python_include_path, "%{PYTHON_INCLUDE_PATH}%": python_include_path,
}) })
repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD") repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")
arm_compiler_configure = repository_rule( arm_compiler_configure = repository_rule(
implementation = _arm_compiler_configure_impl, implementation = _arm_compiler_configure_impl,
attrs = { attrs = {

View File

@ -27,6 +27,7 @@ def _tensorflow_rbe_config(name, compiler, python_version, cuda_version = None,
if cuda_version != None: if cuda_version != None:
base = "@cuda%s-cudnn%s-ubuntu14.04//image" % (cuda_version, cudnn_version) base = "@cuda%s-cudnn%s-ubuntu14.04//image" % (cuda_version, cudnn_version)
# The cuda toolchain currently contains its own C++ toolchain definition, # The cuda toolchain currently contains its own C++ toolchain definition,
# so we do not fetch local_config_cc. # so we do not fetch local_config_cc.
config_repos = [ config_repos = [