Rolling back tensorflow .bzl file changes
END_PUBLIC BEGIN_PUBLIC Automated g4 rollback of changelist 203459720 PiperOrigin-RevId: 203501636
This commit is contained in:
parent
8f5e2a740e
commit
ed494f17fc
@ -16,355 +16,339 @@ tf_library(
|
||||
)
|
||||
"""
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"if_android",
|
||||
"tf_cc_test",
|
||||
"tf_copts",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl",
|
||||
"if_android", "tf_cc_test", "tf_copts")
|
||||
|
||||
def tf_library(
|
||||
name,
|
||||
graph,
|
||||
config,
|
||||
freeze_checkpoint = None,
|
||||
freeze_saver = None,
|
||||
cpp_class = None,
|
||||
gen_test = True,
|
||||
gen_benchmark = True,
|
||||
visibility = None,
|
||||
testonly = None,
|
||||
tfcompile_flags = None,
|
||||
tfcompile_tool = "//tensorflow/compiler/aot:tfcompile",
|
||||
include_standard_runtime_deps = True,
|
||||
enable_xla_hlo_profiling = False,
|
||||
deps = None,
|
||||
tags = None):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
def tf_library(name, graph, config,
|
||||
freeze_checkpoint=None, freeze_saver=None,
|
||||
cpp_class=None, gen_test=True, gen_benchmark=True,
|
||||
visibility=None, testonly=None,
|
||||
tfcompile_flags=None,
|
||||
tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
|
||||
include_standard_runtime_deps=True,
|
||||
enable_xla_hlo_profiling=False, deps=None, tags=None):
|
||||
"""Runs tfcompile to compile a TensorFlow graph into executable code.
|
||||
|
||||
Given an invocation of tf_library(name="foo", ...), generates the following
|
||||
build targets:
|
||||
foo: A cc_library containing the generated header and computation.
|
||||
foo_test: A cc_test with simple tests and benchmarks. Only created if
|
||||
gen_test=True.
|
||||
foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful
|
||||
for mobile devices or other platforms that can't compile the
|
||||
full test libraries. Only created if gen_benchmark=True.
|
||||
Given an invocation of tf_library(name="foo", ...), generates the following
|
||||
build targets:
|
||||
foo: A cc_library containing the generated header and computation.
|
||||
foo_test: A cc_test with simple tests and benchmarks. Only created if
|
||||
gen_test=True.
|
||||
foo_benchmark: A cc_binary that runs a minimal-dependency benchmark, useful
|
||||
for mobile devices or other platforms that can't compile the
|
||||
full test libraries. Only created if gen_benchmark=True.
|
||||
|
||||
Args:
|
||||
name: The name of the build rule.
|
||||
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
|
||||
is expected to be in the human-readable proto text format, otherwise it is
|
||||
expected to be in the proto binary format.
|
||||
config: File containing tensorflow.tf2xla.Config proto. If the file ends
|
||||
in '.pbtxt' it is expected to be in the human-readable proto text format,
|
||||
otherwise it is expected to be in the proto binary format.
|
||||
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
|
||||
convert variables into constants.
|
||||
freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
|
||||
binary form, to convert variables into constants.
|
||||
cpp_class: The name of the generated C++ class, wrapping the generated
|
||||
function. The syntax of this flag is
|
||||
[[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
|
||||
for referring to a class, where multiple namespaces may precede the class
|
||||
name, separated by double-colons. The class will be generated in the
|
||||
given namespace(s), or if no namespaces are given, within the global
|
||||
namespace.
|
||||
gen_test: If True, also generate a cc_test rule that builds a simple
|
||||
test and benchmark.
|
||||
gen_benchmark: If True, also generate a binary with a simple benchmark.
|
||||
Unlike the output of gen_test, this benchmark can be run on android.
|
||||
visibility: Bazel build visibility.
|
||||
testonly: Bazel testonly attribute.
|
||||
tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
|
||||
tfcompile_tool: The tfcompile binary. A non-default can be passed to
|
||||
use a tfcompile built with extra dependencies.
|
||||
include_standard_runtime_deps: If True, the standard list of kernel/runtime
|
||||
deps is added to deps. If False, deps must contain the full set of deps
|
||||
needed by the generated library.
|
||||
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program,
|
||||
and emit metadata that lets us pretty-print the gathered profile counters.
|
||||
deps: a list of deps to include on the build rules for the generated
|
||||
library, added to the standard deps if standard_runtime_deps is True.
|
||||
tags: tags to apply to subsidiary build rules.
|
||||
Args:
|
||||
name: The name of the build rule.
|
||||
graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
|
||||
is expected to be in the human-readable proto text format, otherwise it is
|
||||
expected to be in the proto binary format.
|
||||
config: File containing tensorflow.tf2xla.Config proto. If the file ends
|
||||
in '.pbtxt' it is expected to be in the human-readable proto text format,
|
||||
otherwise it is expected to be in the proto binary format.
|
||||
freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
|
||||
convert variables into constants.
|
||||
freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
|
||||
binary form, to convert variables into constants.
|
||||
cpp_class: The name of the generated C++ class, wrapping the generated
|
||||
function. The syntax of this flag is
|
||||
[[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
|
||||
for referring to a class, where multiple namespaces may precede the class
|
||||
name, separated by double-colons. The class will be generated in the
|
||||
given namespace(s), or if no namespaces are given, within the global
|
||||
namespace.
|
||||
gen_test: If True, also generate a cc_test rule that builds a simple
|
||||
test and benchmark.
|
||||
gen_benchmark: If True, also generate a binary with a simple benchmark.
|
||||
Unlike the output of gen_test, this benchmark can be run on android.
|
||||
visibility: Bazel build visibility.
|
||||
testonly: Bazel testonly attribute.
|
||||
tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
|
||||
tfcompile_tool: The tfcompile binary. A non-default can be passed to
|
||||
use a tfcompile built with extra dependencies.
|
||||
include_standard_runtime_deps: If True, the standard list of kernel/runtime
|
||||
deps is added to deps. If False, deps must contain the full set of deps
|
||||
needed by the generated library.
|
||||
enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program,
|
||||
and emit metadata that lets us pretty-print the gathered profile counters.
|
||||
deps: a list of deps to include on the build rules for the generated
|
||||
library, added to the standard deps if standard_runtime_deps is True.
|
||||
tags: tags to apply to subsidiary build rules.
|
||||
|
||||
The output header is called <name>.h.
|
||||
"""
|
||||
if not cpp_class:
|
||||
fail("cpp_class must be specified")
|
||||
The output header is called <name>.h.
|
||||
"""
|
||||
if not cpp_class:
|
||||
fail("cpp_class must be specified")
|
||||
|
||||
tfcompile_graph = graph
|
||||
if freeze_checkpoint or freeze_saver:
|
||||
if not freeze_checkpoint:
|
||||
fail("freeze_checkpoint must be specified when freeze_saver is specified")
|
||||
tfcompile_graph = graph
|
||||
if freeze_checkpoint or freeze_saver:
|
||||
if not freeze_checkpoint:
|
||||
fail("freeze_checkpoint must be specified when freeze_saver is specified")
|
||||
|
||||
freeze_name = "freeze_" + name
|
||||
freeze_file = freeze_name + ".pb"
|
||||
freeze_name = "freeze_" + name
|
||||
freeze_file = freeze_name + ".pb"
|
||||
|
||||
# First run tfcompile to generate the list of out_nodes.
|
||||
out_nodes_file = "out_nodes_" + freeze_name
|
||||
native.genrule(
|
||||
name = ("gen_" + out_nodes_file),
|
||||
srcs = [config],
|
||||
outs = [out_nodes_file],
|
||||
cmd = ("$(location " + tfcompile_tool + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --dump_fetch_nodes > $@"),
|
||||
tools = [tfcompile_tool],
|
||||
# Run tfcompile on the build host, rather than forge, since it's
|
||||
# typically way faster on the local machine.
|
||||
local = 1,
|
||||
tags = tags,
|
||||
)
|
||||
|
||||
# Now run freeze_graph to convert variables into constants.
|
||||
freeze_args = (" --input_graph=$(location " + graph + ")" +
|
||||
" --checkpoint_version=1" +
|
||||
" --input_binary=" + str(not graph.endswith(".pbtxt")) +
|
||||
" --input_checkpoint=$(location " + freeze_checkpoint + ")" +
|
||||
" --output_graph=$(location " + freeze_file + ")" +
|
||||
" --output_node_names=$$(<$(location " + out_nodes_file +
|
||||
"))")
|
||||
freeze_saver_srcs = []
|
||||
if freeze_saver:
|
||||
freeze_args += " --input_saver=$(location " + freeze_saver + ")"
|
||||
freeze_saver_srcs += [freeze_saver]
|
||||
native.genrule(
|
||||
name = freeze_name,
|
||||
srcs = [
|
||||
graph,
|
||||
freeze_checkpoint,
|
||||
out_nodes_file,
|
||||
] + freeze_saver_srcs,
|
||||
outs = [freeze_file],
|
||||
cmd = ("$(location //tensorflow/python/tools:freeze_graph)" +
|
||||
freeze_args),
|
||||
tools = ["//tensorflow/python/tools:freeze_graph"],
|
||||
tags = tags,
|
||||
)
|
||||
tfcompile_graph = freeze_file
|
||||
|
||||
# Rule that runs tfcompile to produce the header and object file.
|
||||
header_file = name + ".h"
|
||||
metadata_object_file = name + "_tfcompile_metadata.o"
|
||||
function_object_file = name + "_tfcompile_function.o"
|
||||
ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
|
||||
if type(tfcompile_flags) == type(""):
|
||||
flags = tfcompile_flags
|
||||
else:
|
||||
flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
|
||||
if enable_xla_hlo_profiling:
|
||||
profiling_flag = "--xla_hlo_profile"
|
||||
else:
|
||||
profiling_flag = ""
|
||||
# First run tfcompile to generate the list of out_nodes.
|
||||
out_nodes_file = "out_nodes_" + freeze_name
|
||||
native.genrule(
|
||||
name = ("gen_" + name),
|
||||
srcs = [
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
outs = [
|
||||
name=("gen_" + out_nodes_file),
|
||||
srcs=[config],
|
||||
outs=[out_nodes_file],
|
||||
cmd=("$(location " + tfcompile_tool + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --dump_fetch_nodes > $@"),
|
||||
tools=[tfcompile_tool],
|
||||
# Run tfcompile on the build host, rather than forge, since it's
|
||||
# typically way faster on the local machine.
|
||||
local=1,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# Now run freeze_graph to convert variables into constants.
|
||||
freeze_args = (" --input_graph=$(location " + graph + ")" +
|
||||
" --checkpoint_version=1" +
|
||||
" --input_binary=" + str(not graph.endswith(".pbtxt")) +
|
||||
" --input_checkpoint=$(location " + freeze_checkpoint + ")" +
|
||||
" --output_graph=$(location " + freeze_file + ")" +
|
||||
" --output_node_names=$$(<$(location " + out_nodes_file +
|
||||
"))")
|
||||
freeze_saver_srcs = []
|
||||
if freeze_saver:
|
||||
freeze_args += " --input_saver=$(location " + freeze_saver + ")"
|
||||
freeze_saver_srcs += [freeze_saver]
|
||||
native.genrule(
|
||||
name=freeze_name,
|
||||
srcs=[
|
||||
graph,
|
||||
freeze_checkpoint,
|
||||
out_nodes_file,
|
||||
] + freeze_saver_srcs,
|
||||
outs=[freeze_file],
|
||||
cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
|
||||
freeze_args),
|
||||
tools=["//tensorflow/python/tools:freeze_graph"],
|
||||
tags=tags,
|
||||
)
|
||||
tfcompile_graph = freeze_file
|
||||
|
||||
# Rule that runs tfcompile to produce the header and object file.
|
||||
header_file = name + ".h"
|
||||
metadata_object_file = name + "_tfcompile_metadata.o"
|
||||
function_object_file = name + "_tfcompile_function.o"
|
||||
ep = ("__" + native.package_name() + "__" + name).replace("/", "_")
|
||||
if type(tfcompile_flags) == type(""):
|
||||
flags = tfcompile_flags
|
||||
else:
|
||||
flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
|
||||
if enable_xla_hlo_profiling:
|
||||
profiling_flag = "--xla_hlo_profile"
|
||||
else:
|
||||
profiling_flag = ""
|
||||
native.genrule(
|
||||
name=("gen_" + name),
|
||||
srcs=[
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
outs=[
|
||||
header_file,
|
||||
metadata_object_file,
|
||||
function_object_file,
|
||||
],
|
||||
cmd=("$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
" --target_triple=" + target_llvm_triple() +
|
||||
" --out_header=$(@D)/" + header_file +
|
||||
" --out_metadata_object=$(@D)/" + metadata_object_file +
|
||||
" --out_function_object=$(@D)/" + function_object_file +
|
||||
" " + flags + " " + profiling_flag),
|
||||
tools=[tfcompile_tool],
|
||||
visibility=visibility,
|
||||
testonly=testonly,
|
||||
# Run tfcompile on the build host since it's typically faster on the local
|
||||
# machine.
|
||||
#
|
||||
# Note that setting the local=1 attribute on a *test target* causes the
|
||||
# test infrastructure to skip that test. However this is a genrule, not a
|
||||
# test target, and runs with --genrule_strategy=forced_forge, meaning the
|
||||
# local=1 attribute is ignored, and the genrule is still run.
|
||||
#
|
||||
# https://www.bazel.io/versions/master/docs/be/general.html#genrule
|
||||
local=1,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# Rule that runs tfcompile to produce the SessionModule proto, useful for
|
||||
# debugging. TODO(b/64813587): Once the SessionModule proto is
|
||||
# deterministic, move this into the main rule above.
|
||||
session_module_pb = name + "_session_module.pb"
|
||||
native.genrule(
|
||||
name=(name + "_session_module"),
|
||||
srcs=[
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
outs=[
|
||||
session_module_pb,
|
||||
],
|
||||
cmd=("$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
" --target_triple=" + target_llvm_triple() +
|
||||
" --out_session_module=$(@D)/" + session_module_pb +
|
||||
" " + flags),
|
||||
tools=[tfcompile_tool],
|
||||
visibility=visibility,
|
||||
testonly=testonly,
|
||||
local=1,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# The cc_library rule packaging up the header and object file, and needed
|
||||
# kernel implementations.
|
||||
need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
|
||||
native.cc_library(
|
||||
name=name,
|
||||
srcs=[function_object_file, metadata_object_file],
|
||||
hdrs=[header_file],
|
||||
visibility=visibility,
|
||||
testonly=testonly,
|
||||
deps = [
|
||||
# These deps are required by all tf_library targets even if
|
||||
# include_standard_runtime_deps is False. Without them, the
|
||||
# generated code will fail to compile.
|
||||
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
|
||||
"//tensorflow/core:framework_lite",
|
||||
] + (need_xla_data_proto and [
|
||||
# If we're generating the program shape, we must depend on the proto.
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
] or []) + (enable_xla_hlo_profiling and [
|
||||
"//tensorflow/compiler/xla/service:hlo_profile_printer_data"
|
||||
] or []) + (include_standard_runtime_deps and [
|
||||
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
|
||||
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
|
||||
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
|
||||
"//third_party/eigen3",
|
||||
] or []) + (deps or []),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# Variables used for gen_test and gen_benchmark.
|
||||
no_ns_name = ""
|
||||
cpp_class_split = cpp_class.rsplit("::", maxsplit=2)
|
||||
if len(cpp_class_split) == 1:
|
||||
no_ns_name = cpp_class_split[0]
|
||||
else:
|
||||
no_ns_name = cpp_class_split[1]
|
||||
sed_replace = (
|
||||
"-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
|
||||
"-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
|
||||
"-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ")
|
||||
|
||||
if gen_test:
|
||||
test_name = name + "_test"
|
||||
test_file = test_name + ".cc"
|
||||
# Rule to rewrite test.cc to produce the test_file.
|
||||
native.genrule(
|
||||
name=("gen_" + test_name),
|
||||
testonly=1,
|
||||
srcs=[
|
||||
"//tensorflow/compiler/aot:test.cc",
|
||||
header_file,
|
||||
metadata_object_file,
|
||||
function_object_file,
|
||||
],
|
||||
cmd = ("$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
" --target_triple=" + target_llvm_triple() +
|
||||
" --out_header=$(@D)/" + header_file +
|
||||
" --out_metadata_object=$(@D)/" + metadata_object_file +
|
||||
" --out_function_object=$(@D)/" + function_object_file +
|
||||
" " + flags + " " + profiling_flag),
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
# Run tfcompile on the build host since it's typically faster on the local
|
||||
# machine.
|
||||
#
|
||||
# Note that setting the local=1 attribute on a *test target* causes the
|
||||
# test infrastructure to skip that test. However this is a genrule, not a
|
||||
# test target, and runs with --genrule_strategy=forced_forge, meaning the
|
||||
# local=1 attribute is ignored, and the genrule is still run.
|
||||
#
|
||||
# https://www.bazel.io/versions/master/docs/be/general.html#genrule
|
||||
local = 1,
|
||||
tags = tags,
|
||||
outs=[test_file],
|
||||
cmd=("sed " + sed_replace +
|
||||
" $(location //tensorflow/compiler/aot:test.cc) " +
|
||||
"> $(OUTS)"),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# Rule that runs tfcompile to produce the SessionModule proto, useful for
|
||||
# debugging. TODO(b/64813587): Once the SessionModule proto is
|
||||
# deterministic, move this into the main rule above.
|
||||
session_module_pb = name + "_session_module.pb"
|
||||
native.genrule(
|
||||
name = (name + "_session_module"),
|
||||
srcs = [
|
||||
tfcompile_graph,
|
||||
config,
|
||||
],
|
||||
outs = [
|
||||
session_module_pb,
|
||||
],
|
||||
cmd = ("$(location " + tfcompile_tool + ")" +
|
||||
" --graph=$(location " + tfcompile_graph + ")" +
|
||||
" --config=$(location " + config + ")" +
|
||||
" --entry_point=" + ep +
|
||||
" --cpp_class=" + cpp_class +
|
||||
" --target_triple=" + target_llvm_triple() +
|
||||
" --out_session_module=$(@D)/" + session_module_pb +
|
||||
" " + flags),
|
||||
tools = [tfcompile_tool],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
local = 1,
|
||||
tags = tags,
|
||||
)
|
||||
|
||||
# The cc_library rule packaging up the header and object file, and needed
|
||||
# kernel implementations.
|
||||
need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1)
|
||||
native.cc_library(
|
||||
name = name,
|
||||
srcs = [function_object_file, metadata_object_file],
|
||||
hdrs = [header_file],
|
||||
visibility = visibility,
|
||||
testonly = testonly,
|
||||
deps = [
|
||||
# These deps are required by all tf_library targets even if
|
||||
# include_standard_runtime_deps is False. Without them, the
|
||||
# generated code will fail to compile.
|
||||
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
|
||||
"//tensorflow/core:framework_lite",
|
||||
] + (need_xla_data_proto and [
|
||||
# If we're generating the program shape, we must depend on the proto.
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
] or []) + (enable_xla_hlo_profiling and [
|
||||
"//tensorflow/compiler/xla/service:hlo_profile_printer_data",
|
||||
] or []) + (include_standard_runtime_deps and [
|
||||
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
|
||||
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
|
||||
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
|
||||
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
|
||||
# The cc_test rule for the generated code. To ensure that this works
|
||||
# reliably across build configurations, we must use tf_cc_test instead of
|
||||
# native.cc_test. This is related to how we build
|
||||
# //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
|
||||
# for more details.
|
||||
tf_cc_test(
|
||||
name=test_name,
|
||||
srcs=[test_file],
|
||||
deps=[
|
||||
":" + name,
|
||||
"//tensorflow/compiler/aot:runtime",
|
||||
"//tensorflow/compiler/aot:tf_library_test_main",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//third_party/eigen3",
|
||||
] or []) + (deps or []),
|
||||
tags = tags,
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# Variables used for gen_test and gen_benchmark.
|
||||
no_ns_name = ""
|
||||
cpp_class_split = cpp_class.rsplit("::", maxsplit = 2)
|
||||
if len(cpp_class_split) == 1:
|
||||
no_ns_name = cpp_class_split[0]
|
||||
else:
|
||||
no_ns_name = cpp_class_split[1]
|
||||
sed_replace = (
|
||||
"-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
|
||||
"-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
|
||||
"-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" "
|
||||
if gen_benchmark:
|
||||
benchmark_name = name + "_benchmark"
|
||||
benchmark_file = benchmark_name + ".cc"
|
||||
benchmark_main = ("//tensorflow/compiler/aot:" +
|
||||
"benchmark_main.template")
|
||||
|
||||
# Rule to rewrite benchmark.cc to produce the benchmark_file.
|
||||
native.genrule(
|
||||
name=("gen_" + benchmark_name),
|
||||
srcs=[
|
||||
benchmark_main,
|
||||
header_file,
|
||||
],
|
||||
testonly = testonly,
|
||||
outs=[benchmark_file],
|
||||
cmd=("sed " + sed_replace +
|
||||
" $(location " + benchmark_main + ") " +
|
||||
"> $(OUTS)"),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
if gen_test:
|
||||
test_name = name + "_test"
|
||||
test_file = test_name + ".cc"
|
||||
|
||||
# Rule to rewrite test.cc to produce the test_file.
|
||||
native.genrule(
|
||||
name = ("gen_" + test_name),
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"//tensorflow/compiler/aot:test.cc",
|
||||
header_file,
|
||||
],
|
||||
outs = [test_file],
|
||||
cmd = ("sed " + sed_replace +
|
||||
" $(location //tensorflow/compiler/aot:test.cc) " +
|
||||
"> $(OUTS)"),
|
||||
tags = tags,
|
||||
)
|
||||
|
||||
# The cc_test rule for the generated code. To ensure that this works
|
||||
# reliably across build configurations, we must use tf_cc_test instead of
|
||||
# native.cc_test. This is related to how we build
|
||||
# //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
|
||||
# for more details.
|
||||
tf_cc_test(
|
||||
name = test_name,
|
||||
srcs = [test_file],
|
||||
deps = [
|
||||
":" + name,
|
||||
"//tensorflow/compiler/aot:runtime",
|
||||
"//tensorflow/compiler/aot:tf_library_test_main",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
tags = tags,
|
||||
)
|
||||
|
||||
if gen_benchmark:
|
||||
benchmark_name = name + "_benchmark"
|
||||
benchmark_file = benchmark_name + ".cc"
|
||||
benchmark_main = ("//tensorflow/compiler/aot:" +
|
||||
"benchmark_main.template")
|
||||
|
||||
# Rule to rewrite benchmark.cc to produce the benchmark_file.
|
||||
native.genrule(
|
||||
name = ("gen_" + benchmark_name),
|
||||
srcs = [
|
||||
benchmark_main,
|
||||
header_file,
|
||||
],
|
||||
testonly = testonly,
|
||||
outs = [benchmark_file],
|
||||
cmd = ("sed " + sed_replace +
|
||||
" $(location " + benchmark_main + ") " +
|
||||
"> $(OUTS)"),
|
||||
tags = tags,
|
||||
)
|
||||
|
||||
# The cc_benchmark rule for the generated code. This does not need the
|
||||
# tf_cc_binary since we (by deliberate design) do not depend on
|
||||
# //tensorflow/core:lib.
|
||||
#
|
||||
# Note: to get smaller size on android for comparison, compile with:
|
||||
# --copt=-fvisibility=hidden
|
||||
# --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
|
||||
# --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
|
||||
native.cc_binary(
|
||||
name = benchmark_name,
|
||||
srcs = [benchmark_file],
|
||||
testonly = testonly,
|
||||
copts = tf_copts(),
|
||||
linkopts = if_android(["-pie", "-s"]),
|
||||
deps = [
|
||||
":" + name,
|
||||
"//tensorflow/compiler/aot:benchmark",
|
||||
"//tensorflow/compiler/aot:runtime",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//third_party/eigen3",
|
||||
] + if_android([
|
||||
"//tensorflow/compiler/aot:benchmark_extra_android",
|
||||
]),
|
||||
tags = tags,
|
||||
)
|
||||
# The cc_benchmark rule for the generated code. This does not need the
|
||||
# tf_cc_binary since we (by deliberate design) do not depend on
|
||||
# //tensorflow/core:lib.
|
||||
#
|
||||
# Note: to get smaller size on android for comparison, compile with:
|
||||
# --copt=-fvisibility=hidden
|
||||
# --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
|
||||
# --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
|
||||
native.cc_binary(
|
||||
name=benchmark_name,
|
||||
srcs=[benchmark_file],
|
||||
testonly = testonly,
|
||||
copts = tf_copts(),
|
||||
linkopts = if_android(["-pie", "-s"]),
|
||||
deps=[
|
||||
":" + name,
|
||||
"//tensorflow/compiler/aot:benchmark",
|
||||
"//tensorflow/compiler/aot:runtime",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//third_party/eigen3",
|
||||
] + if_android([
|
||||
"//tensorflow/compiler/aot:benchmark_extra_android",
|
||||
]),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
def target_llvm_triple():
|
||||
"""Returns the target LLVM triple to be used for compiling the target."""
|
||||
|
||||
# TODO(toddw): Add target_triple for other targets. For details see:
|
||||
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
|
||||
return select({
|
||||
"//tensorflow:android_armeabi": "armv5-none-android",
|
||||
"//tensorflow:android_arm": "armv7-none-android",
|
||||
"//tensorflow:android_arm64": "aarch64-none-android",
|
||||
"//tensorflow:android_x86": "i686-none-android",
|
||||
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
|
||||
"//tensorflow:darwin": "x86_64-none-darwin",
|
||||
"//conditions:default": "x86_64-pc-linux",
|
||||
})
|
||||
"""Returns the target LLVM triple to be used for compiling the target."""
|
||||
# TODO(toddw): Add target_triple for other targets. For details see:
|
||||
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
|
||||
return select({
|
||||
"//tensorflow:android_armeabi": "armv5-none-android",
|
||||
"//tensorflow:android_arm": "armv7-none-android",
|
||||
"//tensorflow:android_arm64": "aarch64-none-android",
|
||||
"//tensorflow:android_x86": "i686-none-android",
|
||||
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
|
||||
"//tensorflow:darwin": "x86_64-none-darwin",
|
||||
"//conditions:default": "x86_64-pc-linux",
|
||||
})
|
||||
|
@ -4,97 +4,88 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
|
||||
load("//tensorflow/compiler/tests:plugin.bzl", "plugins")
|
||||
|
||||
def all_backends():
|
||||
b = ["cpu"] + plugins.keys()
|
||||
if cuda_is_configured():
|
||||
return b + ["gpu"]
|
||||
b = ["cpu"] + plugins.keys()
|
||||
if cuda_is_configured():
|
||||
return b + ["gpu"]
|
||||
else:
|
||||
return b
|
||||
|
||||
def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
|
||||
disabled_backends=None, **kwargs):
|
||||
"""Generates py_test targets, one per XLA backend.
|
||||
|
||||
This rule generates py_test() targets named name_backend, for each backend
|
||||
in all_backends(). The rule also generates a test suite with named `name` that
|
||||
tests all backends for the test.
|
||||
|
||||
For example, the following rule generates test cases foo_test_cpu,
|
||||
foo_test_gpu, and a test suite name foo_test that tests both.
|
||||
tf_xla_py_test(
|
||||
name="foo_test",
|
||||
srcs="foo_test.py",
|
||||
deps=[...],
|
||||
)
|
||||
|
||||
Args:
|
||||
name: Name of the target.
|
||||
srcs: Sources for the target.
|
||||
deps: Dependencies of the target.
|
||||
tags: Tags to apply to the generated targets.
|
||||
data: Data dependencies of the target.
|
||||
main: Same as py_test's main attribute.
|
||||
disabled_backends: A list of backends that should not be tested. Supported
|
||||
values include "cpu" and "gpu". If not specified, defaults to None.
|
||||
**kwargs: keyword arguments passed onto the generated py_test() rules.
|
||||
"""
|
||||
if disabled_backends == None:
|
||||
disabled_backends = []
|
||||
|
||||
enabled_backends = [b for b in all_backends() if b not in disabled_backends]
|
||||
test_names = []
|
||||
for backend in enabled_backends:
|
||||
test_name = "{}_{}".format(name, backend)
|
||||
backend_tags = ["tf_xla_{}".format(backend)]
|
||||
backend_args = []
|
||||
backend_deps = []
|
||||
backend_data = []
|
||||
if backend == "cpu":
|
||||
backend_args += [
|
||||
"--test_device=XLA_CPU",
|
||||
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64"
|
||||
]
|
||||
elif backend == "gpu":
|
||||
backend_args += [
|
||||
"--test_device=XLA_GPU",
|
||||
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16"
|
||||
]
|
||||
backend_tags += ["requires-gpu-sm35"]
|
||||
elif backend in plugins:
|
||||
backend_args += ["--test_device=" + plugins[backend]["device"],
|
||||
"--types=" + plugins[backend]["types"]]
|
||||
backend_tags += plugins[backend]["tags"]
|
||||
backend_args += plugins[backend]["args"]
|
||||
backend_deps += plugins[backend]["deps"]
|
||||
backend_data += plugins[backend]["data"]
|
||||
else:
|
||||
return b
|
||||
fail("Unknown backend {}".format(backend))
|
||||
|
||||
def tf_xla_py_test(
|
||||
name,
|
||||
srcs = [],
|
||||
deps = [],
|
||||
tags = [],
|
||||
data = [],
|
||||
main = None,
|
||||
disabled_backends = None,
|
||||
**kwargs):
|
||||
"""Generates py_test targets, one per XLA backend.
|
||||
|
||||
This rule generates py_test() targets named name_backend, for each backend
|
||||
in all_backends(). The rule also generates a test suite with named `name` that
|
||||
tests all backends for the test.
|
||||
|
||||
For example, the following rule generates test cases foo_test_cpu,
|
||||
foo_test_gpu, and a test suite name foo_test that tests both.
|
||||
tf_xla_py_test(
|
||||
name="foo_test",
|
||||
srcs="foo_test.py",
|
||||
deps=[...],
|
||||
native.py_test(
|
||||
name=test_name,
|
||||
srcs=srcs,
|
||||
srcs_version="PY2AND3",
|
||||
args=backend_args,
|
||||
main="{}.py".format(name) if main == None else main,
|
||||
data=data + backend_data,
|
||||
deps=deps + backend_deps,
|
||||
tags=tags + backend_tags,
|
||||
**kwargs
|
||||
)
|
||||
test_names.append(test_name)
|
||||
native.test_suite(name=name, tests=test_names)
|
||||
|
||||
Args:
|
||||
name: Name of the target.
|
||||
srcs: Sources for the target.
|
||||
deps: Dependencies of the target.
|
||||
tags: Tags to apply to the generated targets.
|
||||
data: Data dependencies of the target.
|
||||
main: Same as py_test's main attribute.
|
||||
disabled_backends: A list of backends that should not be tested. Supported
|
||||
values include "cpu" and "gpu". If not specified, defaults to None.
|
||||
**kwargs: keyword arguments passed onto the generated py_test() rules.
|
||||
"""
|
||||
if disabled_backends == None:
|
||||
disabled_backends = []
|
||||
|
||||
enabled_backends = [b for b in all_backends() if b not in disabled_backends]
|
||||
test_names = []
|
||||
for backend in enabled_backends:
|
||||
test_name = "{}_{}".format(name, backend)
|
||||
backend_tags = ["tf_xla_{}".format(backend)]
|
||||
backend_args = []
|
||||
backend_deps = []
|
||||
backend_data = []
|
||||
if backend == "cpu":
|
||||
backend_args += [
|
||||
"--test_device=XLA_CPU",
|
||||
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
|
||||
]
|
||||
elif backend == "gpu":
|
||||
backend_args += [
|
||||
"--test_device=XLA_GPU",
|
||||
"--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
|
||||
]
|
||||
backend_tags += ["requires-gpu-sm35"]
|
||||
elif backend in plugins:
|
||||
backend_args += [
|
||||
"--test_device=" + plugins[backend]["device"],
|
||||
"--types=" + plugins[backend]["types"],
|
||||
]
|
||||
backend_tags += plugins[backend]["tags"]
|
||||
backend_args += plugins[backend]["args"]
|
||||
backend_deps += plugins[backend]["deps"]
|
||||
backend_data += plugins[backend]["data"]
|
||||
else:
|
||||
fail("Unknown backend {}".format(backend))
|
||||
|
||||
native.py_test(
|
||||
name = test_name,
|
||||
srcs = srcs,
|
||||
srcs_version = "PY2AND3",
|
||||
args = backend_args,
|
||||
main = "{}.py".format(name) if main == None else main,
|
||||
data = data + backend_data,
|
||||
deps = deps + backend_deps,
|
||||
tags = tags + backend_tags,
|
||||
**kwargs
|
||||
)
|
||||
test_names.append(test_name)
|
||||
native.test_suite(name = name, tests = test_names)
|
||||
|
||||
def generate_backend_suites(backends = []):
|
||||
"""Generates per-backend test_suites that run all tests for a backend."""
|
||||
if not backends:
|
||||
backends = all_backends()
|
||||
for backend in backends:
|
||||
native.test_suite(name = "%s_tests" % backend, tags = ["tf_xla_%s" % backend])
|
||||
def generate_backend_suites(backends=[]):
|
||||
"""Generates per-backend test_suites that run all tests for a backend."""
|
||||
if not backends:
|
||||
backends = all_backends()
|
||||
for backend in backends:
|
||||
native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
|
||||
|
@ -18,12 +18,13 @@
|
||||
# git update-index --assume-unchanged tensorflow/compiler/tests/plugin.bzl
|
||||
|
||||
plugins = {
|
||||
#"example": {
|
||||
# "device":"XLA_MY_DEVICE",
|
||||
# "types":"DT_FLOAT,DT_HALF,DT_INT32",
|
||||
# "tags":[],
|
||||
# "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"],
|
||||
# "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"],
|
||||
# "deps":[],
|
||||
#},
|
||||
#"example": {
|
||||
# "device":"XLA_MY_DEVICE",
|
||||
# "types":"DT_FLOAT,DT_HALF,DT_INT32",
|
||||
# "tags":[],
|
||||
# "args":["--disabled_manifest=tensorflow/compiler/plugin/example/disabled_manifest.txt"],
|
||||
# "data":["//tensorflow/compiler/plugin/example:disabled_manifest.txt"],
|
||||
# "deps":[],
|
||||
#},
|
||||
}
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
"""build_defs for service/cpu."""
|
||||
|
||||
|
||||
def runtime_copts():
|
||||
"""Returns copts used for CPU runtime libraries."""
|
||||
return (["-DEIGEN_AVOID_STL_ARRAY"] + select({
|
||||
"//tensorflow:android_arm": ["-mfpu=neon"],
|
||||
"//conditions:default": [],
|
||||
}) + select({
|
||||
"//tensorflow:android": ["-O2"],
|
||||
"//conditions:default": [],
|
||||
}))
|
||||
"""Returns copts used for CPU runtime libraries."""
|
||||
return (["-DEIGEN_AVOID_STL_ARRAY"] + select({
|
||||
"//tensorflow:android_arm": ["-mfpu=neon"],
|
||||
"//conditions:default": []
|
||||
}) + select({
|
||||
"//tensorflow:android": ["-O2"],
|
||||
"//conditions:default": []
|
||||
}))
|
||||
|
@ -7,258 +7,252 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
all_backends = ["cpu", "gpu"] + plugins.keys()
|
||||
|
||||
def filter_backends(backends):
|
||||
"""Removes "gpu" from a backend list if CUDA is not enabled.
|
||||
"""Removes "gpu" from a backend list if CUDA is not enabled.
|
||||
|
||||
This allows us to simply hardcode lists including "gpu" here and in the
|
||||
BUILD file, without causing failures when CUDA isn't enabled.'
|
||||
This allows us to simply hardcode lists including "gpu" here and in the
|
||||
BUILD file, without causing failures when CUDA isn't enabled.'
|
||||
|
||||
Args:
|
||||
backends: A list of backends to filter.
|
||||
Args:
|
||||
backends: A list of backends to filter.
|
||||
|
||||
Returns:
|
||||
The filtered list of backends.
|
||||
"""
|
||||
if cuda_is_configured():
|
||||
return backends
|
||||
else:
|
||||
return [backend for backend in backends if backend != "gpu"]
|
||||
Returns:
|
||||
The filtered list of backends.
|
||||
"""
|
||||
if cuda_is_configured():
|
||||
return backends
|
||||
else:
|
||||
return [backend for backend in backends if backend != "gpu"]
|
||||
|
||||
def xla_test(
|
||||
name,
|
||||
srcs,
|
||||
deps,
|
||||
xla_test_library_deps = [],
|
||||
backends = [],
|
||||
blacklisted_backends = [],
|
||||
args = [],
|
||||
tags = [],
|
||||
copts = [],
|
||||
data = [],
|
||||
backend_tags = {},
|
||||
backend_args = {},
|
||||
**kwargs):
|
||||
"""Generates cc_test targets for the given XLA backends.
|
||||
|
||||
This rule generates a cc_test target for one or more XLA backends and also a
|
||||
platform-agnostic cc_library rule. The arguments are identical to cc_test with
|
||||
two additions: 'backends' and 'backend_args'. 'backends' specifies the
|
||||
backends to generate tests for ("cpu", "gpu"), and
|
||||
'backend_args'/'backend_tags' specifies backend-specific args parameters to
|
||||
use when generating the cc_test.
|
||||
def xla_test(name,
|
||||
srcs,
|
||||
deps,
|
||||
xla_test_library_deps=[],
|
||||
backends=[],
|
||||
blacklisted_backends=[],
|
||||
args=[],
|
||||
tags=[],
|
||||
copts=[],
|
||||
data=[],
|
||||
backend_tags={},
|
||||
backend_args={},
|
||||
**kwargs):
|
||||
"""Generates cc_test targets for the given XLA backends.
|
||||
|
||||
The name of the cc_tests are the provided name argument with the backend name
|
||||
appended, and the cc_library target name is the provided name argument with
|
||||
"_lib" appended. For example, if name parameter is "foo_test", then the cpu
|
||||
test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
|
||||
This rule generates a cc_test target for one or more XLA backends and also a
|
||||
platform-agnostic cc_library rule. The arguments are identical to cc_test with
|
||||
two additions: 'backends' and 'backend_args'. 'backends' specifies the
|
||||
backends to generate tests for ("cpu", "gpu"), and
|
||||
'backend_args'/'backend_tags' specifies backend-specific args parameters to
|
||||
use when generating the cc_test.
|
||||
|
||||
The cc_library target can be used to link with other plugins outside of
|
||||
xla_test.
|
||||
The name of the cc_tests are the provided name argument with the backend name
|
||||
appended, and the cc_library target name is the provided name argument with
|
||||
"_lib" appended. For example, if name parameter is "foo_test", then the cpu
|
||||
test target will be "foo_test_cpu" and the cc_library target is "foo_lib".
|
||||
|
||||
The build rule also defines a test suite ${name} which includes the tests for
|
||||
each of the supported backends.
|
||||
The cc_library target can be used to link with other plugins outside of
|
||||
xla_test.
|
||||
|
||||
Each generated cc_test target has a tag indicating which backend the test is
|
||||
for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
|
||||
tags can be used to gather tests for a particular backend into a test_suite.
|
||||
The build rule also defines a test suite ${name} which includes the tests for
|
||||
each of the supported backends.
|
||||
|
||||
Examples:
|
||||
Each generated cc_test target has a tag indicating which backend the test is
|
||||
for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These
|
||||
tags can be used to gather tests for a particular backend into a test_suite.
|
||||
|
||||
# Generates the targets: foo_test_cpu and foo_test_gpu.
|
||||
xla_test(
|
||||
name = "foo_test",
|
||||
srcs = ["foo_test.cc"],
|
||||
backends = ["cpu", "gpu"],
|
||||
deps = [...],
|
||||
)
|
||||
Examples:
|
||||
|
||||
# Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
|
||||
# includes the additional arg "--special_cpu_flag".
|
||||
xla_test(
|
||||
name = "bar_test",
|
||||
srcs = ["bar_test.cc"],
|
||||
backends = ["cpu", "gpu"],
|
||||
backend_args = {"cpu": ["--special_cpu_flag"]}
|
||||
deps = [...],
|
||||
)
|
||||
|
||||
The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
|
||||
to the value 1 where ${BACKEND} is the uppercase name of the backend.
|
||||
|
||||
Args:
|
||||
name: Name of the target.
|
||||
srcs: Sources for the target.
|
||||
deps: Dependencies of the target.
|
||||
xla_test_library_deps: If set, the generated test targets will depend on the
|
||||
respective cc_libraries generated by the xla_test_library rule.
|
||||
backends: A list of backends to generate tests for. Supported values: "cpu",
|
||||
"gpu". If this list is empty, the test will be generated for all supported
|
||||
backends.
|
||||
blacklisted_backends: A list of backends to NOT generate tests for.
|
||||
args: Test arguments for the target.
|
||||
tags: Tags for the target.
|
||||
copts: Additional copts to pass to the build.
|
||||
data: Additional data to pass to the build.
|
||||
backend_tags: A dict mapping backend name to list of additional tags to
|
||||
use for that target.
|
||||
backend_args: A dict mapping backend name to list of additional args to
|
||||
use for that target.
|
||||
**kwargs: Additional keyword arguments to pass to native.cc_test.
|
||||
"""
|
||||
test_names = []
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
|
||||
backends = [
|
||||
backend
|
||||
for backend in backends
|
||||
if backend not in blacklisted_backends
|
||||
]
|
||||
|
||||
native.cc_library(
|
||||
name = "%s_lib" % name,
|
||||
srcs = srcs,
|
||||
copts = copts,
|
||||
testonly = True,
|
||||
deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
|
||||
# Generates the targets: foo_test_cpu and foo_test_gpu.
|
||||
xla_test(
|
||||
name = "foo_test",
|
||||
srcs = ["foo_test.cc"],
|
||||
backends = ["cpu", "gpu"],
|
||||
deps = [...],
|
||||
)
|
||||
|
||||
for backend in filter_backends(backends):
|
||||
test_name = "%s_%s" % (name, backend)
|
||||
this_backend_tags = ["xla_%s" % backend]
|
||||
this_backend_copts = []
|
||||
this_backend_args = backend_args.get(backend, [])
|
||||
this_backend_data = []
|
||||
if backend == "cpu":
|
||||
backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
|
||||
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
|
||||
elif backend == "gpu":
|
||||
backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
|
||||
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
|
||||
this_backend_tags += ["requires-gpu-sm35"]
|
||||
elif backend in plugins:
|
||||
backend_deps = []
|
||||
backend_deps += plugins[backend]["deps"]
|
||||
this_backend_copts += plugins[backend]["copts"]
|
||||
this_backend_tags += plugins[backend]["tags"]
|
||||
this_backend_args += plugins[backend]["args"]
|
||||
this_backend_data += plugins[backend]["data"]
|
||||
else:
|
||||
fail("Unknown backend %s" % backend)
|
||||
# Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
|
||||
# includes the additional arg "--special_cpu_flag".
|
||||
xla_test(
|
||||
name = "bar_test",
|
||||
srcs = ["bar_test.cc"],
|
||||
backends = ["cpu", "gpu"],
|
||||
backend_args = {"cpu": ["--special_cpu_flag"]}
|
||||
deps = [...],
|
||||
)
|
||||
|
||||
if xla_test_library_deps:
|
||||
for lib_dep in xla_test_library_deps:
|
||||
backend_deps += ["%s_%s" % (lib_dep, backend)]
|
||||
The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
|
||||
to the value 1 where ${BACKEND} is the uppercase name of the backend.
|
||||
|
||||
tf_cc_test(
|
||||
name = test_name,
|
||||
srcs = srcs,
|
||||
tags = tags + backend_tags.get(backend, []) + this_backend_tags,
|
||||
extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
|
||||
this_backend_copts,
|
||||
args = args + this_backend_args,
|
||||
deps = deps + backend_deps,
|
||||
data = data + this_backend_data,
|
||||
**kwargs
|
||||
)
|
||||
Args:
|
||||
name: Name of the target.
|
||||
srcs: Sources for the target.
|
||||
deps: Dependencies of the target.
|
||||
xla_test_library_deps: If set, the generated test targets will depend on the
|
||||
respective cc_libraries generated by the xla_test_library rule.
|
||||
backends: A list of backends to generate tests for. Supported values: "cpu",
|
||||
"gpu". If this list is empty, the test will be generated for all supported
|
||||
backends.
|
||||
blacklisted_backends: A list of backends to NOT generate tests for.
|
||||
args: Test arguments for the target.
|
||||
tags: Tags for the target.
|
||||
copts: Additional copts to pass to the build.
|
||||
data: Additional data to pass to the build.
|
||||
backend_tags: A dict mapping backend name to list of additional tags to
|
||||
use for that target.
|
||||
backend_args: A dict mapping backend name to list of additional args to
|
||||
use for that target.
|
||||
**kwargs: Additional keyword arguments to pass to native.cc_test.
|
||||
"""
|
||||
test_names = []
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
|
||||
test_names.append(test_name)
|
||||
backends = [backend for backend in backends
|
||||
if backend not in blacklisted_backends]
|
||||
|
||||
native.test_suite(name = name, tests = test_names)
|
||||
native.cc_library(
|
||||
name="%s_lib" % name,
|
||||
srcs=srcs,
|
||||
copts=copts,
|
||||
testonly=True,
|
||||
deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
|
||||
)
|
||||
|
||||
def xla_test_library(
|
||||
name,
|
||||
srcs,
|
||||
hdrs = [],
|
||||
deps = [],
|
||||
backends = []):
|
||||
"""Generates cc_library targets for the given XLA backends.
|
||||
for backend in filter_backends(backends):
|
||||
test_name = "%s_%s" % (name, backend)
|
||||
this_backend_tags = ["xla_%s" % backend]
|
||||
this_backend_copts = []
|
||||
this_backend_args = backend_args.get(backend, [])
|
||||
this_backend_data = []
|
||||
if backend == "cpu":
|
||||
backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
|
||||
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
|
||||
elif backend == "gpu":
|
||||
backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
|
||||
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
|
||||
this_backend_tags += ["requires-gpu-sm35"]
|
||||
elif backend in plugins:
|
||||
backend_deps = []
|
||||
backend_deps += plugins[backend]["deps"]
|
||||
this_backend_copts += plugins[backend]["copts"]
|
||||
this_backend_tags += plugins[backend]["tags"]
|
||||
this_backend_args += plugins[backend]["args"]
|
||||
this_backend_data += plugins[backend]["data"]
|
||||
else:
|
||||
fail("Unknown backend %s" % backend)
|
||||
|
||||
This rule forces the sources to be compiled for each backend so that the
|
||||
backend specific macros could expand correctly. It's useful when test targets
|
||||
in different directories referring to the same sources but test with different
|
||||
arguments.
|
||||
if xla_test_library_deps:
|
||||
for lib_dep in xla_test_library_deps:
|
||||
backend_deps += ["%s_%s" % (lib_dep, backend)]
|
||||
|
||||
Examples:
|
||||
tf_cc_test(
|
||||
name=test_name,
|
||||
srcs=srcs,
|
||||
tags=tags + backend_tags.get(backend, []) + this_backend_tags,
|
||||
extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
|
||||
this_backend_copts,
|
||||
args=args + this_backend_args,
|
||||
deps=deps + backend_deps,
|
||||
data=data + this_backend_data,
|
||||
**kwargs)
|
||||
|
||||
# Generates the targets: foo_test_library_cpu and foo_test_gpu.
|
||||
xla_test_library(
|
||||
name = "foo_test_library",
|
||||
srcs = ["foo_test.cc"],
|
||||
backends = ["cpu", "gpu"],
|
||||
deps = [...],
|
||||
)
|
||||
# Then use the xla_test rule to generate test targets:
|
||||
xla_test(
|
||||
name = "foo_test",
|
||||
srcs = [],
|
||||
backends = ["cpu", "gpu"],
|
||||
deps = [...],
|
||||
xla_test_library_deps = [":foo_test_library"],
|
||||
)
|
||||
test_names.append(test_name)
|
||||
|
||||
Args:
|
||||
name: Name of the target.
|
||||
srcs: Sources for the target.
|
||||
hdrs: Headers for the target.
|
||||
deps: Dependencies of the target.
|
||||
backends: A list of backends to generate libraries for.
|
||||
Supported values: "cpu", "gpu". If this list is empty, the
|
||||
library will be generated for all supported backends.
|
||||
"""
|
||||
native.test_suite(name=name, tests=test_names)
|
||||
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
def xla_test_library(name,
|
||||
srcs,
|
||||
hdrs=[],
|
||||
deps=[],
|
||||
backends=[]):
|
||||
"""Generates cc_library targets for the given XLA backends.
|
||||
|
||||
for backend in filter_backends(backends):
|
||||
this_backend_copts = []
|
||||
if backend in ["cpu", "gpu"]:
|
||||
backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
|
||||
elif backend in plugins:
|
||||
backend_deps = plugins[backend]["deps"]
|
||||
this_backend_copts += plugins[backend]["copts"]
|
||||
else:
|
||||
fail("Unknown backend %s" % backend)
|
||||
This rule forces the sources to be compiled for each backend so that the
|
||||
backend specific macros could expand correctly. It's useful when test targets
|
||||
in different directories referring to the same sources but test with different
|
||||
arguments.
|
||||
|
||||
native.cc_library(
|
||||
name = "%s_%s" % (name, backend),
|
||||
srcs = srcs,
|
||||
testonly = True,
|
||||
hdrs = hdrs,
|
||||
copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
|
||||
this_backend_copts,
|
||||
deps = deps + backend_deps,
|
||||
)
|
||||
Examples:
|
||||
|
||||
def generate_backend_suites(backends = []):
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
for backend in filter_backends(backends):
|
||||
native.test_suite(
|
||||
name = "%s_tests" % backend,
|
||||
tags = ["xla_%s" % backend],
|
||||
)
|
||||
# Generates the targets: foo_test_library_cpu and foo_test_gpu.
|
||||
xla_test_library(
|
||||
name = "foo_test_library",
|
||||
srcs = ["foo_test.cc"],
|
||||
backends = ["cpu", "gpu"],
|
||||
deps = [...],
|
||||
)
|
||||
# Then use the xla_test rule to generate test targets:
|
||||
xla_test(
|
||||
name = "foo_test",
|
||||
srcs = [],
|
||||
backends = ["cpu", "gpu"],
|
||||
deps = [...],
|
||||
xla_test_library_deps = [":foo_test_library"],
|
||||
)
|
||||
|
||||
def generate_backend_test_macros(backends = []):
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
for backend in filter_backends(backends):
|
||||
manifest = ""
|
||||
if backend in plugins:
|
||||
manifest = plugins[backend]["disabled_manifest"]
|
||||
Args:
|
||||
name: Name of the target.
|
||||
srcs: Sources for the target.
|
||||
hdrs: Headers for the target.
|
||||
deps: Dependencies of the target.
|
||||
backends: A list of backends to generate libraries for.
|
||||
Supported values: "cpu", "gpu". If this list is empty, the
|
||||
library will be generated for all supported backends.
|
||||
"""
|
||||
|
||||
native.cc_library(
|
||||
name = "test_macros_%s" % backend,
|
||||
testonly = True,
|
||||
srcs = ["test_macros.cc"],
|
||||
hdrs = ["test_macros.h"],
|
||||
copts = [
|
||||
"-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
|
||||
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
|
||||
for backend in filter_backends(backends):
|
||||
this_backend_copts = []
|
||||
if backend in ["cpu", "gpu"]:
|
||||
backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
|
||||
elif backend in plugins:
|
||||
backend_deps = plugins[backend]["deps"]
|
||||
this_backend_copts += plugins[backend]["copts"]
|
||||
else:
|
||||
fail("Unknown backend %s" % backend)
|
||||
|
||||
native.cc_library(
|
||||
name = "%s_%s" % (name, backend),
|
||||
srcs = srcs,
|
||||
testonly = True,
|
||||
hdrs = hdrs,
|
||||
copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()]
|
||||
+ this_backend_copts,
|
||||
deps = deps + backend_deps,
|
||||
)
|
||||
|
||||
|
||||
def generate_backend_suites(backends=[]):
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
for backend in filter_backends(backends):
|
||||
native.test_suite(name="%s_tests" % backend,
|
||||
tags = ["xla_%s" % backend])
|
||||
|
||||
|
||||
def generate_backend_test_macros(backends=[]):
|
||||
if not backends:
|
||||
backends = all_backends
|
||||
for backend in filter_backends(backends):
|
||||
manifest = ""
|
||||
if backend in plugins:
|
||||
manifest = plugins[backend]["disabled_manifest"]
|
||||
|
||||
native.cc_library(
|
||||
name="test_macros_%s" % backend,
|
||||
testonly = True,
|
||||
srcs = ["test_macros.cc"],
|
||||
hdrs = ["test_macros.h"],
|
||||
copts = [
|
||||
"-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(),
|
||||
"-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest,
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core:test",
|
||||
])
|
||||
|
@ -33,3 +33,4 @@
|
||||
# }
|
||||
|
||||
plugins = {}
|
||||
|
||||
|
@ -1,35 +1,30 @@
|
||||
"""Wrapper around cc_proto_library used inside the XLA codebase."""
|
||||
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"cc_proto_library",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"if_static",
|
||||
)
|
||||
load("//tensorflow/core:platform/default/build_config.bzl",
|
||||
"cc_proto_library")
|
||||
load("//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
"if_static")
|
||||
|
||||
# xla_proto_library() is a convenience wrapper around cc_proto_library.
|
||||
def xla_proto_library(name, srcs = [], deps = [], visibility = None, testonly = 0, **kwargs):
|
||||
if kwargs.get("use_grpc_plugin"):
|
||||
kwargs["use_grpc_namespace"] = True
|
||||
cc_proto_library(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
deps = deps,
|
||||
cc_libs = if_static(
|
||||
["@protobuf_archive//:protobuf"],
|
||||
otherwise = ["@protobuf_archive//:protobuf_headers"],
|
||||
),
|
||||
protoc = "@protobuf_archive//:protoc",
|
||||
testonly = testonly,
|
||||
visibility = visibility,
|
||||
**kwargs
|
||||
)
|
||||
def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0, **kwargs):
|
||||
if kwargs.get('use_grpc_plugin'):
|
||||
kwargs['use_grpc_namespace'] = True
|
||||
cc_proto_library(name=name,
|
||||
srcs=srcs,
|
||||
deps=deps,
|
||||
cc_libs = if_static(
|
||||
["@protobuf_archive//:protobuf"],
|
||||
otherwise=["@protobuf_archive//:protobuf_headers"],
|
||||
),
|
||||
protoc="@protobuf_archive//:protoc",
|
||||
testonly=testonly,
|
||||
visibility=visibility,
|
||||
**kwargs)
|
||||
|
||||
def xla_py_grpc_library(**kwargs):
|
||||
# Note: we don't currently define any special targets for Python GRPC in OSS.
|
||||
_ignore = kwargs
|
||||
pass
|
||||
# Note: we don't currently define any special targets for Python GRPC in OSS.
|
||||
_ignore = kwargs
|
||||
pass
|
||||
|
||||
|
||||
ORC_JIT_MEMORY_MAPPER_TARGETS = []
|
||||
|
@ -1,196 +1,193 @@
|
||||
"""Generate Flatbuffer binary from json."""
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
def tflite_copts():
|
||||
"""Defines compile time flags."""
|
||||
copts = [
|
||||
"-DFARMHASH_NO_CXX_STRING",
|
||||
] + select({
|
||||
str(Label("//tensorflow:android_arm64")): [
|
||||
"-std=c++11",
|
||||
"-O3",
|
||||
],
|
||||
str(Label("//tensorflow:android_arm")): [
|
||||
"-mfpu=neon",
|
||||
"-mfloat-abi=softfp",
|
||||
"-std=c++11",
|
||||
"-O3",
|
||||
],
|
||||
str(Label("//tensorflow:android_x86")): [
|
||||
"-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
|
||||
],
|
||||
str(Label("//tensorflow:ios_x86_64")): [
|
||||
"-msse4.1",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}) + select({
|
||||
str(Label("//tensorflow:with_default_optimizations")): [],
|
||||
"//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
|
||||
})
|
||||
"""Defines compile time flags."""
|
||||
copts = [
|
||||
"-DFARMHASH_NO_CXX_STRING",
|
||||
] + select({
|
||||
str(Label("//tensorflow:android_arm64")): [
|
||||
"-std=c++11",
|
||||
"-O3",
|
||||
],
|
||||
str(Label("//tensorflow:android_arm")): [
|
||||
"-mfpu=neon",
|
||||
"-mfloat-abi=softfp",
|
||||
"-std=c++11",
|
||||
"-O3",
|
||||
],
|
||||
str(Label("//tensorflow:android_x86")): [
|
||||
"-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK",
|
||||
],
|
||||
str(Label("//tensorflow:ios_x86_64")): [
|
||||
"-msse4.1",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}) + select({
|
||||
str(Label("//tensorflow:with_default_optimizations")): [],
|
||||
"//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"],
|
||||
})
|
||||
|
||||
return copts
|
||||
return copts
|
||||
|
||||
LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds"
|
||||
|
||||
def tflite_linkopts_unstripped():
|
||||
"""Defines linker flags to reduce size of TFLite binary.
|
||||
"""Defines linker flags to reduce size of TFLite binary.
|
||||
|
||||
These are useful when trying to investigate the relative size of the
|
||||
symbols in TFLite.
|
||||
These are useful when trying to investigate the relative size of the
|
||||
symbols in TFLite.
|
||||
|
||||
Returns:
|
||||
a select object with proper linkopts
|
||||
"""
|
||||
return select({
|
||||
"//tensorflow:android": [
|
||||
"-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj.
|
||||
"-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export.
|
||||
"-Wl,--gc-sections", # Eliminate unused code and data.
|
||||
"-Wl,--as-needed", # Don't link unused libs.
|
||||
],
|
||||
"//tensorflow/contrib/lite:mips": [],
|
||||
"//tensorflow/contrib/lite:mips64": [],
|
||||
"//conditions:default": [
|
||||
"-Wl,--icf=all", # Identical code folding.
|
||||
],
|
||||
})
|
||||
Returns:
|
||||
a select object with proper linkopts
|
||||
"""
|
||||
return select({
|
||||
"//tensorflow:android": [
|
||||
"-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj.
|
||||
"-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export.
|
||||
"-Wl,--gc-sections", # Eliminate unused code and data.
|
||||
"-Wl,--as-needed", # Don't link unused libs.
|
||||
],
|
||||
"//tensorflow/contrib/lite:mips": [],
|
||||
"//tensorflow/contrib/lite:mips64": [],
|
||||
"//conditions:default": [
|
||||
"-Wl,--icf=all", # Identical code folding.
|
||||
],
|
||||
})
|
||||
|
||||
def tflite_jni_linkopts_unstripped():
|
||||
"""Defines linker flags to reduce size of TFLite binary with JNI.
|
||||
"""Defines linker flags to reduce size of TFLite binary with JNI.
|
||||
|
||||
These are useful when trying to investigate the relative size of the
|
||||
symbols in TFLite.
|
||||
These are useful when trying to investigate the relative size of the
|
||||
symbols in TFLite.
|
||||
|
||||
Returns:
|
||||
a select object with proper linkopts
|
||||
"""
|
||||
return select({
|
||||
"//tensorflow:android": [
|
||||
"-Wl,--gc-sections", # Eliminate unused code and data.
|
||||
"-Wl,--as-needed", # Don't link unused libs.
|
||||
],
|
||||
"//tensorflow/contrib/lite:mips": [],
|
||||
"//tensorflow/contrib/lite:mips64": [],
|
||||
"//conditions:default": [
|
||||
"-Wl,--icf=all", # Identical code folding.
|
||||
],
|
||||
})
|
||||
Returns:
|
||||
a select object with proper linkopts
|
||||
"""
|
||||
return select({
|
||||
"//tensorflow:android": [
|
||||
"-Wl,--gc-sections", # Eliminate unused code and data.
|
||||
"-Wl,--as-needed", # Don't link unused libs.
|
||||
],
|
||||
"//tensorflow/contrib/lite:mips": [],
|
||||
"//tensorflow/contrib/lite:mips64": [],
|
||||
"//conditions:default": [
|
||||
"-Wl,--icf=all", # Identical code folding.
|
||||
],
|
||||
})
|
||||
|
||||
def tflite_linkopts():
|
||||
"""Defines linker flags to reduce size of TFLite binary."""
|
||||
return tflite_linkopts_unstripped() + select({
|
||||
"//tensorflow:android": [
|
||||
"-s", # Omit symbol table.
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
"""Defines linker flags to reduce size of TFLite binary."""
|
||||
return tflite_linkopts_unstripped() + select({
|
||||
"//tensorflow:android": [
|
||||
"-s", # Omit symbol table.
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tflite_jni_linkopts():
|
||||
"""Defines linker flags to reduce size of TFLite binary with JNI."""
|
||||
return tflite_jni_linkopts_unstripped() + select({
|
||||
"//tensorflow:android": [
|
||||
"-s", # Omit symbol table.
|
||||
"-latomic", # Required for some uses of ISO C++11 <atomic> in x86.
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
"""Defines linker flags to reduce size of TFLite binary with JNI."""
|
||||
return tflite_jni_linkopts_unstripped() + select({
|
||||
"//tensorflow:android": [
|
||||
"-s", # Omit symbol table.
|
||||
"-latomic", # Required for some uses of ISO C++11 <atomic> in x86.
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tflite_jni_binary(
|
||||
name,
|
||||
copts = tflite_copts(),
|
||||
linkopts = tflite_jni_linkopts(),
|
||||
linkscript = LINKER_SCRIPT,
|
||||
linkshared = 1,
|
||||
linkstatic = 1,
|
||||
deps = []):
|
||||
"""Builds a jni binary for TFLite."""
|
||||
linkopts = linkopts + [
|
||||
"-Wl,--version-script", # Export only jni functions & classes.
|
||||
"$(location {})".format(linkscript),
|
||||
]
|
||||
native.cc_binary(
|
||||
name = name,
|
||||
copts = copts,
|
||||
linkshared = linkshared,
|
||||
linkstatic = linkstatic,
|
||||
deps = deps + [linkscript],
|
||||
linkopts = linkopts,
|
||||
)
|
||||
def tflite_jni_binary(name,
|
||||
copts=tflite_copts(),
|
||||
linkopts=tflite_jni_linkopts(),
|
||||
linkscript=LINKER_SCRIPT,
|
||||
linkshared=1,
|
||||
linkstatic=1,
|
||||
deps=[]):
|
||||
"""Builds a jni binary for TFLite."""
|
||||
linkopts = linkopts + [
|
||||
"-Wl,--version-script", # Export only jni functions & classes.
|
||||
"$(location {})".format(linkscript),
|
||||
]
|
||||
native.cc_binary(
|
||||
name=name,
|
||||
copts=copts,
|
||||
linkshared=linkshared,
|
||||
linkstatic=linkstatic,
|
||||
deps= deps + [linkscript],
|
||||
linkopts=linkopts)
|
||||
|
||||
def tf_to_tflite(name, src, options, out):
|
||||
"""Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.
|
||||
"""Convert a frozen tensorflow graphdef to TF Lite's flatbuffer.
|
||||
|
||||
Args:
|
||||
name: Name of rule.
|
||||
src: name of the input graphdef file.
|
||||
options: options passed to TOCO.
|
||||
out: name of the output flatbuffer file.
|
||||
"""
|
||||
Args:
|
||||
name: Name of rule.
|
||||
src: name of the input graphdef file.
|
||||
options: options passed to TOCO.
|
||||
out: name of the output flatbuffer file.
|
||||
"""
|
||||
|
||||
toco_cmdline = " ".join([
|
||||
"//tensorflow/contrib/lite/toco:toco",
|
||||
"--input_format=TENSORFLOW_GRAPHDEF",
|
||||
"--output_format=TFLITE",
|
||||
("--input_file=$(location %s)" % src),
|
||||
("--output_file=$(location %s)" % out),
|
||||
] + options)
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [src],
|
||||
outs = [out],
|
||||
cmd = toco_cmdline,
|
||||
tools = ["//tensorflow/contrib/lite/toco:toco"],
|
||||
)
|
||||
toco_cmdline = " ".join([
|
||||
"//tensorflow/contrib/lite/toco:toco",
|
||||
"--input_format=TENSORFLOW_GRAPHDEF",
|
||||
"--output_format=TFLITE",
|
||||
("--input_file=$(location %s)" % src),
|
||||
("--output_file=$(location %s)" % out),
|
||||
] + options )
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs=[src],
|
||||
outs=[out],
|
||||
cmd = toco_cmdline,
|
||||
tools= ["//tensorflow/contrib/lite/toco:toco"],
|
||||
)
|
||||
|
||||
def tflite_to_json(name, src, out):
|
||||
"""Convert a TF Lite flatbuffer to JSON.
|
||||
"""Convert a TF Lite flatbuffer to JSON.
|
||||
|
||||
Args:
|
||||
name: Name of rule.
|
||||
src: name of the input flatbuffer file.
|
||||
out: name of the output JSON file.
|
||||
"""
|
||||
Args:
|
||||
name: Name of rule.
|
||||
src: name of the input flatbuffer file.
|
||||
out: name of the output JSON file.
|
||||
"""
|
||||
|
||||
flatc = "@flatbuffers//:flatc"
|
||||
schema = "//tensorflow/contrib/lite/schema:schema.fbs"
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [schema, src],
|
||||
outs = [out],
|
||||
cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" +
|
||||
"$(location %s) --raw-binary --strict-json -t" +
|
||||
" -o /tmp $(location %s) -- $${TMP}.bin &&" +
|
||||
"cp $${TMP}.json $(location %s)") %
|
||||
(src, flatc, schema, out),
|
||||
tools = [flatc],
|
||||
)
|
||||
flatc = "@flatbuffers//:flatc"
|
||||
schema = "//tensorflow/contrib/lite/schema:schema.fbs"
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [schema, src],
|
||||
outs = [out],
|
||||
cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" +
|
||||
"$(location %s) --raw-binary --strict-json -t" +
|
||||
" -o /tmp $(location %s) -- $${TMP}.bin &&" +
|
||||
"cp $${TMP}.json $(location %s)")
|
||||
% (src, flatc, schema, out),
|
||||
tools = [flatc],
|
||||
)
|
||||
|
||||
def json_to_tflite(name, src, out):
|
||||
"""Convert a JSON file to TF Lite's flatbuffer.
|
||||
"""Convert a JSON file to TF Lite's flatbuffer.
|
||||
|
||||
Args:
|
||||
name: Name of rule.
|
||||
src: name of the input JSON file.
|
||||
out: name of the output flatbuffer file.
|
||||
"""
|
||||
Args:
|
||||
name: Name of rule.
|
||||
src: name of the input JSON file.
|
||||
out: name of the output flatbuffer file.
|
||||
"""
|
||||
|
||||
flatc = "@flatbuffers//:flatc"
|
||||
schema = "//tensorflow/contrib/lite/schema:schema_fbs"
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [schema, src],
|
||||
outs = [out],
|
||||
cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" +
|
||||
"$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
|
||||
" -o /tmp $(location %s) $${TMP}.json &&" +
|
||||
"cp $${TMP}.bin $(location %s)") %
|
||||
(src, flatc, schema, out),
|
||||
tools = [flatc],
|
||||
)
|
||||
flatc = "@flatbuffers//:flatc"
|
||||
schema = "//tensorflow/contrib/lite/schema:schema_fbs"
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [schema, src],
|
||||
outs = [out],
|
||||
cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" +
|
||||
"$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" +
|
||||
" -o /tmp $(location %s) $${TMP}.json &&" +
|
||||
"cp $${TMP}.bin $(location %s)")
|
||||
% (src, flatc, schema, out),
|
||||
tools = [flatc],
|
||||
)
|
||||
|
||||
# This is the master list of generated examples that will be made into tests. A
|
||||
# function called make_XXX_tests() must also appear in generate_examples.py.
|
||||
@ -265,58 +262,58 @@ def generated_test_models():
|
||||
]
|
||||
|
||||
def gen_zip_test(name, test_name, **kwargs):
|
||||
"""Generate a zipped-example test and its dependent zip files.
|
||||
"""Generate a zipped-example test and its dependent zip files.
|
||||
|
||||
Args:
|
||||
name: Resulting cc_test target name
|
||||
test_name: Test targets this model. Comes from the list above.
|
||||
**kwargs: tf_cc_test kwargs.
|
||||
"""
|
||||
gen_zipped_test_file(
|
||||
name = "zip_%s" % test_name,
|
||||
file = "%s.zip" % test_name,
|
||||
)
|
||||
tf_cc_test(name, **kwargs)
|
||||
Args:
|
||||
name: Resulting cc_test target name
|
||||
test_name: Test targets this model. Comes from the list above.
|
||||
**kwargs: tf_cc_test kwargs.
|
||||
"""
|
||||
gen_zipped_test_file(
|
||||
name = "zip_%s" % test_name,
|
||||
file = "%s.zip" % test_name,
|
||||
)
|
||||
tf_cc_test(name, **kwargs)
|
||||
|
||||
def gen_zipped_test_file(name, file):
|
||||
"""Generate a zip file of tests by using :generate_examples.
|
||||
"""Generate a zip file of tests by using :generate_examples.
|
||||
|
||||
Args:
|
||||
name: Name of output. We will produce "`file`.files" as a target.
|
||||
file: The name of one of the generated_examples targets, e.g. "transpose"
|
||||
"""
|
||||
toco = "//tensorflow/contrib/lite/toco:toco"
|
||||
native.genrule(
|
||||
name = file + ".files",
|
||||
cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco +
|
||||
" --zip_to_output " + file + " $(@D)"),
|
||||
outs = [file],
|
||||
tools = [
|
||||
":generate_examples",
|
||||
toco,
|
||||
],
|
||||
)
|
||||
Args:
|
||||
name: Name of output. We will produce "`file`.files" as a target.
|
||||
file: The name of one of the generated_examples targets, e.g. "transpose"
|
||||
"""
|
||||
toco = "//tensorflow/contrib/lite/toco:toco"
|
||||
native.genrule(
|
||||
name = file + ".files",
|
||||
cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco
|
||||
+ " --zip_to_output " + file + " $(@D)"),
|
||||
outs = [file],
|
||||
tools = [
|
||||
":generate_examples",
|
||||
toco,
|
||||
],
|
||||
)
|
||||
|
||||
native.filegroup(
|
||||
name = name,
|
||||
srcs = [file],
|
||||
)
|
||||
native.filegroup(
|
||||
name = name,
|
||||
srcs = [file],
|
||||
)
|
||||
|
||||
def gen_selected_ops(name, model):
|
||||
"""Generate the library that includes only used ops.
|
||||
"""Generate the library that includes only used ops.
|
||||
|
||||
Args:
|
||||
name: Name of the generated library.
|
||||
model: TFLite model to interpret.
|
||||
"""
|
||||
out = name + "_registration.cc"
|
||||
tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
|
||||
tflite_path = "//tensorflow/contrib/lite"
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [model],
|
||||
outs = [out],
|
||||
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") %
|
||||
(tool, model, out, tflite_path[2:]),
|
||||
tools = [tool],
|
||||
)
|
||||
Args:
|
||||
name: Name of the generated library.
|
||||
model: TFLite model to interpret.
|
||||
"""
|
||||
out = name + "_registration.cc"
|
||||
tool = "//tensorflow/contrib/lite/tools:generate_op_registrations"
|
||||
tflite_path = "//tensorflow/contrib/lite"
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [model],
|
||||
outs = [out],
|
||||
cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s")
|
||||
% (tool, model, out, tflite_path[2:]),
|
||||
tools = [tool],
|
||||
)
|
||||
|
@ -3,12 +3,12 @@
|
||||
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
|
||||
|
||||
def aar_with_jni(name, android_library):
|
||||
# Generate dummy AndroidManifest.xml for dummy apk usage
|
||||
# (dummy apk is generated by <name>_dummy_app_for_so target below)
|
||||
native.genrule(
|
||||
name = name + "_binary_manifest_generator",
|
||||
outs = [name + "_generated_AndroidManifest.xml"],
|
||||
cmd = """
|
||||
# Generate dummy AndroidManifest.xml for dummy apk usage
|
||||
# (dummy apk is generated by <name>_dummy_app_for_so target below)
|
||||
native.genrule(
|
||||
name = name + "_binary_manifest_generator",
|
||||
outs = [name + "_generated_AndroidManifest.xml"],
|
||||
cmd = """
|
||||
cat > $(OUTS) <<EOF
|
||||
<manifest
|
||||
xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
@ -17,27 +17,27 @@ cat > $(OUTS) <<EOF
|
||||
</manifest>
|
||||
EOF
|
||||
""",
|
||||
)
|
||||
)
|
||||
|
||||
# Generate dummy apk including .so files and later we extract out
|
||||
# .so files and throw away the apk.
|
||||
android_binary(
|
||||
name = name + "_dummy_app_for_so",
|
||||
manifest = name + "_generated_AndroidManifest.xml",
|
||||
custom_package = "dummy.package.for.so",
|
||||
deps = [android_library],
|
||||
# In some platforms we don't have an Android SDK/NDK and this target
|
||||
# can't be built. We need to prevent the build system from trying to
|
||||
# use the target in that case.
|
||||
tags = ["manual"],
|
||||
)
|
||||
# Generate dummy apk including .so files and later we extract out
|
||||
# .so files and throw away the apk.
|
||||
android_binary(
|
||||
name = name + "_dummy_app_for_so",
|
||||
manifest = name + "_generated_AndroidManifest.xml",
|
||||
custom_package = "dummy.package.for.so",
|
||||
deps = [android_library],
|
||||
# In some platforms we don't have an Android SDK/NDK and this target
|
||||
# can't be built. We need to prevent the build system from trying to
|
||||
# use the target in that case.
|
||||
tags = ["manual"],
|
||||
)
|
||||
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
|
||||
outs = [name + ".aar"],
|
||||
tags = ["manual"],
|
||||
cmd = """
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
|
||||
outs = [name + ".aar"],
|
||||
tags = ["manual"],
|
||||
cmd = """
|
||||
cp $(location {}.aar) $(location :{}.aar)
|
||||
chmod +w $(location :{}.aar)
|
||||
origdir=$$PWD
|
||||
@ -46,4 +46,4 @@ unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*"
|
||||
cp -r lib jni
|
||||
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
|
||||
""".format(android_library, name, name, name, name),
|
||||
)
|
||||
)
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""External versions of build rules that differ outside of Google."""
|
||||
|
||||
def tflite_portable_test_suite(**kwargs):
|
||||
"""This is a no-op outside of Google."""
|
||||
_ignore = [kwargs]
|
||||
pass
|
||||
"""This is a no-op outside of Google."""
|
||||
_ignore = [kwargs]
|
||||
pass
|
||||
|
@ -9,87 +9,81 @@ load("//tensorflow:tensorflow.bzl", "register_extension_info")
|
||||
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
|
||||
|
||||
def _test_name(test, path):
|
||||
return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
|
||||
return "%s_%s_test" % (test, path.split("/")[-1].split(".")[0])
|
||||
|
||||
def decode_proto_test_suite(name, examples):
|
||||
"""Build the decode_proto py_test for each test filename."""
|
||||
for test_filename in examples:
|
||||
tf_py_test(
|
||||
name = _test_name("decode_proto", test_filename),
|
||||
srcs = ["decode_proto_op_test.py"],
|
||||
size = "small",
|
||||
data = [test_filename] + if_static(
|
||||
[],
|
||||
otherwise = [":libtestexample.so"],
|
||||
),
|
||||
main = "decode_proto_op_test.py",
|
||||
args = [
|
||||
"--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
|
||||
],
|
||||
additional_deps = [
|
||||
":py_test_deps",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/proto:proto",
|
||||
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
|
||||
],
|
||||
tags = [
|
||||
"no_pip", # TODO(b/78026780)
|
||||
"no_windows", # TODO(b/78028010)
|
||||
],
|
||||
)
|
||||
native.test_suite(
|
||||
name = name,
|
||||
tests = [
|
||||
":" + _test_name("decode_proto", test_filename)
|
||||
for test_filename in examples
|
||||
"""Build the decode_proto py_test for each test filename."""
|
||||
for test_filename in examples:
|
||||
tf_py_test(
|
||||
name = _test_name("decode_proto", test_filename),
|
||||
srcs = ["decode_proto_op_test.py"],
|
||||
size = "small",
|
||||
data = [test_filename] + if_static(
|
||||
[],
|
||||
otherwise = [":libtestexample.so"],
|
||||
),
|
||||
main = "decode_proto_op_test.py",
|
||||
args = [
|
||||
"--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
|
||||
],
|
||||
additional_deps = [
|
||||
":py_test_deps",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/proto:proto",
|
||||
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
|
||||
],
|
||||
tags = [
|
||||
"no_pip", # TODO(b/78026780)
|
||||
"no_windows", # TODO(b/78028010)
|
||||
],
|
||||
)
|
||||
native.test_suite(
|
||||
name = name,
|
||||
tests = [":" + _test_name("decode_proto", test_filename)
|
||||
for test_filename in examples],
|
||||
)
|
||||
|
||||
def encode_proto_test_suite(name, examples):
|
||||
"""Build the encode_proto py_test for each test filename."""
|
||||
for test_filename in examples:
|
||||
tf_py_test(
|
||||
name = _test_name("encode_proto", test_filename),
|
||||
srcs = ["encode_proto_op_test.py"],
|
||||
size = "small",
|
||||
data = [test_filename] + if_static(
|
||||
[],
|
||||
otherwise = [":libtestexample.so"],
|
||||
),
|
||||
main = "encode_proto_op_test.py",
|
||||
args = [
|
||||
"--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
|
||||
],
|
||||
additional_deps = [
|
||||
":py_test_deps",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/proto:proto",
|
||||
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
|
||||
"//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
|
||||
],
|
||||
tags = [
|
||||
"no_pip", # TODO(b/78026780)
|
||||
"no_windows", # TODO(b/78028010)
|
||||
],
|
||||
)
|
||||
native.test_suite(
|
||||
name = name,
|
||||
tests = [
|
||||
":" + _test_name("encode_proto", test_filename)
|
||||
for test_filename in examples
|
||||
"""Build the encode_proto py_test for each test filename."""
|
||||
for test_filename in examples:
|
||||
tf_py_test(
|
||||
name = _test_name("encode_proto", test_filename),
|
||||
srcs = ["encode_proto_op_test.py"],
|
||||
size = "small",
|
||||
data = [test_filename] + if_static(
|
||||
[],
|
||||
otherwise = [":libtestexample.so"],
|
||||
),
|
||||
main = "encode_proto_op_test.py",
|
||||
args = [
|
||||
"--message_text_file=\"%s/%s\"" % (native.package_name(), test_filename),
|
||||
],
|
||||
additional_deps = [
|
||||
":py_test_deps",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/contrib/proto:proto",
|
||||
"//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
|
||||
"//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
|
||||
],
|
||||
tags = [
|
||||
"no_pip", # TODO(b/78026780)
|
||||
"no_windows", # TODO(b/78028010)
|
||||
],
|
||||
)
|
||||
native.test_suite(
|
||||
name = name,
|
||||
tests = [":" + _test_name("encode_proto", test_filename)
|
||||
for test_filename in examples],
|
||||
)
|
||||
|
||||
register_extension_info(
|
||||
extension_name = "decode_proto_test_suite",
|
||||
label_regex_map = {
|
||||
"deps": "deps:decode_example_.*",
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
register_extension_info(
|
||||
extension_name = "encode_proto_test_suite",
|
||||
label_regex_map = {
|
||||
"deps": "deps:encode_example_.*",
|
||||
},
|
||||
)
|
||||
})
|
||||
|
@ -1,13 +1,13 @@
|
||||
"""Fuzzing template for TensorFlow ops."""
|
||||
|
||||
def tf_ops_fuzz_target_lib(name):
|
||||
native.cc_library(
|
||||
name = name + "_fuzz_lib",
|
||||
srcs = [name + "_fuzz.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core/kernels/fuzzing:fuzz_session",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
],
|
||||
tags = ["no_windows"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
native.cc_library(
|
||||
name = name + "_fuzz_lib",
|
||||
srcs = [name + "_fuzz.cc"],
|
||||
deps = [
|
||||
"//tensorflow/core/kernels/fuzzing:fuzz_session",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
],
|
||||
tags = ["no_windows"],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -3,58 +3,58 @@
|
||||
# be separate to avoid cyclic references.
|
||||
|
||||
def tf_cuda_tests_tags():
|
||||
return ["requires-gpu"]
|
||||
return ["requires-gpu"]
|
||||
|
||||
def tf_sycl_tests_tags():
|
||||
return ["requires-gpu"]
|
||||
return ["requires-gpu"]
|
||||
|
||||
def tf_additional_plugin_deps():
|
||||
return select({
|
||||
str(Label("//tensorflow:with_xla_support")): [
|
||||
str(Label("//tensorflow/compiler/jit")),
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
return select({
|
||||
str(Label("//tensorflow:with_xla_support")): [
|
||||
str(Label("//tensorflow/compiler/jit"))
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_xla_deps_py():
|
||||
return []
|
||||
return []
|
||||
|
||||
def tf_additional_grpc_deps_py():
|
||||
return []
|
||||
return []
|
||||
|
||||
def tf_additional_license_deps():
|
||||
return select({
|
||||
str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
return select({
|
||||
str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_verbs_deps():
|
||||
return select({
|
||||
str(Label("//tensorflow:with_verbs_support")): [
|
||||
str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
|
||||
str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
return select({
|
||||
str(Label("//tensorflow:with_verbs_support")): [
|
||||
str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
|
||||
str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_mpi_deps():
|
||||
return select({
|
||||
str(Label("//tensorflow:with_mpi_support")): [
|
||||
str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
return select({
|
||||
str(Label("//tensorflow:with_mpi_support")): [
|
||||
str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_gdr_deps():
|
||||
return select({
|
||||
str(Label("//tensorflow:with_gdr_support")): [
|
||||
str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
return select({
|
||||
str(Label("//tensorflow:with_gdr_support")): [
|
||||
str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
|
||||
],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def if_static(extra_deps, otherwise = []):
|
||||
return select({
|
||||
str(Label("//tensorflow:framework_shared_object")): otherwise,
|
||||
"//conditions:default": extra_deps,
|
||||
})
|
||||
def if_static(extra_deps, otherwise=[]):
|
||||
return select({
|
||||
str(Label("//tensorflow:framework_shared_object")): otherwise,
|
||||
"//conditions:default": extra_deps,
|
||||
})
|
||||
|
@ -5,52 +5,55 @@ CUDNN_VERSION = ""
|
||||
PLATFORM = ""
|
||||
|
||||
def cuda_sdk_version():
|
||||
return CUDA_VERSION
|
||||
return CUDA_VERSION
|
||||
|
||||
def cudnn_sdk_version():
|
||||
return CUDNN_VERSION
|
||||
return CUDNN_VERSION
|
||||
|
||||
def cuda_library_path(name, version = cuda_sdk_version()):
|
||||
if PLATFORM == "Darwin":
|
||||
if not version:
|
||||
return "lib/lib{}.dylib".format(name)
|
||||
else:
|
||||
return "lib/lib{}.{}.dylib".format(name, version)
|
||||
elif not version:
|
||||
return "lib64/lib{}.so".format(name)
|
||||
if PLATFORM == "Darwin":
|
||||
if not version:
|
||||
return "lib/lib{}.dylib".format(name)
|
||||
else:
|
||||
return "lib64/lib{}.so.{}".format(name, version)
|
||||
return "lib/lib{}.{}.dylib".format(name, version)
|
||||
else:
|
||||
if not version:
|
||||
return "lib64/lib{}.so".format(name)
|
||||
else:
|
||||
return "lib64/lib{}.so.{}".format(name, version)
|
||||
|
||||
def cuda_static_library_path(name):
|
||||
if PLATFORM == "Darwin":
|
||||
return "lib/lib{}_static.a".format(name)
|
||||
else:
|
||||
return "lib64/lib{}_static.a".format(name)
|
||||
if PLATFORM == "Darwin":
|
||||
return "lib/lib{}_static.a".format(name)
|
||||
else:
|
||||
return "lib64/lib{}_static.a".format(name)
|
||||
|
||||
def cudnn_library_path(version = cudnn_sdk_version()):
|
||||
if PLATFORM == "Darwin":
|
||||
if not version:
|
||||
return "lib/libcudnn.dylib"
|
||||
else:
|
||||
return "lib/libcudnn.{}.dylib".format(version)
|
||||
elif not version:
|
||||
return "lib64/libcudnn.so"
|
||||
if PLATFORM == "Darwin":
|
||||
if not version:
|
||||
return "lib/libcudnn.dylib"
|
||||
else:
|
||||
return "lib64/libcudnn.so.{}".format(version)
|
||||
return "lib/libcudnn.{}.dylib".format(version)
|
||||
else:
|
||||
if not version:
|
||||
return "lib64/libcudnn.so"
|
||||
else:
|
||||
return "lib64/libcudnn.so.{}".format(version)
|
||||
|
||||
def cupti_library_path(version = cuda_sdk_version()):
|
||||
if PLATFORM == "Darwin":
|
||||
if not version:
|
||||
return "extras/CUPTI/lib/libcupti.dylib"
|
||||
else:
|
||||
return "extras/CUPTI/lib/libcupti.{}.dylib".format(version)
|
||||
elif not version:
|
||||
return "extras/CUPTI/lib64/libcupti.so"
|
||||
if PLATFORM == "Darwin":
|
||||
if not version:
|
||||
return "extras/CUPTI/lib/libcupti.dylib"
|
||||
else:
|
||||
return "extras/CUPTI/lib64/libcupti.so.{}".format(version)
|
||||
return "extras/CUPTI/lib/libcupti.{}.dylib".format(version)
|
||||
else:
|
||||
if not version:
|
||||
return "extras/CUPTI/lib64/libcupti.so"
|
||||
else:
|
||||
return "extras/CUPTI/lib64/libcupti.so.{}".format(version)
|
||||
|
||||
def readlink_command():
|
||||
if PLATFORM == "Darwin":
|
||||
return "greadlink"
|
||||
else:
|
||||
return "readlink"
|
||||
if PLATFORM == "Darwin":
|
||||
return "greadlink"
|
||||
else:
|
||||
return "readlink"
|
||||
|
@ -18,7 +18,7 @@ XLINT_OPTS = [
|
||||
"-Xlint:-processing",
|
||||
"-Xlint:-serial",
|
||||
"-Xlint:-try",
|
||||
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
|
||||
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
|
||||
]
|
||||
|
||||
# The bazel errorprone plugin currently only enables default errorChecks
|
||||
|
@ -17,48 +17,46 @@ load(
|
||||
# and then archive those source files into
|
||||
# ops/gen_sources.srcjar
|
||||
#
|
||||
def tf_java_op_gen_srcjar(
|
||||
name,
|
||||
gen_tool,
|
||||
base_package,
|
||||
api_def_srcs = [],
|
||||
out_dir = "ops/",
|
||||
out_src_dir = "src/main/java/",
|
||||
visibility = ["//tensorflow/java:__pkg__"]):
|
||||
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
|
||||
srcs = api_def_srcs[:]
|
||||
def tf_java_op_gen_srcjar(name,
|
||||
gen_tool,
|
||||
base_package,
|
||||
api_def_srcs=[],
|
||||
out_dir="ops/",
|
||||
out_src_dir="src/main/java/",
|
||||
visibility=["//tensorflow/java:__pkg__"]):
|
||||
|
||||
if not api_def_srcs:
|
||||
api_def_args_str = ","
|
||||
else:
|
||||
api_def_args = []
|
||||
for api_def_src in api_def_srcs:
|
||||
# Add directory of the first ApiDef source to args.
|
||||
# We are assuming all ApiDefs in a single api_def_src are in the
|
||||
# same directory.
|
||||
api_def_args.append(
|
||||
"$$(dirname $$(echo $(locations " + api_def_src +
|
||||
") | cut -d\" \" -f1))",
|
||||
)
|
||||
api_def_args_str = ",".join(api_def_args)
|
||||
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
|
||||
srcs = api_def_srcs[:]
|
||||
|
||||
gen_cmds += ["$(location " + gen_tool + ")" +
|
||||
" --output_dir=$(@D)/" + out_src_dir +
|
||||
" --base_package=" + base_package +
|
||||
" --api_dirs=" + api_def_args_str]
|
||||
if not api_def_srcs:
|
||||
api_def_args_str = ","
|
||||
else:
|
||||
api_def_args = []
|
||||
for api_def_src in api_def_srcs:
|
||||
# Add directory of the first ApiDef source to args.
|
||||
# We are assuming all ApiDefs in a single api_def_src are in the
|
||||
# same directory.
|
||||
api_def_args.append(
|
||||
"$$(dirname $$(echo $(locations " + api_def_src +
|
||||
") | cut -d\" \" -f1))")
|
||||
api_def_args_str = ",".join(api_def_args)
|
||||
|
||||
# Generate a source archive containing generated code for these ops.
|
||||
gen_srcjar = out_dir + name + ".srcjar"
|
||||
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]
|
||||
gen_cmds += ["$(location " + gen_tool + ")" +
|
||||
" --output_dir=$(@D)/" + out_src_dir +
|
||||
" --base_package=" + base_package +
|
||||
" --api_dirs=" + api_def_args_str]
|
||||
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
outs = [gen_srcjar],
|
||||
tools = [
|
||||
"@local_jdk//:jar",
|
||||
"@local_jdk//:jdk",
|
||||
gen_tool,
|
||||
] + tf_binary_additional_srcs(),
|
||||
cmd = " && ".join(gen_cmds),
|
||||
)
|
||||
# Generate a source archive containing generated code for these ops.
|
||||
gen_srcjar = out_dir + name + ".srcjar"
|
||||
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]
|
||||
|
||||
native.genrule(
|
||||
name=name,
|
||||
srcs=srcs,
|
||||
outs=[gen_srcjar],
|
||||
tools=[
|
||||
"@local_jdk//:jar",
|
||||
"@local_jdk//:jdk",
|
||||
gen_tool
|
||||
] + tf_binary_additional_srcs(),
|
||||
cmd=" && ".join(gen_cmds))
|
||||
|
@ -12,26 +12,22 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
|
||||
# consumers of the tf_gen_op_wrapper_py rule would be simplified if we don't
|
||||
# hard code the ops/ directory.
|
||||
|
||||
def tf_gen_op_wrapper_private_py(
|
||||
name,
|
||||
out = None,
|
||||
deps = [],
|
||||
require_shape_functions = True,
|
||||
visibility = []):
|
||||
if not name.endswith("_gen"):
|
||||
fail("name must end in _gen")
|
||||
if not visibility:
|
||||
visibility = ["//visibility:private"]
|
||||
bare_op_name = name[:-4] # Strip off the _gen
|
||||
tf_gen_op_wrapper_py(
|
||||
name = bare_op_name,
|
||||
out = out,
|
||||
visibility = visibility,
|
||||
deps = deps,
|
||||
require_shape_functions = require_shape_functions,
|
||||
generated_target_name = name,
|
||||
api_def_srcs = [
|
||||
"//tensorflow/core/api_def:base_api_def",
|
||||
"//tensorflow/core/api_def:python_api_def",
|
||||
],
|
||||
)
|
||||
def tf_gen_op_wrapper_private_py(name, out=None, deps=[],
|
||||
require_shape_functions=True,
|
||||
visibility=[]):
|
||||
if not name.endswith("_gen"):
|
||||
fail("name must end in _gen")
|
||||
if not visibility:
|
||||
visibility = ["//visibility:private"]
|
||||
bare_op_name = name[:-4] # Strip off the _gen
|
||||
tf_gen_op_wrapper_py(name=bare_op_name,
|
||||
out=out,
|
||||
visibility=visibility,
|
||||
deps=deps,
|
||||
require_shape_functions=require_shape_functions,
|
||||
generated_target_name=name,
|
||||
api_def_srcs = [
|
||||
"//tensorflow/core/api_def:base_api_def",
|
||||
"//tensorflow/core/api_def:python_api_def",
|
||||
],
|
||||
)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -134,7 +134,7 @@ def gen_api_init_files(
|
||||
package_dep = "//tensorflow/python:no_contrib"):
|
||||
root_init_template_flag = ""
|
||||
if root_init_template:
|
||||
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
|
||||
root_init_template_flag = "--root_init_template=$(location " + root_init_template + ")"
|
||||
|
||||
api_gen_binary_target = "create_" + package + "_api"
|
||||
native.py_binary(
|
||||
@ -154,9 +154,8 @@ def gen_api_init_files(
|
||||
outs = output_files,
|
||||
cmd = (
|
||||
"$(location :" + api_gen_binary_target + ") " +
|
||||
root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"
|
||||
),
|
||||
root_init_template_flag + " --apidir=$(@D) --apiname=" + api_name + " --package=" + package + " $(OUTS)"),
|
||||
srcs = srcs,
|
||||
tools = [":" + api_gen_binary_target],
|
||||
tools = [":" + api_gen_binary_target ],
|
||||
visibility = ["//tensorflow:__pkg__"],
|
||||
)
|
||||
|
@ -24,27 +24,27 @@ load("@bazel_tools//tools/cpp:windows_cc_configure.bzl", "find_msvc_tool")
|
||||
load("@bazel_tools//tools/cpp:lib_cc_configure.bzl", "auto_configure_fail")
|
||||
|
||||
def _def_file_filter_configure_impl(repository_ctx):
|
||||
if repository_ctx.os.name.lower().find("windows") == -1:
|
||||
repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD")
|
||||
repository_ctx.file("def_file_filter.py", "")
|
||||
return
|
||||
vc_path = find_vc_path(repository_ctx)
|
||||
if vc_path == "visual-studio-not-found":
|
||||
auto_configure_fail("Visual C++ build tools not found on your machine")
|
||||
|
||||
undname = find_msvc_tool(repository_ctx, vc_path, "undname.exe")
|
||||
if undname == None:
|
||||
auto_configure_fail("Couldn't find undname.exe under %s, please check your VC installation and set BAZEL_VC environment variable correctly." % vc_path)
|
||||
undname_bin_path = undname.replace("\\", "\\\\")
|
||||
|
||||
repository_ctx.template(
|
||||
"def_file_filter.py",
|
||||
Label("//tensorflow/tools/def_file_filter:def_file_filter.py.tpl"),
|
||||
{
|
||||
"%{undname_bin_path}": undname_bin_path,
|
||||
},
|
||||
)
|
||||
if repository_ctx.os.name.lower().find("windows") == -1:
|
||||
repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD")
|
||||
repository_ctx.file("def_file_filter.py", "")
|
||||
return
|
||||
vc_path = find_vc_path(repository_ctx)
|
||||
if vc_path == "visual-studio-not-found":
|
||||
auto_configure_fail("Visual C++ build tools not found on your machine")
|
||||
|
||||
undname = find_msvc_tool(repository_ctx, vc_path, "undname.exe")
|
||||
if undname == None:
|
||||
auto_configure_fail("Couldn't find undname.exe under %s, please check your VC installation and set BAZEL_VC environment variable correctly." % vc_path)
|
||||
undname_bin_path = undname.replace("\\", "\\\\")
|
||||
|
||||
repository_ctx.template(
|
||||
"def_file_filter.py",
|
||||
Label("//tensorflow/tools/def_file_filter:def_file_filter.py.tpl"),
|
||||
{
|
||||
"%{undname_bin_path}": undname_bin_path,
|
||||
})
|
||||
repository_ctx.symlink(Label("//tensorflow/tools/def_file_filter:BUILD.tpl"), "BUILD")
|
||||
|
||||
|
||||
def_file_filter_configure = repository_rule(
|
||||
implementation = _def_file_filter_configure_impl,
|
||||
@ -55,6 +55,6 @@ def_file_filter_configure = repository_rule(
|
||||
"VS100COMNTOOLS",
|
||||
"VS110COMNTOOLS",
|
||||
"VS120COMNTOOLS",
|
||||
"VS140COMNTOOLS",
|
||||
"VS140COMNTOOLS"
|
||||
],
|
||||
)
|
||||
|
@ -4,66 +4,60 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||
|
||||
# Create a benchmark test target of a TensorFlow C++ test (tf_cc_*_test)
|
||||
def tf_cc_logged_benchmark(
|
||||
name = None,
|
||||
target = None,
|
||||
benchmarks = "..",
|
||||
tags = [],
|
||||
test_log_output_prefix = "",
|
||||
benchmark_type = "cpp_microbenchmark"):
|
||||
if not name:
|
||||
fail("Must provide a name")
|
||||
if not target:
|
||||
fail("Must provide a target")
|
||||
if (not ":" in target or
|
||||
not target.startswith("//") or
|
||||
target.endswith(":all") or
|
||||
target.endswith(".")):
|
||||
fail(" ".join((
|
||||
"Target must be a single well-defined test, e.g.,",
|
||||
"//path/to:test. Received: %s" % target,
|
||||
)))
|
||||
name=None,
|
||||
target=None,
|
||||
benchmarks="..",
|
||||
tags=[],
|
||||
test_log_output_prefix="",
|
||||
benchmark_type="cpp_microbenchmark"):
|
||||
if not name:
|
||||
fail("Must provide a name")
|
||||
if not target:
|
||||
fail("Must provide a target")
|
||||
if (not ":" in target
|
||||
or not target.startswith("//")
|
||||
or target.endswith(":all")
|
||||
or target.endswith(".")):
|
||||
fail(" ".join(("Target must be a single well-defined test, e.g.,",
|
||||
"//path/to:test. Received: %s" % target)))
|
||||
|
||||
all_tags = (
|
||||
depset(tags) + depset(
|
||||
["benchmark-test", "local", "manual", "regression-test"],
|
||||
)
|
||||
).to_list()
|
||||
all_tags = (
|
||||
depset(tags) + depset(
|
||||
["benchmark-test", "local", "manual", "regression-test"])).to_list()
|
||||
|
||||
tf_py_test(
|
||||
name = name,
|
||||
tags = all_tags,
|
||||
size = "large",
|
||||
srcs = ["//tensorflow/tools/test:run_and_gather_logs"],
|
||||
args = [
|
||||
"--name=//%s:%s" % (native.package_name(), name),
|
||||
"--test_name=" + target,
|
||||
"--test_args=--benchmarks=%s" % benchmarks,
|
||||
"--benchmark_type=%s" % benchmark_type,
|
||||
],
|
||||
data = [
|
||||
target,
|
||||
],
|
||||
main = "run_and_gather_logs.py",
|
||||
additional_deps = [
|
||||
"//tensorflow/tools/test:run_and_gather_logs",
|
||||
],
|
||||
)
|
||||
tf_py_test(
|
||||
name = name,
|
||||
tags = all_tags,
|
||||
size = "large",
|
||||
srcs = ["//tensorflow/tools/test:run_and_gather_logs"],
|
||||
args = [
|
||||
"--name=//%s:%s" % (native.package_name(), name),
|
||||
"--test_name=" + target,
|
||||
"--test_args=--benchmarks=%s" % benchmarks,
|
||||
"--benchmark_type=%s" % benchmark_type,
|
||||
],
|
||||
data = [
|
||||
target,
|
||||
],
|
||||
main = "run_and_gather_logs.py",
|
||||
additional_deps = [
|
||||
"//tensorflow/tools/test:run_and_gather_logs"
|
||||
])
|
||||
|
||||
# Create a benchmark test target of a TensorFlow python test (*py_tests)
|
||||
def tf_py_logged_benchmark(
|
||||
name = None,
|
||||
target = None,
|
||||
benchmarks = "..",
|
||||
tags = [],
|
||||
test_log_output_prefix = ""):
|
||||
# For now generating a py benchmark is the same as generating a C++
|
||||
# benchmark target. In the future this may change, so we have
|
||||
# two macros just in case
|
||||
tf_cc_logged_benchmark(
|
||||
name = name,
|
||||
target = target,
|
||||
benchmarks = benchmarks,
|
||||
tags = tags,
|
||||
test_log_output_prefix = test_log_output_prefix,
|
||||
benchmark_type = "python_benchmark",
|
||||
)
|
||||
name=None,
|
||||
target=None,
|
||||
benchmarks="..",
|
||||
tags=[],
|
||||
test_log_output_prefix=""):
|
||||
# For now generating a py benchmark is the same as generating a C++
|
||||
# benchmark target. In the future this may change, so we have
|
||||
# two macros just in case
|
||||
tf_cc_logged_benchmark(
|
||||
name=name,
|
||||
target=target,
|
||||
benchmarks=benchmarks,
|
||||
tags=tags,
|
||||
test_log_output_prefix=test_log_output_prefix,
|
||||
benchmark_type="python_benchmark")
|
||||
|
@ -1,50 +1,48 @@
|
||||
""" Helpers to check minimum version of bazel."""
|
||||
|
||||
def _extract_version_number(bazel_version):
|
||||
"""Extracts the semantic version number from a version string
|
||||
"""Extracts the semantic version number from a version string
|
||||
|
||||
Args:
|
||||
bazel_version: the version string that begins with the semantic version
|
||||
e.g. "1.2.3rc1 abc1234" where "abc1234" is a commit hash.
|
||||
Args:
|
||||
bazel_version: the version string that begins with the semantic version
|
||||
e.g. "1.2.3rc1 abc1234" where "abc1234" is a commit hash.
|
||||
|
||||
Returns:
|
||||
The semantic version string, like "1.2.3".
|
||||
"""
|
||||
for i in range(len(bazel_version)):
|
||||
c = bazel_version[i]
|
||||
if not (c.isdigit() or c == "."):
|
||||
return bazel_version[:i]
|
||||
return bazel_version
|
||||
Returns:
|
||||
The semantic version string, like "1.2.3".
|
||||
"""
|
||||
for i in range(len(bazel_version)):
|
||||
c = bazel_version[i]
|
||||
if not (c.isdigit() or c == "."):
|
||||
return bazel_version[:i]
|
||||
return bazel_version
|
||||
|
||||
# Parse the bazel version string from `native.bazel_version`.
|
||||
# e.g.
|
||||
# "0.10.0rc1 abc123d" => (0, 10, 0)
|
||||
# "0.3.0" => (0, 3, 0)
|
||||
def _parse_bazel_version(bazel_version):
|
||||
"""Parses a version string into a 3-tuple of ints
|
||||
"""Parses a version string into a 3-tuple of ints
|
||||
|
||||
int tuples can be compared directly using binary operators (<, >).
|
||||
int tuples can be compared directly using binary operators (<, >).
|
||||
|
||||
Args:
|
||||
bazel_version: the Bazel version string
|
||||
Args:
|
||||
bazel_version: the Bazel version string
|
||||
|
||||
Returns:
|
||||
An int 3-tuple of a (major, minor, patch) version.
|
||||
"""
|
||||
Returns:
|
||||
An int 3-tuple of a (major, minor, patch) version.
|
||||
"""
|
||||
|
||||
version = _extract_version_number(bazel_version)
|
||||
return tuple([int(n) for n in version.split(".")])
|
||||
version = _extract_version_number(bazel_version)
|
||||
return tuple([int(n) for n in version.split(".")])
|
||||
|
||||
def check_bazel_version_at_least(minimum_bazel_version):
|
||||
if "bazel_version" not in dir(native):
|
||||
fail("\nCurrent Bazel version is lower than 0.2.1, expected at least %s\n" % minimum_bazel_version)
|
||||
elif not native.bazel_version:
|
||||
print("\nCurrent Bazel is not a release version, cannot check for compatibility.")
|
||||
print("Make sure that you are running at least Bazel %s.\n" % minimum_bazel_version)
|
||||
return
|
||||
if "bazel_version" not in dir(native):
|
||||
fail("\nCurrent Bazel version is lower than 0.2.1, expected at least %s\n" % minimum_bazel_version)
|
||||
elif not native.bazel_version:
|
||||
print("\nCurrent Bazel is not a release version, cannot check for compatibility.")
|
||||
print("Make sure that you are running at least Bazel %s.\n" % minimum_bazel_version)
|
||||
return
|
||||
|
||||
if _parse_bazel_version(native.bazel_version) < _parse_bazel_version(minimum_bazel_version):
|
||||
fail("\nCurrent Bazel version is {}, expected at least {}\n".format(
|
||||
native.bazel_version,
|
||||
minimum_bazel_version,
|
||||
))
|
||||
if _parse_bazel_version(native.bazel_version) < _parse_bazel_version(minimum_bazel_version):
|
||||
fail("\nCurrent Bazel version is {}, expected at least {}\n".format(
|
||||
native.bazel_version, minimum_bazel_version))
|
||||
|
File diff suppressed because it is too large
Load Diff
54
third_party/android/android_configure.bzl
vendored
54
third_party/android/android_configure.bzl
vendored
@ -36,39 +36,33 @@ _ANDROID_NDK_REPO_TEMPLATE = """
|
||||
"""
|
||||
|
||||
def _android_autoconf_impl(repository_ctx):
|
||||
"""Implementation of the android_autoconf repository rule."""
|
||||
sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
|
||||
sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
|
||||
build_tools_version = repository_ctx.os.environ.get(
|
||||
_ANDROID_BUILD_TOOLS_VERSION,
|
||||
)
|
||||
ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
|
||||
ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)
|
||||
"""Implementation of the android_autoconf repository rule."""
|
||||
sdk_home = repository_ctx.os.environ.get(_ANDROID_SDK_HOME)
|
||||
sdk_api_level = repository_ctx.os.environ.get(_ANDROID_SDK_API_VERSION)
|
||||
build_tools_version = repository_ctx.os.environ.get(
|
||||
_ANDROID_BUILD_TOOLS_VERSION)
|
||||
ndk_home = repository_ctx.os.environ.get(_ANDROID_NDK_HOME)
|
||||
ndk_api_level = repository_ctx.os.environ.get(_ANDROID_NDK_API_VERSION)
|
||||
|
||||
sdk_rule = "pass"
|
||||
if all([sdk_home, sdk_api_level, build_tools_version]):
|
||||
sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
|
||||
sdk_home,
|
||||
sdk_api_level,
|
||||
build_tools_version,
|
||||
)
|
||||
sdk_rule = "pass"
|
||||
if all([sdk_home, sdk_api_level, build_tools_version]):
|
||||
sdk_rule = _ANDROID_SDK_REPO_TEMPLATE % (
|
||||
sdk_home, sdk_api_level, build_tools_version)
|
||||
|
||||
ndk_rule = "pass"
|
||||
if all([ndk_home, ndk_api_level]):
|
||||
ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level)
|
||||
ndk_rule = "pass"
|
||||
if all([ndk_home, ndk_api_level]):
|
||||
ndk_rule = _ANDROID_NDK_REPO_TEMPLATE % (ndk_home, ndk_api_level)
|
||||
|
||||
repository_ctx.template(
|
||||
"BUILD",
|
||||
Label("//third_party/android:android_configure.BUILD.tpl"),
|
||||
)
|
||||
repository_ctx.template(
|
||||
"android.bzl",
|
||||
Label("//third_party/android:android.bzl.tpl"),
|
||||
substitutions = {
|
||||
"MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule,
|
||||
"MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule,
|
||||
},
|
||||
)
|
||||
repository_ctx.template(
|
||||
"BUILD",
|
||||
Label("//third_party/android:android_configure.BUILD.tpl"))
|
||||
repository_ctx.template(
|
||||
"android.bzl",
|
||||
Label("//third_party/android:android.bzl.tpl"),
|
||||
substitutions={
|
||||
"MAYBE_ANDROID_SDK_REPOSITORY": sdk_rule,
|
||||
"MAYBE_ANDROID_NDK_REPOSITORY": ndk_rule,
|
||||
})
|
||||
|
||||
android_configure = repository_rule(
|
||||
implementation = _android_autoconf_impl,
|
||||
|
@ -7,16 +7,16 @@ _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG"
|
||||
_TF_NEED_CUDA = "TF_NEED_CUDA"
|
||||
|
||||
def _cc_clang_autoconf(repo_ctx):
|
||||
if repo_ctx.os.environ.get(_TF_DOWNLOAD_CLANG) != "1":
|
||||
return
|
||||
if repo_ctx.os.environ.get(_TF_NEED_CUDA) == "1":
|
||||
# Clang is handled separately for CUDA configs.
|
||||
# See cuda_configure.bzl for more details.
|
||||
return
|
||||
if repo_ctx.os.environ.get(_TF_DOWNLOAD_CLANG) != "1":
|
||||
return
|
||||
if repo_ctx.os.environ.get(_TF_NEED_CUDA) == "1":
|
||||
# Clang is handled separately for CUDA configs.
|
||||
# See cuda_configure.bzl for more details.
|
||||
return
|
||||
|
||||
download_clang(repo_ctx, out_folder = "extra_tools")
|
||||
overriden_tools = {"gcc": "extra_tools/bin/clang"}
|
||||
cc_autoconf_impl(repo_ctx, overriden_tools)
|
||||
download_clang(repo_ctx, out_folder='extra_tools')
|
||||
overriden_tools = {'gcc': 'extra_tools/bin/clang'}
|
||||
cc_autoconf_impl(repo_ctx, overriden_tools)
|
||||
|
||||
cc_download_clang_toolchain = repository_rule(
|
||||
environ = [
|
||||
|
90
third_party/clang_toolchain/download_clang.bzl
vendored
90
third_party/clang_toolchain/download_clang.bzl
vendored
@ -1,60 +1,54 @@
|
||||
""" Helpers to download a recent clang release."""
|
||||
|
||||
def _get_platform_folder(os_name):
|
||||
os_name = os_name.lower()
|
||||
if os_name.startswith("windows"):
|
||||
return "Win"
|
||||
if os_name.startswith("mac os"):
|
||||
return "Mac"
|
||||
if not os_name.startswith("linux"):
|
||||
fail("Unknown platform")
|
||||
return "Linux_x64"
|
||||
os_name = os_name.lower()
|
||||
if os_name.startswith('windows'):
|
||||
return 'Win'
|
||||
if os_name.startswith('mac os'):
|
||||
return 'Mac'
|
||||
if not os_name.startswith('linux'):
|
||||
fail('Unknown platform')
|
||||
return 'Linux_x64'
|
||||
|
||||
def _download_chromium_clang(
|
||||
repo_ctx,
|
||||
platform_folder,
|
||||
package_version,
|
||||
sha256,
|
||||
out_folder):
|
||||
cds_url = "https://commondatastorage.googleapis.com/chromium-browser-clang"
|
||||
cds_file = "clang-%s.tgz" % package_version
|
||||
cds_full_url = "{0}/{1}/{2}".format(cds_url, platform_folder, cds_file)
|
||||
repo_ctx.download_and_extract(cds_full_url, output = out_folder, sha256 = sha256)
|
||||
def _download_chromium_clang(repo_ctx, platform_folder, package_version, sha256,
|
||||
out_folder):
|
||||
cds_url = 'https://commondatastorage.googleapis.com/chromium-browser-clang'
|
||||
cds_file = 'clang-%s.tgz' % package_version
|
||||
cds_full_url = '{0}/{1}/{2}'.format(cds_url, platform_folder, cds_file)
|
||||
repo_ctx.download_and_extract(cds_full_url, output=out_folder, sha256=sha256)
|
||||
|
||||
def download_clang(repo_ctx, out_folder):
|
||||
""" Download a fresh clang release and put it into out_folder.
|
||||
""" Download a fresh clang release and put it into out_folder.
|
||||
|
||||
Clang itself will be located in 'out_folder/bin/clang'.
|
||||
We currently download one of the latest releases of clang by the
|
||||
Chromium project (see
|
||||
https://chromium.googlesource.com/chromium/src/+/master/docs/clang.md).
|
||||
Clang itself will be located in 'out_folder/bin/clang'.
|
||||
We currently download one of the latest releases of clang by the
|
||||
Chromium project (see
|
||||
https://chromium.googlesource.com/chromium/src/+/master/docs/clang.md).
|
||||
|
||||
Args:
|
||||
repo_ctx: An instance of repository_context object.
|
||||
out_folder: A folder to extract the compiler into.
|
||||
"""
|
||||
# TODO(ibiryukov): we currently download and extract some extra tools in the
|
||||
# clang release (e.g., sanitizers). We should probably remove the ones
|
||||
# we don't need and document the ones we want provide in addition to clang.
|
||||
Args:
|
||||
repo_ctx: An instance of repository_context object.
|
||||
out_folder: A folder to extract the compiler into.
|
||||
"""
|
||||
# TODO(ibiryukov): we currently download and extract some extra tools in the
|
||||
# clang release (e.g., sanitizers). We should probably remove the ones
|
||||
# we don't need and document the ones we want provide in addition to clang.
|
||||
|
||||
# Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromiums's release
|
||||
# can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
|
||||
CLANG_REVISION = "335091"
|
||||
CLANG_SUB_REVISION = 1
|
||||
# Latest CLANG_REVISION and CLANG_SUB_REVISION of the Chromiums's release
|
||||
# can be found in https://chromium.googlesource.com/chromium/src/tools/clang/+/master/scripts/update.py
|
||||
CLANG_REVISION = '335091'
|
||||
CLANG_SUB_REVISION = 1
|
||||
|
||||
package_version = "%s-%s" % (CLANG_REVISION, CLANG_SUB_REVISION)
|
||||
package_version = '%s-%s' % (CLANG_REVISION, CLANG_SUB_REVISION)
|
||||
|
||||
checksums = {
|
||||
"Linux_x64": "17002b75293fccfdd175eacdc9ee47d97b58d7e98fef343384fbbef1b68ce99f",
|
||||
"Mac": "9351e46d28315daaa06a1eb55bd0370ed4aaeb693a2a3e82e48d2737d7723468",
|
||||
"Win": "e78a1e469224d6f6751b4df4374bf58893ac03900ec924e4c8264888ba4aeb1e",
|
||||
}
|
||||
checksums = {
|
||||
'Linux_x64':
|
||||
'17002b75293fccfdd175eacdc9ee47d97b58d7e98fef343384fbbef1b68ce99f',
|
||||
'Mac':
|
||||
'9351e46d28315daaa06a1eb55bd0370ed4aaeb693a2a3e82e48d2737d7723468',
|
||||
'Win':
|
||||
'e78a1e469224d6f6751b4df4374bf58893ac03900ec924e4c8264888ba4aeb1e',
|
||||
}
|
||||
|
||||
platform_folder = _get_platform_folder(repo_ctx.os.name)
|
||||
_download_chromium_clang(
|
||||
repo_ctx,
|
||||
platform_folder,
|
||||
package_version,
|
||||
checksums[platform_folder],
|
||||
out_folder,
|
||||
)
|
||||
platform_folder = _get_platform_folder(repo_ctx.os.name)
|
||||
_download_chromium_clang(repo_ctx, platform_folder, package_version,
|
||||
checksums[platform_folder], out_folder)
|
||||
|
10
third_party/common.bzl
vendored
10
third_party/common.bzl
vendored
@ -21,11 +21,11 @@
|
||||
# substitutions: A dictionary mapping strings to their substitutions
|
||||
|
||||
def template_rule_impl(ctx):
|
||||
ctx.template_action(
|
||||
template = ctx.file.src,
|
||||
output = ctx.outputs.out,
|
||||
substitutions = ctx.attr.substitutions,
|
||||
)
|
||||
ctx.template_action(
|
||||
template = ctx.file.src,
|
||||
output = ctx.outputs.out,
|
||||
substitutions = ctx.attr.substitutions,
|
||||
)
|
||||
|
||||
template_rule = rule(
|
||||
attrs = {
|
||||
|
361
third_party/flatbuffers/build_defs.bzl
vendored
361
third_party/flatbuffers/build_defs.bzl
vendored
@ -8,49 +8,66 @@ DEFAULT_FLATC_ARGS = [
|
||||
"--gen-object-api",
|
||||
]
|
||||
|
||||
def flatbuffer_library_public(
|
||||
name,
|
||||
srcs,
|
||||
outs,
|
||||
language_flag,
|
||||
out_prefix = "",
|
||||
includes = [],
|
||||
include_paths = [],
|
||||
flatc_args = DEFAULT_FLATC_ARGS,
|
||||
reflection_name = "",
|
||||
reflection_visiblity = None,
|
||||
output_to_bindir = False):
|
||||
"""Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler.
|
||||
def flatbuffer_library_public(name,
|
||||
srcs,
|
||||
outs,
|
||||
language_flag,
|
||||
out_prefix="",
|
||||
includes=[],
|
||||
include_paths=[],
|
||||
flatc_args=DEFAULT_FLATC_ARGS,
|
||||
reflection_name="",
|
||||
reflection_visiblity=None,
|
||||
output_to_bindir=False):
|
||||
'''Generates code files for reading/writing the given flatbuffers in the requested language using the public compiler.
|
||||
|
||||
Args:
|
||||
name: Rule name.
|
||||
srcs: Source .fbs files. Sent in order to the compiler.
|
||||
outs: Output files from flatc.
|
||||
language_flag: Target language flag. One of [-c, -j, -js].
|
||||
out_prefix: Prepend this path to the front of all generated files except on
|
||||
single source targets. Usually is a directory name.
|
||||
includes: Optional, list of filegroups of schemas that the srcs depend on.
|
||||
include_paths: Optional, list of paths the includes files can be found in.
|
||||
flatc_args: Optional, list of additional arguments to pass to flatc.
|
||||
reflection_name: Optional, if set this will generate the flatbuffer
|
||||
reflection binaries for the schemas.
|
||||
reflection_visiblity: The visibility of the generated reflection Fileset.
|
||||
output_to_bindir: Passed to genrule for output to bin directory.
|
||||
Outs:
|
||||
filegroup(name): all generated source files.
|
||||
Fileset([reflection_name]): (Optional) all generated reflection binaries.
|
||||
"""
|
||||
include_paths_cmd = ["-I %s" % (s) for s in include_paths]
|
||||
|
||||
# '$(@D)' when given a single source target will give the appropriate
|
||||
# directory. Appending 'out_prefix' is only necessary when given a build
|
||||
# target with multiple sources.
|
||||
output_directory = (
|
||||
("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)")
|
||||
)
|
||||
genrule_cmd = " ".join([
|
||||
Args:
|
||||
name: Rule name.
|
||||
srcs: Source .fbs files. Sent in order to the compiler.
|
||||
outs: Output files from flatc.
|
||||
language_flag: Target language flag. One of [-c, -j, -js].
|
||||
out_prefix: Prepend this path to the front of all generated files except on
|
||||
single source targets. Usually is a directory name.
|
||||
includes: Optional, list of filegroups of schemas that the srcs depend on.
|
||||
include_paths: Optional, list of paths the includes files can be found in.
|
||||
flatc_args: Optional, list of additional arguments to pass to flatc.
|
||||
reflection_name: Optional, if set this will generate the flatbuffer
|
||||
reflection binaries for the schemas.
|
||||
reflection_visiblity: The visibility of the generated reflection Fileset.
|
||||
output_to_bindir: Passed to genrule for output to bin directory.
|
||||
Outs:
|
||||
filegroup(name): all generated source files.
|
||||
Fileset([reflection_name]): (Optional) all generated reflection binaries.
|
||||
'''
|
||||
include_paths_cmd = ["-I %s" % (s) for s in include_paths]
|
||||
# '$(@D)' when given a single source target will give the appropriate
|
||||
# directory. Appending 'out_prefix' is only necessary when given a build
|
||||
# target with multiple sources.
|
||||
output_directory = (
|
||||
("-o $(@D)/%s" % (out_prefix)) if len(srcs) > 1 else ("-o $(@D)"))
|
||||
genrule_cmd = " ".join([
|
||||
"for f in $(SRCS); do",
|
||||
"$(location %s)" % (flatc_path),
|
||||
" ".join(flatc_args),
|
||||
" ".join(include_paths_cmd),
|
||||
language_flag,
|
||||
output_directory,
|
||||
"$$f;",
|
||||
"done",
|
||||
])
|
||||
native.genrule(
|
||||
name=name,
|
||||
srcs=srcs,
|
||||
outs=outs,
|
||||
output_to_bindir=output_to_bindir,
|
||||
tools=includes + [flatc_path,],
|
||||
cmd=genrule_cmd,
|
||||
message="Generating flatbuffer files for %s:" % (name),)
|
||||
if reflection_name:
|
||||
reflection_genrule_cmd = " ".join([
|
||||
"for f in $(SRCS); do",
|
||||
"$(location %s)" % (flatc_path),
|
||||
"-b --schema",
|
||||
" ".join(flatc_args),
|
||||
" ".join(include_paths_cmd),
|
||||
language_flag,
|
||||
@ -58,156 +75,122 @@ def flatbuffer_library_public(
|
||||
"$$f;",
|
||||
"done",
|
||||
])
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = srcs,
|
||||
outs = outs,
|
||||
output_to_bindir = output_to_bindir,
|
||||
tools = includes + [flatc_path],
|
||||
cmd = genrule_cmd,
|
||||
message = "Generating flatbuffer files for %s:" % (name),
|
||||
)
|
||||
if reflection_name:
|
||||
reflection_genrule_cmd = " ".join([
|
||||
"for f in $(SRCS); do",
|
||||
"$(location %s)" % (flatc_path),
|
||||
"-b --schema",
|
||||
" ".join(flatc_args),
|
||||
" ".join(include_paths_cmd),
|
||||
language_flag,
|
||||
output_directory,
|
||||
"$$f;",
|
||||
"done",
|
||||
])
|
||||
reflection_outs = [
|
||||
(out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1])
|
||||
for s in srcs
|
||||
]
|
||||
native.genrule(
|
||||
name = "%s_srcs" % reflection_name,
|
||||
srcs = srcs,
|
||||
outs = reflection_outs,
|
||||
output_to_bindir = output_to_bindir,
|
||||
tools = includes + [flatc_path],
|
||||
cmd = reflection_genrule_cmd,
|
||||
message = "Generating flatbuffer reflection binary for %s:" % (name),
|
||||
)
|
||||
native.Fileset(
|
||||
name = reflection_name,
|
||||
out = "%s_out" % reflection_name,
|
||||
entries = [
|
||||
native.FilesetEntry(files = reflection_outs),
|
||||
],
|
||||
visibility = reflection_visiblity,
|
||||
)
|
||||
|
||||
def flatbuffer_cc_library(
|
||||
name,
|
||||
srcs,
|
||||
srcs_filegroup_name = "",
|
||||
out_prefix = "",
|
||||
includes = [],
|
||||
include_paths = [],
|
||||
flatc_args = DEFAULT_FLATC_ARGS,
|
||||
visibility = None,
|
||||
srcs_filegroup_visibility = None,
|
||||
gen_reflections = False):
|
||||
'''A cc_library with the generated reader/writers for the given flatbuffer definitions.
|
||||
|
||||
Args:
|
||||
name: Rule name.
|
||||
srcs: Source .fbs files. Sent in order to the compiler.
|
||||
srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this
|
||||
filegroup into the `includes` parameter of any other
|
||||
flatbuffer_cc_library that depends on this one's schemas.
|
||||
out_prefix: Prepend this path to the front of all generated files. Usually
|
||||
is a directory name.
|
||||
includes: Optional, list of filegroups of schemas that the srcs depend on.
|
||||
** SEE REMARKS BELOW **
|
||||
include_paths: Optional, list of paths the includes files can be found in.
|
||||
flatc_args: Optional list of additional arguments to pass to flatc
|
||||
(e.g. --gen-mutable).
|
||||
visibility: The visibility of the generated cc_library. By default, use the
|
||||
default visibility of the project.
|
||||
srcs_filegroup_visibility: The visibility of the generated srcs filegroup.
|
||||
By default, use the value of the visibility parameter above.
|
||||
gen_reflections: Optional, if true this will generate the flatbuffer
|
||||
reflection binaries for the schemas.
|
||||
Outs:
|
||||
filegroup([name]_srcs): all generated .h files.
|
||||
filegroup(srcs_filegroup_name if specified, or [name]_includes if not):
|
||||
Other flatbuffer_cc_library's can pass this in for their `includes`
|
||||
parameter, if they depend on the schemas in this library.
|
||||
Fileset([name]_reflection): (Optional) all generated reflection binaries.
|
||||
cc_library([name]): library with sources and flatbuffers deps.
|
||||
|
||||
Remarks:
|
||||
** Because the genrule used to call flatc does not have any trivial way of
|
||||
computing the output list of files transitively generated by includes and
|
||||
--gen-includes (the default) being defined for flatc, the --gen-includes
|
||||
flag will not work as expected. The way around this is to add a dependency
|
||||
to the flatbuffer_cc_library defined alongside the flatc included Fileset.
|
||||
For example you might define:
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "my_fbs",
|
||||
srcs = [ "schemas/foo.fbs" ],
|
||||
includes = [ "//third_party/bazz:bazz_fbs_includes" ],
|
||||
)
|
||||
|
||||
In which foo.fbs includes a few files from the Fileset defined at
|
||||
//third_party/bazz:bazz_fbs_includes. When compiling the library that
|
||||
includes foo_generated.h, and therefore has my_fbs as a dependency, it
|
||||
will fail to find any of the bazz *_generated.h files unless you also
|
||||
add bazz's flatbuffer_cc_library to your own dependency list, e.g.:
|
||||
|
||||
cc_library(
|
||||
name = "my_lib",
|
||||
deps = [
|
||||
":my_fbs",
|
||||
"//third_party/bazz:bazz_fbs"
|
||||
],
|
||||
)
|
||||
|
||||
Happy dependent Flatbuffering!
|
||||
'''
|
||||
output_headers = [
|
||||
(out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1])
|
||||
for s in srcs
|
||||
reflection_outs = [
|
||||
(out_prefix + "%s.bfbs") % (s.replace(".fbs", "").split("/")[-1]) for s in srcs
|
||||
]
|
||||
reflection_name = "%s_reflection" % name if gen_reflections else ""
|
||||
|
||||
flatbuffer_library_public(
|
||||
name = "%s_srcs" % (name),
|
||||
srcs = srcs,
|
||||
outs = output_headers,
|
||||
language_flag = "-c",
|
||||
out_prefix = out_prefix,
|
||||
includes = includes,
|
||||
include_paths = include_paths,
|
||||
flatc_args = flatc_args,
|
||||
reflection_name = reflection_name,
|
||||
reflection_visiblity = visibility,
|
||||
)
|
||||
native.cc_library(
|
||||
name = name,
|
||||
hdrs = output_headers,
|
||||
srcs = output_headers,
|
||||
features = [
|
||||
"-parse_headers",
|
||||
native.genrule(
|
||||
name= "%s_srcs" % reflection_name,
|
||||
srcs=srcs,
|
||||
outs=reflection_outs,
|
||||
output_to_bindir=output_to_bindir,
|
||||
tools=includes + [flatc_path,],
|
||||
cmd=reflection_genrule_cmd,
|
||||
message="Generating flatbuffer reflection binary for %s:" % (name),)
|
||||
native.Fileset(
|
||||
name=reflection_name,
|
||||
out="%s_out" % reflection_name,
|
||||
entries=[
|
||||
native.FilesetEntry(files=reflection_outs),
|
||||
],
|
||||
deps = [
|
||||
"@flatbuffers//:runtime_cc",
|
||||
],
|
||||
includes = ["."],
|
||||
linkstatic = 1,
|
||||
visibility = visibility,
|
||||
visibility=reflection_visiblity
|
||||
)
|
||||
|
||||
# A filegroup for the `srcs`. That is, all the schema files for this
|
||||
# Flatbuffer set.
|
||||
native.filegroup(
|
||||
name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name),
|
||||
srcs = srcs,
|
||||
visibility = srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility,
|
||||
)
|
||||
|
||||
def flatbuffer_cc_library(name, srcs, srcs_filegroup_name="",
|
||||
out_prefix="", includes=[], include_paths=[],
|
||||
flatc_args=DEFAULT_FLATC_ARGS,
|
||||
visibility=None, srcs_filegroup_visibility=None,
|
||||
gen_reflections=False):
|
||||
'''A cc_library with the generated reader/writers for the given flatbuffer definitions.
|
||||
|
||||
Args:
|
||||
name: Rule name.
|
||||
srcs: Source .fbs files. Sent in order to the compiler.
|
||||
srcs_filegroup_name: Name of the output filegroup that holds srcs. Pass this
|
||||
filegroup into the `includes` parameter of any other
|
||||
flatbuffer_cc_library that depends on this one's schemas.
|
||||
out_prefix: Prepend this path to the front of all generated files. Usually
|
||||
is a directory name.
|
||||
includes: Optional, list of filegroups of schemas that the srcs depend on.
|
||||
** SEE REMARKS BELOW **
|
||||
include_paths: Optional, list of paths the includes files can be found in.
|
||||
flatc_args: Optional list of additional arguments to pass to flatc
|
||||
(e.g. --gen-mutable).
|
||||
visibility: The visibility of the generated cc_library. By default, use the
|
||||
default visibility of the project.
|
||||
srcs_filegroup_visibility: The visibility of the generated srcs filegroup.
|
||||
By default, use the value of the visibility parameter above.
|
||||
gen_reflections: Optional, if true this will generate the flatbuffer
|
||||
reflection binaries for the schemas.
|
||||
Outs:
|
||||
filegroup([name]_srcs): all generated .h files.
|
||||
filegroup(srcs_filegroup_name if specified, or [name]_includes if not):
|
||||
Other flatbuffer_cc_library's can pass this in for their `includes`
|
||||
parameter, if they depend on the schemas in this library.
|
||||
Fileset([name]_reflection): (Optional) all generated reflection binaries.
|
||||
cc_library([name]): library with sources and flatbuffers deps.
|
||||
|
||||
Remarks:
|
||||
** Because the genrule used to call flatc does not have any trivial way of
|
||||
computing the output list of files transitively generated by includes and
|
||||
--gen-includes (the default) being defined for flatc, the --gen-includes
|
||||
flag will not work as expected. The way around this is to add a dependency
|
||||
to the flatbuffer_cc_library defined alongside the flatc included Fileset.
|
||||
For example you might define:
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "my_fbs",
|
||||
srcs = [ "schemas/foo.fbs" ],
|
||||
includes = [ "//third_party/bazz:bazz_fbs_includes" ],
|
||||
)
|
||||
|
||||
In which foo.fbs includes a few files from the Fileset defined at
|
||||
//third_party/bazz:bazz_fbs_includes. When compiling the library that
|
||||
includes foo_generated.h, and therefore has my_fbs as a dependency, it
|
||||
will fail to find any of the bazz *_generated.h files unless you also
|
||||
add bazz's flatbuffer_cc_library to your own dependency list, e.g.:
|
||||
|
||||
cc_library(
|
||||
name = "my_lib",
|
||||
deps = [
|
||||
":my_fbs",
|
||||
"//third_party/bazz:bazz_fbs"
|
||||
],
|
||||
)
|
||||
|
||||
Happy dependent Flatbuffering!
|
||||
'''
|
||||
output_headers = [
|
||||
(out_prefix + "%s_generated.h") % (s.replace(".fbs", "").split("/")[-1]) for s in srcs
|
||||
]
|
||||
reflection_name = "%s_reflection" % name if gen_reflections else ""
|
||||
|
||||
flatbuffer_library_public(name="%s_srcs" % (name),
|
||||
srcs=srcs,
|
||||
outs=output_headers,
|
||||
language_flag="-c",
|
||||
out_prefix=out_prefix,
|
||||
includes=includes,
|
||||
include_paths=include_paths,
|
||||
flatc_args=flatc_args,
|
||||
reflection_name=reflection_name,
|
||||
reflection_visiblity=visibility,)
|
||||
native.cc_library(name=name,
|
||||
hdrs=output_headers,
|
||||
srcs=output_headers,
|
||||
features=[
|
||||
"-parse_headers",
|
||||
],
|
||||
deps=[
|
||||
"@flatbuffers//:runtime_cc",
|
||||
],
|
||||
includes=["."],
|
||||
linkstatic=1,
|
||||
visibility=visibility)
|
||||
|
||||
# A filegroup for the `srcs`. That is, all the schema files for this
|
||||
# Flatbuffer set.
|
||||
native.filegroup(
|
||||
name = srcs_filegroup_name if srcs_filegroup_name else "%s_includes" % (name),
|
||||
srcs = srcs,
|
||||
visibility=srcs_filegroup_visibility if srcs_filegroup_visibility != None else visibility)
|
||||
|
185
third_party/llvm/llvm.bzl
vendored
185
third_party/llvm/llvm.bzl
vendored
@ -8,114 +8,102 @@ correctly understood by the build system.
|
||||
"""
|
||||
|
||||
def gentbl(name, tblgen, td_file, td_srcs, tbl_outs, library = True, **kwargs):
|
||||
"""gentbl() generates tabular code from a table definition file.
|
||||
"""gentbl() generates tabular code from a table definition file.
|
||||
|
||||
Args:
|
||||
name: The name of the build rule for use in dependencies.
|
||||
tblgen: The binary used to produce the output.
|
||||
td_file: The primary table definitions file.
|
||||
td_srcs: A list of table definition files included transitively.
|
||||
tbl_outs: A list of tuples (opts, out), where each opts is a string of
|
||||
options passed to tblgen, and the out is the corresponding output file
|
||||
produced.
|
||||
library: Whether to bundle the generated files into a library.
|
||||
**kwargs: Keyword arguments to pass to subsidiary cc_library() rule.
|
||||
"""
|
||||
if td_file not in td_srcs:
|
||||
td_srcs += [td_file]
|
||||
includes = []
|
||||
for (opts, out) in tbl_outs:
|
||||
outdir = out[:out.rindex("/")]
|
||||
if outdir not in includes:
|
||||
includes.append(outdir)
|
||||
rule_suffix = "_".join(opts.replace("-", "_").replace("=", "_").split(" "))
|
||||
native.genrule(
|
||||
name = "%s_%s_genrule" % (name, rule_suffix),
|
||||
srcs = td_srcs,
|
||||
outs = [out],
|
||||
tools = [tblgen],
|
||||
message = "Generating code from table: %s" % td_file,
|
||||
cmd = (("$(location %s) " + "-I external/llvm/include " +
|
||||
"-I external/llvm/tools/clang/include " +
|
||||
"-I $$(dirname $(location %s)) " + "%s $(location %s) -o $@") % (
|
||||
tblgen,
|
||||
td_file,
|
||||
opts,
|
||||
td_file,
|
||||
)),
|
||||
)
|
||||
|
||||
# For now, all generated files can be assumed to comprise public interfaces.
|
||||
# If this is not true, you should specify library = False
|
||||
# and list the generated '.inc' files in "srcs".
|
||||
if library:
|
||||
native.cc_library(
|
||||
name = name,
|
||||
textual_hdrs = [f for (_, f) in tbl_outs],
|
||||
includes = includes,
|
||||
**kwargs
|
||||
)
|
||||
Args:
|
||||
name: The name of the build rule for use in dependencies.
|
||||
tblgen: The binary used to produce the output.
|
||||
td_file: The primary table definitions file.
|
||||
td_srcs: A list of table definition files included transitively.
|
||||
tbl_outs: A list of tuples (opts, out), where each opts is a string of
|
||||
options passed to tblgen, and the out is the corresponding output file
|
||||
produced.
|
||||
library: Whether to bundle the generated files into a library.
|
||||
**kwargs: Keyword arguments to pass to subsidiary cc_library() rule.
|
||||
"""
|
||||
if td_file not in td_srcs:
|
||||
td_srcs += [td_file]
|
||||
includes = []
|
||||
for (opts, out) in tbl_outs:
|
||||
outdir = out[:out.rindex("/")]
|
||||
if outdir not in includes:
|
||||
includes.append(outdir)
|
||||
rule_suffix = "_".join(opts.replace("-", "_").replace("=", "_").split(" "))
|
||||
native.genrule(
|
||||
name="%s_%s_genrule" % (name, rule_suffix),
|
||||
srcs=td_srcs,
|
||||
outs=[out],
|
||||
tools=[tblgen],
|
||||
message="Generating code from table: %s" % td_file,
|
||||
cmd=(("$(location %s) " + "-I external/llvm/include " +
|
||||
"-I external/llvm/tools/clang/include " +
|
||||
"-I $$(dirname $(location %s)) " + "%s $(location %s) -o $@") % (
|
||||
tblgen, td_file, opts, td_file)))
|
||||
# For now, all generated files can be assumed to comprise public interfaces.
|
||||
# If this is not true, you should specify library = False
|
||||
# and list the generated '.inc' files in "srcs".
|
||||
if library:
|
||||
native.cc_library(name=name, textual_hdrs=[f for (_, f) in tbl_outs],
|
||||
includes=includes, **kwargs)
|
||||
|
||||
def llvm_target_cmake_vars(native_arch, target_triple):
|
||||
return {
|
||||
"LLVM_HOST_TRIPLE": target_triple,
|
||||
"LLVM_DEFAULT_TARGET_TRIPLE": target_triple,
|
||||
"LLVM_NATIVE_ARCH": native_arch,
|
||||
}
|
||||
return {
|
||||
"LLVM_HOST_TRIPLE": target_triple,
|
||||
"LLVM_DEFAULT_TARGET_TRIPLE": target_triple,
|
||||
"LLVM_NATIVE_ARCH": native_arch,
|
||||
}
|
||||
|
||||
def _quote(s):
|
||||
"""Quotes the given string for use in a shell command.
|
||||
"""Quotes the given string for use in a shell command.
|
||||
|
||||
This function double-quotes the given string (in case it contains spaces or
|
||||
other special characters) and escapes any special characters (dollar signs,
|
||||
double-quotes, and backslashes) that may be present.
|
||||
This function double-quotes the given string (in case it contains spaces or
|
||||
other special characters) and escapes any special characters (dollar signs,
|
||||
double-quotes, and backslashes) that may be present.
|
||||
|
||||
Args:
|
||||
s: The string to quote.
|
||||
Returns:
|
||||
An escaped and quoted version of the string that can be passed to a shell
|
||||
command.
|
||||
"""
|
||||
return ('"' +
|
||||
s.replace("\\", "\\\\").replace("$", "\\$").replace('"', '\\"') +
|
||||
'"')
|
||||
Args:
|
||||
s: The string to quote.
|
||||
Returns:
|
||||
An escaped and quoted version of the string that can be passed to a shell
|
||||
command.
|
||||
"""
|
||||
return ('"' +
|
||||
s.replace("\\", "\\\\").replace("$", "\\$").replace('"', '\\"') +
|
||||
'"')
|
||||
|
||||
def cmake_var_string(cmake_vars):
|
||||
"""Converts a dictionary to an input suitable for expand_cmake_vars.
|
||||
"""Converts a dictionary to an input suitable for expand_cmake_vars.
|
||||
|
||||
Ideally we would jist stringify in the expand_cmake_vars() rule, but select()
|
||||
interacts badly with genrules.
|
||||
Ideally we would jist stringify in the expand_cmake_vars() rule, but select()
|
||||
interacts badly with genrules.
|
||||
|
||||
TODO(phawkins): replace the genrule() with native rule and delete this rule.
|
||||
TODO(phawkins): replace the genrule() with native rule and delete this rule.
|
||||
|
||||
Args:
|
||||
cmake_vars: a dictionary with string keys and values that are convertable to
|
||||
strings.
|
||||
"""
|
||||
return " ".join([
|
||||
_quote("{}={}".format(k, str(v)))
|
||||
for (k, v) in cmake_vars.items()
|
||||
])
|
||||
Args:
|
||||
cmake_vars: a dictionary with string keys and values that are convertable to
|
||||
strings.
|
||||
"""
|
||||
return " ".join([_quote("{}={}".format(k, str(v)))
|
||||
for (k, v) in cmake_vars.items()])
|
||||
|
||||
def expand_cmake_vars(name, src, dst, cmake_vars):
|
||||
"""Expands #cmakedefine, #cmakedefine01, and CMake variables in a text file.
|
||||
"""Expands #cmakedefine, #cmakedefine01, and CMake variables in a text file.
|
||||
|
||||
Args:
|
||||
name: the name of the rule
|
||||
src: the input of the rule
|
||||
dst: the output of the rule
|
||||
cmake_vars: a string containing the CMake variables, as generated by
|
||||
cmake_var_string.
|
||||
"""
|
||||
expand_cmake_vars_tool = Label("@org_tensorflow//third_party/llvm:expand_cmake_vars")
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [src],
|
||||
tools = [expand_cmake_vars_tool],
|
||||
outs = [dst],
|
||||
cmd = ("$(location {}) ".format(expand_cmake_vars_tool) + cmake_vars +
|
||||
"< $< > $@"),
|
||||
)
|
||||
Args:
|
||||
name: the name of the rule
|
||||
src: the input of the rule
|
||||
dst: the output of the rule
|
||||
cmake_vars: a string containing the CMake variables, as generated by
|
||||
cmake_var_string.
|
||||
"""
|
||||
expand_cmake_vars_tool = Label("@org_tensorflow//third_party/llvm:expand_cmake_vars")
|
||||
native.genrule(
|
||||
name = name,
|
||||
srcs = [src],
|
||||
tools = [expand_cmake_vars_tool],
|
||||
outs = [dst],
|
||||
cmd = ("$(location {}) ".format(expand_cmake_vars_tool) + cmake_vars +
|
||||
"< $< > $@")
|
||||
)
|
||||
|
||||
# TODO(phawkins): the set of CMake variables was hardcoded for expediency.
|
||||
# However, we should really detect many of these via configure-time tests.
|
||||
@ -225,18 +213,17 @@ darwin_cmake_vars = {
|
||||
llvm_all_cmake_vars = select({
|
||||
"@org_tensorflow//tensorflow:darwin": cmake_var_string(
|
||||
cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") +
|
||||
darwin_cmake_vars,
|
||||
),
|
||||
darwin_cmake_vars),
|
||||
"@org_tensorflow//tensorflow:linux_ppc64le": cmake_var_string(
|
||||
cmake_vars +
|
||||
llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") +
|
||||
linux_cmake_vars,
|
||||
),
|
||||
"//conditions:default": cmake_var_string(
|
||||
cmake_vars +
|
||||
llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") +
|
||||
linux_cmake_vars,
|
||||
),
|
||||
cmake_vars +
|
||||
llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") +
|
||||
linux_cmake_vars),
|
||||
|
||||
})
|
||||
|
||||
LLVM_LINKOPTS = ["-ldl", "-lm", "-lpthread"]
|
||||
|
49
third_party/mkl/build_defs.bzl
vendored
49
third_party/mkl/build_defs.bzl
vendored
@ -8,8 +8,10 @@ mkl_repository depends on the following environment variables:
|
||||
* `TF_MKL_ROOT`: The root folder where a copy of libmkl is located.
|
||||
"""
|
||||
|
||||
|
||||
_TF_MKL_ROOT = "TF_MKL_ROOT"
|
||||
|
||||
|
||||
def if_mkl(if_true, if_false = []):
|
||||
"""Shorthand for select()'ing on whether we're building with MKL.
|
||||
|
||||
@ -19,7 +21,7 @@ def if_mkl(if_true, if_false = []):
|
||||
"""
|
||||
return select({
|
||||
str(Label("//third_party/mkl:using_mkl")): if_true,
|
||||
"//conditions:default": if_false,
|
||||
"//conditions:default": if_false
|
||||
})
|
||||
|
||||
def if_mkl_lnx_x64(if_true, if_false = []):
|
||||
@ -31,34 +33,37 @@ def if_mkl_lnx_x64(if_true, if_false = []):
|
||||
"""
|
||||
return select({
|
||||
str(Label("//third_party/mkl:using_mkl_lnx_x64")): if_true,
|
||||
"//conditions:default": if_false,
|
||||
"//conditions:default": if_false
|
||||
})
|
||||
|
||||
|
||||
def _enable_local_mkl(repository_ctx):
|
||||
return _TF_MKL_ROOT in repository_ctx.os.environ
|
||||
return _TF_MKL_ROOT in repository_ctx.os.environ
|
||||
|
||||
|
||||
def _mkl_autoconf_impl(repository_ctx):
|
||||
"""Implementation of the local_mkl_autoconf repository rule."""
|
||||
"""Implementation of the local_mkl_autoconf repository rule."""
|
||||
|
||||
if _enable_local_mkl(repository_ctx):
|
||||
# Symlink lib and include local folders.
|
||||
mkl_root = repository_ctx.os.environ[_TF_MKL_ROOT]
|
||||
mkl_lib_path = "%s/lib" % mkl_root
|
||||
repository_ctx.symlink(mkl_lib_path, "lib")
|
||||
mkl_include_path = "%s/include" % mkl_root
|
||||
repository_ctx.symlink(mkl_include_path, "include")
|
||||
mkl_license_path = "%s/license.txt" % mkl_root
|
||||
repository_ctx.symlink(mkl_license_path, "license.txt")
|
||||
else:
|
||||
# setup remote mkl repository.
|
||||
repository_ctx.download_and_extract(
|
||||
repository_ctx.attr.urls,
|
||||
sha256 = repository_ctx.attr.sha256,
|
||||
stripPrefix = repository_ctx.attr.strip_prefix,
|
||||
)
|
||||
if _enable_local_mkl(repository_ctx):
|
||||
# Symlink lib and include local folders.
|
||||
mkl_root = repository_ctx.os.environ[_TF_MKL_ROOT]
|
||||
mkl_lib_path = "%s/lib" % mkl_root
|
||||
repository_ctx.symlink(mkl_lib_path, "lib")
|
||||
mkl_include_path = "%s/include" % mkl_root
|
||||
repository_ctx.symlink(mkl_include_path, "include")
|
||||
mkl_license_path = "%s/license.txt" % mkl_root
|
||||
repository_ctx.symlink(mkl_license_path, "license.txt")
|
||||
else:
|
||||
# setup remote mkl repository.
|
||||
repository_ctx.download_and_extract(
|
||||
repository_ctx.attr.urls,
|
||||
sha256=repository_ctx.attr.sha256,
|
||||
stripPrefix=repository_ctx.attr.strip_prefix,
|
||||
)
|
||||
|
||||
# Also setup BUILD file.
|
||||
repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")
|
||||
|
||||
# Also setup BUILD file.
|
||||
repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")
|
||||
|
||||
mkl_repository = repository_rule(
|
||||
implementation = _mkl_autoconf_impl,
|
||||
|
10
third_party/mpi/mpi.bzl
vendored
10
third_party/mpi/mpi.bzl
vendored
@ -2,16 +2,16 @@
|
||||
#based on the configuration options return one or the other
|
||||
|
||||
def mpi_hdr():
|
||||
MPI_LIB_IS_OPENMPI = True
|
||||
hdrs = []
|
||||
MPI_LIB_IS_OPENMPI=True
|
||||
hdrs = []
|
||||
if MPI_LIB_IS_OPENMPI:
|
||||
hdrs = ["mpi.h", "mpi_portable_platform.h"] #When using OpenMPI
|
||||
hdrs = ["mpi.h", "mpi_portable_platform.h"] #When using OpenMPI
|
||||
else:
|
||||
hdrs = ["mpi.h", "mpio.h", "mpicxx.h"] #When using MVAPICH
|
||||
hdrs = ["mpi.h", "mpio.h", "mpicxx.h"] #When using MVAPICH
|
||||
return hdrs
|
||||
|
||||
def if_mpi(if_true, if_false = []):
|
||||
return select({
|
||||
"//tensorflow:with_mpi_support": if_true,
|
||||
"//conditions:default": if_false,
|
||||
"//conditions:default": if_false
|
||||
})
|
||||
|
126
third_party/repo.bzl
vendored
126
third_party/repo.bzl
vendored
@ -19,98 +19,90 @@ _SINGLE_URL_WHITELIST = depset([
|
||||
])
|
||||
|
||||
def _is_windows(ctx):
|
||||
return ctx.os.name.lower().find("windows") != -1
|
||||
return ctx.os.name.lower().find("windows") != -1
|
||||
|
||||
def _wrap_bash_cmd(ctx, cmd):
|
||||
if _is_windows(ctx):
|
||||
bazel_sh = _get_env_var(ctx, "BAZEL_SH")
|
||||
if not bazel_sh:
|
||||
fail("BAZEL_SH environment variable is not set")
|
||||
cmd = [bazel_sh, "-l", "-c", " ".join(cmd)]
|
||||
return cmd
|
||||
if _is_windows(ctx):
|
||||
bazel_sh = _get_env_var(ctx, "BAZEL_SH")
|
||||
if not bazel_sh:
|
||||
fail("BAZEL_SH environment variable is not set")
|
||||
cmd = [bazel_sh, "-l", "-c", " ".join(cmd)]
|
||||
return cmd
|
||||
|
||||
def _get_env_var(ctx, name):
|
||||
if name in ctx.os.environ:
|
||||
return ctx.os.environ[name]
|
||||
else:
|
||||
return None
|
||||
if name in ctx.os.environ:
|
||||
return ctx.os.environ[name]
|
||||
else:
|
||||
return None
|
||||
|
||||
# Executes specified command with arguments and calls 'fail' if it exited with
|
||||
# non-zero code
|
||||
def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
|
||||
result = repo_ctx.execute(cmd_and_args, timeout = 10)
|
||||
if result.return_code != 0:
|
||||
fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n" +
|
||||
"Stderr: {3}").format(
|
||||
" ".join(cmd_and_args),
|
||||
result.return_code,
|
||||
result.stdout,
|
||||
result.stderr,
|
||||
))
|
||||
result = repo_ctx.execute(cmd_and_args, timeout=10)
|
||||
if result.return_code != 0:
|
||||
fail(("Non-zero return code({1}) when executing '{0}':\n" + "Stdout: {2}\n"
|
||||
+ "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code,
|
||||
result.stdout, result.stderr))
|
||||
|
||||
def _repos_are_siblings():
|
||||
return Label("@foo//bar").workspace_root.startswith("../")
|
||||
return Label("@foo//bar").workspace_root.startswith("../")
|
||||
|
||||
# Apply a patch_file to the repository root directory
|
||||
# Runs 'patch -p1'
|
||||
def _apply_patch(ctx, patch_file):
|
||||
# Don't check patch on Windows, because patch is only available under bash.
|
||||
if not _is_windows(ctx) and not ctx.which("patch"):
|
||||
fail("patch command is not found, please install it")
|
||||
cmd = _wrap_bash_cmd(
|
||||
ctx,
|
||||
["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)],
|
||||
)
|
||||
_execute_and_check_ret_code(ctx, cmd)
|
||||
# Don't check patch on Windows, because patch is only available under bash.
|
||||
if not _is_windows(ctx) and not ctx.which("patch"):
|
||||
fail("patch command is not found, please install it")
|
||||
cmd = _wrap_bash_cmd(
|
||||
ctx, ["patch", "-p1", "-d", ctx.path("."), "-i", ctx.path(patch_file)])
|
||||
_execute_and_check_ret_code(ctx, cmd)
|
||||
|
||||
def _apply_delete(ctx, paths):
|
||||
for path in paths:
|
||||
if path.startswith("/"):
|
||||
fail("refusing to rm -rf path starting with '/': " + path)
|
||||
if ".." in path:
|
||||
fail("refusing to rm -rf path containing '..': " + path)
|
||||
cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths])
|
||||
_execute_and_check_ret_code(ctx, cmd)
|
||||
for path in paths:
|
||||
if path.startswith("/"):
|
||||
fail("refusing to rm -rf path starting with '/': " + path)
|
||||
if ".." in path:
|
||||
fail("refusing to rm -rf path containing '..': " + path)
|
||||
cmd = _wrap_bash_cmd(ctx, ["rm", "-rf"] + [ctx.path(path) for path in paths])
|
||||
_execute_and_check_ret_code(ctx, cmd)
|
||||
|
||||
def _tf_http_archive(ctx):
|
||||
if ("mirror.bazel.build" not in ctx.attr.urls[0] and
|
||||
(len(ctx.attr.urls) < 2 and
|
||||
ctx.attr.name not in _SINGLE_URL_WHITELIST)):
|
||||
fail("tf_http_archive(urls) must have redundant URLs. The " +
|
||||
"mirror.bazel.build URL must be present and it must come first. " +
|
||||
"Even if you don't have permission to mirror the file, please " +
|
||||
"put the correctly formatted mirror URL there anyway, because " +
|
||||
"someone will come along shortly thereafter and mirror the file.")
|
||||
ctx.download_and_extract(
|
||||
ctx.attr.urls,
|
||||
"",
|
||||
ctx.attr.sha256,
|
||||
ctx.attr.type,
|
||||
ctx.attr.strip_prefix,
|
||||
)
|
||||
if ctx.attr.delete:
|
||||
_apply_delete(ctx, ctx.attr.delete)
|
||||
if ctx.attr.patch_file != None:
|
||||
_apply_patch(ctx, ctx.attr.patch_file)
|
||||
if ctx.attr.build_file != None:
|
||||
# Use BUILD.bazel to avoid conflict with third party projects with
|
||||
# BUILD or build (directory) underneath.
|
||||
ctx.template("BUILD.bazel", ctx.attr.build_file, {
|
||||
"%prefix%": ".." if _repos_are_siblings() else "external",
|
||||
}, False)
|
||||
if ("mirror.bazel.build" not in ctx.attr.urls[0] and
|
||||
(len(ctx.attr.urls) < 2 and
|
||||
ctx.attr.name not in _SINGLE_URL_WHITELIST)):
|
||||
fail("tf_http_archive(urls) must have redundant URLs. The " +
|
||||
"mirror.bazel.build URL must be present and it must come first. " +
|
||||
"Even if you don't have permission to mirror the file, please " +
|
||||
"put the correctly formatted mirror URL there anyway, because " +
|
||||
"someone will come along shortly thereafter and mirror the file.")
|
||||
ctx.download_and_extract(
|
||||
ctx.attr.urls,
|
||||
"",
|
||||
ctx.attr.sha256,
|
||||
ctx.attr.type,
|
||||
ctx.attr.strip_prefix)
|
||||
if ctx.attr.delete:
|
||||
_apply_delete(ctx, ctx.attr.delete)
|
||||
if ctx.attr.patch_file != None:
|
||||
_apply_patch(ctx, ctx.attr.patch_file)
|
||||
if ctx.attr.build_file != None:
|
||||
# Use BUILD.bazel to avoid conflict with third party projects with
|
||||
# BUILD or build (directory) underneath.
|
||||
ctx.template("BUILD.bazel", ctx.attr.build_file, {
|
||||
"%prefix%": ".." if _repos_are_siblings() else "external",
|
||||
}, False)
|
||||
|
||||
tf_http_archive = repository_rule(
|
||||
implementation = _tf_http_archive,
|
||||
attrs = {
|
||||
"sha256": attr.string(mandatory = True),
|
||||
"urls": attr.string_list(mandatory = True, allow_empty = False),
|
||||
implementation=_tf_http_archive,
|
||||
attrs={
|
||||
"sha256": attr.string(mandatory=True),
|
||||
"urls": attr.string_list(mandatory=True, allow_empty=False),
|
||||
"strip_prefix": attr.string(),
|
||||
"type": attr.string(),
|
||||
"delete": attr.string_list(),
|
||||
"patch_file": attr.label(),
|
||||
"build_file": attr.label(),
|
||||
},
|
||||
)
|
||||
})
|
||||
"""Downloads and creates Bazel repos for dependencies.
|
||||
|
||||
This is a swappable replacement for both http_archive() and
|
||||
|
317
third_party/sycl/sycl_configure.bzl
vendored
317
third_party/sycl/sycl_configure.bzl
vendored
@ -11,124 +11,122 @@
|
||||
"""
|
||||
|
||||
_HOST_CXX_COMPILER = "HOST_CXX_COMPILER"
|
||||
_HOST_C_COMPILER = "HOST_C_COMPILER"
|
||||
_HOST_C_COMPILER= "HOST_C_COMPILER"
|
||||
_COMPUTECPP_TOOLKIT_PATH = "COMPUTECPP_TOOLKIT_PATH"
|
||||
_TRISYCL_INCLUDE_DIR = "TRISYCL_INCLUDE_DIR"
|
||||
_PYTHON_LIB_PATH = "PYTHON_LIB_PATH"
|
||||
|
||||
def _enable_sycl(repository_ctx):
|
||||
if "TF_NEED_OPENCL_SYCL" in repository_ctx.os.environ:
|
||||
enable_sycl = repository_ctx.os.environ["TF_NEED_OPENCL_SYCL"].strip()
|
||||
return enable_sycl == "1"
|
||||
return False
|
||||
if "TF_NEED_OPENCL_SYCL" in repository_ctx.os.environ:
|
||||
enable_sycl = repository_ctx.os.environ["TF_NEED_OPENCL_SYCL"].strip()
|
||||
return enable_sycl == "1"
|
||||
return False
|
||||
|
||||
def _enable_compute_cpp(repository_ctx):
|
||||
return _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ
|
||||
return _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ
|
||||
|
||||
def auto_configure_fail(msg):
|
||||
"""Output failure message when auto configuration fails."""
|
||||
red = "\033[0;31m"
|
||||
no_color = "\033[0m"
|
||||
fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))
|
||||
|
||||
"""Output failure message when auto configuration fails."""
|
||||
red = "\033[0;31m"
|
||||
no_color = "\033[0m"
|
||||
fail("\n%sAuto-Configuration Error:%s %s\n" % (red, no_color, msg))
|
||||
# END cc_configure common functions (see TODO above).
|
||||
|
||||
def find_c(repository_ctx):
|
||||
"""Find host C compiler."""
|
||||
c_name = "gcc"
|
||||
if _HOST_C_COMPILER in repository_ctx.os.environ:
|
||||
c_name = repository_ctx.os.environ[_HOST_C_COMPILER].strip()
|
||||
if c_name.startswith("/"):
|
||||
return c_name
|
||||
c = repository_ctx.which(c_name)
|
||||
if c == None:
|
||||
fail("Cannot find C compiler, please correct your path.")
|
||||
return c
|
||||
"""Find host C compiler."""
|
||||
c_name = "gcc"
|
||||
if _HOST_C_COMPILER in repository_ctx.os.environ:
|
||||
c_name = repository_ctx.os.environ[_HOST_C_COMPILER].strip()
|
||||
if c_name.startswith("/"):
|
||||
return c_name
|
||||
c = repository_ctx.which(c_name)
|
||||
if c == None:
|
||||
fail("Cannot find C compiler, please correct your path.")
|
||||
return c
|
||||
|
||||
def find_cc(repository_ctx):
|
||||
"""Find host C++ compiler."""
|
||||
cc_name = "g++"
|
||||
if _HOST_CXX_COMPILER in repository_ctx.os.environ:
|
||||
cc_name = repository_ctx.os.environ[_HOST_CXX_COMPILER].strip()
|
||||
if cc_name.startswith("/"):
|
||||
return cc_name
|
||||
cc = repository_ctx.which(cc_name)
|
||||
if cc == None:
|
||||
fail("Cannot find C++ compiler, please correct your path.")
|
||||
return cc
|
||||
"""Find host C++ compiler."""
|
||||
cc_name = "g++"
|
||||
if _HOST_CXX_COMPILER in repository_ctx.os.environ:
|
||||
cc_name = repository_ctx.os.environ[_HOST_CXX_COMPILER].strip()
|
||||
if cc_name.startswith("/"):
|
||||
return cc_name
|
||||
cc = repository_ctx.which(cc_name)
|
||||
if cc == None:
|
||||
fail("Cannot find C++ compiler, please correct your path.")
|
||||
return cc
|
||||
|
||||
def find_computecpp_root(repository_ctx):
|
||||
"""Find ComputeCpp compiler."""
|
||||
sycl_name = ""
|
||||
if _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ:
|
||||
sycl_name = repository_ctx.os.environ[_COMPUTECPP_TOOLKIT_PATH].strip()
|
||||
if sycl_name.startswith("/"):
|
||||
return sycl_name
|
||||
fail("Cannot find SYCL compiler, please correct your path")
|
||||
"""Find ComputeCpp compiler."""
|
||||
sycl_name = ""
|
||||
if _COMPUTECPP_TOOLKIT_PATH in repository_ctx.os.environ:
|
||||
sycl_name = repository_ctx.os.environ[_COMPUTECPP_TOOLKIT_PATH].strip()
|
||||
if sycl_name.startswith("/"):
|
||||
return sycl_name
|
||||
fail("Cannot find SYCL compiler, please correct your path")
|
||||
|
||||
def find_trisycl_include_dir(repository_ctx):
|
||||
"""Find triSYCL include directory. """
|
||||
if _TRISYCL_INCLUDE_DIR in repository_ctx.os.environ:
|
||||
sycl_name = repository_ctx.os.environ[_TRISYCL_INCLUDE_DIR].strip()
|
||||
if sycl_name.startswith("/"):
|
||||
return sycl_name
|
||||
fail("Cannot find triSYCL include directory, please correct your path")
|
||||
"""Find triSYCL include directory. """
|
||||
if _TRISYCL_INCLUDE_DIR in repository_ctx.os.environ:
|
||||
sycl_name = repository_ctx.os.environ[_TRISYCL_INCLUDE_DIR].strip()
|
||||
if sycl_name.startswith("/"):
|
||||
return sycl_name
|
||||
fail( "Cannot find triSYCL include directory, please correct your path")
|
||||
|
||||
def find_python_lib(repository_ctx):
|
||||
"""Returns python path."""
|
||||
if _PYTHON_LIB_PATH in repository_ctx.os.environ:
|
||||
return repository_ctx.os.environ[_PYTHON_LIB_PATH].strip()
|
||||
fail("Environment variable PYTHON_LIB_PATH was not specified re-run ./configure")
|
||||
"""Returns python path."""
|
||||
if _PYTHON_LIB_PATH in repository_ctx.os.environ:
|
||||
return repository_ctx.os.environ[_PYTHON_LIB_PATH].strip()
|
||||
fail("Environment variable PYTHON_LIB_PATH was not specified re-run ./configure")
|
||||
|
||||
|
||||
def _check_lib(repository_ctx, toolkit_path, lib):
|
||||
"""Checks if lib exists under sycl_toolkit_path or fail if it doesn't.
|
||||
"""Checks if lib exists under sycl_toolkit_path or fail if it doesn't.
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
toolkit_path: The toolkit directory containing the libraries.
|
||||
ib: The library to look for under toolkit_path.
|
||||
"""
|
||||
lib_path = toolkit_path + "/" + lib
|
||||
if not repository_ctx.path(lib_path).exists:
|
||||
auto_configure_fail("Cannot find %s" % lib_path)
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
toolkit_path: The toolkit directory containing the libraries.
|
||||
ib: The library to look for under toolkit_path.
|
||||
"""
|
||||
lib_path = toolkit_path + "/" + lib
|
||||
if not repository_ctx.path(lib_path).exists:
|
||||
auto_configure_fail("Cannot find %s" % lib_path)
|
||||
|
||||
def _check_dir(repository_ctx, directory):
|
||||
"""Checks whether the directory exists and fail if it does not.
|
||||
"""Checks whether the directory exists and fail if it does not.
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
directory: The directory to check the existence of.
|
||||
"""
|
||||
if not repository_ctx.path(directory).exists:
|
||||
auto_configure_fail("Cannot find dir: %s" % directory)
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
directory: The directory to check the existence of.
|
||||
"""
|
||||
if not repository_ctx.path(directory).exists:
|
||||
auto_configure_fail("Cannot find dir: %s" % directory)
|
||||
|
||||
def _symlink_dir(repository_ctx, src_dir, dest_dir):
|
||||
"""Symlinks all the files in a directory.
|
||||
"""Symlinks all the files in a directory.
|
||||
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
src_dir: The source directory.
|
||||
dest_dir: The destination directory to create the symlinks in.
|
||||
"""
|
||||
files = repository_ctx.path(src_dir).readdir()
|
||||
for src_file in files:
|
||||
repository_ctx.symlink(src_file, dest_dir + "/" + src_file.basename)
|
||||
Args:
|
||||
repository_ctx: The repository context.
|
||||
src_dir: The source directory.
|
||||
dest_dir: The destination directory to create the symlinks in.
|
||||
"""
|
||||
files = repository_ctx.path(src_dir).readdir()
|
||||
for src_file in files:
|
||||
repository_ctx.symlink(src_file, dest_dir + "/" + src_file.basename)
|
||||
|
||||
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
|
||||
if not out:
|
||||
out = tpl.replace(":", "/")
|
||||
repository_ctx.template(
|
||||
out,
|
||||
Label("//third_party/sycl/%s.tpl" % tpl),
|
||||
substitutions,
|
||||
)
|
||||
def _tpl(repository_ctx, tpl, substitutions={}, out=None):
|
||||
if not out:
|
||||
out = tpl.replace(":", "/")
|
||||
repository_ctx.template(
|
||||
out,
|
||||
Label("//third_party/sycl/%s.tpl" % tpl),
|
||||
substitutions)
|
||||
|
||||
def _file(repository_ctx, label):
|
||||
repository_ctx.template(
|
||||
label.replace(":", "/"),
|
||||
Label("//third_party/sycl/%s" % label),
|
||||
{},
|
||||
)
|
||||
repository_ctx.template(
|
||||
label.replace(":", "/"),
|
||||
Label("//third_party/sycl/%s" % label),
|
||||
{})
|
||||
|
||||
_DUMMY_CROSSTOOL_BZL_FILE = """
|
||||
def error_sycl_disabled():
|
||||
@ -149,6 +147,7 @@ def error_sycl_disabled():
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
_DUMMY_CROSSTOOL_BUILD_FILE = """
|
||||
load("//crosstool:error_sycl_disabled.bzl", "error_sycl_disabled")
|
||||
|
||||
@ -156,97 +155,87 @@ error_sycl_disabled()
|
||||
"""
|
||||
|
||||
def _create_dummy_repository(repository_ctx):
|
||||
# Set up BUILD file for sycl/.
|
||||
_tpl(repository_ctx, "sycl:build_defs.bzl")
|
||||
_tpl(repository_ctx, "sycl:BUILD")
|
||||
_file(repository_ctx, "sycl:LICENSE.text")
|
||||
_tpl(repository_ctx, "sycl:platform.bzl")
|
||||
# Set up BUILD file for sycl/.
|
||||
_tpl(repository_ctx, "sycl:build_defs.bzl")
|
||||
_tpl(repository_ctx, "sycl:BUILD")
|
||||
_file(repository_ctx, "sycl:LICENSE.text")
|
||||
_tpl(repository_ctx, "sycl:platform.bzl")
|
||||
|
||||
# Create dummy files for the SYCL toolkit since they are still required by
|
||||
# tensorflow/sycl/platform/default/build_config:sycl.
|
||||
repository_ctx.file("sycl/include/sycl.hpp", "")
|
||||
repository_ctx.file("sycl/lib/libComputeCpp.so", "")
|
||||
# Create dummy files for the SYCL toolkit since they are still required by
|
||||
# tensorflow/sycl/platform/default/build_config:sycl.
|
||||
repository_ctx.file("sycl/include/sycl.hpp", "")
|
||||
repository_ctx.file("sycl/lib/libComputeCpp.so", "")
|
||||
|
||||
# If sycl_configure is not configured to build with SYCL support, and the user
|
||||
# attempts to build with --config=sycl, add a dummy build rule to intercept
|
||||
# this and fail with an actionable error message.
|
||||
repository_ctx.file("crosstool/error_sycl_disabled.bzl",
|
||||
_DUMMY_CROSSTOOL_BZL_FILE)
|
||||
repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
|
||||
|
||||
# If sycl_configure is not configured to build with SYCL support, and the user
|
||||
# attempts to build with --config=sycl, add a dummy build rule to intercept
|
||||
# this and fail with an actionable error message.
|
||||
repository_ctx.file(
|
||||
"crosstool/error_sycl_disabled.bzl",
|
||||
_DUMMY_CROSSTOOL_BZL_FILE,
|
||||
)
|
||||
repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
|
||||
|
||||
def _sycl_autoconf_imp(repository_ctx):
|
||||
"""Implementation of the sycl_autoconf rule."""
|
||||
if not _enable_sycl(repository_ctx):
|
||||
_create_dummy_repository(repository_ctx)
|
||||
"""Implementation of the sycl_autoconf rule."""
|
||||
if not _enable_sycl(repository_ctx):
|
||||
_create_dummy_repository(repository_ctx)
|
||||
else:
|
||||
# copy template files
|
||||
_tpl(repository_ctx, "sycl:build_defs.bzl")
|
||||
_tpl(repository_ctx, "sycl:BUILD")
|
||||
_tpl(repository_ctx, "sycl:platform.bzl")
|
||||
_tpl(repository_ctx, "crosstool:BUILD")
|
||||
_file(repository_ctx, "sycl:LICENSE.text")
|
||||
|
||||
if _enable_compute_cpp(repository_ctx):
|
||||
_tpl(repository_ctx, "crosstool:computecpp",
|
||||
{
|
||||
"%{host_cxx_compiler}" : find_cc(repository_ctx),
|
||||
"%{host_c_compiler}" : find_c(repository_ctx)
|
||||
})
|
||||
|
||||
computecpp_root = find_computecpp_root(repository_ctx);
|
||||
_check_dir(repository_ctx, computecpp_root)
|
||||
|
||||
_tpl(repository_ctx, "crosstool:CROSSTOOL",
|
||||
{
|
||||
"%{sycl_include_dir}" : computecpp_root,
|
||||
"%{sycl_impl}" : "computecpp",
|
||||
"%{c++_std}" : "-std=c++11",
|
||||
"%{python_lib_path}" : find_python_lib(repository_ctx),
|
||||
})
|
||||
|
||||
# symlink libraries
|
||||
_check_lib(repository_ctx, computecpp_root+"/lib", "libComputeCpp.so" )
|
||||
_symlink_dir(repository_ctx, computecpp_root + "/lib", "sycl/lib")
|
||||
_symlink_dir(repository_ctx, computecpp_root + "/include", "sycl/include")
|
||||
_symlink_dir(repository_ctx, computecpp_root + "/bin", "sycl/bin")
|
||||
else:
|
||||
# copy template files
|
||||
_tpl(repository_ctx, "sycl:build_defs.bzl")
|
||||
_tpl(repository_ctx, "sycl:BUILD")
|
||||
_tpl(repository_ctx, "sycl:platform.bzl")
|
||||
_tpl(repository_ctx, "crosstool:BUILD")
|
||||
_file(repository_ctx, "sycl:LICENSE.text")
|
||||
|
||||
if _enable_compute_cpp(repository_ctx):
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"crosstool:computecpp",
|
||||
{
|
||||
"%{host_cxx_compiler}": find_cc(repository_ctx),
|
||||
"%{host_c_compiler}": find_c(repository_ctx),
|
||||
},
|
||||
)
|
||||
trisycl_include_dir = find_trisycl_include_dir(repository_ctx);
|
||||
_check_dir(repository_ctx, trisycl_include_dir)
|
||||
|
||||
computecpp_root = find_computecpp_root(repository_ctx)
|
||||
_check_dir(repository_ctx, computecpp_root)
|
||||
_tpl(repository_ctx, "crosstool:trisycl",
|
||||
{
|
||||
"%{host_cxx_compiler}" : find_cc(repository_ctx),
|
||||
"%{host_c_compiler}" : find_c(repository_ctx),
|
||||
"%{trisycl_include_dir}" : trisycl_include_dir
|
||||
})
|
||||
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"crosstool:CROSSTOOL",
|
||||
{
|
||||
"%{sycl_include_dir}": computecpp_root,
|
||||
"%{sycl_impl}": "computecpp",
|
||||
"%{c++_std}": "-std=c++11",
|
||||
"%{python_lib_path}": find_python_lib(repository_ctx),
|
||||
},
|
||||
)
|
||||
|
||||
# symlink libraries
|
||||
_check_lib(repository_ctx, computecpp_root + "/lib", "libComputeCpp.so")
|
||||
_symlink_dir(repository_ctx, computecpp_root + "/lib", "sycl/lib")
|
||||
_symlink_dir(repository_ctx, computecpp_root + "/include", "sycl/include")
|
||||
_symlink_dir(repository_ctx, computecpp_root + "/bin", "sycl/bin")
|
||||
else:
|
||||
trisycl_include_dir = find_trisycl_include_dir(repository_ctx)
|
||||
_check_dir(repository_ctx, trisycl_include_dir)
|
||||
_tpl(repository_ctx, "crosstool:CROSSTOOL",
|
||||
{
|
||||
"%{sycl_include_dir}" : trisycl_include_dir,
|
||||
"%{sycl_impl}" : "trisycl",
|
||||
"%{c++_std}" : "-std=c++1y",
|
||||
"%{python_lib_path}" : find_python_lib(repository_ctx),
|
||||
})
|
||||
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"crosstool:trisycl",
|
||||
{
|
||||
"%{host_cxx_compiler}": find_cc(repository_ctx),
|
||||
"%{host_c_compiler}": find_c(repository_ctx),
|
||||
"%{trisycl_include_dir}": trisycl_include_dir,
|
||||
},
|
||||
)
|
||||
_symlink_dir(repository_ctx, trisycl_include_dir, "sycl/include")
|
||||
|
||||
_tpl(
|
||||
repository_ctx,
|
||||
"crosstool:CROSSTOOL",
|
||||
{
|
||||
"%{sycl_include_dir}": trisycl_include_dir,
|
||||
"%{sycl_impl}": "trisycl",
|
||||
"%{c++_std}": "-std=c++1y",
|
||||
"%{python_lib_path}": find_python_lib(repository_ctx),
|
||||
},
|
||||
)
|
||||
|
||||
_symlink_dir(repository_ctx, trisycl_include_dir, "sycl/include")
|
||||
|
||||
sycl_configure = repository_rule(
|
||||
implementation = _sycl_autoconf_imp,
|
||||
local = True,
|
||||
implementation = _sycl_autoconf_imp,
|
||||
local = True,
|
||||
)
|
||||
"""Detects and configures the SYCL toolchain.
|
||||
|
||||
|
49
third_party/toolchains/clang6/repo.bzl
vendored
49
third_party/toolchains/clang6/repo.bzl
vendored
@ -1,37 +1,30 @@
|
||||
"""Repository rule for Debian 8 Jessie Clang-6.0 portable Linux builds."""
|
||||
|
||||
def _clang6_configure(ctx):
|
||||
# TODO(jart): It'd probably be better to use Bazel's struct.to_proto()
|
||||
# method to generate a gigantic CROSSTOOL file that allows
|
||||
# Clang to support everything.
|
||||
ctx.symlink(
|
||||
ctx.os.environ.get(
|
||||
"TF_LLVM_PATH",
|
||||
"/usr/lib/llvm-6.0",
|
||||
),
|
||||
"clang6/llvm",
|
||||
)
|
||||
ctx.symlink(
|
||||
ctx.os.environ.get("STRIP", "/usr/bin/strip"),
|
||||
"clang6/sbin/strip",
|
||||
)
|
||||
ctx.symlink(
|
||||
ctx.os.environ.get("OBJDUMP", "/usr/bin/objdump"),
|
||||
"clang6/sbin/objdump",
|
||||
)
|
||||
ctx.symlink(ctx.attr._build, "clang6/BUILD")
|
||||
ctx.template("clang6/CROSSTOOL", ctx.attr._crosstool, {
|
||||
"%package(@local_config_clang6//clang6)%": str(ctx.path("clang6")),
|
||||
})
|
||||
# TODO(jart): It'd probably be better to use Bazel's struct.to_proto()
|
||||
# method to generate a gigantic CROSSTOOL file that allows
|
||||
# Clang to support everything.
|
||||
ctx.symlink(
|
||||
ctx.os.environ.get('TF_LLVM_PATH',
|
||||
'/usr/lib/llvm-6.0'),
|
||||
'clang6/llvm')
|
||||
ctx.symlink(
|
||||
ctx.os.environ.get('STRIP', '/usr/bin/strip'),
|
||||
'clang6/sbin/strip')
|
||||
ctx.symlink(
|
||||
ctx.os.environ.get('OBJDUMP', '/usr/bin/objdump'),
|
||||
'clang6/sbin/objdump')
|
||||
ctx.symlink(ctx.attr._build, 'clang6/BUILD')
|
||||
ctx.template('clang6/CROSSTOOL', ctx.attr._crosstool, {
|
||||
'%package(@local_config_clang6//clang6)%': str(ctx.path('clang6')),
|
||||
})
|
||||
|
||||
clang6_configure = repository_rule(
|
||||
implementation = _clang6_configure,
|
||||
attrs = {
|
||||
"_build": attr.label(
|
||||
default = str(Label("//third_party/toolchains/clang6:clang.BUILD")),
|
||||
),
|
||||
"_crosstool": attr.label(
|
||||
default = str(Label("//third_party/toolchains/clang6:CROSSTOOL.tpl")),
|
||||
),
|
||||
'_build': attr.label(
|
||||
default=str(Label('//third_party/toolchains/clang6:clang.BUILD'))),
|
||||
'_crosstool': attr.label(
|
||||
default=str(Label('//third_party/toolchains/clang6:CROSSTOOL.tpl'))),
|
||||
},
|
||||
)
|
||||
|
@ -1,38 +1,38 @@
|
||||
# -*- Python -*-
|
||||
"""Repository rule for arm compiler autoconfiguration."""
|
||||
|
||||
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
|
||||
if not out:
|
||||
out = tpl
|
||||
repository_ctx.template(
|
||||
out,
|
||||
Label("//third_party/toolchains/cpus/arm:%s.tpl" % tpl),
|
||||
substitutions,
|
||||
)
|
||||
def _tpl(repository_ctx, tpl, substitutions={}, out=None):
|
||||
if not out:
|
||||
out = tpl
|
||||
repository_ctx.template(
|
||||
out,
|
||||
Label("//third_party/toolchains/cpus/arm:%s.tpl" % tpl),
|
||||
substitutions)
|
||||
|
||||
|
||||
def _arm_compiler_configure_impl(repository_ctx):
|
||||
# We need to find a cross-compilation include directory for Python, so look
|
||||
# for an environment variable. Be warned, this crosstool template is only
|
||||
# regenerated on the first run of Bazel, so if you change the variable after
|
||||
# it may not be reflected in later builds. Doing a shutdown and clean of Bazel
|
||||
# doesn't fix this, you'll need to delete the generated file at something like:
|
||||
# external/local_config_arm_compiler/CROSSTOOL in your Bazel install.
|
||||
if "CROSSTOOL_PYTHON_INCLUDE_PATH" in repository_ctx.os.environ:
|
||||
python_include_path = repository_ctx.os.environ["CROSSTOOL_PYTHON_INCLUDE_PATH"]
|
||||
else:
|
||||
python_include_path = "/usr/include/python2.7"
|
||||
_tpl(repository_ctx, "CROSSTOOL", {
|
||||
"%{ARM_COMPILER_PATH}%": str(repository_ctx.path(
|
||||
repository_ctx.attr.remote_config_repo,
|
||||
)),
|
||||
"%{PYTHON_INCLUDE_PATH}%": python_include_path,
|
||||
})
|
||||
repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")
|
||||
# We need to find a cross-compilation include directory for Python, so look
|
||||
# for an environment variable. Be warned, this crosstool template is only
|
||||
# regenerated on the first run of Bazel, so if you change the variable after
|
||||
# it may not be reflected in later builds. Doing a shutdown and clean of Bazel
|
||||
# doesn't fix this, you'll need to delete the generated file at something like:
|
||||
# external/local_config_arm_compiler/CROSSTOOL in your Bazel install.
|
||||
if "CROSSTOOL_PYTHON_INCLUDE_PATH" in repository_ctx.os.environ:
|
||||
python_include_path = repository_ctx.os.environ["CROSSTOOL_PYTHON_INCLUDE_PATH"]
|
||||
else:
|
||||
python_include_path = "/usr/include/python2.7"
|
||||
_tpl(repository_ctx, "CROSSTOOL", {
|
||||
"%{ARM_COMPILER_PATH}%": str(repository_ctx.path(
|
||||
repository_ctx.attr.remote_config_repo)),
|
||||
"%{PYTHON_INCLUDE_PATH}%": python_include_path,
|
||||
})
|
||||
repository_ctx.symlink(repository_ctx.attr.build_file, "BUILD")
|
||||
|
||||
|
||||
arm_compiler_configure = repository_rule(
|
||||
implementation = _arm_compiler_configure_impl,
|
||||
attrs = {
|
||||
"remote_config_repo": attr.string(mandatory = False, default = ""),
|
||||
"remote_config_repo": attr.string(mandatory = False, default =""),
|
||||
"build_file": attr.label(),
|
||||
},
|
||||
)
|
||||
|
@ -12,13 +12,15 @@ def if_cuda(if_true, if_false = []):
|
||||
return select({
|
||||
"@local_config_cuda//cuda:using_nvcc": if_true,
|
||||
"@local_config_cuda//cuda:using_clang": if_true,
|
||||
"//conditions:default": if_false,
|
||||
"//conditions:default": if_false
|
||||
})
|
||||
|
||||
|
||||
def cuda_default_copts():
|
||||
"""Default options for all CUDA compilations."""
|
||||
return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + ["--cuda-gpu-arch=sm_30"])
|
||||
|
||||
|
||||
def cuda_is_configured():
|
||||
"""Returns true if CUDA was enabled during the configure process."""
|
||||
return True
|
||||
@ -30,5 +32,6 @@ def if_cuda_is_configured(x):
|
||||
--config=cuda. Used to allow non-CUDA code to depend on CUDA libraries.
|
||||
"""
|
||||
if cuda_is_configured():
|
||||
return x
|
||||
return x
|
||||
return []
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user