Formatting BUILD and bzl files.

PiperOrigin-RevId: 234543907
A. Unique TensorFlower authored 2019-02-18 22:19:08 -08:00; committed by TensorFlower Gardener
parent 6f7a0a83c1
commit a0b0a50328
26 changed files with 2088 additions and 2032 deletions


@@ -18,13 +18,12 @@
# git update-index --assume-unchanged tensorflow/compiler/tests/plugin.bzl

plugins = {
    #"example": {
    #    "device":"XLA_MY_DEVICE",
    #    "types":"DT_FLOAT,DT_HALF,DT_INT32",
    #    "tags":[],
    #    "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"],
    #    "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"],
    #    "deps":[],
    #},
}


@@ -1,12 +1,11 @@
"""build_defs for service/cpu."""

def runtime_copts():
    """Returns copts used for CPU runtime libraries."""
    return (["-DEIGEN_AVOID_STL_ARRAY"] + select({
        "//tensorflow:android_arm": ["-mfpu=neon"],
        "//conditions:default": [],
    }) + select({
        "//tensorflow:android": ["-O2"],
        "//conditions:default": [],
    }))
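
For illustration, a BUILD target might consume this macro as sketched below; the load label and target name are assumptions, not part of this commit:

load("//tensorflow/compiler/xla/service/cpu:build_defs.bzl", "runtime_copts")  # load label assumed

cc_library(
    name = "runtime_single_threaded_matmul",  # hypothetical target name
    srcs = ["runtime_single_threaded_matmul.cc"],
    copts = runtime_copts(),  # adds -DEIGEN_AVOID_STL_ARRAY plus Android-specific flags
)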


@@ -33,4 +33,3 @@
# }
plugins = {}


@@ -5,55 +5,52 @@ CUDNN_VERSION = ""
PLATFORM = ""

def cuda_sdk_version():
    return CUDA_VERSION

def cudnn_sdk_version():
    return CUDNN_VERSION

def cuda_library_path(name, version = cuda_sdk_version()):
    if PLATFORM == "Darwin":
        if not version:
            return "lib/lib{}.dylib".format(name)
        else:
            return "lib/lib{}.{}.dylib".format(name, version)
    elif not version:
        return "lib64/lib{}.so".format(name)
    else:
        return "lib64/lib{}.so.{}".format(name, version)

def cuda_static_library_path(name):
    if PLATFORM == "Darwin":
        return "lib/lib{}_static.a".format(name)
    else:
        return "lib64/lib{}_static.a".format(name)

def cudnn_library_path(version = cudnn_sdk_version()):
    if PLATFORM == "Darwin":
        if not version:
            return "lib/libcudnn.dylib"
        else:
            return "lib/libcudnn.{}.dylib".format(version)
    elif not version:
        return "lib64/libcudnn.so"
    else:
        return "lib64/libcudnn.so.{}".format(version)

def cupti_library_path(version = cuda_sdk_version()):
    if PLATFORM == "Darwin":
        if not version:
            return "extras/CUPTI/lib/libcupti.dylib"
        else:
            return "extras/CUPTI/lib/libcupti.{}.dylib".format(version)
    elif not version:
        return "extras/CUPTI/lib64/libcupti.so"
    else:
        return "extras/CUPTI/lib64/libcupti.so.{}".format(version)

def readlink_command():
    if PLATFORM == "Darwin":
        return "greadlink"
    else:
        return "readlink"


@@ -18,7 +18,7 @@ XLINT_OPTS = [
    "-Xlint:-processing",
    "-Xlint:-serial",
    "-Xlint:-try",
    "-Xlint:-classfile",  # see b/32750402, go/javac-warnings#classfile
]

# The bazel errorprone plugin currently only enables default errorChecks


@@ -17,46 +17,48 @@ load(
# and then archive those source files into
# ops/gen_sources.srcjar
#
def tf_java_op_gen_srcjar(
        name,
        gen_tool,
        base_package,
        api_def_srcs = [],
        out_dir = "ops/",
        out_src_dir = "src/main/java/",
        visibility = ["//tensorflow/java:__pkg__"]):
    gen_cmds = ["rm -rf $(@D)"]  # Always start from fresh when generating source files
    srcs = api_def_srcs[:]

    if not api_def_srcs:
        api_def_args_str = ","
    else:
        api_def_args = []
        for api_def_src in api_def_srcs:
            # Add directory of the first ApiDef source to args.
            # We are assuming all ApiDefs in a single api_def_src are in the
            # same directory.
            api_def_args.append(
                "$$(dirname $$(echo $(locations " + api_def_src +
                ") | cut -d\" \" -f1))",
            )
        api_def_args_str = ",".join(api_def_args)

    gen_cmds += ["$(location " + gen_tool + ")" +
                 " --output_dir=$(@D)/" + out_src_dir +
                 " --base_package=" + base_package +
                 " --api_dirs=" + api_def_args_str]

    # Generate a source archive containing generated code for these ops.
    gen_srcjar = out_dir + name + ".srcjar"
    gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]

    native.genrule(
        name = name,
        srcs = srcs,
        outs = [gen_srcjar],
        tools = [
            "@local_jdk//:jar",
            "@local_jdk//:jdk",
            gen_tool,
        ] + tf_binary_additional_srcs(),
        cmd = " && ".join(gen_cmds),
    )


@@ -12,22 +12,26 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
# consumers of the tf_gen_op_wrapper_py rule would be simplified if we don't
# hard code the ops/ directory.

def tf_gen_op_wrapper_private_py(
        name,
        out = None,
        deps = [],
        require_shape_functions = True,
        visibility = []):
    if not name.endswith("_gen"):
        fail("name must end in _gen")
    if not visibility:
        visibility = ["//visibility:private"]
    bare_op_name = name[:-4]  # Strip off the _gen
    tf_gen_op_wrapper_py(
        name = bare_op_name,
        out = out,
        visibility = visibility,
        deps = deps,
        require_shape_functions = require_shape_functions,
        generated_target_name = name,
        api_def_srcs = [
            "//tensorflow/core/api_def:base_api_def",
            "//tensorflow/core/api_def:python_api_def",
        ],
    )
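
For example, a private op wrapper might be declared like this (the op library dependency is an assumption):

tf_gen_op_wrapper_private_py(
    name = "nn_ops_gen",  # must end in _gen; generates the bare "nn_ops" wrapper target
    deps = ["//tensorflow/core:nn_ops_op_lib"],  # hypothetical op library
)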


@@ -17,5 +17,5 @@ def tf_additional_cudnn_plugin_deps():

# Returns whether any GPU backend is configuered.
def if_gpu_is_configured(x):
    if cuda_is_configured() or rocm_is_configured():
        return x
    return []
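
A sketch of typical usage in a BUILD file (target names are assumptions):

cc_library(
    name = "stream_executor_plugins",  # hypothetical
    deps = if_gpu_is_configured([
        ":cudnn_plugin",  # hypothetical GPU-only dependency
    ]),
)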


@@ -45,6 +45,7 @@ load(
    "//third_party/ngraph:build_defs.bzl",
    "if_ngraph",
)

def register_extension_info(**kwargs):
    pass

@@ -1463,7 +1464,7 @@ def cc_header_only_library(name, deps = [], includes = [], extra_deps = [], **kw
def tf_custom_op_library_additional_deps():
    return [
        "@protobuf_archive//:protobuf_headers",
        clean_dep("//third_party/eigen3"),
        clean_dep("//tensorflow/core:framework_headers_lib"),
    ] + if_windows(["//tensorflow/python:pywrap_tensorflow_import_lib"])

@@ -1473,8 +1474,8 @@ def tf_custom_op_library_additional_deps():
# exporting symbols from _pywrap_tensorflow.dll on Windows.
def tf_custom_op_library_additional_deps_impl():
    return [
        "@protobuf_archive//:protobuf",
        "@nsync//:nsync_cpp",
        # for //third_party/eigen3
        clean_dep("//third_party/eigen3"),
        # for //tensorflow/core:framework_headers_lib

@@ -4,60 +4,66 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test")

# Create a benchmark test target of a TensorFlow C++ test (tf_cc_*_test)
def tf_cc_logged_benchmark(
        name = None,
        target = None,
        benchmarks = "..",
        tags = [],
        test_log_output_prefix = "",
        benchmark_type = "cpp_microbenchmark"):
    if not name:
        fail("Must provide a name")
    if not target:
        fail("Must provide a target")
    if (not ":" in target or
        not target.startswith("//") or
        target.endswith(":all") or
        target.endswith(".")):
        fail(" ".join((
            "Target must be a single well-defined test, e.g.,",
            "//path/to:test. Received: %s" % target,
        )))

    all_tags = (
        depset(tags) + depset(
            ["benchmark-test", "local", "manual", "regression-test"],
        )
    ).to_list()

    tf_py_test(
        name = name,
        tags = all_tags,
        size = "large",
        srcs = ["//tensorflow/tools/test:run_and_gather_logs"],
        args = [
            "--name=//%s:%s" % (native.package_name(), name),
            "--test_name=" + target,
            "--test_args=--benchmarks=%s" % benchmarks,
            "--benchmark_type=%s" % benchmark_type,
        ],
        data = [
            target,
        ],
        main = "run_and_gather_logs.py",
        additional_deps = [
            "//tensorflow/tools/test:run_and_gather_logs",
        ],
    )

# Create a benchmark test target of a TensorFlow python test (*py_tests)
def tf_py_logged_benchmark(
        name = None,
        target = None,
        benchmarks = "..",
        tags = [],
        test_log_output_prefix = ""):
    # For now generating a py benchmark is the same as generating a C++
    # benchmark target. In the future this may change, so we have
    # two macros just in case
    tf_cc_logged_benchmark(
        name = name,
        target = target,
        benchmarks = benchmarks,
        tags = tags,
        test_log_output_prefix = test_log_output_prefix,
        benchmark_type = "python_benchmark",
    )
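
A hypothetical use of these macros; the benchmark target below is an assumption:

tf_cc_logged_benchmark(
    name = "cast_op_logged_benchmark",                  # hypothetical
    target = "//tensorflow/core/kernels:cast_op_test",  # hypothetical tf_cc_test target
    benchmarks = "BM_.*",
)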


@@ -7,7 +7,6 @@ load("//third_party:nccl/nccl_configure.bzl", "nccl_configure")
load("//third_party/mkl:build_defs.bzl", "mkl_repository")
load("//third_party/git:git_configure.bzl", "git_configure")
load("//third_party/py:python_configure.bzl", "python_configure")
load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
load("//third_party/systemlibs:syslibs_configure.bzl", "syslibs_configure")
load("//third_party/toolchains/remote:configure.bzl", "remote_execution_configure")


@@ -1,9 +1,9 @@
"""Set up configurable Android SDK and NDK dependencies."""

def android_workspace():
    # String for replacement in Bazel template.
    # These will either be replaced by android_sdk_repository if various ENV
    # variables are set when `local_config_android` repo_rule is run, or they
    # will be replaced by noops otherwise.
    MAYBE_ANDROID_SDK_REPOSITORY
    MAYBE_ANDROID_NDK_REPOSITORY


@@ -36,33 +36,39 @@ _ANDROID_NDK_REPO_TEMPLATE = """
"""

def _android_autoconf_impl(repository_ctx):
    """Implementation of the android_autoconf repository rule."""
    sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
    sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
    build_tools_version = repository_ctx.os.environ.get(
        _ANDROID_BUILD_TOOLS_VERSION,
    )
    ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
    ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)

    sdk_rule = "pass"
    if all([sdk_home, sdk_api_level, build_tools_version]):
        sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
            sdk_home,
            sdk_api_level,
            build_tools_version,
        )

    ndk_rule = "pass"
    if all([ndk_home, ndk_api_level]):
        ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level)

    repository_ctx.template(
        "BUILD",
        Label("//third_party/android:android_configure.BUILD.tpl"),
    )
    repository_ctx.template(
        "android.bzl",
        Label("//third_party/android:android.bzl.tpl"),
        substitutions = {
            "MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule,
            "MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule,
        },
    )

android_configure = repository_rule(
    implementation = _android_autoconf_impl,


@@ -21,11 +21,11 @@
#   substitutions: A dictionary mapping strings to their substitutions
def template_rule_impl(ctx):
    ctx.template_action(
        template = ctx.file.src,
        output = ctx.outputs.out,
        substitutions = ctx.attr.substitutions,
    )

template_rule = rule(
    attrs = {


@@ -8,49 +8,57 @@
_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"

def _fail(msg):
    """Output failure message when auto configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("%sGit Configuration Error:%s %s\n" % (red, no_color, msg))

def _get_python_bin(repository_ctx):
    """Gets the python bin path."""
    python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
    if python_bin != None:
        return python_bin
    python_bin_path = repository_ctx.which("python")
    if python_bin_path != None:
        return str(python_bin_path)
    _fail("Cannot find python in PATH, please make sure " +
          "python is installed and add its directory in PATH, or --define " +
          "%s='/something/else'.\nPATH=%s" % (
              _PYTHON_BIN_PATH,
              repository_ctx.os.environ.get("PATH", ""),
          ))

def _git_conf_impl(repository_ctx):
    repository_ctx.template(
        "BUILD",
        Label("//third_party/git:BUILD.tpl"),
    )

    tensorflow_root_path = str(repository_ctx.path(
        Label("@org_tensorflow//:BUILD"),
    ))[:-len("BUILD")]
    python_script_path = repository_ctx.path(
        Label("@org_tensorflow//tensorflow/tools/git:gen_git_source.py"),
    )
    generated_files_path = repository_ctx.path("gen")

    r = repository_ctx.execute(
        ["test", "-f", "%s/.git/logs/HEAD" % tensorflow_root_path],
    )
    if r.return_code == 0:
        unused_var = repository_ctx.path(Label("//:.git/HEAD"))  # pylint: disable=unused-variable

    result = repository_ctx.execute([
        _get_python_bin(repository_ctx),
        python_script_path,
        "--configure",
        tensorflow_root_path,
        "--gen_root_path",
        generated_files_path,
    ], quiet = False)

    if not result.return_code == 0:
        _fail(result.stderr)

git_configure = repository_rule(
    implementation = _git_conf_impl,
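
As a usage sketch (the repository name is assumed):

load("//third_party/git:git_configure.bzl", "git_configure")

git_configure(name = "local_config_git")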

File diff suppressed because it is too large.


@@ -242,11 +242,16 @@ def _hipcc_env(repository_ctx):
      A string containing environment variables for hipcc.
    """
    hipcc_env = ""
    for name in [
        "HIP_CLANG_PATH",
        "DEVICE_LIB_PATH",
        "HIP_VDI_HOME",
        "HIPCC_VERBOSE",
        "HIPCC_COMPILE_FLAGS_APPEND",
    ]:
        if name in repository_ctx.os.environ:
            hipcc_env = hipcc_env + " " + name + "=\"" + \
                repository_ctx.os.environ[name].strip() + "\";"
    return hipcc_env.strip()

def _crosstool_verbose(repository_ctx):

@@ -636,7 +641,6 @@ def _create_local_rocm_repository(repository_ctx):
        outs = rocm_lib_outs,
    ))

    # Set up BUILD file for rocm/
    _tpl(
        repository_ctx,


@@ -1,26 +1,25 @@
filegroup(
    name = "LICENSE",
    visibility = ["//visibility:public"],
)

cc_library(
    name = "nccl",
    srcs = ["libnccl.so.%{version}"],
    hdrs = ["nccl.h"],
    include_prefix = "third_party/nccl",
    visibility = ["//visibility:public"],
    deps = [
        "@local_config_cuda//cuda:cuda_headers",
    ],
)

genrule(
    name = "nccl-files",
    outs = [
        "libnccl.so.%{version}",
        "nccl.h",
    ],
    cmd = """cp "%{hdr_path}/nccl.h" "$(@D)/nccl.h" &&
cp "%{install_path}/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
)
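
For illustration, with hypothetical substitution values %{version} = "2", %{hdr_path} = "/usr/include", and %{install_path} = "/usr/lib/x86_64-linux-gnu", the genrule above would expand to:

genrule(
    name = "nccl-files",
    outs = [
        "libnccl.so.2",
        "nccl.h",
    ],
    cmd = """cp "/usr/include/nccl.h" "$(@D)/nccl.h" &&
cp "/usr/lib/x86_64-linux-gnu/libnccl.so.2" "$(@D)/libnccl.so.2" """,
)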


@@ -11,300 +11,337 @@ _PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
_PYTHON_LIB_PATH = "PYTHON_LIB_PATH"
_TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO"

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    if not out:
        out = tpl
    repository_ctx.template(
        out,
        Label("//third_party/py:%s.tpl" % tpl),
        substitutions,
    )

def _fail(msg):
    """Output failure message when auto configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg))

def _is_windows(repository_ctx):
    """Returns true if the host operating system is windows."""
    os_name = repository_ctx.os.name.lower()
    if os_name.find("windows") != -1:
        return True
    return False

def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Executes an arbitrary shell command.

    Args:
      repository_ctx: the repository_ctx object
      cmdline: list of strings, the command to execute
      error_msg: string, a summary of the error if the command fails
      error_details: string, details about the error or steps to fix it
      empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
        it's an error
    Return:
      the result of repository_ctx.execute(cmdline)
    """
    result = repository_ctx.execute(cmdline)
    if result.stderr or not (empty_stdout_fine or result.stdout):
        _fail("\n".join([
            error_msg.strip() if error_msg else "Repository command failed",
            result.stderr.strip(),
            error_details if error_details else "",
        ]))
    return result

def _read_dir(repository_ctx, src_dir):
    """Returns a string with all files in a directory.

    Finds all files inside a directory, traversing subfolders and following
    symlinks. The returned string contains the full path of all files
    separated by line breaks.
    """
    if _is_windows(repository_ctx):
        src_dir = src_dir.replace("/", "\\")
        find_result = _execute(
            repository_ctx,
            ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
            empty_stdout_fine = True,
        )

        # src_files will be used in genrule.outs where the paths must
        # use forward slashes.
        result = find_result.stdout.replace("\\", "/")
    else:
        find_result = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        result = find_result.stdout
    return result

def _genrule(src_dir, genrule_name, command, outs):
    """Returns a string with a genrule.

    Genrule executes the given command and produces the given outputs.
    """
    return (
        "genrule(\n" +
        '    name = "' +
        genrule_name + '",\n' +
        "    outs = [\n" +
        outs +
        "\n    ],\n" +
        '    cmd = """\n' +
        command +
        '\n    """,\n' +
        ")\n"
    )

def _norm_path(path):
    """Returns a path with '/' and remove the trailing slash."""
    path = path.replace("\\", "/")
    if path[-1] == "/":
        path = path[:-1]
    return path

def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = []):
    """Returns a genrule to symlink(or copy if on Windows) a set of files.

    If src_dir is passed, files will be read from the given directory; otherwise
    we assume files are in src_files and dest_files
    """
    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))

        # Create a list with the src_dir stripped to use for outputs.
        dest_files = files.replace(src_dir, "").splitlines()
        src_files = files.splitlines()
    command = []
    outs = []
    for i in range(len(dest_files)):
        if dest_files[i] != "":
            # If we have only one file to link we do not want to use the dest_dir, as
            # $(@D) will include the full path to the file.
            dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]

            # Copy the headers to create a sandboxable setup.
            cmd = "cp -f"
            command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
            outs.append('        "' + dest_dir + dest_files[i] + '",')
    genrule = _genrule(
        src_dir,
        genrule_name,
        " && ".join(command),
        "\n".join(outs),
    )
    return genrule

def _get_python_bin(repository_ctx):
    """Gets the python bin path."""
    python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
    if python_bin != None:
        return python_bin
    python_bin_path = repository_ctx.which("python")
    if python_bin_path != None:
        return str(python_bin_path)
    _fail("Cannot find python in PATH, please make sure " +
          "python is installed and add its directory in PATH, or --define " +
          "%s='/something/else'.\nPATH=%s" % (
              _PYTHON_BIN_PATH,
              repository_ctx.os.environ.get("PATH", ""),
          ))

def _get_bash_bin(repository_ctx):
    """Gets the bash bin path."""
    bash_bin = repository_ctx.os.environ.get(_BAZEL_SH)
    if bash_bin != None:
        return bash_bin
    else:
        bash_bin_path = repository_ctx.which("bash")
        if bash_bin_path != None:
            return str(bash_bin_path)
        else:
            _fail("Cannot find bash in PATH, please make sure " +
                  "bash is installed and add its directory in PATH, or --define " +
                  "%s='/path/to/bash'.\nPATH=%s" % (
                      _BAZEL_SH,
                      repository_ctx.os.environ.get("PATH", ""),
                  ))

def _get_python_lib(repository_ctx, python_bin):
    """Gets the python lib path."""
    python_lib = repository_ctx.os.environ.get(_PYTHON_LIB_PATH)
    if python_lib != None:
        return python_lib
    print_lib = ("<<END\n" +
                 "from __future__ import print_function\n" +
                 "import site\n" +
                 "import os\n" +
                 "\n" +
                 "try:\n" +
                 "  input = raw_input\n" +
                 "except NameError:\n" +
                 "  pass\n" +
                 "\n" +
                 "python_paths = []\n" +
                 "if os.getenv('PYTHONPATH') is not None:\n" +
                 "  python_paths = os.getenv('PYTHONPATH').split(':')\n" +
                 "try:\n" +
                 "  library_paths = site.getsitepackages()\n" +
                 "except AttributeError:\n" +
                 "  from distutils.sysconfig import get_python_lib\n" +
                 "  library_paths = [get_python_lib()]\n" +
                 "all_paths = set(python_paths + library_paths)\n" +
                 "paths = []\n" +
                 "for path in all_paths:\n" +
                 "  if os.path.isdir(path):\n" +
                 "    paths.append(path)\n" +
                 "if len(paths) >=1:\n" +
                 "  print(paths[0])\n" +
                 "END")
    cmd = "%s - %s" % (python_bin, print_lib)
    result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd])
    return result.stdout.strip("\n")

def _check_python_lib(repository_ctx, python_lib):
    """Checks the python lib path."""
    cmd = 'test -d "%s" -a -x "%s"' % (python_lib, python_lib)
    result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd])
    if result.return_code == 1:
        _fail("Invalid python library path: %s" % python_lib)

def _check_python_bin(repository_ctx, python_bin):
    """Checks the python bin path."""
    cmd = '[[ -x "%s" ]] && [[ ! -d "%s" ]]' % (python_bin, python_bin)
    result = repository_ctx.execute([_get_bash_bin(repository_ctx), "-c", cmd])
    if result.return_code == 1:
        _fail("--define %s='%s' is not executable. Is it the python binary?" % (
            _PYTHON_BIN_PATH,
            python_bin,
        ))

def _get_python_include(repository_ctx, python_bin):
    """Gets the python include path."""
    result = _execute(
        repository_ctx,
        [
            python_bin,
            "-c",
            "from __future__ import print_function;" +
            "from distutils import sysconfig;" +
            "print(sysconfig.get_python_inc())",
        ],
        error_msg = "Problem getting python include path.",
        error_details = ("Is the Python binary path set up right? " +
                         "(See ./configure or " + _PYTHON_BIN_PATH + ".) " +
                         "Is distutils installed?"),
    )
    return result.stdout.splitlines()[0]

def _get_python_import_lib_name(repository_ctx, python_bin):
    """Get Python import library name (pythonXY.lib) on Windows."""
    result = _execute(
        repository_ctx,
        [
            python_bin,
            "-c",
            "import sys;" +
            'print("python" + str(sys.version_info[0]) + ' +
            '    str(sys.version_info[1]) + ".lib")',
        ],
        error_msg = "Problem getting python import library.",
        error_details = ("Is the Python binary path set up right? " +
                         "(See ./configure or " + _PYTHON_BIN_PATH + ".) "),
    )
    return result.stdout.splitlines()[0]

def _get_numpy_include(repository_ctx, python_bin):
    """Gets the numpy include path."""
    return _execute(
        repository_ctx,
        [
            python_bin,
            "-c",
            "from __future__ import print_function;" +
            "import numpy;" +
            " print(numpy.get_include());",
        ],
        error_msg = "Problem getting numpy include path.",
        error_details = "Is numpy installed?",
    ).stdout.splitlines()[0]

def _create_local_python_repository(repository_ctx):
    """Creates the repository containing files set up to build with Python."""
    python_bin = _get_python_bin(repository_ctx)
    _check_python_bin(repository_ctx, python_bin)
    python_lib = _get_python_lib(repository_ctx, python_bin)
    _check_python_lib(repository_ctx, python_lib)
    python_include = _get_python_include(repository_ctx, python_bin)
    numpy_include = _get_numpy_include(repository_ctx, python_bin) + "/numpy"
    python_include_rule = _symlink_genrule_for_dir(
        repository_ctx,
        python_include,
        "python_include",
        "python_include",
    )
    python_import_lib_genrule = ""

    # To build Python C/C++ extension on Windows, we need to link to python import library pythonXY.lib
    # See https://docs.python.org/3/extending/windows.html
    if _is_windows(repository_ctx):
        python_include = _norm_path(python_include)
        python_import_lib_name = _get_python_import_lib_name(repository_ctx, python_bin)
        python_import_lib_src = python_include.rsplit("/", 1)[0] + "/libs/" + python_import_lib_name
        python_import_lib_genrule = _symlink_genrule_for_dir(
            repository_ctx,
            None,
            "",
            "python_import_lib",
            [python_import_lib_src],
            [python_import_lib_name],
        )
    numpy_include_rule = _symlink_genrule_for_dir(
        repository_ctx,
        numpy_include,
        "numpy_include/numpy",
        "numpy_include",
    )
    _tpl(repository_ctx, "BUILD", {
        "%{PYTHON_INCLUDE_GENRULE}": python_include_rule,
        "%{PYTHON_IMPORT_LIB_GENRULE}": python_import_lib_genrule,
        "%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule,
    })

def _create_remote_python_repository(repository_ctx, remote_config_repo):
    """Creates pointers to a remotely configured repo set up to build with Python.
    """
    repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})

def _python_autoconf_impl(repository_ctx):
    """Implementation of the python_autoconf repository rule."""
    if _TF_PYTHON_CONFIG_REPO in repository_ctx.os.environ:
        _create_remote_python_repository(
            repository_ctx,
            repository_ctx.os.environ[_TF_PYTHON_CONFIG_REPO],
        )
    else:
        _create_local_python_repository(repository_ctx)

python_configure = repository_rule(
    implementation = _python_autoconf_impl,
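
As a usage sketch (the repository name is the conventional one and is assumed here):

load("//third_party/py:python_configure.bzl", "python_configure")

python_configure(name = "local_config_python")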


@@ -11,7 +11,7 @@ def if_sycl(if_true, if_false = []):
    return select({
        "@local_config_sycl//sycl:using_sycl_ccpp": if_true,
        "@local_config_sycl//sycl:using_sycl_trisycl": if_true[0:1],
        "//conditions:default": if_false,
    })

def if_ccpp(if_true, if_false = []):

@@ -24,5 +24,5 @@ def if_ccpp(if_true, if_false = []):
    return select({
        "@local_config_sycl//sycl:using_sycl_ccpp": if_true,
        "@local_config_sycl//sycl:using_sycl_trisycl": if_false,
        "//conditions:default": if_false,
    })
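
A hypothetical BUILD usage of if_sycl (target names are assumptions):

cc_library(
    name = "device_runtime",  # hypothetical
    deps = ["//tensorflow/core:framework"] + if_sycl([":sycl_runtime"]),  # :sycl_runtime assumed
)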


@@ -11,122 +11,124 @@
"""

_HOST_CXX_COMPILER = "HOST_CXX_COMPILER"
_HOST_C_COMPILER = "HOST_C_COMPILER"
_COMPUTECPP_TOOLKIT_PATH = "COMPUTECPP_TOOLKIT_PATH"
_TRISYCL_INCLUDE_DIR = "TRISYCL_INCLUDE_DIR"
_PYTHON_LIB_PATH = "PYTHON_LIB_PATH"

def _enable_sycl(repository_ctx):
    if "TF_NEED_OPENCL_SYCL" in repository_ctx.os.environ:
        enable_sycl = repository_ctx.os.environ["TF_NEED_OPENCL_SYCL"].strip()
        return enable_sycl == "1"
    return False

def _enable_compute_cpp(repository_ctx):
    return _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ

def auto_configure_fail(msg):
    """Output failure message when auto configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))

# END cc_configure common functions (see TODO above).

def find_c(repository_ctx):
    """Find host C compiler."""
    c_name = "gcc"
    if _HOST_C_COMPILER in repository_ctx.os.environ:
        c_name = repository_ctx.os.environ[_HOST_C_COMPILER].strip()
    if c_name.startswith("/"):
        return c_name
    c = repository_ctx.which(c_name)
    if c == None:
        fail("Cannot find C compiler, please correct your path.")
    return c

def find_cc(repository_ctx):
    """Find host C++ compiler."""
    cc_name = "g++"
    if _HOST_CXX_COMPILER in repository_ctx.os.environ:
        cc_name = repository_ctx.os.environ[_HOST_CXX_COMPILER].strip()
    if cc_name.startswith("/"):
        return cc_name
    cc = repository_ctx.which(cc_name)
    if cc == None:
        fail("Cannot find C++ compiler, please correct your path.")
    return cc

def find_computecpp_root(repository_ctx):
    """Find ComputeCpp compiler."""
    sycl_name = ""
    if _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ:
        sycl_name = repository_ctx.os.environ[_COMPUTECPP_TOOLKIT_PATH].strip()
    if sycl_name.startswith("/"):
        return sycl_name
    fail("Cannot find SYCL compiler, please correct your path")

def find_trisycl_include_dir(repository_ctx):
    """Find triSYCL include directory. """
    if _TRISYCL_INCLUDE_DIR in repository_ctx.os.environ:
        sycl_name = repository_ctx.os.environ[_TRISYCL_INCLUDE_DIR].strip()
        if sycl_name.startswith("/"):
            return sycl_name
    fail("Cannot find triSYCL include directory, please correct your path")

def find_python_lib(repository_ctx):
    """Returns python path."""
    if _PYTHON_LIB_PATH in repository_ctx.os.environ:
        return repository_ctx.os.environ[_PYTHON_LIB_PATH].strip()
    fail("Environment variable PYTHON_LIB_PATH was not specified re-run ./configure")

def _check_lib(repository_ctx, toolkit_path, lib):
    """Checks if lib exists under sycl_toolkit_path or fail if it doesn't.

    Args:
      repository_ctx: The repository context.
      toolkit_path: The toolkit directory containing the libraries.
      ib: The library to look for under toolkit_path.
    """
    lib_path = toolkit_path + "/" + lib
    if not repository_ctx.path(lib_path).exists:
        auto_configure_fail("Cannot find %s" % lib_path)

def _check_dir(repository_ctx, directory):
    """Checks whether the directory exists and fail if it does not.

    Args:
      repository_ctx: The repository context.
      directory: The directory to check the existence of.
    """
    if not repository_ctx.path(directory).exists:
        auto_configure_fail("Cannot find dir: %s" % directory)

def _symlink_dir(repository_ctx, src_dir, dest_dir):
    """Symlinks all the files in a directory.

    Args:
      repository_ctx: The repository context.
      src_dir: The source directory.
      dest_dir: The destination directory to create the symlinks in.
    """
    files = repository_ctx.path(src_dir).readdir()
    for src_file in files:
        repository_ctx.symlink(src_file, dest_dir + "/" + src_file.basename)

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    if not out:
        out = tpl.replace(":", "/")
    repository_ctx.template(
        out,
        Label("//third_party/sycl/%s.tpl" % tpl),
        substitutions,
    )

def _file(repository_ctx, label):
    repository_ctx.template(
        label.replace(":", "/"),
        Label("//third_party/sycl/%s" % label),
        {},
    )

_DUMMY_CROSSTOOL_BZL_FILE = """
def error_sycl_disabled():

@@ -147,7 +149,6 @@ def error_sycl_disabled():
    )
"""

_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_sycl_disabled.bzl", "error_sycl_disabled")

@@ -155,87 +156,97 @@ error_sycl_disabled()
"""

def _create_dummy_repository(repository_ctx):
    # Set up BUILD file for sycl/.
    _tpl(repository_ctx, "sycl:build_defs.bzl")
    _tpl(repository_ctx, "sycl:BUILD")
    _file(repository_ctx, "sycl:LICENSE.text")
    _tpl(repository_ctx, "sycl:platform.bzl")

    # Create dummy files for the SYCL toolkit since they are still required by
    # tensorflow/sycl/platform/default/build_config:sycl.
    repository_ctx.file("sycl/include/sycl.hpp", "")
    repository_ctx.file("sycl/lib/libComputeCpp.so", "")

    # If sycl_configure is not configured to build with SYCL support, and the user
    # attempts to build with --config=sycl, add a dummy build rule to intercept
    # this and fail with an actionable error message.
    repository_ctx.file(
        "crosstool/error_sycl_disabled.bzl",
        _DUMMY_CROSSTOOL_BZL_FILE,
    )
    repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)

def _sycl_autoconf_imp(repository_ctx):
    """Implementation of the sycl_autoconf rule."""
    if not _enable_sycl(repository_ctx):
        _create_dummy_repository(repository_ctx)
    else:
        # copy template files
        _tpl(repository_ctx, "sycl:build_defs.bzl")
        _tpl(repository_ctx, "sycl:BUILD")
        _tpl(repository_ctx, "sycl:platform.bzl")
        _tpl(repository_ctx, "crosstool:BUILD")
        _file(repository_ctx, "sycl:LICENSE.text")

        if _enable_compute_cpp(repository_ctx):
            _tpl(
                repository_ctx,
                "crosstool:computecpp",
                {
                    "%{host_cxx_compiler}": find_cc(repository_ctx),
                    "%{host_c_compiler}": find_c(repository_ctx),
                },
            )

            computecpp_root = find_computecpp_root(repository_ctx)
            _check_dir(repository_ctx, computecpp_root)

            _tpl(
                repository_ctx,
                "crosstool:CROSSTOOL",
                {
                    "%{sycl_include_dir}": computecpp_root,
                    "%{sycl_impl}": "computecpp",
                    "%{c++_std}": "-std=c++11",
                    "%{python_lib_path}": find_python_lib(repository_ctx),
                },
            )

            # symlink libraries
            _check_lib(repository_ctx, computecpp_root + "/lib", "libComputeCpp.so")
            _symlink_dir(repository_ctx, computecpp_root + "/lib", "sycl/lib")
            _symlink_dir(repository_ctx, computecpp_root + "/include", "sycl/include")
            _symlink_dir(repository_ctx, computecpp_root + "/bin", "sycl/bin")
        else:
            trisycl_include_dir = find_trisycl_include_dir(repository_ctx)
            _check_dir(repository_ctx, trisycl_include_dir)

            _tpl(
                repository_ctx,
                "crosstool:trisycl",
                {
                    "%{host_cxx_compiler}": find_cc(repository_ctx),
                    "%{host_c_compiler}": find_c(repository_ctx),
                    "%{trisycl_include_dir}": trisycl_include_dir,
                },
            )

            _tpl(
                repository_ctx,
                "crosstool:CROSSTOOL",
                {
                    "%{sycl_include_dir}": trisycl_include_dir,
                    "%{sycl_impl}": "trisycl",
                    "%{c++_std}": "-std=c++1y",
                    "%{python_lib_path}": find_python_lib(repository_ctx),
                },
            )

            _symlink_dir(repository_ctx, trisycl_include_dir, "sycl/include")

sycl_configure = repository_rule(
    implementation = _sycl_autoconf_imp,
    local = True,
)

"""Detects and configures the SYCL toolchain.

View File

@ -10,13 +10,13 @@
load(
    "//third_party/gpus:cuda_configure.bzl",
    "auto_configure_fail",
    "find_cuda_define",
    "find_lib",
    "get_cpu_value",
    "lib_name",
    "make_copy_dir_rule",
    "make_copy_files_rule",
    "matches_version",
)
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
@ -30,185 +30,200 @@ _DEFINE_TENSORRT_SONAME_MAJOR = "#define NV_TENSORRT_SONAME_MAJOR"
_DEFINE_TENSORRT_SONAME_MINOR = "#define NV_TENSORRT_SONAME_MINOR"
_DEFINE_TENSORRT_SONAME_PATCH = "#define NV_TENSORRT_SONAME_PATCH"
def _headers_exist(repository_ctx, path):
    """Returns whether all TensorRT header files could be found in 'path'.

    Args:
      repository_ctx: The repository context.
      path: The TensorRT include path to check.

    Returns:
      True if all TensorRT header files can be found in the path.
    """
    for h in _TF_TENSORRT_HEADERS:
        if not repository_ctx.path("%s/%s" % (path, h)).exists:
            return False
    return True
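Put differently, the helper just probes a candidate include directory against the header list in _TF_TENSORRT_HEADERS. A hedged sketch of calling it from a repository rule implementation; the directory literal is only an example, not the configured path:

# Illustrative only: probe one hypothetical include directory.
def _probe_headers_sketch(repository_ctx):
    candidate = "/usr/include/x86_64-linux-gnu"  # example path, an assumption
    if _headers_exist(repository_ctx, candidate):
        return candidate
    return None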
def _find_trt_header_dir(repository_ctx, trt_install_path):
    """Returns the path to the directory containing headers of TensorRT.

    Args:
      repository_ctx: The repository context.
      trt_install_path: The TensorRT library install directory.

    Returns:
      The path of the directory containing the TensorRT header.
    """
    if trt_install_path == "/usr/lib/x86_64-linux-gnu":
        path = "/usr/include/x86_64-linux-gnu"
        if _headers_exist(repository_ctx, path):
            return path
    if trt_install_path == "/usr/lib/aarch64-linux-gnu":
        path = "/usr/include/aarch64-linux-gnu"
        if _headers_exist(repository_ctx, path):
            return path
    path = str(repository_ctx.path("%s/../include" % trt_install_path).realpath)
    if _headers_exist(repository_ctx, path):
        return path
    auto_configure_fail(
        "Cannot find NvInfer.h with TensorRT install path %s" % trt_install_path,
    )
def _trt_lib_version(repository_ctx, trt_install_path):
    """Detects the library (e.g. libnvinfer) version of TensorRT.

    Args:
      repository_ctx: The repository context.
      trt_install_path: The TensorRT library install directory.

    Returns:
      A string containing the library version of TensorRT.
    """
    trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
    major_version = find_cuda_define(
        repository_ctx,
        trt_header_dir,
        "NvInfer.h",
        _DEFINE_TENSORRT_SONAME_MAJOR,
    )
    minor_version = find_cuda_define(
        repository_ctx,
        trt_header_dir,
        "NvInfer.h",
        _DEFINE_TENSORRT_SONAME_MINOR,
    )
    patch_version = find_cuda_define(
        repository_ctx,
        trt_header_dir,
        "NvInfer.h",
        _DEFINE_TENSORRT_SONAME_PATCH,
    )
    full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
    environ_version = repository_ctx.os.environ[_TF_TENSORRT_VERSION].strip()
    if not matches_version(environ_version, full_version):
        auto_configure_fail(
            ("TensorRT library version detected from %s/%s (%s) does not match " +
             "TF_TENSORRT_VERSION (%s). To fix this rerun configure again.") %
            (trt_header_dir, "NvInfer.h", full_version, environ_version),
        )

    # Only use the major version to match the SONAME of the library.
    return major_version
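The version detection above leans on find_cuda_define() pulling the #define NV_TENSORRT_SONAME_* values out of NvInfer.h. A minimal sketch of that idea, as a hypothetical helper rather than the real implementation in //third_party/gpus:cuda_configure.bzl:

# Hypothetical sketch: extract the value of a "#define NAME value" line.
def _find_define_sketch(repository_ctx, header_dir, header_file, define):
    # repository_ctx.read() returns the file contents as a string.
    contents = repository_ctx.read("%s/%s" % (header_dir, header_file))
    for line in contents.splitlines():
        if line.startswith(define):
            # e.g. "#define NV_TENSORRT_SONAME_MAJOR 5" -> "5"
            return line[len(define):].strip()
    return None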
def _find_trt_libs(repository_ctx, cpu_value, trt_install_path, trt_lib_version):
    """Finds the given TensorRT library on the system.

    Adapted from code contributed by Sami Kama (https://github.com/samikama).

    Args:
      repository_ctx: The repository context.
      trt_install_path: The TensorRT library installation directory.
      trt_lib_version: The version of TensorRT library files as returned
        by _trt_lib_version.

    Returns:
      The path to the library.
    """
    result = {}
    for lib in _TF_TENSORRT_LIBS:
        file_name = lib_name("nvinfer", cpu_value, trt_lib_version)
        path = find_lib(repository_ctx, ["%s/%s" % (trt_install_path, file_name)])
        result[file_name] = path
    return result
def _tpl(repository_ctx, tpl, substitutions):
    repository_ctx.template(
        tpl,
        Label("//third_party/tensorrt:%s.tpl" % tpl),
        substitutions,
    )
def _create_dummy_repository(repository_ctx):
    """Create a dummy TensorRT repository."""
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_false"})
    _tpl(repository_ctx, "BUILD", {
        "%{tensorrt_genrules}": "",
        "%{tensorrt_headers}": "[]",
        "%{tensorrt_libs}": "[]",
    })
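For context, the %{if_tensorrt} substitution drives a select-style helper in the generated build_defs.bzl. Assuming the template follows the usual TensorFlow configure pattern (an assumption, not the literal template shipped in //third_party/tensorrt), the dummy repository would end up with a helper that always takes the fallback branch:

# Assumed shape of the generated build_defs.bzl when TensorRT is disabled:
# the "if_false" substitution makes the helper always return the fallback.
def if_tensorrt(if_true, if_false = []):
    return if_false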
def _tensorrt_configure_impl(repository_ctx):
    """Implementation of the tensorrt_configure repository rule."""
    if _TF_TENSORRT_CONFIG_REPO in repository_ctx.os.environ:
        # Forward to the pre-configured remote repository.
        remote_config_repo = repository_ctx.os.environ[_TF_TENSORRT_CONFIG_REPO]
        repository_ctx.template("BUILD", Label(remote_config_repo + ":BUILD"), {})
        repository_ctx.template(
            "build_defs.bzl",
            Label(remote_config_repo + ":build_defs.bzl"),
            {},
        )
        return

    if _TENSORRT_INSTALL_PATH not in repository_ctx.os.environ:
        _create_dummy_repository(repository_ctx)
        return

    cpu_value = get_cpu_value(repository_ctx)
    if (cpu_value != "Linux"):
        auto_configure_fail("TensorRT is supported only on Linux.")
    if _TF_TENSORRT_VERSION not in repository_ctx.os.environ:
        auto_configure_fail("TensorRT library (libnvinfer) version is not set.")
    trt_install_path = repository_ctx.os.environ[_TENSORRT_INSTALL_PATH].strip()
    if not repository_ctx.path(trt_install_path).exists:
        auto_configure_fail(
            "Cannot find TensorRT install path %s." % trt_install_path,
        )

    # Copy the library files.
    trt_lib_version = _trt_lib_version(repository_ctx, trt_install_path)
    trt_libs = _find_trt_libs(repository_ctx, cpu_value, trt_install_path, trt_lib_version)
    trt_lib_srcs = []
    trt_lib_outs = []
    for path in trt_libs.values():
        trt_lib_srcs.append(str(path))
        trt_lib_outs.append("tensorrt/lib/" + path.basename)
    copy_rules = [make_copy_files_rule(
        repository_ctx,
        name = "tensorrt_lib",
        srcs = trt_lib_srcs,
        outs = trt_lib_outs,
    )]

    # Copy the header files header files.
    trt_header_dir = _find_trt_header_dir(repository_ctx, trt_install_path)
    trt_header_srcs = [
        "%s/%s" % (trt_header_dir, header)
        for header in _TF_TENSORRT_HEADERS
    ]
    trt_header_outs = [
        "tensorrt/include/" + header
        for header in _TF_TENSORRT_HEADERS
    ]
    copy_rules.append(
        make_copy_files_rule(
            repository_ctx,
            name = "tensorrt_include",
            srcs = trt_header_srcs,
            outs = trt_header_outs,
        ),
    )

    # Set up config file.
    _tpl(repository_ctx, "build_defs.bzl", {"%{if_tensorrt}": "if_true"})

    # Set up BUILD file.
    _tpl(repository_ctx, "BUILD", {
        "%{copy_rules}": "\n".join(copy_rules),
        "%{tensorrt_headers}": '":tensorrt_include"',
        "%{tensorrt_libs}": str(trt_lib_outs),
    })
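The implementation is driven entirely by the environment: TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION, listed in the environ attribute of the repository rule declared just below, must be set when the rule runs. A hedged sketch of the WORKSPACE wiring, with the load label and repository name assumed from common TensorFlow conventions:

# Hypothetical WORKSPACE usage sketch for the rule declared below.
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
tensorrt_configure(name = "local_config_tensorrt")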
tensorrt_configure = repository_rule(
    implementation = _tensorrt_configure_impl,
    environ = [
        _TENSORRT_INSTALL_PATH,
        _TF_TENSORRT_VERSION,
    ],

View File

@ -1,30 +1,37 @@
"""Repository rule for Debian 8 Jessie Clang-6.0 portable Linux builds.""" """Repository rule for Debian 8 Jessie Clang-6.0 portable Linux builds."""
def _clang6_configure(ctx): def _clang6_configure(ctx):
# TODO(jart): It'd probably be better to use Bazel's struct.to_proto() # TODO(jart): It'd probably be better to use Bazel's struct.to_proto()
# method to generate a gigantic CROSSTOOL file that allows # method to generate a gigantic CROSSTOOL file that allows
# Clang to support everything. # Clang to support everything.
ctx.symlink( ctx.symlink(
ctx.os.environ.get('TF_LLVM_PATH', ctx.os.environ.get(
'/usr/lib/llvm-6.0'), "TF_LLVM_PATH",
'clang6/llvm') "/usr/lib/llvm-6.0",
ctx.symlink( ),
ctx.os.environ.get('STRIP', '/usr/bin/strip'), "clang6/llvm",
'clang6/sbin/strip') )
ctx.symlink( ctx.symlink(
ctx.os.environ.get('OBJDUMP', '/usr/bin/objdump'), ctx.os.environ.get("STRIP", "/usr/bin/strip"),
'clang6/sbin/objdump') "clang6/sbin/strip",
ctx.symlink(ctx.attr._build, 'clang6/BUILD') )
ctx.template('clang6/CROSSTOOL', ctx.attr._crosstool, { ctx.symlink(
'%package(@local_config_clang6//clang6)%': str(ctx.path('clang6')), ctx.os.environ.get("OBJDUMP", "/usr/bin/objdump"),
}) "clang6/sbin/objdump",
)
ctx.symlink(ctx.attr._build, "clang6/BUILD")
ctx.template("clang6/CROSSTOOL", ctx.attr._crosstool, {
"%package(@local_config_clang6//clang6)%": str(ctx.path("clang6")),
})
clang6_configure = repository_rule( clang6_configure = repository_rule(
implementation = _clang6_configure, implementation = _clang6_configure,
attrs = { attrs = {
'_build': attr.label( "_build": attr.label(
default=str(Label('//third_party/toolchains/clang6:clang.BUILD'))), default = str(Label("//third_party/toolchains/clang6:clang.BUILD")),
'_crosstool': attr.label( ),
default=str(Label('//third_party/toolchains/clang6:CROSSTOOL.tpl'))), "_crosstool": attr.label(
default = str(Label("//third_party/toolchains/clang6:CROSSTOOL.tpl")),
),
}, },
) )
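As a usage sketch: the local_config_clang6 name comes from the %package substitution above, while the load label is an assumption about where this rule lives:

# Hypothetical WORKSPACE usage sketch for the Clang 6 toolchain repository.
load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure")
clang6_configure(name = "local_config_clang6")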

View File

@ -1,38 +1,38 @@
# -*- Python -*-
"""Repository rule for arm compiler autoconfiguration."""

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    if not out:
        out = tpl
    repository_ctx.template(
        out,
        Label("//third_party/toolchains/cpus/arm:%s.tpl" % tpl),
        substitutions,
    )

def _arm_compiler_configure_impl(repository_ctx):
    # We need to find a cross-compilation include directory for Python, so look
    # for an environment variable. Be warned, this crosstool template is only
    # regenerated on the first run of Bazel, so if you change the variable after
    # it may not be reflected in later builds. Doing a shutdown and clean of Bazel
    # doesn't fix this, you'll need to delete the generated file at something like:
    # external/local_config_arm_compiler/CROSSTOOL in your Bazel install.
    if "CROSSTOOL_PYTHON_INCLUDE_PATH" in repository_ctx.os.environ:
        python_include_path = repository_ctx.os.environ["CROSSTOOL_PYTHON_INCLUDE_PATH"]
    else:
        python_include_path = "/usr/include/python2.7"
    _tpl(repository_ctx, "CROSSTOOL", {
        "%{ARM_COMPILER_PATH}%": str(repository_ctx.path(
            repository_ctx.attr.remote_config_repo,
        )),
        "%{PYTHON_INCLUDE_PATH}%": python_include_path,
    })
    repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")

arm_compiler_configure = repository_rule(
    implementation = _arm_compiler_configure_impl,
    attrs = {
        "remote_config_repo": attr.string(mandatory = False, default = ""),
        "build_file": attr.label(),
    },
)
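A hedged sketch of instantiating this rule; the attribute values and load label are illustrative assumptions, not taken from this diff:

# Hypothetical WORKSPACE usage sketch for the ARM cross-compiler toolchain.
load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure")
arm_compiler_configure(
    name = "local_config_arm_compiler",
    build_file = "//third_party/toolchains/cpus/arm:BUILD",  # assumed label
    remote_config_repo = "../arm_compiler",  # assumed path to the fetched toolchain repo
)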

View File

@ -2,11 +2,11 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
def bazel_toolchains_archive():
    http_archive(
        name = "bazel_toolchains",
        sha256 = "109a99384f9d08f9e75136d218ebaebc68cc810c56897aea2224c57932052d30",
        strip_prefix = "bazel-toolchains-94d31935a2c94fe7e7c7379a0f3393e181928ff7",
        urls = [
            "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/94d31935a2c94fe7e7c7379a0f3393e181928ff7.tar.gz",
            "https://github.com/bazelbuild/bazel-toolchains/archive/94d31935a2c94fe7e7c7379a0f3393e181928ff7.tar.gz",
        ],
    )
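A minimal sketch of invoking the macro during workspace setup; the load label is an assumption:

# Hypothetical usage: declare the bazel_toolchains repository before configuring RBE.
load("//third_party/toolchains/preconfig/generate:archives.bzl", "bazel_toolchains_archive")
bazel_toolchains_archive()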

View File

@ -27,6 +27,7 @@ def _tensorflow_rbe_config(name, compiler, python_version, cuda_version = None,
    if cuda_version != None:
        base = "@cuda%s-cudnn%s-ubuntu14.04//image" % (cuda_version, cudnn_version)

        # The cuda toolchain currently contains its own C++ toolchain definition,
        # so we do not fetch local_config_cc.
        config_repos = [
@ -42,7 +43,7 @@ def _tensorflow_rbe_config(name, compiler, python_version, cuda_version = None,
"TF_CUDNN_VERSION": cudnn_version, "TF_CUDNN_VERSION": cudnn_version,
"TF_CUDA_VERSION": cuda_version, "TF_CUDA_VERSION": cuda_version,
"CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "CUDNN_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
"TF_NEED_TENSORRT" : "1", "TF_NEED_TENSORRT": "1",
"TF_TENSORRT_VERSION": tensorrt_version, "TF_TENSORRT_VERSION": tensorrt_version,
"TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu",
"GCC_HOST_COMPILER_PATH": compiler if compiler != "clang" else "", "GCC_HOST_COMPILER_PATH": compiler if compiler != "clang" else "",