STT-tensorflow/tensorflow/python/tools/tools.bzl
commit 84967b39fa by Eugene Brevdo: [TF] [saved_model_cli] Add support for multithreaded cpu service. Off by default.
This change allows the linkage of multithreaded XLA AOT CPU backend objects,
such as multithreaded matmul, conv2d, etc.  These are not enabled by default.

New unit tests confirm that the objects are emitted and linked correctly,
and the resulting computations are numerically correct.

MKL service backend objects are not included.

Other changes:
* C++ unit tests now use arg_feed_{x,y} instead of arg0/arg1, since the
  positional names are flaky (their order may swap in the signature).
* Add an argument "multithreading=" to the bzl file and saved_model_cli
  (see the BUILD sketch after this list).
* Add unit tests that use "nm" to ensure the proper symbols are linked when
  multithreading is enabled or disabled (these may not be Windows-friendly).
* Use a simpler and more unique entry_point string.
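
A minimal sketch of the new argument as used from a BUILD file (the target,
directory, and class names here are illustrative, not taken from this change;
the macro signature is defined in the file below):

```
saved_model_compile_aot(
    name = "aot_compiled_my_model",              # hypothetical target name
    cpp_class = "my_model::InferenceRunner",     # hypothetical C++ class
    directory = "//path/to:my_saved_model_dir",  # hypothetical SavedModel dir
    filegroups = ["//path/to:my_saved_model_files"],
    multithreading = True,  # new flag; defaults to False (single-threaded)
)
```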

PiperOrigin-RevId: 338112208
Change-Id: Id734e75e63e72db93a743f451ddb7eb6f489c1c7

"""Definitions for using tools like saved_model_cli."""
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")

def _maybe_force_compile(args, force_compile):
    # When force_compile is set, always return args; otherwise gate them
    # behind --define=with_xla_support=true via if_xla_available.
    if force_compile:
        return args
    else:
        return if_xla_available(args)

def saved_model_compile_aot(
        name,
        directory,
        filegroups,
        cpp_class,
        checkpoint_path = None,
        tag_set = "serve",
        signature_def = "serving_default",
        variables_to_feed = "",
        target_triple = None,
        target_cpu = None,
        multithreading = False,
        force_without_xla_support_flag = True,
        tags = None):
"""Compile a SavedModel directory accessible from a filegroup.
This target rule takes a path to a filegroup directory containing a
SavedModel and generates a cc_library with an AOT compiled model.
For extra details, see the help for saved_model_cli's aot_compile_cpu help.
**NOTE** Any variables passed to `variables_to_feed` *must be set by the
user*. These variables will NOT be frozen and their values will be
uninitialized in the compiled object (this applies to all input
arguments from the signature as well).
Example usage:
```
saved_model_compile_aot(
name = "aot_compiled_x_plus_y",
cpp_class = "tensorflow::CompiledModel",
directory = "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo",
filegroups = [
"//tensorflow/cc/saved_model:saved_model_half_plus_two",
]
)
cc_test(
name = "test",
srcs = ["test.cc"],
deps = [
"//tensorflow/core:test_main",
":aot_compiled_x_plus_y",
"//tensorflow/core:test",
"//tensorflow/core/platform:logging",
]),
)
In "test.cc":
#include "third_party/tensorflow/python/tools/aot_compiled_x_plus_y.h"
TEST(Test, Run) {
tensorflow::CompiledModel model;
CHECK(model.Run());
}
```
Args:
name: The rule name, and the name prefix of the headers and object file
emitted by this rule.
directory: The bazel directory containing saved_model.pb and variables/
subdirectories.
filegroups: List of `filegroup` targets; these filegroups contain the
files pointed to by `directory` and `checkpoint_path`.
cpp_class: The name of the C++ class that will be generated, including
namespace; e.g. "my_model::InferenceRunner".
checkpoint_path: The bazel directory containing `variables.index`. If
not provided, then `$directory/variables/` is used
(default for SavedModels).
tag_set: The tag set to use in the SavedModel.
signature_def: The name of the signature to use from the SavedModel.
variables_to_feed: (optional) The names of the variables to feed, a comma
separated string, or 'all'. If empty, all variables will be frozen and none
may be fed at runtime.
**NOTE** Any variables passed to `variables_to_feed` *must be set by
the user*. These variables will NOT be frozen and their values will be
uninitialized in the compiled object (this applies to all input
arguments from the signature as well).
target_triple: The LLVM target triple to use (defaults to current build's
target architecture's triple). Similar to clang's -target flag.
target_cpu: The LLVM cpu name used for compilation. Similar to clang's
-mcpu flag.
multithreading: Whether to compile multithreaded AOT code.
Note, this increases the set of dependencies for binaries using
the AOT library at both build and runtime. For example,
the resulting object files may have external dependencies on
multithreading libraries like nsync.
force_without_xla_support_flag: Whether to compile even when
`--define=with_xla_support=true` is not set. If `False`, and the
define is not passed when building, then the created `cc_library`
will be empty. In this case, downstream targets should
conditionally build using macro `tfcompile.bzl:if_xla_available`.
This flag is used by the TensorFlow build to avoid building on
architectures that do not support XLA.
tags: List of target tags.
"""
saved_model = "{}/saved_model.pb".format(directory)
target_triple = target_triple or target_llvm_triple()
target_cpu = target_cpu or tfcompile_target_cpu() or ""
variables_to_feed = variables_to_feed or "''"
if checkpoint_path:
checkpoint_cmd_args = (
"--checkpoint_path \"$$(dirname $(location {}/variables.index))\" "
.format(checkpoint_path)
)
checkpoint_srcs = ["{}/variables.index".format(checkpoint_path)]
else:
checkpoint_cmd_args = ""
checkpoint_srcs = []

    native.genrule(
        name = "{}_gen".format(name),
        srcs = filegroups + [saved_model] + checkpoint_srcs,
        outs = [
            "{}.h".format(name),
            "{}.o".format(name),
            "{}_metadata.o".format(name),
            "{}_makefile.inc".format(name),
        ],
        cmd = (
            "$(location {}) aot_compile_cpu ".format(
                clean_dep("//tensorflow/python/tools:saved_model_cli"),
            ) +
            "--dir \"$$(dirname $(location {}))\" ".format(saved_model) +
            checkpoint_cmd_args +
            "--output_prefix $(@D)/{} ".format(name) +
            "--cpp_class {} ".format(cpp_class) +
            "--variables_to_feed {} ".format(variables_to_feed) +
            "--signature_def_key {} ".format(signature_def) +
            "--multithreading {} ".format(multithreading) +
            "--target_triple " + target_triple + " " +
            ("--target_cpu " + target_cpu + " " if target_cpu else "") +
            "--tag_set {} ".format(tag_set)
        ),
        tags = tags,
        tools = [
            "//tensorflow/python/tools:saved_model_cli",
        ],
    )

    native.cc_library(
        name = name,
        srcs = _maybe_force_compile(
            [
                ":{}.o".format(name),
                ":{}_metadata.o".format(name),
            ],
            force_compile = force_without_xla_support_flag,
        ),
        hdrs = _maybe_force_compile(
            [
                ":{}.h".format(name),
            ],
            force_compile = force_without_xla_support_flag,
        ),
        tags = tags,
        deps = _maybe_force_compile(
            [
                "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_standalone",
            ],
            force_compile = force_without_xla_support_flag,
        ),
    )
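
# Illustrative sketch (not part of the original file): when
# `force_without_xla_support_flag = False`, the generated cc_library is empty
# unless the build passes --define=with_xla_support=true, so downstream
# targets should gate their dependency with `if_xla_available` (loaded above
# in this file; the docstring points at tfcompile.bzl for the same macro).
# All target and file names below are hypothetical; shown commented out since
# macro calls belong in BUILD files, not at the top level of a .bzl file.
#
#   saved_model_compile_aot(
#       name = "aot_compiled_model",
#       directory = "//path/to:model_dir",        # hypothetical
#       filegroups = ["//path/to:model_files"],   # hypothetical
#       cpp_class = "example::Model",             # hypothetical
#       force_without_xla_support_flag = False,
#   )
#
#   cc_binary(
#       name = "model_runner",
#       srcs = ["model_runner.cc"],
#       deps = if_xla_available([":aot_compiled_model"]),
#   )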