Rolling back tensorflow .bzl file changes

Automated g4 rollback of changelist 203459720

PiperOrigin-RevId: 203501636
Rohan Jain 2018-07-06 11:14:53 -07:00 committed by TensorFlower Gardener
parent 8f5e2a740e
commit ed494f17fc
37 changed files with 4512 additions and 4762 deletions

@@ -16,355 +16,339 @@ tf_library(
)
"""

load(
    "//tensorflow:tensorflow.bzl",
    "if_android",
    "tf_cc_test",
    "tf_copts",
)

def tf_library(
        name,
        graph,
        config,
        freeze_checkpoint = None,
        freeze_saver = None,
        cpp_class = None,
        gen_test = True,
        gen_benchmark = True,
        visibility = None,
        testonly = None,
        tfcompile_flags = None,
        tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
        include_standard_runtime_deps = True,
        enable_xla_hlo_profiling = False,
        deps = None,
        tags = None):
    """Runs tfcompile to compile a TensorFlow graph into executable code.

    Given an invocation of tf_library(name="foo", ...), generates the following
    build targets:
      foo: A cc_library containing the generated header and computation.
      foo_test: A cc_test with simple tests and benchmarks. Only created if
        gen_test=True.
      foo_benchmark: A cc_binary that runs a minimal-dependency benchmark,
        useful for mobile devices or other platforms that can't compile the
        full test libraries. Only created if gen_benchmark=True.

    Args:
      name: The name of the build rule.
      graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt'
        it is expected to be in the human-readable proto text format; otherwise
        it is expected to be in the proto binary format.
      config: File containing tensorflow.tf2xla.Config proto. If the file ends
        in '.pbtxt' it is expected to be in the human-readable proto text
        format; otherwise it is expected to be in the proto binary format.
      freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
        convert variables into constants.
      freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
        binary form, to convert variables into constants.
      cpp_class: The name of the generated C++ class, wrapping the generated
        function. The syntax of this flag is
        [[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
        for referring to a class, where multiple namespaces may precede the
        class name, separated by double-colons. The class will be generated in
        the given namespace(s), or if no namespaces are given, within the
        global namespace.
      gen_test: If True, also generate a cc_test rule that builds a simple
        test and benchmark.
      gen_benchmark: If True, also generate a binary with a simple benchmark.
        Unlike the output of gen_test, this benchmark can be run on android.
      visibility: Bazel build visibility.
      testonly: Bazel testonly attribute.
      tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
      tfcompile_tool: The tfcompile binary. A non-default can be passed to
        use a tfcompile built with extra dependencies.
      include_standard_runtime_deps: If True, the standard list of
        kernel/runtime deps is added to deps. If False, deps must contain the
        full set of deps needed by the generated library.
      enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated
        program, and emit metadata that lets us pretty-print the gathered
        profile counters.
      deps: a list of deps to include on the build rules for the generated
        library, added to the standard deps if standard_runtime_deps is True.
      tags: tags to apply to subsidiary build rules.

    The output header is called <name>.h.
    """
    if not cpp_class:
        fail("cpp_class must be specified")
    tfcompile_graph = graph
    if freeze_checkpoint or freeze_saver:
        if not freeze_checkpoint:
            fail("freeze_checkpoint must be specified when freeze_saver is specified")

        freeze_name = "freeze_" + name
        freeze_file = freeze_name + ".pb"

        # First run tfcompile to generate the list of out_nodes.
        out_nodes_file = "out_nodes_" + freeze_name
        native.genrule(
            name = ("gen_" + out_nodes_file),
            srcs = [config],
            outs = [out_nodes_file],
            cmd = ("$(location " + tfcompile_tool + ")" +
                   " --config=$(location " + config + ")" +
                   " --dump_fetch_nodes > $@"),
            tools = [tfcompile_tool],
            # Run tfcompile on the build host, rather than forge, since it's
            # typically way faster on the local machine.
            local = 1,
            tags = tags,
        )

        # Now run freeze_graph to convert variables into constants.
        freeze_args = (" --input_graph=$(location " + graph + ")" +
                       " --checkpoint_version=1" +
                       " --input_binary=" + str(not graph.endswith(".pbtxt")) +
                       " --input_checkpoint=$(location " + freeze_checkpoint + ")" +
                       " --output_graph=$(location " + freeze_file + ")" +
                       " --output_node_names=$$(<$(location " + out_nodes_file +
                       "))")
        freeze_saver_srcs = []
        if freeze_saver:
            freeze_args += " --input_saver=$(location " + freeze_saver + ")"
            freeze_saver_srcs += [freeze_saver]
        native.genrule(
            name = freeze_name,
            srcs = [
                graph,
                freeze_checkpoint,
                out_nodes_file,
            ] + freeze_saver_srcs,
            outs = [freeze_file],
            cmd = ("$(location //tensorflow/python/tools:freeze_graph)" +
                   freeze_args),
            tools = ["//tensorflow/python/tools:freeze_graph"],
            tags = tags,
        )
        tfcompile_graph = freeze_file

    # Rule that runs tfcompile to produce the header and object file.
    header_file = name + ".h"
    metadata_object_file = name + "_tfcompile_metadata.o"
    function_object_file = name + "_tfcompile_function.o"
    ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
    if type(tfcompile_flags) == type(""):
        flags = tfcompile_flags
    else:
        flags = " ".join([
            "'" + arg.replace("'", "'\\''") + "'"
            for arg in (tfcompile_flags or [])
        ])
    if enable_xla_hlo_profiling:
        profiling_flag = "--xla_hlo_profile"
    else:
        profiling_flag = ""
    native.genrule(
        name = ("gen_" + name),
        srcs = [
            tfcompile_graph,
            config,
        ],
        outs = [
            header_file,
            metadata_object_file,
            function_object_file,
        ],
        cmd = ("$(location " + tfcompile_tool + ")" +
               " --graph=$(location " + tfcompile_graph + ")" +
               " --config=$(location " + config + ")" +
               " --entry_point=" + ep +
               " --cpp_class=" + cpp_class +
               " --target_triple=" + target_llvm_triple() +
               " --out_header=$(@D)/" + header_file +
               " --out_metadata_object=$(@D)/" + metadata_object_file +
               " --out_function_object=$(@D)/" + function_object_file +
               " " + flags + " " + profiling_flag),
        tools = [tfcompile_tool],
        visibility = visibility,
        testonly = testonly,
        # Run tfcompile on the build host since it's typically faster on the
        # local machine.
        #
        # Note that setting the local=1 attribute on a *test target* causes the
        # test infrastructure to skip that test. However this is a genrule, not
        # a test target, and runs with --genrule_strategy=forced_forge, meaning
        # the local=1 attribute is ignored, and the genrule is still run.
        #
        # https://www.bazel.io/versions/master/docs/be/general.html#genrule
        local = 1,
        tags = tags,
    )
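    # An illustrative sketch of the tfcompile_flags normalization above
    # (example values hypothetical): a string is passed through verbatim,
    # while each list element is single-quoted with embedded quotes escaped:
    #
    #   tfcompile_flags = "--xla_cpu_multi_thread_eigen=false"
    #     -> flags == "--xla_cpu_multi_thread_eigen=false"
    #   tfcompile_flags = ["--dump_dir=/tmp/o'brien"]
    #     -> shell text: '--dump_dir=/tmp/o'\''brien'
    #
    # so list elements containing spaces or quotes reach the tfcompile
    # command line intact.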
    # Rule that runs tfcompile to produce the SessionModule proto, useful for
    # debugging. TODO(b/64813587): Once the SessionModule proto is
    # deterministic, move this into the main rule above.
    session_module_pb = name + "_session_module.pb"
    native.genrule(
        name = (name + "_session_module"),
        srcs = [
            tfcompile_graph,
            config,
        ],
        outs = [
            session_module_pb,
        ],
        cmd = ("$(location " + tfcompile_tool + ")" +
               " --graph=$(location " + tfcompile_graph + ")" +
               " --config=$(location " + config + ")" +
               " --entry_point=" + ep +
               " --cpp_class=" + cpp_class +
               " --target_triple=" + target_llvm_triple() +
               " --out_session_module=$(@D)/" + session_module_pb +
               " " + flags),
        tools = [tfcompile_tool],
        visibility = visibility,
        testonly = testonly,
        local = 1,
        tags = tags,
    )

    # The cc_library rule packaging up the header and object file, and needed
    # kernel implementations.
    need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
    native.cc_library(
        name = name,
        srcs = [function_object_file, metadata_object_file],
        hdrs = [header_file],
        visibility = visibility,
        testonly = testonly,
        deps = [
            # These deps are required by all tf_library targets even if
            # include_standard_runtime_deps is False. Without them, the
            # generated code will fail to compile.
            "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
            "//tensorflow/core:framework_lite",
        ] + (need_xla_data_proto and [
            # If we're generating the program shape, we must depend on the
            # proto.
            "//tensorflow/compiler/xla:xla_data_proto",
        ] or []) + (enable_xla_hlo_profiling and [
            "//tensorflow/compiler/xla/service:hlo_profile_printer_data",
        ] or []) + (include_standard_runtime_deps and [
            # TODO(cwhipkey): only depend on kernel code that the model
            # actually needs.
            "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
            "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
            "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
            "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
            "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
            "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
            "//third_party/eigen3",
        ] or []) + (deps or []),
        tags = tags,
    )

    # Variables used for gen_test and gen_benchmark.
    no_ns_name = ""
    cpp_class_split = cpp_class.rsplit("::", maxsplit = 2)
    if len(cpp_class_split) == 1:
        no_ns_name = cpp_class_split[0]
    else:
        no_ns_name = cpp_class_split[1]
    sed_replace = (
        "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
        "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
        "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ")

    if gen_test:
        test_name = name + "_test"
        test_file = test_name + ".cc"

        # Rule to rewrite test.cc to produce the test_file.
        native.genrule(
            name = ("gen_" + test_name),
            testonly = 1,
            srcs = [
                "//tensorflow/compiler/aot:test.cc",
                header_file,
            ],
            outs = [test_file],
            cmd = ("sed " + sed_replace +
                   " $(location //tensorflow/compiler/aot:test.cc) " +
                   "> $(OUTS)"),
            tags = tags,
        )

        # The cc_test rule for the generated code. To ensure that this works
        # reliably across build configurations, we must use tf_cc_test instead
        # of native.cc_test. This is related to how we build
        # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
        # for more details.
        tf_cc_test(
            name = test_name,
            srcs = [test_file],
            deps = [
                ":" + name,
                "//tensorflow/compiler/aot:runtime",
                "//tensorflow/compiler/aot:tf_library_test_main",
                "//tensorflow/compiler/xla:executable_run_options",
                "//third_party/eigen3",
                "//tensorflow/core:lib",
                "//tensorflow/core:test",
            ],
            tags = tags,
        )

    if gen_benchmark:
        benchmark_name = name + "_benchmark"
        benchmark_file = benchmark_name + ".cc"
        benchmark_main = ("//tensorflow/compiler/aot:" +
                          "benchmark_main.template")

        # Rule to rewrite benchmark.cc to produce the benchmark_file.
        native.genrule(
            name = ("gen_" + benchmark_name),
            srcs = [
                benchmark_main,
                header_file,
            ],
            testonly = testonly,
            outs = [benchmark_file],
            cmd = ("sed " + sed_replace +
                   " $(location " + benchmark_main + ") " +
                   "> $(OUTS)"),
            tags = tags,
        )

        # The cc_benchmark rule for the generated code. This does not need the
        # tf_cc_binary since we (by deliberate design) do not depend on
        # //tensorflow/core:lib.
        #
        # Note: to get smaller size on android for comparison, compile with:
        #   --copt=-fvisibility=hidden
        #   --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
        #   --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
        native.cc_binary(
            name = benchmark_name,
            srcs = [benchmark_file],
            testonly = testonly,
            copts = tf_copts(),
            linkopts = if_android(["-pie", "-s"]),
            deps = [
                ":" + name,
                "//tensorflow/compiler/aot:benchmark",
                "//tensorflow/compiler/aot:runtime",
                "//tensorflow/compiler/xla:executable_run_options",
                "//third_party/eigen3",
            ] + if_android([
                "//tensorflow/compiler/aot:benchmark_extra_android",
            ]),
            tags = tags,
        )

def target_llvm_triple():
    """Returns the target LLVM triple to be used for compiling the target."""

    # TODO(toddw): Add target_triple for other targets. For details see:
    # http://llvm.org/docs/doxygen/html/Triple_8h_source.html
    return select({
        "//tensorflow:android_armeabi": "armv5-none-android",
        "//tensorflow:android_arm": "armv7-none-android",
        "//tensorflow:android_arm64": "aarch64-none-android",
        "//tensorflow:android_x86": "i686-none-android",
        "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
        "//tensorflow:darwin": "x86_64-none-darwin",
        "//conditions:default": "x86_64-pc-linux",
    })

@ -4,97 +4,88 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
def all_backends():
b = ["cpu"] + plugins.keys()
if cuda_is_configured():
return b + ["gpu"]
b = ["cpu"] + plugins.keys()
if cuda_is_configured():
return b + ["gpu"]
else:
return b
def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
disabled_backends=None, **kwargs):
"""Generates py_test targets, one per XLA backend.
This rule generates py_test() targets named name_backend, for each backend
in all_backends(). The rule also generates a test suite with named `name` that
tests all backends for the test.
For example, the following rule generates test cases foo_test_cpu,
foo_test_gpu, and a test suite name foo_test that tests both.
tf_xla_py_test(
name="foo_test",
srcs="foo_test.py",
deps=[...],
)
Args:
name: Name of the target.
srcs: Sources for the target.
deps: Dependencies of the target.
tags: Tags to apply to the generated targets.
data: Data dependencies of the target.
main: Same as py_test's main attribute.
disabled_backends: A list of backends that should not be tested. Supported
values include "cpu" and "gpu". If not specified, defaults to None.
**kwargs: keyword arguments passed onto the generated py_test() rules.
"""
if disabled_backends == None:
disabled_backends = []
enabled_backends = [b for b in all_backends() if b not in disabled_backends]
test_names = []
for backend in enabled_backends:
test_name = "{}_{}".format(name, backend)
backend_tags = ["tf_xla_{}".format(backend)]
backend_args = []
backend_deps = []
backend_data = []
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
]
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16"
]
backend_tags += ["requires-gpu-sm35"]
elif backend in plugins:
backend_args += ["--test_device=" + plugins[backend]["device"],
"--types=" + plugins[backend]["types"]]
backend_tags += plugins[backend]["tags"]
backend_args += plugins[backend]["args"]
backend_deps += plugins[backend]["deps"]
backend_data += plugins[backend]["data"]
else:
return b
fail("Unknown backend {}".format(backend))
def tf_xla_py_test(
name,
srcs = [],
deps = [],
tags = [],
data = [],
main = None,
disabled_backends = None,
**kwargs):
"""Generates py_test targets, one per XLA backend.
This rule generates py_test() targets named name_backend, for each backend
in all_backends(). The rule also generates a test suite with named `name` that
tests all backends for the test.
For example, the following rule generates test cases foo_test_cpu,
foo_test_gpu, and a test suite name foo_test that tests both.
tf_xla_py_test(
name="foo_test",
srcs="foo_test.py",
deps=[...],
native.py_test(
name=test_name,
srcs=srcs,
srcs_version="PY2AND3",
args=backend_args,
main="{}.py".format(name) if main == None else main,
data=data + backend_data,
deps=deps + backend_deps,
tags=tags + backend_tags,
**kwargs
)
test_names.append(test_name)
native.test_suite(name=name, tests=test_names)
Args:
name: Name of the target.
srcs: Sources for the target.
deps: Dependencies of the target.
tags: Tags to apply to the generated targets.
data: Data dependencies of the target.
main: Same as py_test's main attribute.
disabled_backends: A list of backends that should not be tested. Supported
values include "cpu" and "gpu". If not specified, defaults to None.
**kwargs: keyword arguments passed onto the generated py_test() rules.
"""
if disabled_backends == None:
disabled_backends = []
enabled_backends = [b for b in all_backends() if b not in disabled_backends]
test_names = []
for backend in enabled_backends:
test_name = "{}_{}".format(name, backend)
backend_tags = ["tf_xla_{}".format(backend)]
backend_args = []
backend_deps = []
backend_data = []
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
]
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
]
backend_tags += ["requires-gpu-sm35"]
elif backend in plugins:
backend_args += [
"--test_device=" + plugins[backend]["device"],
"--types=" + plugins[backend]["types"],
]
backend_tags += plugins[backend]["tags"]
backend_args += plugins[backend]["args"]
backend_deps += plugins[backend]["deps"]
backend_data += plugins[backend]["data"]
else:
fail("Unknown backend {}".format(backend))
native.py_test(
name = test_name,
srcs = srcs,
srcs_version = "PY2AND3",
args = backend_args,
main = "{}.py".format(name) if main == None else main,
data = data + backend_data,
deps = deps + backend_deps,
tags = tags + backend_tags,
**kwargs
)
test_names.append(test_name)
native.test_suite(name = name, tests = test_names)
def generate_backend_suites(backends = []):
"""Generates per-backend test_suites that run all tests for a backend."""
if not backends:
backends = all_backends()
for backend in backends:
native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend])
def generate_backend_suites(backends=[]):
"""Generates per-backend test_suites that run all tests for a backend."""
if not backends:
backends = all_backends()
for backend in backends:
native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
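As a hedged usage sketch (the test name and dependency below are hypothetical, not targets defined in this change), a BUILD file can disable a backend per test and still get a combined suite:

    tf_xla_py_test(
        name = "concat_ops_test",      # hypothetical test name
        srcs = ["concat_ops_test.py"],
        disabled_backends = ["gpu"],   # skip the XLA_GPU variant
        deps = ["//tensorflow/python:framework"],  # hypothetical dep
    )

This generates concat_ops_test_cpu (plus one target per configured plugin backend) and a test suite :concat_ops_test covering all of them.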

@@ -18,12 +18,13 @@
# git update-index --assume-unchanged tensorflow/compiler/tests/plugin.bzl

plugins = {
    #"example": {
    #    "device": "XLA_MY_DEVICE",
    #    "types": "DT_FLOAT,DT_HALF,DT_INT32",
    #    "tags": [],
    #    "args": ["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"],
    #    "data": ["//tensorflow/compiler/plugin/example:disabled_manifest.txt"],
    #    "deps": [],
    #},
}

@@ -1,11 +1,12 @@
"""build_defs for service/cpu."""

def runtime_copts():
    """Returns copts used for CPU runtime libraries."""
    return (["-DEIGEN_AVOID_STL_ARRAY"] + select({
        "//tensorflow:android_arm": ["-mfpu=neon"],
        "//conditions:default": [],
    }) + select({
        "//tensorflow:android": ["-O2"],
        "//conditions:default": [],
    }))
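For illustration, a minimal sketch of how a CPU runtime library might consume these copts (the target and source names are hypothetical):

    cc_library(
        name = "runtime_matmul_example",       # hypothetical target
        srcs = ["runtime_matmul_example.cc"],  # hypothetical source
        copts = runtime_copts(),  # -DEIGEN_AVOID_STL_ARRAY plus per-platform flags
    )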

@ -7,258 +7,252 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
all_backends = ["cpu", "gpu"] + plugins.keys()
def filter_backends(backends):
"""Removes "gpu" from a backend list if CUDA is not enabled.
"""Removes "gpu" from a backend list if CUDA is not enabled.
This allows us to simply hardcode lists including "gpu" here and in the
BUILD file, without causing failures when CUDA isn't enabled.'
This allows us to simply hardcode lists including "gpu" here and in the
BUILD file, without causing failures when CUDA isn't enabled.'
Args:
backends: A list of backends to filter.
Args:
backends: A list of backends to filter.
Returns:
The filtered list of backends.
"""
if cuda_is_configured():
return backends
else:
return [backend for backend in backends if backend != "gpu"]
Returns:
The filtered list of backends.
"""
if cuda_is_configured():
return backends
else:
return [backend for backend in backends if backend != "gpu"]
def xla_test(
name,
srcs,
deps,
xla_test_library_deps = [],
backends = [],
blacklisted_backends = [],
args = [],
tags = [],
copts = [],
data = [],
backend_tags = {},
backend_args = {},
**kwargs):
"""Generates cc_test targets for the given XLA backends.
This rule generates a cc_test target for one or more XLA backends and also a
platform-agnostic cc_library rule. The arguments are identical to cc_test with
two additions: 'backends' and 'backend_args'. 'backends' specifies the
backends to generate tests for ("cpu", "gpu"), and
'backend_args'/'backend_tags' specifies backend-specific args parameters to
use when generating the cc_test.
def xla_test(name,
srcs,
deps,
xla_test_library_deps=[],
backends=[],
blacklisted_backends=[],
args=[],
tags=[],
copts=[],
data=[],
backend_tags={},
backend_args={},
**kwargs):
"""Generates cc_test targets for the given XLA backends.
The name of the cc_tests are the provided name argument with the backend name
appended, and the cc_library target name is the provided name argument with
"_lib" appended. For example, if name parameter is "foo_test", then the cpu
test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
This rule generates a cc_test target for one or more XLA backends and also a
platform-agnostic cc_library rule. The arguments are identical to cc_test with
two additions: 'backends' and 'backend_args'. 'backends' specifies the
backends to generate tests for ("cpu", "gpu"), and
'backend_args'/'backend_tags' specifies backend-specific args parameters to
use when generating the cc_test.
The cc_library target can be used to link with other plugins outside of
xla_test.
The name of the cc_tests are the provided name argument with the backend name
appended, and the cc_library target name is the provided name argument with
"_lib" appended. For example, if name parameter is "foo_test", then the cpu
test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
The build rule also defines a test suite ${name} which includes the tests for
each of the supported backends.
The cc_library target can be used to link with other plugins outside of
xla_test.
Each generated cc_test target has a tag indicating which backend the test is
for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
tags can be used to gather tests for a particular backend into a test_suite.
The build rule also defines a test suite ${name} which includes the tests for
each of the supported backends.
Examples:
Each generated cc_test target has a tag indicating which backend the test is
for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
tags can be used to gather tests for a particular backend into a test_suite.
# Generates the targets: foo_test_cpu and foo_test_gpu.
xla_test(
name = "foo_test",
srcs = ["foo_test.cc"],
backends = ["cpu", "gpu"],
deps = [...],
)
Examples:
# Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
# includes the additional arg "--special_cpu_flag".
xla_test(
name = "bar_test",
srcs = ["bar_test.cc"],
backends = ["cpu", "gpu"],
backend_args = {"cpu": ["--special_cpu_flag"]}
deps = [...],
)
The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
to the value 1 where ${BACKEND} is the uppercase name of the backend.
Args:
name: Name of the target.
srcs: Sources for the target.
deps: Dependencies of the target.
xla_test_library_deps: If set, the generated test targets will depend on the
respective cc_libraries generated by the xla_test_library rule.
backends: A list of backends to generate tests for. Supported values: "cpu",
"gpu". If this list is empty, the test will be generated for all supported
backends.
blacklisted_backends: A list of backends to NOT generate tests for.
args: Test arguments for the target.
tags: Tags for the target.
copts: Additional copts to pass to the build.
data: Additional data to pass to the build.
backend_tags: A dict mapping backend name to list of additional tags to
use for that target.
backend_args: A dict mapping backend name to list of additional args to
use for that target.
**kwargs: Additional keyword arguments to pass to native.cc_test.
"""
test_names = []
if not backends:
backends = all_backends
backends = [
backend
for backend in backends
if backend not in blacklisted_backends
]
native.cc_library(
name = "%s_lib" % name,
srcs = srcs,
copts = copts,
testonly = True,
deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
# Generates the targets: foo_test_cpu and foo_test_gpu.
xla_test(
name = "foo_test",
srcs = ["foo_test.cc"],
backends = ["cpu", "gpu"],
deps = [...],
)
for backend in filter_backends(backends):
test_name = "%s_%s" % (name, backend)
this_backend_tags = ["xla_%s" % backend]
this_backend_copts = []
this_backend_args = backend_args.get(backend, [])
this_backend_data = []
if backend == "cpu":
backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
elif backend == "gpu":
backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
this_backend_tags += ["requires-gpu-sm35"]
elif backend in plugins:
backend_deps = []
backend_deps += plugins[backend]["deps"]
this_backend_copts += plugins[backend]["copts"]
this_backend_tags += plugins[backend]["tags"]
this_backend_args += plugins[backend]["args"]
this_backend_data += plugins[backend]["data"]
else:
fail("Unknown backend %s" % backend)
# Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
# includes the additional arg "--special_cpu_flag".
xla_test(
name = "bar_test",
srcs = ["bar_test.cc"],
backends = ["cpu", "gpu"],
backend_args = {"cpu": ["--special_cpu_flag"]}
deps = [...],
)
if xla_test_library_deps:
for lib_dep in xla_test_library_deps:
backend_deps += ["%s_%s" % (lib_dep, backend)]
The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
to the value 1 where ${BACKEND} is the uppercase name of the backend.
tf_cc_test(
name = test_name,
srcs = srcs,
tags = tags + backend_tags.get(backend, []) + this_backend_tags,
extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
this_backend_copts,
args = args + this_backend_args,
deps = deps + backend_deps,
data = data + this_backend_data,
**kwargs
)
Args:
name: Name of the target.
srcs: Sources for the target.
deps: Dependencies of the target.
xla_test_library_deps: If set, the generated test targets will depend on the
respective cc_libraries generated by the xla_test_library rule.
backends: A list of backends to generate tests for. Supported values: "cpu",
"gpu". If this list is empty, the test will be generated for all supported
backends.
blacklisted_backends: A list of backends to NOT generate tests for.
args: Test arguments for the target.
tags: Tags for the target.
copts: Additional copts to pass to the build.
data: Additional data to pass to the build.
backend_tags: A dict mapping backend name to list of additional tags to
use for that target.
backend_args: A dict mapping backend name to list of additional args to
use for that target.
**kwargs: Additional keyword arguments to pass to native.cc_test.
"""
test_names = []
if not backends:
backends = all_backends
test_names.append(test_name)
backends = [backend for backend in backends
if backend not in blacklisted_backends]
native.test_suite(name = name, tests = test_names)
native.cc_library(
name="%s_lib" % name,
srcs=srcs,
copts=copts,
testonly=True,
deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
)
def xla_test_library(
name,
srcs,
hdrs = [],
deps = [],
backends = []):
"""Generates cc_library targets for the given XLA backends.
for backend in filter_backends(backends):
test_name = "%s_%s" % (name, backend)
this_backend_tags = ["xla_%s" % backend]
this_backend_copts = []
this_backend_args = backend_args.get(backend, [])
this_backend_data = []
if backend == "cpu":
backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
elif backend == "gpu":
backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
this_backend_tags += ["requires-gpu-sm35"]
elif backend in plugins:
backend_deps = []
backend_deps += plugins[backend]["deps"]
this_backend_copts += plugins[backend]["copts"]
this_backend_tags += plugins[backend]["tags"]
this_backend_args += plugins[backend]["args"]
this_backend_data += plugins[backend]["data"]
else:
fail("Unknown backend %s" % backend)
This rule forces the sources to be compiled for each backend so that the
backend specific macros could expand correctly. It's useful when test targets
in different directories referring to the same sources but test with different
arguments.
if xla_test_library_deps:
for lib_dep in xla_test_library_deps:
backend_deps += ["%s_%s" % (lib_dep, backend)]
Examples:
tf_cc_test(
name=test_name,
srcs=srcs,
tags=tags + backend_tags.get(backend, []) + this_backend_tags,
extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
this_backend_copts,
args=args + this_backend_args,
deps=deps + backend_deps,
data=data + this_backend_data,
**kwargs)
# Generates the targets: foo_test_library_cpu and foo_test_gpu.
xla_test_library(
name = "foo_test_library",
srcs = ["foo_test.cc"],
backends = ["cpu", "gpu"],
deps = [...],
)
# Then use the xla_test rule to generate test targets:
xla_test(
name = "foo_test",
srcs = [],
backends = ["cpu", "gpu"],
deps = [...],
xla_test_library_deps = [":foo_test_library"],
)
test_names.append(test_name)
Args:
name: Name of the target.
srcs: Sources for the target.
hdrs: Headers for the target.
deps: Dependencies of the target.
backends: A list of backends to generate libraries for.
Supported values: "cpu", "gpu". If this list is empty, the
library will be generated for all supported backends.
"""
native.test_suite(name=name, tests=test_names)
if not backends:
backends = all_backends
def xla_test_library(name,
srcs,
hdrs=[],
deps=[],
backends=[]):
"""Generates cc_library targets for the given XLA backends.
for backend in filter_backends(backends):
this_backend_copts = []
if backend in ["cpu", "gpu"]:
backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
elif backend in plugins:
backend_deps = plugins[backend]["deps"]
this_backend_copts += plugins[backend]["copts"]
else:
fail("Unknown backend %s" % backend)
This rule forces the sources to be compiled for each backend so that the
backend specific macros could expand correctly. It's useful when test targets
in different directories referring to the same sources but test with different
arguments.
native.cc_library(
name = "%s_%s" % (name, backend),
srcs = srcs,
testonly = True,
hdrs = hdrs,
copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
this_backend_copts,
deps = deps + backend_deps,
)
Examples:
def generate_backend_suites(backends = []):
if not backends:
backends = all_backends
for backend in filter_backends(backends):
native.test_suite(
name = "%s_tests" % backend,
tags = ["xla_%s" % backend],
)
# Generates the targets: foo_test_library_cpu and foo_test_gpu.
xla_test_library(
name = "foo_test_library",
srcs = ["foo_test.cc"],
backends = ["cpu", "gpu"],
deps = [...],
)
# Then use the xla_test rule to generate test targets:
xla_test(
name = "foo_test",
srcs = [],
backends = ["cpu", "gpu"],
deps = [...],
xla_test_library_deps = [":foo_test_library"],
)
def generate_backend_test_macros(backends = []):
if not backends:
backends = all_backends
for backend in filter_backends(backends):
manifest = ""
if backend in plugins:
manifest = plugins[backend]["disabled_manifest"]
Args:
name: Name of the target.
srcs: Sources for the target.
hdrs: Headers for the target.
deps: Dependencies of the target.
backends: A list of backends to generate libraries for.
Supported values: "cpu", "gpu". If this list is empty, the
library will be generated for all supported backends.
"""
native.cc_library(
name = "test_macros_%s" % backend,
testonly = True,
srcs = ["test_macros.cc"],
hdrs = ["test_macros.h"],
copts = [
"-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
],
)
if not backends:
backends = all_backends
for backend in filter_backends(backends):
this_backend_copts = []
if backend in ["cpu", "gpu"]:
backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
elif backend in plugins:
backend_deps = plugins[backend]["deps"]
this_backend_copts += plugins[backend]["copts"]
else:
fail("Unknown backend %s" % backend)
native.cc_library(
name = "%s_%s" % (name, backend),
srcs = srcs,
testonly = True,
hdrs = hdrs,
copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()]
+ this_backend_copts,
deps = deps + backend_deps,
)
def generate_backend_suites(backends=[]):
if not backends:
backends = all_backends
for backend in filter_backends(backends):
native.test_suite(name="%s_tests" % backend,
tags = ["xla_%s" % backend])
def generate_backend_test_macros(backends=[]):
if not backends:
backends = all_backends
for backend in filter_backends(backends):
manifest = ""
if backend in plugins:
manifest = plugins[backend]["disabled_manifest"]
native.cc_library(
name="test_macros_%s" % backend,
testonly = True,
srcs = ["test_macros.cc"],
hdrs = ["test_macros.h"],
copts = [
"-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
],
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
])
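As a usage sketch, assuming these macros are loaded into this package's BUILD file, calling the two helpers with no arguments expands to one target per configured backend:

    generate_backend_suites()       # defines :cpu_tests, :gpu_tests, ... test_suites
    generate_backend_test_macros()  # defines :test_macros_cpu, :test_macros_gpu, ...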

@@ -33,3 +33,4 @@
# }
plugins = {}

@@ -1,35 +1,30 @@
"""Wrapper around cc_proto_library used inside the XLA codebase."""

load(
    "//tensorflow/core:platform/default/build_config.bzl",
    "cc_proto_library",
)
load(
    "//tensorflow/core:platform/default/build_config_root.bzl",
    "if_static",
)

# xla_proto_library() is a convenience wrapper around cc_proto_library.
def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = 0, **kwargs):
    if kwargs.get("use_grpc_plugin"):
        kwargs["use_grpc_namespace"] = True
    cc_proto_library(
        name = name,
        srcs = srcs,
        deps = deps,
        cc_libs = if_static(
            ["@protobuf_archive//:protobuf"],
            otherwise = ["@protobuf_archive//:protobuf_headers"],
        ),
        protoc = "@protobuf_archive//:protoc",
        testonly = testonly,
        visibility = visibility,
        **kwargs
    )

def xla_py_grpc_library(**kwargs):
    # Note: we don't currently define any special targets for Python GRPC in
    # OSS.
    _ignore = kwargs
    pass

ORC_JIT_MEMORY_MAPPER_TARGETS = []
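A minimal usage sketch (the proto target and source names are hypothetical):

    xla_proto_library(
        name = "my_service_proto",    # hypothetical target
        srcs = ["my_service.proto"],  # hypothetical proto source
        use_grpc_plugin = True,       # also flips use_grpc_namespace above
    )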

@@ -1,196 +1,193 @@
"""Generate Flatbuffer binary from json."""

load(
    "//tensorflow:tensorflow.bzl",
    "tf_cc_test",
)

def tflite_copts():
    """Defines compile time flags."""
    copts = [
        "-DFARMHASH_NO_CXX_STRING",
    ] + select({
        str(Label("//tensorflow:android_arm64")): [
            "-std=c++11",
            "-O3",
        ],
        str(Label("//tensorflow:android_arm")): [
            "-mfpu=neon",
            "-mfloat-abi=softfp",
            "-std=c++11",
            "-O3",
        ],
        str(Label("//tensorflow:android_x86")): [
            "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
        ],
        str(Label("//tensorflow:ios_x86_64")): [
            "-msse4.1",
        ],
        "//conditions:default": [],
    }) + select({
        str(Label("//tensorflow:with_default_optimizations")): [],
        "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
    })

    return copts

LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds"

def tflite_linkopts_unstripped():
    """Defines linker flags to reduce size of TFLite binary.

    These are useful when trying to investigate the relative size of the
    symbols in TFLite.

    Returns:
      a select object with proper linkopts
    """
    return select({
        "//tensorflow:android": [
            "-Wl,--no-export-dynamic",  # Only inc syms referenced by dynamic obj.
            "-Wl,--exclude-libs,ALL",  # Exclude syms in all libs from auto export.
            "-Wl,--gc-sections",  # Eliminate unused code and data.
            "-Wl,--as-needed",  # Don't link unused libs.
        ],
        "//tensorflow/contrib/lite:mips": [],
        "//tensorflow/contrib/lite:mips64": [],
        "//conditions:default": [
            "-Wl,--icf=all",  # Identical code folding.
        ],
    })

def tflite_jni_linkopts_unstripped():
    """Defines linker flags to reduce size of TFLite binary with JNI.

    These are useful when trying to investigate the relative size of the
    symbols in TFLite.

    Returns:
      a select object with proper linkopts
    """
    return select({
        "//tensorflow:android": [
            "-Wl,--gc-sections",  # Eliminate unused code and data.
            "-Wl,--as-needed",  # Don't link unused libs.
        ],
        "//tensorflow/contrib/lite:mips": [],
        "//tensorflow/contrib/lite:mips64": [],
        "//conditions:default": [
            "-Wl,--icf=all",  # Identical code folding.
        ],
    })

def tflite_linkopts():
    """Defines linker flags to reduce size of TFLite binary."""
    return tflite_linkopts_unstripped() + select({
        "//tensorflow:android": [
            "-s",  # Omit symbol table.
        ],
        "//conditions:default": [],
    })

def tflite_jni_linkopts():
    """Defines linker flags to reduce size of TFLite binary with JNI."""
    return tflite_jni_linkopts_unstripped() + select({
        "//tensorflow:android": [
            "-s",  # Omit symbol table.
            "-latomic",  # Required for some uses of ISO C++11 <atomic> in x86.
        ],
        "//conditions:default": [],
    })
def tflite_jni_binary(
        name,
        copts = tflite_copts(),
        linkopts = tflite_jni_linkopts(),
        linkscript = LINKER_SCRIPT,
        linkshared = 1,
        linkstatic = 1,
        deps = []):
    """Builds a jni binary for TFLite."""
    linkopts = linkopts + [
        "-Wl,--version-script",  # Export only jni functions & classes.
        "$(location {})".format(linkscript),
    ]
    native.cc_binary(
        name = name,
        copts = copts,
        linkshared = linkshared,
        linkstatic = linkstatic,
        deps = deps + [linkscript],
        linkopts = linkopts,
    )
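# A hedged usage sketch (the target and dependency names are hypothetical):
#
#   tflite_jni_binary(
#       name = "libtensorflowlite_jni_example.so",
#       deps = ["//tensorflow/contrib/lite/java/src/main/native:example_native_lib"],
#   )
#
# The default linkopts append -Wl,--version-script with LINKER_SCRIPT above,
# so only JNI symbols are exported from the resulting shared object.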
def tf_to_tflite(name, src, options, out):
    """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.

    Args:
      name: Name of rule.
      src: name of the input graphdef file.
      options: options passed to TOCO.
      out: name of the output flatbuffer file.
    """
    toco_cmdline = " ".join([
        "//tensorflow/contrib/lite/toco:toco",
        "--input_format=TENSORFLOW_GRAPHDEF",
        "--output_format=TFLITE",
        ("--input_file=$(location %s)" % src),
        ("--output_file=$(location %s)" % out),
    ] + options)
    native.genrule(
        name = name,
        srcs = [src],
        outs = [out],
        cmd = toco_cmdline,
        tools = ["//tensorflow/contrib/lite/toco:toco"],
    )
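# A hedged usage sketch (the file names below are hypothetical):
#
#   tf_to_tflite(
#       name = "convert_example",
#       src = "example_frozen.pb",
#       options = ["--inference_type=FLOAT"],
#       out = "example.tflite",
#   )
#
# The options list is appended verbatim to the TOCO command line built above.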
def tflite_to_json(name, src, out):
    """Convert a TF Lite flatbuffer to JSON.

    Args:
      name: Name of rule.
      src: name of the input flatbuffer file.
      out: name of the output JSON file.
    """
    flatc = "@flatbuffers//:flatc"
    schema = "//tensorflow/contrib/lite/schema:schema.fbs"
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" +
               "$(location %s) --raw-binary --strict-json -t" +
               " -o /tmp $(location %s) -- $${TMP}.bin &&" +
               "cp $${TMP}.json $(location %s)") %
              (src, flatc, schema, out),
        tools = [flatc],
    )

def json_to_tflite(name, src, out):
    """Convert a JSON file to TF Lite's flatbuffer.

    Args:
      name: Name of rule.
      src: name of the input JSON file.
      out: name of the output flatbuffer file.
    """
    flatc = "@flatbuffers//:flatc"
    schema = "//tensorflow/contrib/lite/schema:schema_fbs"
    native.genrule(
        name = name,
        srcs = [schema, src],
        outs = [out],
        cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" +
               "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
               " -o /tmp $(location %s) $${TMP}.json &&" +
               "cp $${TMP}.bin $(location %s)") %
              (src, flatc, schema, out),
        tools = [flatc],
    )
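# The two rules above are inverses, so a hedged round-trip sketch (file names
# hypothetical) is:
#
#   tflite_to_json(name = "model_json", src = "model.tflite", out = "model.json")
#   json_to_tflite(name = "model_bin", src = "model.json", out = "model_roundtrip.tflite")
#
# which is handy for inspecting or hand-editing a flatbuffer model.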
# This is the master list of generated examples that will be made into tests. A
# function called make_XXX_tests() must also appear in generate_examples.py.
@ -265,58 +262,58 @@ def generated_test_models():
]
def gen_zip_test(name, test_name, **kwargs):
"""Generate a zipped-example test and its dependent zip files.
"""Generate a zipped-example test and its dependent zip files.
Args:
name: Resulting cc_test target name
test_name: Test targets this model. Comes from the list above.
**kwargs: tf_cc_test kwargs.
"""
gen_zipped_test_file(
name = "zip_%s" % test_name,
file = "%s.zip" % test_name,
)
tf_cc_test(name, **kwargs)
Args:
name: Resulting cc_test target name
test_name: Test targets this model. Comes from the list above.
**kwargs: tf_cc_test kwargs.
"""
gen_zipped_test_file(
name = "zip_%s" % test_name,
file = "%s.zip" % test_name,
)
tf_cc_test(name, **kwargs)
def gen_zipped_test_file(name, file):
"""Generate a zip file of tests by using :generate_examples.
"""Generate a zip file of tests by using :generate_examples.
Args:
name: Name of output. We will produce "`file`.files" as a target.
file: The name of one of the generated_examples targets, e.g. "transpose"
"""
toco = "//tensorflow/contrib/lite/toco:toco"
native.genrule(
name = file + ".files",
cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco +
" --zip_to_output " + file + " $(@D)"),
outs = [file],
tools = [
":generate_examples",
toco,
],
)
Args:
name: Name of output. We will produce "`file`.files" as a target.
file: The name of one of the generated_examples targets, e.g. "transpose"
"""
toco = "//tensorflow/contrib/lite/toco:toco"
native.genrule(
name = file + ".files",
cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco
+ " --zip_to_output " + file + " $(@D)"),
outs = [file],
tools = [
":generate_examples",
toco,
],
)
native.filegroup(
name = name,
srcs = [file],
)
def gen_selected_ops(name, model):
"""Generate the library that includes only used ops.
"""Generate the library that includes only used ops.
Args:
name: Name of the generated library.
model: TFLite model to interpret.
"""
out = name + "_registration.cc"
tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
tflite_path = "//tensorflow/contrib/lite"
native.genrule(
name = name,
srcs = [model],
outs = [out],
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") %
(tool, model, out, tflite_path[2:]),
tools = [tool],
)
Args:
name: Name of the generated library.
model: TFLite model to interpret.
"""
out = name + "_registration.cc"
tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
tflite_path = "//tensorflow/contrib/lite"
native.genrule(
name = name,
srcs = [model],
outs = [out],
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s")
% (tool, model, out, tflite_path[2:]),
tools = [tool],
)
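# A usage sketch (hypothetical model target); the generated
# <name>_registration.cc registers only the ops the model actually uses:
#
# gen_selected_ops(
#     name = "my_model_ops",
#     model = ":model.tflite",
# )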


@ -3,12 +3,12 @@
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
def aar_with_jni(name, android_library):
# Generate dummy AndroidManifest.xml for dummy apk usage
# (dummy apk is generated by <name>_dummy_app_for_so target below)
native.genrule(
name = name + "_binary_manifest_generator",
outs = [name + "_generated_AndroidManifest.xml"],
cmd = """
cat > $(OUTS) <<EOF
<manifest
xmlns:android="http://schemas.android.com/apk/res/android"
@ -17,27 +17,27 @@ cat > $(OUTS) <<EOF
</manifest>
EOF
""",
)
# Generate dummy apk including .so files and later we extract out
# .so files and throw away the apk.
android_binary(
name = name + "_dummy_app_for_so",
manifest = name + "_generated_AndroidManifest.xml",
custom_package = "dummy.package.for.so",
deps = [android_library],
# In some platforms we don't have an Android SDK/NDK and this target
# can't be built. We need to prevent the build system from trying to
# use the target in that case.
tags = ["manual"],
)
native.genrule(
name = name,
srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
outs = [name + ".aar"],
tags = ["manual"],
cmd = """
cp $(location {}.aar) $(location :{}.aar)
chmod +w $(location :{}.aar)
origdir=$$PWD
@ -46,4 +46,4 @@ unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*"
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
)
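# A usage sketch (hypothetical android_library target):
#
# aar_with_jni(
#     name = "tensorflow_lite_aar",
#     android_library = ":tensorflowlite",
# )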


@ -1,6 +1,6 @@
"""External versions of build rules that differ outside of Google."""
def tflite_portable_test_suite(**kwargs):
"""This is a no-op outside of Google."""
_ignore = [kwargs]
pass
"""This is a no-op outside of Google."""
_ignore = [kwargs]
pass


@ -9,87 +9,81 @@ load("//tensorflow:tensorflow.bzl", "register_extension_info")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
def _test_name(test, path):
return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
def decode_proto_test_suite(name, examples):
"""Build the decode_proto py_test for each test filename."""
for test_filename in examples:
tf_py_test(
name = _test_name("decode_proto", test_filename),
srcs = ["decode_proto_op_test.py"],
size = "small",
data = [test_filename] + if_static(
[],
otherwise = [":libtestexample.so"],
),
main = "decode_proto_op_test.py",
args = [
"--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
],
additional_deps = [
":py_test_deps",
"//third_party/py/numpy",
"//tensorflow/contrib/proto:proto",
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
],
tags = [
"no_pip", # TODO(b/78026780)
"no_windows", # TODO(b/78028010)
],
)
native.test_suite(
name = name,
tests = [
":" + _test_name("decode_proto", test_filename)
for test_filename in examples
"""Build the decode_proto py_test for each test filename."""
for test_filename in examples:
tf_py_test(
name = _test_name("decode_proto", test_filename),
srcs = ["decode_proto_op_test.py"],
size = "small",
data = [test_filename] + if_static(
[],
otherwise = [":libtestexample.so"],
),
main = "decode_proto_op_test.py",
args = [
"--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
],
additional_deps = [
":py_test_deps",
"//third_party/py/numpy",
"//tensorflow/contrib/proto:proto",
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
],
tags = [
"no_pip", # TODO(b/78026780)
"no_windows", # TODO(b/78028010)
],
)
native.test_suite(
name = name,
tests = [":" + _test_name("decode_proto", test_filename)
for test_filename in examples],
)
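# A usage sketch (hypothetical example file):
#
# decode_proto_test_suite(
#     name = "decode_proto_op_tests",
#     examples = ["example.pbtxt"],
# )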
def encode_proto_test_suite(name, examples):
"""Build the encode_proto py_test for each test filename."""
for test_filename in examples:
tf_py_test(
name = _test_name("encode_proto", test_filename),
srcs = ["encode_proto_op_test.py"],
size = "small",
data = [test_filename] + if_static(
[],
otherwise = [":libtestexample.so"],
),
main = "encode_proto_op_test.py",
args = [
"--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
],
additional_deps = [
":py_test_deps",
"//third_party/py/numpy",
"//tensorflow/contrib/proto:proto",
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
"//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
],
tags = [
"no_pip", # TODO(b/78026780)
"no_windows", # TODO(b/78028010)
],
)
native.test_suite(
name = name,
tests = [
":" + _test_name("encode_proto", test_filename)
for test_filename in examples
"""Build the encode_proto py_test for each test filename."""
for test_filename in examples:
tf_py_test(
name = _test_name("encode_proto", test_filename),
srcs = ["encode_proto_op_test.py"],
size = "small",
data = [test_filename] + if_static(
[],
otherwise = [":libtestexample.so"],
),
main = "encode_proto_op_test.py",
args = [
"--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
],
additional_deps = [
":py_test_deps",
"//third_party/py/numpy",
"//tensorflow/contrib/proto:proto",
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
"//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
],
tags = [
"no_pip", # TODO(b/78026780)
"no_windows", # TODO(b/78028010)
],
)
native.test_suite(
name = name,
tests = [":" + _test_name("encode_proto", test_filename)
for test_filename in examples],
)
register_extension_info(
extension_name = "decode_proto_test_suite",
label_regex_map = {
"deps": "deps:decode_example_.*",
})
register_extension_info(
extension_name = "encode_proto_test_suite",
label_regex_map = {
"deps": "deps:encode_example_.*",
})


@ -1,13 +1,13 @@
"""Fuzzing template for TensorFlow ops."""
def tf_ops_fuzz_target_lib(name):
native.cc_library(
name = name + "_fuzz_lib",
srcs = [name + "_fuzz.cc"],
deps = [
"//tensorflow/core/kernels/fuzzing:fuzz_session",
"//tensorflow/cc:cc_ops",
],
tags = ["no_windows"],
alwayslink = 1,
)

File diff suppressed because it is too large.

@ -3,58 +3,58 @@
# be separate to avoid cyclic references.
def tf_cuda_tests_tags():
return ["requires-gpu"]
return ["requires-gpu"]
def tf_sycl_tests_tags():
return ["requires-gpu"]
return ["requires-gpu"]
def tf_additional_plugin_deps():
return select({
str(Label("//tensorflow:with_xla_support")): [
str(Label("//tensorflow/compiler/jit"))
],
"//conditions:default": [],
})
def tf_additional_xla_deps_py():
return []
def tf_additional_grpc_deps_py():
return []
def tf_additional_license_deps():
return select({
str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
"//conditions:default": [],
})
def tf_additional_verbs_deps():
return select({
str(Label("//tensorflow:with_verbs_support")): [
str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
],
"//conditions:default": [],
})
def tf_additional_mpi_deps():
return select({
str(Label("//tensorflow:with_mpi_support")): [
str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
],
"//conditions:default": [],
})
def tf_additional_gdr_deps():
return select({
str(Label("//tensorflow:with_gdr_support")): [
str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
],
"//conditions:default": [],
})
def if_static(extra_deps, otherwise=[]):
return select({
str(Label("//tensorflow:framework_shared_object")): otherwise,
"//conditions:default": extra_deps,
})
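# A usage sketch (hypothetical labels), mirroring how the proto test suites
# above fall back to ":libtestexample.so" in shared-object builds:
#
# deps = [":common"] + if_static(
#     extra_deps = [":lib_static"],
#     otherwise = [":lib_shared.so"],
# )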


@ -5,52 +5,55 @@ CUDNN_VERSION = ""
PLATFORM = ""
def cuda_sdk_version():
return CUDA_VERSION
def cudnn_sdk_version():
return CUDNN_VERSION
def cuda_library_path(name, version = cuda_sdk_version()):
if PLATFORM == "Darwin":
if not version:
return "lib/lib{}.dylib".format(name)
else:
return "lib/lib{}.{}.dylib".format(name, version)
elif not version:
return "lib64/lib{}.so".format(name)
if PLATFORM == "Darwin":
if not version:
return "lib/lib{}.dylib".format(name)
else:
return "lib64/lib{}.so.{}".format(name, version)
return "lib/lib{}.{}.dylib".format(name, version)
else:
if not version:
return "lib64/lib{}.so".format(name)
else:
return "lib64/lib{}.so.{}".format(name, version)
def cuda_static_library_path(name):
if PLATFORM == "Darwin":
return "lib/lib{}_static.a".format(name)
else:
return "lib64/lib{}_static.a".format(name)
if PLATFORM == "Darwin":
return "lib/lib{}_static.a".format(name)
else:
return "lib64/lib{}_static.a".format(name)
def cudnn_library_path(version = cudnn_sdk_version()):
if PLATFORM == "Darwin":
if not version:
return "lib/libcudnn.dylib"
else:
return "lib/libcudnn.{}.dylib".format(version)
elif not version:
return "lib64/libcudnn.so"
if PLATFORM == "Darwin":
if not version:
return "lib/libcudnn.dylib"
else:
return "lib64/libcudnn.so.{}".format(version)
return "lib/libcudnn.{}.dylib".format(version)
else:
if not version:
return "lib64/libcudnn.so"
else:
return "lib64/libcudnn.so.{}".format(version)
def cupti_library_path(version = cuda_sdk_version()):
if PLATFORM == "Darwin":
if not version:
return "extras/CUPTI/lib/libcupti.dylib"
else:
return "extras/CUPTI/lib/libcupti.{}.dylib".format(version)
elif not version:
return "extras/CUPTI/lib64/libcupti.so"
if PLATFORM == "Darwin":
if not version:
return "extras/CUPTI/lib/libcupti.dylib"
else:
return "extras/CUPTI/lib64/libcupti.so.{}".format(version)
return "extras/CUPTI/lib/libcupti.{}.dylib".format(version)
else:
if not version:
return "extras/CUPTI/lib64/libcupti.so"
else:
return "extras/CUPTI/lib64/libcupti.so.{}".format(version)
def readlink_command():
if PLATFORM == "Darwin":
return "greadlink"
else:
return "readlink"
if PLATFORM == "Darwin":
return "greadlink"
else:
return "readlink"


@ -18,7 +18,7 @@ XLINT_OPTS = [
"-Xlint:-processing",
"-Xlint:-serial",
"-Xlint:-try",
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
]
# The bazel errorprone plugin currently only enables default errorChecks


@ -17,48 +17,46 @@ load(
# and then archive those source files into
# ops/gen_sources.srcjar
#
def tf_java_op_gen_srcjar(name,
gen_tool,
base_package,
api_def_srcs=[],
out_dir="ops/",
out_src_dir="src/main/java/",
visibility=["//tensorflow/java:__pkg__"]):
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
srcs = api_def_srcs[:]
gen_cmds += ["$(location " + gen_tool + ")" +
" --output_dir=$(@D)/" + out_src_dir +
" --base_package=" + base_package +
" --api_dirs=" + api_def_args_str]
if not api_def_srcs:
api_def_args_str = ","
else:
api_def_args = []
for api_def_src in api_def_srcs:
# Add directory of the first ApiDef source to args.
# We are assuming all ApiDefs in a single api_def_src are in the
# same directory.
api_def_args.append(
"$$(dirname $$(echo $(locations " + api_def_src +
") | cut -d\" \" -f1))")
api_def_args_str = ",".join(api_def_args)
gen_cmds += ["$(location " + gen_tool + ")" +
" --output_dir=$(@D)/" + out_src_dir +
" --base_package=" + base_package +
" --api_dirs=" + api_def_args_str]
# Generate a source archive containing generated code for these ops.
gen_srcjar = out_dir + name + ".srcjar"
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]
native.genrule(
name=name,
srcs=srcs,
outs=[gen_srcjar],
tools=[
"@local_jdk//:jar",
"@local_jdk//:jdk",
gen_tool
] + tf_binary_additional_srcs(),
cmd=" && ".join(gen_cmds))


@ -12,26 +12,22 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
# consumers of the tf_gen_op_wrapper_py rule would be simplified if we don't
# hard code the ops/ directory.
def tf_gen_op_wrapper_private_py(name, out=None, deps=[],
require_shape_functions=True,
visibility=[]):
if not name.endswith("_gen"):
fail("name must end in _gen")
if not visibility:
visibility = ["//visibility:private"]
bare_op_name = name[:-4] # Strip off the _gen
tf_gen_op_wrapper_py(name=bare_op_name,
out=out,
visibility=visibility,
deps=deps,
require_shape_functions=require_shape_functions,
generated_target_name=name,
api_def_srcs = [
"//tensorflow/core/api_def:base_api_def",
"//tensorflow/core/api_def:python_api_def",
],
)
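# A usage sketch (hypothetical op library label); the name must end in
# "_gen", so this defines a private ":array_ops" wrapper target:
#
# tf_gen_op_wrapper_private_py(
#     name = "array_ops_gen",
#     deps = ["//tensorflow/core:array_ops_op_lib"],
# )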

File diff suppressed because it is too large.

@ -134,7 +134,7 @@ def gen_api_init_files(
package_dep = "//tensorflow/python:no_contrib"):
root_init_template_flag = ""
if root_init_template:
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
api_gen_binary_target = "create_" + package + "_api"
native.py_binary(
@ -154,9 +154,8 @@ def gen_api_init_files(
outs = output_files,
cmd = (
"$(location :" + api_gen_binary_target + ") " +
root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"
),
root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"),
srcs = srcs,
tools = [":" + api_gen_binary_target],
tools = [":" + api_gen_binary_target ],
visibility = ["//tensorflow:__pkg__"],
)


@ -24,27 +24,27 @@ load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_msvc_tool")
load("@bazel_tools//tools/cpp:lib_cc_configure.bzl", "auto_configure_fail")
def _def_file_filter_configure_impl(repository_ctx):
if repository_ctx.os.name.lower().find("windows") == -1:
repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD")
repository_ctx.file("def_file_filter.py", "")
return
vc_path = find_vc_path(repository_ctx)
if vc_path == "visual-studio-not-found":
auto_configure_fail("Visual C++ build tools not found on your machine")
undname = find_msvc_tool(repository_ctx, vc_path, "undname.exe")
if undname == None:
auto_configure_fail("Couldn't find undname.exe under %s, please check your VC installation and set BAZEL_VC environment variable correctly." % vc_path)
undname_bin_path = undname.replace("\\", "\\\\")
repository_ctx.template(
"def_file_filter.py",
Label("//tensorflow/tools/def_file_filter:def_file_filter.py.tpl"),
{
"%{undname_bin_path}": undname_bin_path,
},
)
if repository_ctx.os.name.lower().find("windows") == -1:
repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD")
repository_ctx.file("def_file_filter.py", "")
return
vc_path = find_vc_path(repository_ctx)
if vc_path == "visual-studio-not-found":
auto_configure_fail("Visual C++ build tools not found on your machine")
undname = find_msvc_tool(repository_ctx, vc_path, "undname.exe")
if undname == None:
auto_configure_fail("Couldn't find undname.exe under %s, please check your VC installation and set BAZEL_VC environment variable correctly." % vc_path)
undname_bin_path = undname.replace("\\", "\\\\")
repository_ctx.template(
"def_file_filter.py",
Label("//tensorflow/tools/def_file_filter:def_file_filter.py.tpl"),
{
"%{undname_bin_path}": undname_bin_path,
})
repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD")
def_file_filter_configure = repository_rule(
implementation = _def_file_filter_configure_impl,
@ -55,6 +55,6 @@ def_file_filter_configure = repository_rule(
"VS100COMNTOOLS",
"VS110COMNTOOLS",
"VS120COMNTOOLS",
"VS140COMNTOOLS",
"VS140COMNTOOLS"
],
)


@ -4,66 +4,60 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test")
# Create a benchmark test target of a TensorFlow C++ test (tf_cc_*_test)
def tf_cc_logged_benchmark(
name=None,
target=None,
benchmarks="..",
tags=[],
test_log_output_prefix="",
benchmark_type="cpp_microbenchmark"):
if not name:
fail("Must provide a name")
if not target:
fail("Must provide a target")
if (not ":" in target
or not target.startswith("//")
or target.endswith(":all")
or target.endswith(".")):
fail(" ".join(("Target must be a single well-defined test, e.g.,",
"//path/to:test. Received: %s" % target)))
all_tags = (
depset(tags) + depset(
["benchmark-test", "local", "manual", "regression-test"])).to_list()
tf_py_test(
name = name,
tags = all_tags,
size = "large",
srcs = ["//tensorflow/tools/test:run_and_gather_logs"],
args = [
"--name=//%s:%s" % (native.package_name(), name),
"--test_name=" + target,
"--test_args=--benchmarks=%s" % benchmarks,
"--benchmark_type=%s" % benchmark_type,
],
data = [
target,
],
main = "run_and_gather_logs.py",
additional_deps = [
"//tensorflow/tools/test:run_and_gather_logs"
])
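# A usage sketch (hypothetical test target):
#
# tf_cc_logged_benchmark(
#     name = "cast_op_logged_benchmark",
#     target = "//tensorflow/core/kernels:cast_op_test",
# )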
# Create a benchmark test target of a TensorFlow python test (*py_tests)
def tf_py_logged_benchmark(
name=None,
target=None,
benchmarks="..",
tags=[],
test_log_output_prefix=""):
# For now generating a py benchmark is the same as generating a C++
# benchmark target. In the future this may change, so we have
# two macros just in case
tf_cc_logged_benchmark(
name=name,
target=target,
benchmarks=benchmarks,
tags=tags,
test_log_output_prefix=test_log_output_prefix,
benchmark_type="python_benchmark")


@ -1,50 +1,48 @@
""" Helpers to check minimum version of bazel."""
def _extract_version_number(bazel_version):
"""Extracts the semantic version number from a version string
"""Extracts the semantic version number from a version string
Args:
bazel_version: the version string that begins with the semantic version
e.g. "1.2.3rc1 abc1234" where "abc1234" is a commit hash.
Returns:
The semantic version string, like "1.2.3".
"""
for i in range(len(bazel_version)):
c = bazel_version[i]
if not (c.isdigit() or c == "."):
return bazel_version[:i]
return bazel_version
# Parse the bazel version string from `native.bazel_version`.
# e.g.
# "0.10.0rc1 abc123d" => (0, 10, 0)
# "0.3.0" => (0, 3, 0)
def _parse_bazel_version(bazel_version):
"""Parses a version string into a 3-tuple of ints
"""Parses a version string into a 3-tuple of ints
int tuples can be compared directly using binary operators (<, >).
Args:
bazel_version: the Bazel version string
Returns:
An int 3-tuple of a (major, minor, patch) version.
"""
version = _extract_version_number(bazel_version)
return tuple([int(n) for n in version.split(".")])
def check_bazel_version_at_least(minimum_bazel_version):
if "bazel_version" not in dir(native):
fail("\nCurrent Bazel version is lower than 0.2.1, expected at least %s\n" % minimum_bazel_version)
elif not native.bazel_version:
print("\nCurrent Bazel is not a release version, cannot check for compatibility.")
print("Make sure that you are running at least Bazel %s.\n" % minimum_bazel_version)
return
if "bazel_version" not in dir(native):
fail("\nCurrent Bazel version is lower than 0.2.1, expected at least %s\n" % minimum_bazel_version)
elif not native.bazel_version:
print("\nCurrent Bazel is not a release version, cannot check for compatibility.")
print("Make sure that you are running at least Bazel %s.\n" % minimum_bazel_version)
return
if _parse_bazel_version(native.bazel_version) < _parse_bazel_version(minimum_bazel_version):
fail("\nCurrent Bazel version is {}, expected at least {}\n".format(
native.bazel_version, minimum_bazel_version))
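# A usage sketch, e.g. near the top of a WORKSPACE file:
#
# check_bazel_version_at_least("0.10.0")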

File diff suppressed because it is too large.

@ -36,39 +36,33 @@ _ANDROID_NDK_REPO_TEMPLATE = """
"""
def _android_autoconf_impl(repository_ctx):
"""Implementation of the android_autoconf repository rule."""
sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
build_tools_version = repository_ctx.os.environ.get(
_ANDROID_BUILD_TOOLS_VERSION,
)
ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)
"""Implementation of the android_autoconf repository rule."""
sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
build_tools_version = repository_ctx.os.environ.get(
_ANDROID_BUILD_TOOLS_VERSION)
ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)
sdk_rule = "pass"
if all([sdk_home, sdk_api_level, build_tools_version]):
sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
sdk_home,
sdk_api_level,
build_tools_version,
)
sdk_rule = "pass"
if all([sdk_home, sdk_api_level, build_tools_version]):
sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
sdk_home, sdk_api_level, build_tools_version)
ndk_rule = "pass"
if all([ndk_home, ndk_api_level]):
ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level)
ndk_rule = "pass"
if all([ndk_home, ndk_api_level]):
ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level)
repository_ctx.template(
"BUILD",
Label("//third_party/android:android_configure.BUILD.tpl"))
repository_ctx.template(
"android.bzl",
Label("//third_party/android:android.bzl.tpl"),
substitutions={
"MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule,
"MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule,
})
android_configure = repository_rule(
implementation = _android_autoconf_impl,


@ -7,16 +7,16 @@ _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
_TF_NEED_CUDA = "TF_NEED_CUDA"
def _cc_clang_autoconf(repo_ctx):
if repo_ctx.os.environ.get(_TF_DOWNLOAD_CLANG) != "1":
return
if repo_ctx.os.environ.get(_TF_NEED_CUDA) == "1":
# Clang is handled separately for CUDA configs.
# See cuda_configure.bzl for more details.
return
if repo_ctx.os.environ.get(_TF_DOWNLOAD_CLANG) != "1":
return
if repo_ctx.os.environ.get(_TF_NEED_CUDA) == "1":
# Clang is handled separately for CUDA configs.
# See cuda_configure.bzl for more details.
return
download_clang(repo_ctx, out_folder = "extra_tools")
overriden_tools = {"gcc": "extra_tools/bin/clang"}
cc_autoconf_impl(repo_ctx, overriden_tools)
download_clang(repo_ctx, out_folder='extra_tools')
overriden_tools = {'gcc': 'extra_tools/bin/clang'}
cc_autoconf_impl(repo_ctx, overriden_tools)
cc_download_clang_toolchain = repository_rule(
environ = [


@ -1,60 +1,54 @@
""" Helpers to download a recent clang release."""
def _get_platform_folder(os_name):
os_name = os_name.lower()
if os_name.startswith('windows'):
return 'Win'
if os_name.startswith('mac os'):
return 'Mac'
if not os_name.startswith('linux'):
fail('Unknown platform')
return 'Linux_x64'
def _download_chromium_clang(repo_ctx, platform_folder, package_version, sha256,
out_folder):
cds_url = 'https://commondatastorage.googleapis.com/chromium-browser-clang'
cds_file = 'clang-%s.tgz' % package_version
cds_full_url = '{0}/{1}/{2}'.format(cds_url, platform_folder, cds_file)
repo_ctx.download_and_extract(cds_full_url, output=out_folder, sha256=sha256)
def download_clang(repo_ctx, out_folder):
""" Download a fresh clang release and put it into out_folder.
""" Download a fresh clang release and put it into out_folder.
Clang itself will be located in 'out_folder/bin/clang'.
We currently download one of the latest releases of clang by the
Chromium project (see
https://chromium.googlesource.com/chromium/src/+/master/docs/clang.md).
Args:
repo_ctx: An instance of repository_context object.
out_folder: A folder to extract the compiler into.
"""
# TODO(ibiryukov): we currently download and extract some extra tools in the
# clang release (e.g., sanitizers). We should probably remove the ones
# we don't need and document the ones we want to provide in addition to clang.
# Latest CLANG_REVISION and CLANG_SUB_REVISION of Chromium's release
# can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
CLANG_REVISION = '335091'
CLANG_SUB_REVISION = 1
package_version = "%s-%s" % (CLANG_REVISION, CLANG_SUB_REVISION)
package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION)
checksums = {
'Linux_x64':
'17002b75293fccfdd175eacdc9ee47d97b58d7e98fef343384fbbef1b68ce99f',
'Mac':
'9351e46d28315daaa06a1eb55bd0370ed4aaeb693a2a3e82e48d2737d7723468',
'Win':
'e78a1e469224d6f6751b4df4374bf58893ac03900ec924e4c8264888ba4aeb1e',
}
platform_folder = _get_platform_folder(repo_ctx.os.name)
_download_chromium_clang(repo_ctx, platform_folder, package_version,
checksums[platform_folder], out_folder)


@ -21,11 +21,11 @@
# substitutions: A dictionary mapping strings to their substitutions
def template_rule_impl(ctx):
ctx.template_action(
template = ctx.file.src,
output = ctx.outputs.out,
substitutions = ctx.attr.substitutions,
)
template_rule = rule(
attrs = {


@ -8,49 +8,66 @@ DEFAULT_FLATC_ARGS = [
"--gen-object-api",
]
def flatbuffer_library_public(name,
srcs,
outs,
language_flag,
out_prefix="",
includes=[],
include_paths=[],
flatc_args=DEFAULT_FLATC_ARGS,
reflection_name="",
reflection_visiblity=None,
output_to_bindir=False):
'''Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler.
Args:
name: Rule name.
srcs: Source .fbs files. Sent in order to the compiler.
outs: Output files from flatc.
language_flag: Target language flag. One of [-c, -j, -js].
out_prefix: Prepend this path to the front of all generated files except on
single source targets. Usually is a directory name.
includes: Optional, list of filegroups of schemas that the srcs depend on.
include_paths: Optional, list of paths the includes files can be found in.
flatc_args: Optional, list of additional arguments to pass to flatc.
reflection_name: Optional, if set this will generate the flatbuffer
reflection binaries for the schemas.
reflection_visiblity: The visibility of the generated reflection Fileset.
output_to_bindir: Passed to genrule for output to bin directory.
Outs:
filegroup(name): all generated source files.
Fileset([reflection_name]): (Optional) all generated reflection binaries.
'''
include_paths_cmd = ["-I %s" % (s) for s in include_paths]
# '$(@D)' when given a single source target will give the appropriate
# directory. Appending 'out_prefix' is only necessary when given a build
# target with multiple sources.
output_directory = (
("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)"))
genrule_cmd = " ".join([
"for f in $(SRCS); do",
"$(location %s)" % (flatc_path),
" ".join(flatc_args),
" ".join(include_paths_cmd),
language_flag,
output_directory,
"$$f;",
"done",
])
native.genrule(
name=name,
srcs=srcs,
outs=outs,
output_to_bindir=output_to_bindir,
tools=includes + [flatc_path,],
cmd=genrule_cmd,
message="Generating flatbuffer files for %s:" % (name),)
if reflection_name:
reflection_genrule_cmd = " ".join([
"for f in $(SRCS); do",
"$(location %s)" % (flatc_path),
"-b --schema",
" ".join(flatc_args),
" ".join(include_paths_cmd),
language_flag,
output_directory,
"$$f;",
"done",
])
reflection_outs = [
(out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1]) for s in srcs
]
native.genrule(
name= "%s_srcs" % reflection_name,
srcs=srcs,
outs=reflection_outs,
output_to_bindir=output_to_bindir,
tools=includes + [flatc_path,],
cmd=reflection_genrule_cmd,
message="Generating flatbuffer reflection binary for %s:" % (name),)
native.Fileset(
name=reflection_name,
out="%s_out" % reflection_name,
entries=[
native.FilesetEntry(files=reflection_outs),
],
visibility=reflection_visiblity
)
def flatbuffer_cc_library(name, srcs, srcs_filegroup_name="",
out_prefix="", includes=[], include_paths=[],
flatc_args=DEFAULT_FLATC_ARGS,
visibility=None, srcs_filegroup_visibility=None,
gen_reflections=False):
'''A cc_library with the generated reader/writers for the given flatbuffer definitions.
Args:
name: Rule name.
srcs: Source .fbs files. Sent in order to the compiler.
srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this
filegroup into the `includes` parameter of any other
flatbuffer_cc_library that depends on this one's schemas.
out_prefix: Prepend this path to the front of all generated files. Usually
is a directory name.
includes: Optional, list of filegroups of schemas that the srcs depend on.
** SEE REMARKS BELOW **
include_paths: Optional, list of paths the includes files can be found in.
flatc_args: Optional list of additional arguments to pass to flatc
(e.g. --gen-mutable).
visibility: The visibility of the generated cc_library. By default, use the
default visibility of the project.
srcs_filegroup_visibility: The visibility of the generated srcs filegroup.
By default, use the value of the visibility parameter above.
gen_reflections: Optional, if true this will generate the flatbuffer
reflection binaries for the schemas.
Outs:
filegroup([name]_srcs): all generated .h files.
filegroup(srcs_filegroup_name if specified, or [name]_includes if not):
Other flatbuffer_cc_library's can pass this in for their `includes`
parameter, if they depend on the schemas in this library.
Fileset([name]_reflection): (Optional) all generated reflection binaries.
cc_library([name]): library with sources and flatbuffers deps.
Remarks:
** Because the genrule used to call flatc does not have any trivial way of
computing the output list of files transitively generated by includes and
--gen-includes (the default) being defined for flatc, the --gen-includes
flag will not work as expected. The way around this is to add a dependency
to the flatbuffer_cc_library defined alongside the flatc included Fileset.
For example you might define:
flatbuffer_cc_library(
name = "my_fbs",
srcs = [ "schemas/foo.fbs" ],
includes = [ "//third_party/bazz:bazz_fbs_includes" ],
)
In which foo.fbs includes a few files from the Fileset defined at
//third_party/bazz:bazz_fbs_includes. When compiling the library that
includes foo_generated.h, and therefore has my_fbs as a dependency, it
will fail to find any of the bazz *_generated.h files unless you also
add bazz's flatbuffer_cc_library to your own dependency list, e.g.:
cc_library(
name = "my_lib",
deps = [
":my_fbs",
"//third_party/bazz:bazz_fbs"
],
)
Happy dependent Flatbuffering!
'''
output_headers = [
(out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1]) for s in srcs
]
reflection_name = "%s_reflection" % name if gen_reflections else ""
flatbuffer_library_public(name="%s_srcs" % (name),
srcs=srcs,
outs=output_headers,
language_flag="-c",
out_prefix=out_prefix,
includes=includes,
include_paths=include_paths,
flatc_args=flatc_args,
reflection_name=reflection_name,
reflection_visiblity=visibility,)
native.cc_library(name=name,
hdrs=output_headers,
srcs=output_headers,
features=[
"-parse_headers",
],
deps=[
"@flatbuffers//:runtime_cc",
],
includes=["."],
linkstatic=1,
visibility=visibility)
# A filegroup for the `srcs`. That is, all the schema files for this
# Flatbuffer set.
native.filegroup(
name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name),
srcs = srcs,
visibility=srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility)


@ -8,114 +8,102 @@ correctly understood by the build system.
"""
def gentbl(name, tblgen, td_file, td_srcs, tbl_outs, library = True, **kwargs):
"""gentbl() generates tabular code from a table definition file.
"""gentbl() generates tabular code from a table definition file.
Args:
name: The name of the build rule for use in dependencies.
tblgen: The binary used to produce the output.
td_file: The primary table definitions file.
td_srcs: A list of table definition files included transitively.
tbl_outs: A list of tuples (opts, out), where each opts is a string of
options passed to tblgen, and the out is the corresponding output file
produced.
library: Whether to bundle the generated files into a library.
**kwargs: Keyword arguments to pass to subsidiary cc_library() rule.
"""
if td_file not in td_srcs:
td_srcs += [td_file]
includes = []
for (opts, out) in tbl_outs:
outdir = out[:out.rindex("/")]
if outdir not in includes:
includes.append(outdir)
rule_suffix = "_".join(opts.replace("-", "_").replace("=", "_").split(" "))
native.genrule(
name="%s_%s_genrule" % (name, rule_suffix),
srcs=td_srcs,
outs=[out],
tools=[tblgen],
message="Generating code from table: %s" % td_file,
cmd=(("$(location %s) " + "-I external/llvm/include " +
"-I external/llvm/tools/clang/include " +
"-I $$(dirname $(location %s)) " + "%s $(location %s) -o $@") % (
tblgen, td_file, opts, td_file)))
# For now, all generated files can be assumed to comprise public interfaces.
# If this is not true, you should specify library = False
# and list the generated '.inc' files in "srcs".
if library:
native.cc_library(name=name, textual_hdrs=[f for (_, f) in tbl_outs],
includes=includes, **kwargs)
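# A usage sketch (hypothetical tblgen binary and .td paths):
#
# gentbl(
#     name = "intrinsics_inc_gen",
#     tblgen = ":llvm-tblgen",
#     td_file = "include/llvm/IR/Intrinsics.td",
#     td_srcs = ["include/llvm/IR/Intrinsics.td"],
#     tbl_outs = [("-gen-intrinsic-enums", "include/llvm/IR/IntrinsicEnums.inc")],
# )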
def llvm_target_cmake_vars(native_arch, target_triple):
return {
"LLVM_HOST_TRIPLE": target_triple,
"LLVM_DEFAULT_TARGET_TRIPLE": target_triple,
"LLVM_NATIVE_ARCH": native_arch,
}
def _quote(s):
"""Quotes the given string for use in a shell command.
"""Quotes the given string for use in a shell command.
This function double-quotes the given string (in case it contains spaces or
other special characters) and escapes any special characters (dollar signs,
double-quotes, and backslashes) that may be present.
Args:
s: The string to quote.
Returns:
An escaped and quoted version of the string that can be passed to a shell
command.
"""
return ('"' +
s.replace("\\", "\\\\").replace("$", "\\$").replace('"', '\\"') +
'"')
def cmake_var_string(cmake_vars):
"""Converts a dictionary to an input suitable for expand_cmake_vars.
"""Converts a dictionary to an input suitable for expand_cmake_vars.
Ideally we would just stringify in the expand_cmake_vars() rule, but select()
interacts badly with genrules.
TODO(phawkins): replace the genrule() with native rule and delete this rule.
Args:
cmake_vars: a dictionary with string keys and values that are convertible to
strings.
"""
return " ".join([_quote("{}={}".format(k, str(v)))
for (k, v) in cmake_vars.items()])
def expand_cmake_vars(name, src, dst, cmake_vars):
"""Expands #cmakedefine, #cmakedefine01, and CMake variables in a text file.
"""Expands #cmakedefine, #cmakedefine01, and CMake variables in a text file.
Args:
name: the name of the rule
src: the input of the rule
dst: the output of the rule
cmake_vars: a string containing the CMake variables, as generated by
cmake_var_string.
"""
expand_cmake_vars_tool = Label("@org_tensorflow//third_party/llvm:expand_cmake_vars")
native.genrule(
name = name,
srcs = [src],
tools = [expand_cmake_vars_tool],
outs = [dst],
cmd = ("$(location {}) ".format(expand_cmake_vars_tool) + cmake_vars +
"< $< > $@")
)
# TODO(phawkins): the set of CMake variables was hardcoded for expediency.
# However, we should really detect many of these via configure-time tests.
@ -225,18 +213,17 @@ darwin_cmake_vars = {
llvm_all_cmake_vars = select({
"@org_tensorflow//tensorflow:darwin": cmake_var_string(
cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") +
darwin_cmake_vars),
"@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string(
cmake_vars +
llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") +
linux_cmake_vars,
),
"//conditions:default": cmake_var_string(
cmake_vars +
llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") +
linux_cmake_vars),
})
LLVM_LINKOPTS = ["-ldl", "-lm", "-lpthread"]


@ -8,8 +8,10 @@ mkl_repository depends on the following environment variables:
* `TF_MKL_ROOT`: The root folder where a copy of libmkl is located.
"""
_TF_MKL_ROOT = "TF_MKL_ROOT"
def if_mkl(if_true, if_false = []):
"""Shorthand for select()'ing on whether we're building with MKL.
@ -19,7 +21,7 @@ def if_mkl(if_true, if_false = []):
"""
return select({
str(Label("//third_party/mkl:using_mkl")): if_true,
"//conditions:default": if_false,
"//conditions:default": if_false
})
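# A usage sketch (hypothetical source files): compile the MKL kernel only
# when building with MKL support:
#
# srcs = ["matmul_op.cc"] + if_mkl(["mkl_matmul_op.cc"])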
def if_mkl_lnx_x64(if_true, if_false = []):
@ -31,34 +33,37 @@ def if_mkl_lnx_x64(if_true, if_false = []):
"""
return select({
str(Label("//third_party/mkl:using_mkl_lnx_x64")): if_true,
"//conditions:default": if_false,
"//conditions:default": if_false
})
def _enable_local_mkl(repository_ctx):
return _TF_MKL_ROOT in repository_ctx.os.environ
def _mkl_autoconf_impl(repository_ctx):
"""Implementation of the local_mkl_autoconf repository rule."""
"""Implementation of the local_mkl_autoconf repository rule."""
if _enable_local_mkl(repository_ctx):
# Symlink lib and include local folders.
mkl_root = repository_ctx.os.environ[_TF_MKL_ROOT]
mkl_lib_path = "%s/lib" % mkl_root
repository_ctx.symlink(mkl_lib_path, "lib")
mkl_include_path = "%s/include" % mkl_root
repository_ctx.symlink(mkl_include_path, "include")
mkl_license_path = "%s/license.txt" % mkl_root
repository_ctx.symlink(mkl_license_path, "license.txt")
else:
# setup remote mkl repository.
repository_ctx.download_and_extract(
repository_ctx.attr.urls,
sha256=repository_ctx.attr.sha256,
stripPrefix=repository_ctx.attr.strip_prefix,
)
# Also setup BUILD file.
repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")
mkl_repository = repository_rule(
implementation = _mkl_autoconf_impl,


@ -2,16 +2,16 @@
# Based on the configuration options, return one or the other.
def mpi_hdr():
MPI_LIB_IS_OPENMPI=True
hdrs = []
if MPI_LIB_IS_OPENMPI:
hdrs = ["mpi.h", "mpi_portable_platform.h"] #When using OpenMPI
hdrs = ["mpi.h", "mpi_portable_platform.h"] #When using OpenMPI
else:
hdrs = ["mpi.h", "mpio.h", "mpicxx.h"] #When using MVAPICH
hdrs = ["mpi.h", "mpio.h", "mpicxx.h"] #When using MVAPICH
return hdrs
def if_mpi(if_true, if_false = []):
return select({
"//tensorflow:with_mpi_support": if_true,
"//conditions:default": if_false,
"//conditions:default": if_false
})
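# A usage sketch (hypothetical source file):
#
# srcs = ["rendezvous_mgr.cc"] + if_mpi(["mpi_rendezvous_mgr.cc"])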

third_party/repo.bzl

@ -19,98 +19,90 @@ _SINGLE_URL_WHITELIST = depset([
])
def _is_windows(ctx):
return ctx.os.name.lower().find("windows") != -1
return ctx.os.name.lower().find("windows") != -1
def _wrap_bash_cmd(ctx, cmd):
if _is_windows(ctx):
bazel_sh = _get_env_var(ctx, "BAZEL_SH")
if not bazel_sh:
fail("BAZEL_SH environment variable is not set")
cmd = [bazel_sh, "-l", "-c", " ".join(cmd)]
return cmd
def _get_env_var(ctx, name):
if name in ctx.os.environ:
return ctx.os.environ[name]
else:
return None
if name in ctx.os.environ:
return ctx.os.environ[name]
else:
return None
# Executes the specified command with arguments and calls 'fail' if it exited
# with a non-zero code.
def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
result = repo_ctx.execute(cmd_and_args, timeout = 10)
if result.return_code != 0:
fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" +
"Stderr: {3}").format(
" ".join(cmd_and_args),
result.return_code,
result.stdout,
result.stderr,
))
result = repo_ctx.execute(cmd_and_args, timeout=10)
if result.return_code != 0:
fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n"
+ "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code,
result.stdout, result.stderr))
def _repos_are_siblings():
return Label("@foo//bar").workspace_root.startswith("../")
return Label("@foo//bar").workspace_root.startswith("../")
# Apply a patch_file to the repository root directory
# Runs 'patch -p1'
def _apply_patch(ctx, patch_file):
# Don't check patch on Windows, because patch is only available under bash.
if not _is_windows(ctx) and not ctx.which("patch"):
fail("patch command is not found, please install it")
cmd = _wrap_bash_cmd(
ctx,
["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)],
)
_execute_and_check_ret_code(ctx, cmd)
# Don't check patch on Windows, because patch is only available under bash.
if not _is_windows(ctx) and not ctx.which("patch"):
fail("patch command is not found, please install it")
cmd = _wrap_bash_cmd(
ctx, ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)])
_execute_and_check_ret_code(ctx, cmd)
def _apply_delete(ctx, paths):
for path in paths:
if path.startswith("/"):
fail("refusing to rm -rf path starting with '/': " + path)
if ".." in path:
fail("refusing to rm -rf path containing '..': " + path)
cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths])
_execute_and_check_ret_code(ctx, cmd)
for path in paths:
if path.startswith("/"):
fail("refusing to rm -rf path starting with '/': " + path)
if ".." in path:
fail("refusing to rm -rf path containing '..': " + path)
cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths])
_execute_and_check_ret_code(ctx, cmd)
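To make the control flow concrete, here is a hypothetical repository rule implementation composing the helpers above; the paths and patch label are illustrative only:

def _example_repo_impl(ctx):  # hypothetical, for illustration only
    # Drop vendored directories we never build (the guards above reject
    # absolute paths and '..'), then apply a local patch via 'patch -p1'.
    _apply_delete(ctx, ["upstream/tests", "upstream/benchmarks"])
    _apply_patch(ctx, Label("//third_party:example.patch"))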
def _tf_http_archive(ctx):
if ("mirror.bazel.build" not in ctx.attr.urls[0] and
(len(ctx.attr.urls) < 2 and
ctx.attr.name not in _SINGLE_URL_WHITELIST)):
fail("tf_http_archive(urls) must have redundant URLs. The " +
"mirror.bazel.build URL must be present and it must come first. " +
"Even if you don't have permission to mirror the file, please " +
"put the correctly formatted mirror URL there anyway, because " +
"someone will come along shortly thereafter and mirror the file.")
ctx.download_and_extract(
ctx.attr.urls,
"",
ctx.attr.sha256,
ctx.attr.type,
ctx.attr.strip_prefix,
)
if ctx.attr.delete:
_apply_delete(ctx, ctx.attr.delete)
if ctx.attr.patch_file != None:
_apply_patch(ctx, ctx.attr.patch_file)
if ctx.attr.build_file != None:
# Use BUILD.bazel to avoid conflict with third party projects with
# BUILD or build (directory) underneath.
ctx.template("BUILD.bazel", ctx.attr.build_file, {
"%prefix%": ".." if _repos_are_siblings() else "external",
}, False)
if ("mirror.bazel.build" not in ctx.attr.urls[0] and
(len(ctx.attr.urls) < 2 and
ctx.attr.name not in _SINGLE_URL_WHITELIST)):
fail("tf_http_archive(urls) must have redundant URLs. The " +
"mirror.bazel.build URL must be present and it must come first. " +
"Even if you don't have permission to mirror the file, please " +
"put the correctly formatted mirror URL there anyway, because " +
"someone will come along shortly thereafter and mirror the file.")
ctx.download_and_extract(
ctx.attr.urls,
"",
ctx.attr.sha256,
ctx.attr.type,
ctx.attr.strip_prefix)
if ctx.attr.delete:
_apply_delete(ctx, ctx.attr.delete)
if ctx.attr.patch_file != None:
_apply_patch(ctx, ctx.attr.patch_file)
if ctx.attr.build_file != None:
# Use BUILD.bazel to avoid conflict with third party projects with
# BUILD or build (directory) underneath.
ctx.template("BUILD.bazel", ctx.attr.build_file, {
"%prefix%": ".." if _repos_are_siblings() else "external",
}, False)
tf_http_archive = repository_rule(
implementation = _tf_http_archive,
attrs = {
"sha256": attr.string(mandatory = True),
"urls": attr.string_list(mandatory = True, allow_empty = False),
implementation=_tf_http_archive,
attrs={
"sha256": attr.string(mandatory=True),
"urls": attr.string_list(mandatory=True, allow_empty=False),
"strip_prefix": attr.string(),
"type": attr.string(),
"delete": attr.string_list(),
"patch_file": attr.label(),
"build_file": attr.label(),
},
)
})
"""Downloads and creates Bazel repos for dependencies.
This is a swappable replacement for both http_archive() and
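A hedged example of invoking tf_http_archive from a workspace macro; the name, checksum, and URLs are placeholders. Note the mirror.bazel.build URL listed first, as the redundancy check above requires:

tf_http_archive(
    name = "some_dep",  # placeholder
    sha256 = "0000000000000000000000000000000000000000000000000000000000000000",  # placeholder
    strip_prefix = "some_dep-1.0",
    urls = [
        "https://mirror.bazel.build/github.com/example/some_dep/archive/v1.0.tar.gz",
        "https://github.com/example/some_dep/archive/v1.0.tar.gz",
    ],
    build_file = str(Label("//third_party:some_dep.BUILD")),  # optional
)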


@ -11,124 +11,122 @@
"""
_HOST_CXX_COMPILER = "HOST_CXX_COMPILER"
_HOST_C_COMPILER = "HOST_C_COMPILER"
_HOST_C_COMPILER= "HOST_C_COMPILER"
_COMPUTECPP_TOOLKIT_PATH = "COMPUTECPP_TOOLKIT_PATH"
_TRISYCL_INCLUDE_DIR = "TRISYCL_INCLUDE_DIR"
_PYTHON_LIB_PATH = "PYTHON_LIB_PATH"
def _enable_sycl(repository_ctx):
if "TF_NEED_OPENCL_SYCL" in repository_ctx.os.environ:
enable_sycl = repository_ctx.os.environ["TF_NEED_OPENCL_SYCL"].strip()
return enable_sycl == "1"
return False
if "TF_NEED_OPENCL_SYCL" in repository_ctx.os.environ:
enable_sycl = repository_ctx.os.environ["TF_NEED_OPENCL_SYCL"].strip()
return enable_sycl == "1"
return False
def _enable_compute_cpp(repository_ctx):
return _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ
return _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ
def auto_configure_fail(msg):
"""Output failure message when auto configuration fails."""
red = "\033[0;31m"
no_color = "\033[0m"
fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))
"""Output failure message when auto configuration fails."""
red = "\033[0;31m"
no_color = "\033[0m"
fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))
# END cc_configure common functions (see TODO above).
def find_c(repository_ctx):
"""Find host C compiler."""
c_name = "gcc"
if _HOST_C_COMPILER in repository_ctx.os.environ:
c_name = repository_ctx.os.environ[_HOST_C_COMPILER].strip()
if c_name.startswith("/"):
return c_name
c = repository_ctx.which(c_name)
if c == None:
fail("Cannot find C compiler, please correct your path.")
return c
"""Find host C compiler."""
c_name = "gcc"
if _HOST_C_COMPILER in repository_ctx.os.environ:
c_name = repository_ctx.os.environ[_HOST_C_COMPILER].strip()
if c_name.startswith("/"):
return c_name
c = repository_ctx.which(c_name)
if c == None:
fail("Cannot find C compiler, please correct your path.")
return c
def find_cc(repository_ctx):
"""Find host C++ compiler."""
cc_name = "g++"
if _HOST_CXX_COMPILER in repository_ctx.os.environ:
cc_name = repository_ctx.os.environ[_HOST_CXX_COMPILER].strip()
if cc_name.startswith("/"):
return cc_name
cc = repository_ctx.which(cc_name)
if cc == None:
fail("Cannot find C++ compiler, please correct your path.")
return cc
"""Find host C++ compiler."""
cc_name = "g++"
if _HOST_CXX_COMPILER in repository_ctx.os.environ:
cc_name = repository_ctx.os.environ[_HOST_CXX_COMPILER].strip()
if cc_name.startswith("/"):
return cc_name
cc = repository_ctx.which(cc_name)
if cc == None:
fail("Cannot find C++ compiler, please correct your path.")
return cc
def find_computecpp_root(repository_ctx):
"""Find ComputeCpp compiler."""
sycl_name = ""
if _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ:
sycl_name = repository_ctx.os.environ[_COMPUTECPP_TOOLKIT_PATH].strip()
if sycl_name.startswith("/"):
return sycl_name
fail("Cannot find SYCL compiler, please correct your path")
"""Find ComputeCpp compiler."""
sycl_name = ""
if _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ:
sycl_name = repository_ctx.os.environ[_COMPUTECPP_TOOLKIT_PATH].strip()
if sycl_name.startswith("/"):
return sycl_name
fail("Cannot find SYCL compiler, please correct your path")
def find_trisycl_include_dir(repository_ctx):
"""Find triSYCL include directory. """
if _TRISYCL_INCLUDE_DIR in repository_ctx.os.environ:
sycl_name = repository_ctx.os.environ[_TRISYCL_INCLUDE_DIR].strip()
if sycl_name.startswith("/"):
return sycl_name
fail("Cannot find triSYCL include directory, please correct your path")
"""Find triSYCL include directory. """
if _TRISYCL_INCLUDE_DIR in repository_ctx.os.environ:
sycl_name = repository_ctx.os.environ[_TRISYCL_INCLUDE_DIR].strip()
if sycl_name.startswith("/"):
return sycl_name
fail( "Cannot find triSYCL include directory, please correct your path")
def find_python_lib(repository_ctx):
"""Returns python path."""
if _PYTHON_LIB_PATH in repository_ctx.os.environ:
return repository_ctx.os.environ[_PYTHON_LIB_PATH].strip()
fail("Environment variable PYTHON_LIB_PATH was not specified re-run ./configure")
"""Returns python path."""
if _PYTHON_LIB_PATH in repository_ctx.os.environ:
return repository_ctx.os.environ[_PYTHON_LIB_PATH].strip()
fail("Environment variable PYTHON_LIB_PATH was not specified re-run ./configure")
def _check_lib(repository_ctx, toolkit_path, lib):
"""Checks if lib exists under sycl_toolkit_path or fail if it doesn't.
"""Checks if lib exists under sycl_toolkit_path or fail if it doesn't.
Args:
repository_ctx: The repository context.
toolkit_path: The toolkit directory containing the libraries.
lib: The library to look for under toolkit_path.
"""
lib_path = toolkit_path + "/" + lib
if not repository_ctx.path(lib_path).exists:
auto_configure_fail("Cannot find %s" % lib_path)
Args:
repository_ctx: The repository context.
toolkit_path: The toolkit directory containing the libraries.
lib: The library to look for under toolkit_path.
"""
lib_path = toolkit_path + "/" + lib
if not repository_ctx.path(lib_path).exists:
auto_configure_fail("Cannot find %s" % lib_path)
def _check_dir(repository_ctx, directory):
"""Checks whether the directory exists and fail if it does not.
"""Checks whether the directory exists and fail if it does not.
Args:
repository_ctx: The repository context.
directory: The directory to check the existence of.
"""
if not repository_ctx.path(directory).exists:
auto_configure_fail("Cannot find dir: %s" % directory)
Args:
repository_ctx: The repository context.
directory: The directory to check the existence of.
"""
if not repository_ctx.path(directory).exists:
auto_configure_fail("Cannot find dir: %s" % directory)
def _symlink_dir(repository_ctx, src_dir, dest_dir):
"""Symlinks all the files in a directory.
"""Symlinks all the files in a directory.
Args:
repository_ctx: The repository context.
src_dir: The source directory.
dest_dir: The destination directory to create the symlinks in.
"""
files = repository_ctx.path(src_dir).readdir()
for src_file in files:
repository_ctx.symlink(src_file, dest_dir + "/" + src_file.basename)
Args:
repository_ctx: The repository context.
src_dir: The source directory.
dest_dir: The destination directory to create the symlinks in.
"""
files = repository_ctx.path(src_dir).readdir()
for src_file in files:
repository_ctx.symlink(src_file, dest_dir + "/" + src_file.basename)
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
if not out:
out = tpl.replace(":", "/")
repository_ctx.template(
out,
Label("//third_party/sycl/%s.tpl" % tpl),
substitutions,
)
def _tpl(repository_ctx, tpl, substitutions={}, out=None):
if not out:
out = tpl.replace(":", "/")
repository_ctx.template(
out,
Label("//third_party/sycl/%s.tpl" % tpl),
substitutions)
def _file(repository_ctx, label):
repository_ctx.template(
label.replace(":", "/"),
Label("//third_party/sycl/%s" % label),
{},
)
repository_ctx.template(
label.replace(":", "/"),
Label("//third_party/sycl/%s" % label),
{})
_DUMMY_CROSSTOOL_BZL_FILE = """
def error_sycl_disabled():
@ -149,6 +147,7 @@ def error_sycl_disabled():
)
"""
_DUMMY_CROSSTOOL_BUILD_FILE = """
load("//crosstool:error_sycl_disabled.bzl", "error_sycl_disabled")
@ -156,97 +155,87 @@ error_sycl_disabled()
"""
def _create_dummy_repository(repository_ctx):
# Set up BUILD file for sycl/.
_tpl(repository_ctx, "sycl:build_defs.bzl")
_tpl(repository_ctx, "sycl:BUILD")
_file(repository_ctx, "sycl:LICENSE.text")
_tpl(repository_ctx, "sycl:platform.bzl")
# Set up BUILD file for sycl/.
_tpl(repository_ctx, "sycl:build_defs.bzl")
_tpl(repository_ctx, "sycl:BUILD")
_file(repository_ctx, "sycl:LICENSE.text")
_tpl(repository_ctx, "sycl:platform.bzl")
# Create dummy files for the SYCL toolkit since they are still required by
# tensorflow/sycl/platform/default/build_config:sycl.
repository_ctx.file("sycl/include/sycl.hpp", "")
repository_ctx.file("sycl/lib/libComputeCpp.so", "")
# Create dummy files for the SYCL toolkit since they are still required by
# tensorflow/sycl/platform/default/build_config:sycl.
repository_ctx.file("sycl/include/sycl.hpp", "")
repository_ctx.file("sycl/lib/libComputeCpp.so", "")
# If sycl_configure is not configured to build with SYCL support, and the user
# attempts to build with --config=sycl, add a dummy build rule to intercept
# this and fail with an actionable error message.
repository_ctx.file("crosstool/error_sycl_disabled.bzl",
_DUMMY_CROSSTOOL_BZL_FILE)
repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
# If sycl_configure is not configured to build with SYCL support, and the user
# attempts to build with --config=sycl, add a dummy build rule to intercept
# this and fail with an actionable error message.
repository_ctx.file(
"crosstool/error_sycl_disabled.bzl",
_DUMMY_CROSSTOOL_BZL_FILE,
)
repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
def _sycl_autoconf_imp(repository_ctx):
"""Implementation of the sycl_autoconf rule."""
if not _enable_sycl(repository_ctx):
_create_dummy_repository(repository_ctx)
"""Implementation of the sycl_autoconf rule."""
if not _enable_sycl(repository_ctx):
_create_dummy_repository(repository_ctx)
else:
# copy template files
_tpl(repository_ctx, "sycl:build_defs.bzl")
_tpl(repository_ctx, "sycl:BUILD")
_tpl(repository_ctx, "sycl:platform.bzl")
_tpl(repository_ctx, "crosstool:BUILD")
_file(repository_ctx, "sycl:LICENSE.text")
if _enable_compute_cpp(repository_ctx):
_tpl(repository_ctx, "crosstool:computecpp",
{
"%{host_cxx_compiler}" : find_cc(repository_ctx),
"%{host_c_compiler}" : find_c(repository_ctx)
})
computecpp_root = find_computecpp_root(repository_ctx);
_check_dir(repository_ctx, computecpp_root)
_tpl(repository_ctx, "crosstool:CROSSTOOL",
{
"%{sycl_include_dir}" : computecpp_root,
"%{sycl_impl}" : "computecpp",
"%{c++_std}" : "-std=c++11",
"%{python_lib_path}" : find_python_lib(repository_ctx),
})
# symlink libraries
_check_lib(repository_ctx, computecpp_root+"/lib", "libComputeCpp.so" )
_symlink_dir(repository_ctx, computecpp_root + "/lib", "sycl/lib")
_symlink_dir(repository_ctx, computecpp_root + "/include", "sycl/include")
_symlink_dir(repository_ctx, computecpp_root + "/bin", "sycl/bin")
else:
# copy template files
_tpl(repository_ctx, "sycl:build_defs.bzl")
_tpl(repository_ctx, "sycl:BUILD")
_tpl(repository_ctx, "sycl:platform.bzl")
_tpl(repository_ctx, "crosstool:BUILD")
_file(repository_ctx, "sycl:LICENSE.text")
if _enable_compute_cpp(repository_ctx):
_tpl(
repository_ctx,
"crosstool:computecpp",
{
"%{host_cxx_compiler}": find_cc(repository_ctx),
"%{host_c_compiler}": find_c(repository_ctx),
},
)
trisycl_include_dir = find_trisycl_include_dir(repository_ctx);
_check_dir(repository_ctx, trisycl_include_dir)
computecpp_root = find_computecpp_root(repository_ctx)
_check_dir(repository_ctx, computecpp_root)
_tpl(repository_ctx, "crosstool:trisycl",
{
"%{host_cxx_compiler}" : find_cc(repository_ctx),
"%{host_c_compiler}" : find_c(repository_ctx),
"%{trisycl_include_dir}" : trisycl_include_dir
})
_tpl(
repository_ctx,
"crosstool:CROSSTOOL",
{
"%{sycl_include_dir}": computecpp_root,
"%{sycl_impl}": "computecpp",
"%{c++_std}": "-std=c++11",
"%{python_lib_path}": find_python_lib(repository_ctx),
},
)
# symlink libraries
_check_lib(repository_ctx, computecpp_root + "/lib", "libComputeCpp.so")
_symlink_dir(repository_ctx, computecpp_root + "/lib", "sycl/lib")
_symlink_dir(repository_ctx, computecpp_root + "/include", "sycl/include")
_symlink_dir(repository_ctx, computecpp_root + "/bin", "sycl/bin")
else:
trisycl_include_dir = find_trisycl_include_dir(repository_ctx)
_check_dir(repository_ctx, trisycl_include_dir)
_tpl(repository_ctx, "crosstool:CROSSTOOL",
{
"%{sycl_include_dir}" : trisycl_include_dir,
"%{sycl_impl}" : "trisycl",
"%{c++_std}" : "-std=c++1y",
"%{python_lib_path}" : find_python_lib(repository_ctx),
})
_tpl(
repository_ctx,
"crosstool:trisycl",
{
"%{host_cxx_compiler}": find_cc(repository_ctx),
"%{host_c_compiler}": find_c(repository_ctx),
"%{trisycl_include_dir}": trisycl_include_dir,
},
)
_symlink_dir(repository_ctx, trisycl_include_dir, "sycl/include")
_tpl(
repository_ctx,
"crosstool:CROSSTOOL",
{
"%{sycl_include_dir}": trisycl_include_dir,
"%{sycl_impl}": "trisycl",
"%{c++_std}": "-std=c++1y",
"%{python_lib_path}": find_python_lib(repository_ctx),
},
)
_symlink_dir(repository_ctx, trisycl_include_dir, "sycl/include")
sycl_configure = repository_rule(
implementation = _sycl_autoconf_imp,
local = True,
implementation = _sycl_autoconf_imp,
local = True,
)
"""Detects and configures the SYCL toolchain.


@ -1,37 +1,30 @@
"""Repository rule for Debian 8 Jessie Clang-6.0 portable Linux builds."""
def _clang6_configure(ctx):
# TODO(jart): It'd probably be better to use Bazel's struct.to_proto()
# method to generate a gigantic CROSSTOOL file that allows
# Clang to support everything.
ctx.symlink(
ctx.os.environ.get(
"TF_LLVM_PATH",
"/usr/lib/llvm-6.0",
),
"clang6/llvm",
)
ctx.symlink(
ctx.os.environ.get("STRIP", "/usr/bin/strip"),
"clang6/sbin/strip",
)
ctx.symlink(
ctx.os.environ.get("OBJDUMP", "/usr/bin/objdump"),
"clang6/sbin/objdump",
)
ctx.symlink(ctx.attr._build, "clang6/BUILD")
ctx.template("clang6/CROSSTOOL", ctx.attr._crosstool, {
"%package(@local_config_clang6//clang6)%": str(ctx.path("clang6")),
})
# TODO(jart): It'd probably be better to use Bazel's struct.to_proto()
# method to generate a gigantic CROSSTOOL file that allows
# Clang to support everything.
ctx.symlink(
ctx.os.environ.get('TF_LLVM_PATH',
'/usr/lib/llvm-6.0'),
'clang6/llvm')
ctx.symlink(
ctx.os.environ.get('STRIP', '/usr/bin/strip'),
'clang6/sbin/strip')
ctx.symlink(
ctx.os.environ.get('OBJDUMP', '/usr/bin/objdump'),
'clang6/sbin/objdump')
ctx.symlink(ctx.attr._build, 'clang6/BUILD')
ctx.template('clang6/CROSSTOOL', ctx.attr._crosstool, {
'%package(@local_config_clang6//clang6)%': str(ctx.path('clang6')),
})
clang6_configure = repository_rule(
implementation = _clang6_configure,
attrs = {
"_build": attr.label(
default = str(Label("//third_party/toolchains/clang6:clang.BUILD")),
),
"_crosstool": attr.label(
default = str(Label("//third_party/toolchains/clang6:CROSSTOOL.tpl")),
),
'_build': attr.label(
default=str(Label('//third_party/toolchains/clang6:clang.BUILD'))),
'_crosstool': attr.label(
default=str(Label('//third_party/toolchains/clang6:CROSSTOOL.tpl'))),
},
)
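A brief usage sketch, assuming the load path; the repository name matches the %package(@local_config_clang6//clang6)% substitution above:

load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure")  # path assumed

# TF_LLVM_PATH, STRIP, and OBJDUMP may be set to override the defaults
# (/usr/lib/llvm-6.0, /usr/bin/strip, /usr/bin/objdump) symlinked above.
clang6_configure(name = "local_config_clang6")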


@ -1,38 +1,38 @@
# -*- Python -*-
"""Repository rule for arm compiler autoconfiguration."""
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
if not out:
out = tpl
repository_ctx.template(
out,
Label("//third_party/toolchains/cpus/arm:%s.tpl" % tpl),
substitutions,
)
def _tpl(repository_ctx, tpl, substitutions={}, out=None):
if not out:
out = tpl
repository_ctx.template(
out,
Label("//third_party/toolchains/cpus/arm:%s.tpl" % tpl),
substitutions)
def _arm_compiler_configure_impl(repository_ctx):
# We need to find a cross-compilation include directory for Python, so look
# for an environment variable. Be warned, this crosstool template is only
# regenerated on the first run of Bazel, so if you change the variable afterwards
# it may not be reflected in later builds. Doing a shutdown and clean of Bazel
# doesn't fix this; you'll need to delete the generated file at something like:
# external/local_config_arm_compiler/CROSSTOOL in your Bazel install.
if "CROSSTOOL_PYTHON_INCLUDE_PATH" in repository_ctx.os.environ:
python_include_path = repository_ctx.os.environ["CROSSTOOL_PYTHON_INCLUDE_PATH"]
else:
python_include_path = "/usr/include/python2.7"
_tpl(repository_ctx, "CROSSTOOL", {
"%{ARM_COMPILER_PATH}%": str(repository_ctx.path(
repository_ctx.attr.remote_config_repo,
)),
"%{PYTHON_INCLUDE_PATH}%": python_include_path,
})
repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")
# We need to find a cross-compilation include directory for Python, so look
# for an environment variable. Be warned, this crosstool template is only
# regenerated on the first run of Bazel, so if you change the variable afterwards
# it may not be reflected in later builds. Doing a shutdown and clean of Bazel
# doesn't fix this; you'll need to delete the generated file at something like:
# external/local_config_arm_compiler/CROSSTOOL in your Bazel install.
if "CROSSTOOL_PYTHON_INCLUDE_PATH" in repository_ctx.os.environ:
python_include_path = repository_ctx.os.environ["CROSSTOOL_PYTHON_INCLUDE_PATH"]
else:
python_include_path = "/usr/include/python2.7"
_tpl(repository_ctx, "CROSSTOOL", {
"%{ARM_COMPILER_PATH}%": str(repository_ctx.path(
repository_ctx.attr.remote_config_repo)),
"%{PYTHON_INCLUDE_PATH}%": python_include_path,
})
repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")
arm_compiler_configure = repository_rule(
implementation = _arm_compiler_configure_impl,
attrs = {
"remote_config_repo": attr.string(mandatory = False, default = ""),
"remote_config_repo": attr.string(mandatory = False, default =""),
"build_file": attr.label(),
},
)
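A hedged WORKSPACE sketch; the repository name and attribute values are illustrative:

load(
    "//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl",  # path assumed
    "arm_compiler_configure",
)

arm_compiler_configure(
    name = "local_config_arm_compiler",  # conventional name, assumed
    build_file = "//third_party/toolchains/cpus/arm:BUILD",  # illustrative
    remote_config_repo = "../arm_compiler",  # illustrative sibling-repo path
)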


@ -12,13 +12,15 @@ def if_cuda(if_true, if_false = []):
return select({
"@local_config_cuda//cuda:using_nvcc": if_true,
"@local_config_cuda//cuda:using_clang": if_true,
"//conditions:default": if_false,
"//conditions:default": if_false
})
def cuda_default_copts():
"""Default options for all CUDA compilations."""
return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + ["--cuda-gpu-arch=sm_30"])
def cuda_is_configured():
"""Returns true if CUDA was enabled during the configure process."""
return True
@ -30,5 +32,6 @@ def if_cuda_is_configured(x):
--config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.
"""
if cuda_is_configured():
return x
return x
return []
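Finally, a minimal sketch of how these CUDA helpers are consumed; the target and file names are illustrative, and the build_defs.bzl location within @local_config_cuda is assumed:

load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")

cc_library(
    name = "gpu_kernels",  # hypothetical target
    # The .cu.cc source is compiled only under --config=cuda.
    srcs = ["kernels.cc"] + if_cuda(["kernels_gpu.cu.cc"]),
    copts = cuda_default_copts(),
    deps = if_cuda(["@local_config_cuda//cuda:cuda_headers"]),  # illustrative
)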