STT-tensorflow/tensorflow/python/tools/tools.bzl
Eugene Brevdo e25d9862ca [TF saved_model_cli] Allow user to set target_cpu for xla aot compilation.
PiperOrigin-RevId: 312290453
Change-Id: I024e2b3884e436578e351d43961199e4c28307c3
2020-05-19 09:14:38 -07:00

171 lines
6.3 KiB
Python

"""Definitions for using tools like saved_model_cli."""
load("//tensorflow:tensorflow.bzl", "clean_dep", "if_xla_available")
load("//tensorflow:tensorflow.bzl", "tfcompile_target_cpu")
load("//tensorflow/compiler/aot:tfcompile.bzl", "target_llvm_triple")
def _maybe_force_compile(args, force_compile):
    """Returns `args`, gated on XLA availability unless compilation is forced.

    Args:
      args: The attribute value to pass through (at this file's call sites, a
        list of labels for a cc_library's srcs/hdrs/deps).
      force_compile: If True, `args` is returned unchanged. Otherwise it is
        wrapped with `if_xla_available`.

    Returns:
      Either `args` itself or `if_xla_available(args)`.
    """
    return args if force_compile else if_xla_available(args)
def saved_model_compile_aot(
        name,
        directory,
        filegroups,
        cpp_class,
        checkpoint_path = None,
        tag_set = "serve",
        signature_def = "serving_default",
        variables_to_feed = "",
        target_triple = None,
        target_cpu = None,
        force_without_xla_support_flag = True,
        tags = None):
    """Compile a SavedModel directory accessible from a filegroup.

    This target rule takes a path to a filegroup directory containing a
    SavedModel and generates a cc_library with an AOT compiled model.
    For extra details, see the help for saved_model_cli's `aot_compile_cpu`
    command.

    The macro creates two targets:

      * `<name>_gen`: a genrule that runs `saved_model_cli aot_compile_cpu`
        and emits `<name>.h`, `<name>.o`, `<name>_metadata.o` and
        `<name>_makefile.inc`.
      * `<name>`: a cc_library wrapping the generated header and objects.

    **NOTE** Any variables passed to `variables_to_feed` *must be set by the
    user*. These variables will NOT be frozen and their values will be
    uninitialized in the compiled object (this applies to all input
    arguments from the signature as well).

    Example usage:

    ```
    saved_model_compile_aot(
      name = "aot_compiled_x_plus_y",
      cpp_class = "tensorflow::CompiledModel",
      directory = "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo",
      filegroups = [
          "//tensorflow/cc/saved_model:saved_model_half_plus_two",
      ]
    )

    cc_test(
      name = "test",
      srcs = ["test.cc"],
      deps = [
        "//tensorflow/core:test_main",
        ":aot_compiled_x_plus_y",
        "//tensorflow/core:test",
        "//tensorflow/core/platform:logging",
      ],
    )

    In "test.cc":

    #include "third_party/tensorflow/python/tools/aot_compiled_x_plus_y.h"

    TEST(Test, Run) {
      tensorflow::CompiledModel model;
      CHECK(model.Run());
    }
    ```

    Args:
      name: The rule name, and the name prefix of the headers and object file
        emitted by this rule.
      directory: The bazel directory containing saved_model.pb and variables/
        subdirectories.
      filegroups: List of `filegroup` targets; these filegroups contain the
        files pointed to by `directory` and `checkpoint_path`.
      cpp_class: The name of the C++ class that will be generated, including
        namespace; e.g. "my_model::InferenceRunner".
      checkpoint_path: The bazel directory containing `variables.index`. If
        not provided, no `--checkpoint_path` flag is passed and
        `$directory/variables/` is used (default for SavedModels).
      tag_set: The tag set to use in the SavedModel.
      signature_def: The name of the signature to use from the SavedModel.
      variables_to_feed: (optional) The names of the variables to feed, a comma
        separated string, or 'all'. If empty, all variables will be frozen and
        none may be fed at runtime.

        **NOTE** Any variables passed to `variables_to_feed` *must be set by
        the user*. These variables will NOT be frozen and their values will be
        uninitialized in the compiled object (this applies to all input
        arguments from the signature as well).
      target_triple: The LLVM target triple to use (defaults to
        `target_llvm_triple()`, the current build's target architecture's
        triple). Similar to clang's -target flag.
      target_cpu: The LLVM cpu name used for compilation. Similar to clang's
        -mcpu flag. Defaults to `tfcompile_target_cpu()`; if that is also
        empty, no `--target_cpu` flag is passed.
      force_without_xla_support_flag: Whether to compile even when
        `--define=with_xla_support=true` is not set. If `False`, and the
        define is not passed when building, then the created `cc_library`
        will be empty. In this case, downstream targets should
        conditionally build using macro `tensorflow.bzl:if_xla_available`.
        This flag is used by the TensorFlow build to avoid building on
        architectures that do not support XLA.
      tags: List of target tags, applied to both the generated genrule and
        the cc_library.
    """
    saved_model = "{}/saved_model.pb".format(directory)
    target_triple = target_triple or target_llvm_triple()
    target_cpu = target_cpu or tfcompile_target_cpu() or ""

    # An empty value is passed as a literal empty shell argument ('') so that
    # `--variables_to_feed` always receives a value.
    variables_to_feed = variables_to_feed or "''"

    # Resolve TensorFlow-internal labels once through `clean_dep` and reuse
    # them. `$(location X)` in `cmd` only expands labels that also appear in
    # `srcs`/`tools`, so the `tools` entry must name exactly the same label.
    # A raw "//tensorflow/..." string would be resolved against the *calling*
    # repository when this macro is used from outside TensorFlow (assuming
    # `clean_dep` canonicalizes labels to the TensorFlow repository, as its
    # original use in `cmd` implies).
    saved_model_cli = clean_dep("//tensorflow/python/tools:saved_model_cli")
    xla_cpu_runtime = clean_dep(
        "//tensorflow/compiler/tf2xla:xla_compiled_cpu_runtime_standalone",
    )

    if checkpoint_path:
        # `$$` escapes `$` for Bazel so the shell performs the command
        # substitution; the checkpoint directory is the parent of
        # `variables.index`.
        checkpoint_cmd_args = (
            "--checkpoint_path \"$$(dirname $(location {}/variables.index))\" "
                .format(checkpoint_path)
        )
        checkpoint_srcs = ["{}/variables.index".format(checkpoint_path)]
    else:
        checkpoint_cmd_args = ""
        checkpoint_srcs = []

    native.genrule(
        name = "{}_gen".format(name),
        srcs = filegroups + [saved_model] + checkpoint_srcs,
        outs = [
            "{}.h".format(name),
            "{}.o".format(name),
            "{}_metadata.o".format(name),
            "{}_makefile.inc".format(name),
        ],
        cmd = (
            "$(location {}) aot_compile_cpu ".format(saved_model_cli) +
            "--dir \"$$(dirname $(location {}))\" ".format(saved_model) +
            checkpoint_cmd_args +
            "--output_prefix $(@D)/{} ".format(name) +
            "--cpp_class {} ".format(cpp_class) +
            "--variables_to_feed {} ".format(variables_to_feed) +
            "--signature_def_key {} ".format(signature_def) +
            # Plain `+` (not `.format`) is kept for the target flags,
            # presumably because `target_llvm_triple()` /
            # `tfcompile_target_cpu()` may return a `select()`, which supports
            # `+` but cannot be formatted into a string. TODO confirm.
            "--target_triple " + target_triple + " " +
            ("--target_cpu " + target_cpu + " " if target_cpu else "") +
            "--tag_set {} ".format(tag_set)
        ),
        tags = tags,
        tools = [saved_model_cli],
    )

    # Unless forced, srcs/hdrs/deps are gated through `_maybe_force_compile`
    # (i.e. `if_xla_available`), so the library may end up empty.
    native.cc_library(
        name = name,
        srcs = _maybe_force_compile(
            [
                ":{}.o".format(name),
                ":{}_metadata.o".format(name),
            ],
            force_compile = force_without_xla_support_flag,
        ),
        hdrs = _maybe_force_compile(
            [
                ":{}.h".format(name),
            ],
            force_compile = force_without_xla_support_flag,
        ),
        tags = tags,
        deps = _maybe_force_compile(
            [xla_cpu_runtime],
            force_compile = force_without_xla_support_flag,
        ),
    )