This change allows the linkage of multithreaded XLA AOT CPU backend objects, such as multithreaded matmul, conv2d, etc. These are not enabled by default. New unit tests confirm that the objects are emitted and linked correctly, and the resulting computations are numerically correct. MKL service backend objects are not included. Other changes: * C++ Unit tests now use arg_feed_{x,y} instead of arg0/arg1, since the names are flaky (they may swap from the signature) * Add argument "multithreading=" to the bzl file and saved_model_cli. * Add unit tests using "nm" to ensure that the proper symbols are used when enabling or disabling multithreading (not sure if they are windows-friendly). * Use a simpler and more unique string for the entry_point string. PiperOrigin-RevId: 338112208 Change-Id: Id734e75e63e72db93a743f451ddb7eb6f489c1c7
178 lines
6.7 KiB
Python
"""Definitions for using tools like saved_model_cli."""
|
|
|
|
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
|
|
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
|
|
load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")
|
|
|
|
def _maybe_force_compile(args, force_compile):
    """Return `args` as-is when forced, otherwise gate them on XLA support.

    Args:
      args: List of labels/srcs to pass through.
      force_compile: If True, always return `args`; if False, wrap them in
        `if_xla_available` so they are only used when XLA support is enabled.

    Returns:
      `args`, possibly wrapped in an `if_xla_available` select.
    """
    return args if force_compile else if_xla_available(args)
|
|
|
|
def saved_model_compile_aot(
        name,
        directory,
        filegroups,
        cpp_class,
        checkpoint_path = None,
        tag_set = "serve",
        signature_def = "serving_default",
        variables_to_feed = "",
        target_triple = None,
        target_cpu = None,
        multithreading = False,
        force_without_xla_support_flag = True,
        tags = None):
    """AOT-compile a SavedModel (visible via a filegroup) into a cc_library.

    This macro emits a genrule that runs `saved_model_cli aot_compile_cpu`
    on the SavedModel found in `directory`, then wraps the generated header
    and object files into a `cc_library` called `name`. For full details,
    see the help for saved_model_cli's aot_compile_cpu subcommand.

    **NOTE** Any variables passed to `variables_to_feed` *must be set by the
    user*. These variables will NOT be frozen and their values will be
    uninitialized in the compiled object (this applies to all input
    arguments from the signature as well).

    Example usage:

    ```
    saved_model_compile_aot(
        name = "aot_compiled_x_plus_y",
        cpp_class = "tensorflow::CompiledModel",
        directory = "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo",
        filegroups = [
            "//tensorflow/cc/saved_model:saved_model_half_plus_two",
        ],
    )

    cc_test(
        name = "test",
        srcs = ["test.cc"],
        deps = [
            "//tensorflow/core:test_main",
            ":aot_compiled_x_plus_y",
            "//tensorflow/core:test",
            "//tensorflow/core/platform:logging",
        ],
    )

    In "test.cc":

      #include "third_party/tensorflow/python/tools/aot_compiled_x_plus_y.h"

      TEST(Test, Run) {
        tensorflow::CompiledModel model;
        CHECK(model.Run());
      }
    ```

    Args:
      name: Name of the rule; also the filename prefix of the header and
        object files this rule generates.
      directory: Bazel directory that contains `saved_model.pb` and the
        `variables/` subdirectory.
      filegroups: List of `filegroup` targets providing the files under
        `directory` and `checkpoint_path`.
      cpp_class: Fully qualified name of the C++ class to generate,
        e.g. "my_model::InferenceRunner".
      checkpoint_path: Bazel directory containing `variables.index`. When
        omitted, `$directory/variables/` (the standard SavedModel layout)
        is used.
      tag_set: SavedModel tag set to load.
      signature_def: Name of the SavedModel signature to compile.
      variables_to_feed: (optional) Comma-separated names of variables to
        keep feedable, or 'all'. If empty, every variable is frozen and
        none may be fed at runtime.

        **NOTE** Any variables passed to `variables_to_feed` *must be set by
        the user*. These variables will NOT be frozen and their values will be
        uninitialized in the compiled object (this applies to all input
        arguments from the signature as well).
      target_triple: LLVM target triple to compile for (defaults to the
        current build's target architecture). Analogous to clang's
        -target flag.
      target_cpu: LLVM CPU name to compile for. Analogous to clang's
        -mcpu flag.
      multithreading: Whether to emit multithreaded AOT code. This grows
        the dependency set of consuming binaries at both build and run
        time; for example the object files may gain external dependencies
        on threading libraries such as nsync.
      force_without_xla_support_flag: Compile even when
        `--define=with_xla_support=true` is not set. If `False` and the
        define is absent at build time, the created `cc_library` is empty
        and downstream targets should build conditionally via the macro
        `tfcompile.bzl:if_xla_available`. The TensorFlow build uses this
        flag to skip architectures without XLA support.
      tags: List of target tags.
    """
    saved_model_pb = "{}/saved_model.pb".format(directory)
    resolved_triple = target_triple or target_llvm_triple()
    resolved_cpu = target_cpu or tfcompile_target_cpu() or ""
    feed_spec = variables_to_feed or "''"

    # Without an explicit checkpoint, saved_model_cli falls back to
    # `$directory/variables/`, so neither a flag nor extra srcs are needed.
    checkpoint_cmd_args = ""
    checkpoint_srcs = []
    if checkpoint_path:
        variables_index = "{}/variables.index".format(checkpoint_path)
        checkpoint_srcs = [variables_index]
        checkpoint_cmd_args = (
            "--checkpoint_path \"$$(dirname $(location {}))\" "
            .format(variables_index)
        )

    cli_label = clean_dep("//tensorflow/python/tools:saved_model_cli")

    # Each segment carries its own trailing space; joining with "" yields
    # the exact command line expected by saved_model_cli.
    cmd_parts = [
        "$(location {}) aot_compile_cpu ".format(cli_label),
        "--dir \"$$(dirname $(location {}))\" ".format(saved_model_pb),
        checkpoint_cmd_args,
        "--output_prefix $(@D)/{} ".format(name),
        "--cpp_class {} ".format(cpp_class),
        "--variables_to_feed {} ".format(feed_spec),
        "--signature_def_key {} ".format(signature_def),
        "--multithreading {} ".format(multithreading),
        "--target_triple {} ".format(resolved_triple),
    ]
    if resolved_cpu:
        cmd_parts.append("--target_cpu {} ".format(resolved_cpu))
    cmd_parts.append("--tag_set {} ".format(tag_set))

    native.genrule(
        name = "{}_gen".format(name),
        srcs = filegroups + [saved_model_pb] + checkpoint_srcs,
        outs = [
            "{}.h".format(name),
            "{}.o".format(name),
            "{}_metadata.o".format(name),
            "{}_makefile.inc".format(name),
        ],
        cmd = "".join(cmd_parts),
        tags = tags,
        tools = [
            "//tensorflow/python/tools:saved_model_cli",
        ],
    )

    object_files = [
        ":{}.o".format(name),
        ":{}_metadata.o".format(name),
    ]
    native.cc_library(
        name = name,
        srcs = _maybe_force_compile(
            object_files,
            force_compile = force_without_xla_support_flag,
        ),
        hdrs = _maybe_force_compile(
            [":{}.h".format(name)],
            force_compile = force_without_xla_support_flag,
        ),
        tags = tags,
        deps = _maybe_force_compile(
            [
                "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_standalone",
            ],
            force_compile = force_without_xla_support_flag,
        ),
    )
|