[TF XLA] Add ability to convert SavedModel subgraphs to compiled [XLA CPU] objects via saved_model_cli.

You can now run, e.g.:

saved_model_cli aot_compile_cpu \
  --dir /path/to/saved_model \
  --tag_set serve \
  --signature_def_key action \
  --output_prefix /tmp/out \
  --cpp_class Serving::Action

This will create the files:
  /tmp/{out.h, out.o, out_metadata.o, out_makefile.inc}

where out.h defines something like:

namespace Serving {
  class Action {
    ...
  };
}

and out_makefile.inc provides the additional flags required to incorporate the header
and object files into your build.
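
For reference, a C++ caller uses the generated class like any other tfcompile
output. Below is a minimal sketch (not part of this change), assuming the
compiled signature has a single 2x2 float input named "x" and a single output
named "res", as in the new unit test; the arg_x_data()/result_res_data()
accessor names are emitted by the codegen and vary with the signature:

#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL

#include <cstdio>

// Adjust include paths to your build; out_makefile.inc supplies the -I flags.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "out.h"  // the header generated above (/tmp/out.h)

int main() {
  // Optional: give the compiled function an Eigen thread pool for intra-op work.
  Eigen::ThreadPool pool(2);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

  Serving::Action action;
  action.set_thread_pool(&device);

  // Fill the input buffer, run the compiled computation, then read the result.
  float* x = static_cast<float*>(action.arg_x_data());
  for (int i = 0; i < 4; ++i) x[i] = 1.0f;
  if (!action.Run()) return 1;
  const float* res = static_cast<const float*>(action.result_res_data());
  std::printf("res[0] = %f\n", res[0]);
  return 0;
}

A user Makefile can simply include /tmp/out_makefile.inc to pick up the
INC/LIB/CXXFLAGS variables it defines when compiling a translation unit like
the one above.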

You can also point aot_compile_cpu to a newer set of checkpoints (weight values) via the optional --checkpoint_path argument.

Also added `tf.test.is_built_with_xla()`.
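
It follows the same pattern as tf.test.is_built_with_cuda() and
is_built_with_rocm(); the new saved_model_cli tests use it to skip when XLA
support is compiled out. A minimal standalone sketch (class and method names
here are illustrative only):

import tensorflow as tf


class XlaBuildTest(tf.test.TestCase):

  def test_requires_xla(self):
    # Same guard the new tests use: skip unless the build was configured
    # with --define=with_xla_support=true.
    if not tf.test.is_built_with_xla():
      self.skipTest('Skipping test because XLA is not compiled in.')
    self.assertTrue(tf.test.is_built_with_xla())


if __name__ == '__main__':
  tf.test.main()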

TESTED:
* bazel test -c opt :saved_model_cli_test passes
* built and installed the pip wheel and tested in the bazel directory via:
  TEST_SRCDIR=/tmp/tfcompile/bazel-bin/tensorflow/python/tools/saved_model_cli_test.runfiles/ python saved_model_cli_test.py

and checked the output files to ensure the proper includes and header directories are
set in out_makefile.inc and out.h.

PiperOrigin-RevId: 290144104
Change-Id: If8eb6c3334b3042c4b9c24813b1b52c06d8fbc06
Eugene Brevdo 2020-01-16 14:20:03 -08:00 committed by TensorFlower Gardener
parent 3bcfb829bb
commit 9959c04433
36 changed files with 934 additions and 141 deletions

View File

@ -2,6 +2,7 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.
load("@bazel_skylib//lib:selects.bzl", "selects")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
"//tensorflow/core/platform:build_config.bzl",

View File

@ -1,6 +1,13 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")
package(
default_visibility = ["//visibility:private"],
licenses = ["notice"], # Apache 2.0
@ -27,9 +34,14 @@ cc_library(
"compile.h",
"flags.h",
],
defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
visibility = ["//tensorflow/python:__pkg__"],
deps = [
":aot_only_var_handle_op",
":embedded_protocol_buffers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:mlir_tf2xla",
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@ -53,12 +65,45 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
# Necessary for the pywrap inclusion below.
tf_pybind_cc_library_wrapper(
name = "tfcompile_headers_lib",
deps = [
":tfcompile_lib",
],
)
tf_python_pybind_extension(
name = "_pywrap_tfcompile",
srcs = ["tfcompile_wrapper.cc"],
features = ["-layering_check"],
module_name = "_pywrap_tfcompile",
visibility = ["//tensorflow/python:__pkg__"],
deps = [
":tfcompile_headers_lib",
"@pybind11",
"//third_party/python_runtime:headers",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
# These headers cannot be brought in via cc_header_only_library
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
] + if_llvm_aarch64_available([
"//third_party/llvm/llvm-project/llvm:aarch64_target", # fixdeps: keep
]),
)
tf_cc_test(
name = "codegen_test",
srcs = ["codegen_test.cc"],
@ -104,11 +149,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:aarch64_code_gen", # fixdeps: keep
"@llvm-project//llvm:arm_code_gen", # fixdeps: keep
"@llvm-project//llvm:powerpc_code_gen", # fixdeps: keep
"@llvm-project//llvm:target",
"@llvm-project//llvm:x86_code_gen", # fixdeps: keep
],
)
@ -214,8 +254,13 @@ cc_library(
cc_library(
name = "aot_only_var_handle_op",
srcs = ["aot_only_var_handle_op.cc"],
hdrs = ["aot_only_var_handle_op.h"],
visibility = [
"//tensorflow/compiler/tf2xla:__pkg__",
],
deps = [
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/core:framework",
],
alwayslink = 1,
)

View File

@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace {
@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
}
} // namespace
REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp);
REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
.Doc(R"doc(
Internal VarHandleOp registration used for XLA AOT compilation.
)doc")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("dtype: type")
.Attr("shape: shape")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
PartialTensorShape p;
TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
return Status::OK();
});
REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
XlaAotOnlyVarHandleOp);
} // namespace tensorflow

View File

@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
namespace tensorflow {
namespace tfcompile {
static constexpr const char* const kXlaAotOnlyVarHandleOp =
"_XlaAotOnlyVarHandleOp";
} // namespace tfcompile
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_

View File

@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
const string include_xla_data_proto =
opts.gen_program_shape
?
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
: "";
const string include_hlo_profile_printer_data_proto =
@ -458,8 +457,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
{{INCLUDE_XLA_DATA_PROTO}}
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "{{TF_HEADER_ROOT}}/core/platform/types.h"
namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }
@ -660,6 +659,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
{"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
{"{{ENTRY}}", compile_result.entry_point},
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
metadata_result.hlo_profile_printer_data_access_shim},

View File

@ -197,6 +197,7 @@ TEST(CodegenTest, Golden) {
variable3->mutable_shape()->add_dim()->set_size(5);
variable3->set_type(DT_INT32);
CompileResult compile_result;
compile_result.tensorflow_header_root = "third_party/tensorflow";
compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
{},
{BufferInfo::MakeTempBuffer(1),

View File

@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -83,6 +85,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
std::move(aot_or.ValueOrDie().back()));
compile_result->entry_point = aot_opts.entry_point_name();
compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
compile_result->pointer_size =
xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
return Status::OK();
@ -90,7 +93,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
} // namespace
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result) {
// Converts the graph into an XLA computation, and compiles the
// computation.
@ -108,8 +111,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
if (!flags.mlir_components.empty()) {
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
}
TF_RETURN_IF_ERROR(
ConvertGraphDefToXla(graph_def, config, client, &computation));
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
client, &computation));
}
if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
@ -127,10 +130,102 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
xla::cpu::CpuAotCompilationOptions aot_opts(
flags.target_triple, flags.target_cpu, flags.target_features,
flags.entry_point,
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
flags.tensorflow_header_root);
return CompileXla(client, computation, aot_opts, compile_result);
}
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
static std::once_flag targets_init;
static void InitializeTargets() {
// Initialize all LLVM targets so we can cross compile.
#if TF_LLVM_AARCH64_AVAILABLE
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
#endif
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
}
Status Main(const MainFlags& flags) {
std::call_once(targets_init, &InitializeTargets);
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(
CompileGraph(std::move(graph_def), config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // namespace tfcompile
} // namespace tensorflow

View File

@ -35,6 +35,7 @@ struct CompileResult {
std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
xla::ProgramShapeProto program_shape; // Static shape of args and results.
string entry_point; // Name of generated function.
string tensorflow_header_root; // Prefix for tensorflow headers.
int pointer_size = 0; // Size of a pointer in bytes.
};
@ -42,9 +43,12 @@ struct CompileResult {
// that performs the graph operations.
//
// The XLA compilation options are specified in the flags.
Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
const MainFlags& flags, CompileResult* compile_result);
// The full compilation method, for reuse in a library setting.
Status Main(const MainFlags& flags);
} // namespace tfcompile
} // namespace tensorflow

View File

@ -74,6 +74,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
{"gen_program_shape", &flags->gen_program_shape,
"Generate program shape data for the ProgramShape method."},
{"tensorflow_header_root", &flags->tensorflow_header_root,
"Root directory of tensorflow headers."},
};
flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
}

View File

@ -25,6 +25,7 @@ namespace tensorflow {
namespace tfcompile {
// Flags for the tfcompile binary. See *.cc file for descriptions.
struct MainFlags {
string graph;
string config;
@ -39,6 +40,7 @@ struct MainFlags {
string out_header;
string out_session_module;
string mlir_components;
string tensorflow_header_root;
// C++ codegen options
bool gen_name_to_index = false;

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
@ -56,88 +55,6 @@ const char kUsageHeader[] =
"--cpp_class=\"mynamespace::MyComputation\"\n"
"\n";
Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
if (absl::EndsWith(fname, ".pbtxt")) {
return ReadTextProto(Env::Default(), fname, proto);
} else {
return ReadBinaryProto(Env::Default(), fname, proto);
}
}
Status Main(const MainFlags& flags) {
// Initialize all LLVM targets so we can cross compile.
LLVMInitializeAArch64Target();
LLVMInitializeAArch64TargetInfo();
LLVMInitializeAArch64TargetMC();
LLVMInitializeAArch64AsmPrinter();
LLVMInitializeARMTarget();
LLVMInitializeARMTargetInfo();
LLVMInitializeARMTargetMC();
LLVMInitializeARMAsmPrinter();
LLVMInitializePowerPCTarget();
LLVMInitializePowerPCTargetInfo();
LLVMInitializePowerPCTargetMC();
LLVMInitializePowerPCAsmPrinter();
LLVMInitializeX86Target();
LLVMInitializeX86TargetInfo();
LLVMInitializeX86TargetMC();
LLVMInitializeX86AsmPrinter();
// Process config.
tf2xla::Config config;
if (flags.config.empty()) {
return errors::InvalidArgument("Must specify --config");
}
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
TF_RETURN_IF_ERROR(ValidateConfig(config));
if (flags.dump_fetch_nodes) {
std::set<string> nodes;
for (const tf2xla::Fetch& fetch : config.fetch()) {
nodes.insert(fetch.id().node_name());
}
std::cout << absl::StrJoin(nodes, ",");
return Status::OK();
}
// Read and initialize the graph.
if (flags.graph.empty()) {
return errors::InvalidArgument("Must specify --graph");
}
GraphDef graph_def;
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
CompileResult compile_result;
TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
// Write output files.
Env* env = Env::Default();
const std::vector<char>& obj = compile_result.aot->object_file_data();
TF_RETURN_IF_ERROR(
WriteStringToFile(env, flags.out_function_object,
absl::string_view(obj.data(), obj.size())));
CodegenOpts codegen_opts;
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
codegen_opts.gen_program_shape = flags.gen_program_shape;
codegen_opts.target_triple = flags.target_triple;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
codegen_opts.gen_hlo_profile_printer_data =
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));
MetadataResult metadata_result;
TF_RETURN_IF_ERROR(
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
metadata_result.object_file_data));
string header;
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
metadata_result, &header));
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
return Status::OK();
}
} // end namespace tfcompile
} // end namespace tensorflow
@ -148,6 +65,7 @@ int main(int argc, char** argv) {
flags.out_metadata_object = "out_helper.o";
flags.out_header = "out.h";
flags.entry_point = "entry";
flags.tensorflow_header_root = "third_party/tensorflow";
std::vector<tensorflow::Flag> flag_list;
AppendMainFlags(&flag_list, &flags);

View File

@ -0,0 +1,75 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "include/pybind11/cast.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "include/pybind11/stl.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = pybind11;
PYBIND11_MODULE(_pywrap_tfcompile, m) {
m.doc() = R"pbdoc(
_pywrap_tfcompile
-----
)pbdoc";
m.def(
"Compile",
[](std::string graph, std::string config, std::string target_triple,
std::string target_cpu, std::string target_features,
std::string entry_point, std::string cpp_class,
std::string out_function_object, std::string out_metadata_object,
std::string out_header, std::string out_session_module,
std::string mlir_components, std::string tensorflow_header_root,
bool gen_name_to_index, bool gen_program_shape) {
tensorflow::tfcompile::MainFlags flags;
flags.graph = std::move(graph);
flags.config = std::move(config);
flags.target_triple = std::move(target_triple);
flags.target_cpu = std::move(target_cpu);
flags.target_features = std::move(target_features);
flags.entry_point = std::move(entry_point);
flags.cpp_class = std::move(cpp_class);
flags.out_function_object = std::move(out_function_object);
flags.out_metadata_object = std::move(out_metadata_object);
flags.out_header = std::move(out_header);
flags.out_session_module = std::move(out_session_module);
flags.mlir_components = std::move(mlir_components);
flags.tensorflow_header_root = std::move(tensorflow_header_root);
// C++ codegen options
flags.gen_name_to_index = gen_name_to_index;
flags.gen_program_shape = gen_program_shape;
tensorflow::MaybeRaiseFromStatus(tensorflow::tfcompile::Main(flags));
},
py::arg("graph") = "", py::arg("config") = "",
py::arg("target_triple") = "x86_64-pc-linux", py::arg("target_cpu") = "",
py::arg("target_features") = "", py::arg("entry_point") = "entry",
py::arg("cpp_class") = "", py::arg("out_function_object") = "out_model.o",
py::arg("out_metadata_object") = "out_helper.o",
py::arg("out_header") = "out.h", py::arg("out_session_module") = "",
py::arg("mlir_components") = "",
py::arg("tensorflow_header_root") = "third_party/tensorflow",
py::arg("gen_name_to_index") = false,
py::arg("gen_program_shape") = false);
}

View File

@ -5,6 +5,7 @@ load(
)
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_proto_library",
"tf_proto_library_cc",
)
load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library")
@ -62,7 +63,7 @@ tf_cc_binary(
deps = [":tf2xla_supported_ops_lib"],
)
tf_proto_library_cc(
tf_proto_library(
name = "tf2xla_proto",
srcs = ["tf2xla.proto"],
cc_api_version = 2,
@ -140,6 +141,7 @@ cc_library(
":tf2xla_proto_cc",
":tf2xla_util",
":xla_compiler",
"//tensorflow/compiler/aot:aot_only_var_handle_op",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:xla_computation",

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
#include "tensorflow/compiler/tf2xla/graph_compiler_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@ -126,12 +127,28 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
return Status::OK();
}
void ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
for (auto& node : *graph_def->mutable_node()) {
if (node.op() == "VarHandleOp") {
node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
}
}
for (auto& fn : *graph_def->mutable_library()->mutable_function()) {
for (auto& node : *fn.mutable_node_def()) {
if (node.op() == "VarHandleOp") {
node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
}
}
}
}
} // namespace
Status ConvertGraphDefToXla(const GraphDef& graph_def,
const tf2xla::Config& config, xla::Client* client,
Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
xla::Client* client,
xla::XlaComputation* computation) {
std::unique_ptr<Graph> graph;
ConvertVarHandlesToAotVarHandles(&graph_def);
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
TF_RETURN_IF_ERROR(
ConvertGraphToXla(std::move(graph), config, client, computation));

View File

@ -30,8 +30,8 @@ namespace tensorflow {
//
// The computation is built in the context of the given `client`, which may
// subsequently be used to compile or execute the computation.
Status ConvertGraphDefToXla(const GraphDef& graph_def,
const tf2xla::Config& config, xla::Client* client,
Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
xla::Client* client,
xla::XlaComputation* computation);
// Similar to ConvertGraphDefToXla, but uses MLIR.

View File

@ -119,12 +119,13 @@ using BufferInfo = cpu_function_runtime::BufferInfo;
CpuAotCompilationOptions::CpuAotCompilationOptions(
string triple, string cpu_name, string features, string entry_point_name,
RelocationModel relocation_model)
RelocationModel relocation_model, string tensorflow_header_root)
: triple_(std::move(triple)),
cpu_name_(std::move(cpu_name)),
features_(std::move(features)),
entry_point_name_(std::move(entry_point_name)),
relocation_model_(relocation_model) {}
relocation_model_(relocation_model),
tensorflow_header_root_(std::move(tensorflow_header_root)) {}
CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;

View File

@ -53,7 +53,17 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
CpuAotCompilationOptions(string triple, string cpu_name, string features,
string entry_point_name,
RelocationModel relocation_model);
RelocationModel relocation_model,
string tensorflow_header_root);
CpuAotCompilationOptions(string triple, string cpu_name, string features,
string entry_point_name,
RelocationModel relocation_model)
: CpuAotCompilationOptions(
std::move(triple), std::move(cpu_name), std::move(features),
std::move(entry_point_name), relocation_model,
/*tensorflow_header_root=*/"third_party/tensorflow") {}
~CpuAotCompilationOptions() override;
se::Platform::Id PlatformId() const override;
@ -66,6 +76,10 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
const string& features() const { return features_; }
// The name to be used for the compiled code's entry point.
const string& entry_point_name() const { return entry_point_name_; }
// The prefix for tensorflow headers, e.g. "third_party/tensorflow".
const string& tensorflow_header_root() const {
return tensorflow_header_root_;
}
// The relocation model used for compilation.
RelocationModel relocation_model() const { return relocation_model_; }
@ -75,6 +89,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
const string features_;
const string entry_point_name_;
const RelocationModel relocation_model_;
const string tensorflow_header_root_;
};
class CpuAotCompilationResult : public AotCompilationResult {

View File

@ -2,6 +2,7 @@
load(
"//tensorflow/core/platform/default:build_config.bzl",
_if_llvm_aarch64_available = "if_llvm_aarch64_available",
_pyx_library = "pyx_library",
_tf_additional_all_protos = "tf_additional_all_protos",
_tf_additional_binary_deps = "tf_additional_binary_deps",
@ -80,3 +81,4 @@ tf_protos_profiler_impl = _tf_protos_profiler_impl
tf_py_clif_cc = _tf_py_clif_cc
tf_pyclif_proto_library = _tf_pyclif_proto_library
tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps
if_llvm_aarch64_available = _if_llvm_aarch64_available

View File

@ -770,3 +770,8 @@ def tf_google_mobile_srcs_no_runtime():
def tf_google_mobile_srcs_only_runtime():
return []
def if_llvm_aarch64_available(then, otherwise = []):
# TODO(b/...): The TF XLA build fails when adding a dependency on
# @llvm/llvm-project/llvm:aarch64_target.
return otherwise

View File

@ -34,6 +34,14 @@ bool IsBuiltWithROCm() {
#endif
}
bool IsBuiltWithXLA() {
#if TENSORFLOW_USE_XLA
return true;
#else
return false;
#endif
}
bool IsBuiltWithNvcc() {
#if TENSORFLOW_USE_NVCC
return true;

View File

@ -24,6 +24,9 @@ bool IsGoogleCudaEnabled();
// Returns true if TENSORFLOW_USE_ROCM is defined. (i.e. TF is built with ROCm)
bool IsBuiltWithROCm();
// Returns true if TENSORFLOW_USE_XLA is defined. (i.e. TF is built with XLA)
bool IsBuiltWithXLA();
// Returns true if TENSORFLOW_USE_NVCC is defined. (i.e. TF is built with nvcc)
bool IsBuiltWithNvcc();

View File

@ -3,7 +3,7 @@
# Public targets:
# ":platform" - Low-level and platform-specific Python code.
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow:tensorflow.bzl", "pybind_extension")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
@ -1109,10 +1109,12 @@ py_library(
":tensor_util",
":type_spec",
":util",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
"@six_archive//:six",
],
"//tensorflow/python/eager:context",
] + if_xla_available([
"//tensorflow/compiler/aot:_pywrap_tfcompile",
]),
)
py_library(
@ -5553,6 +5555,8 @@ tf_py_wrap_cc(
] + (tf_additional_lib_deps() +
tf_additional_plugin_deps()) + if_ngraph([
"@ngraph_tf//:ngraph_tf",
]) + if_xla_available([
"//tensorflow/compiler/aot:tfcompile_lib",
]),
)

View File

@ -284,6 +284,10 @@ def IsBuiltWithROCm():
return _pywrap_util_port.IsBuiltWithROCm()
def IsBuiltWithXLA():
return _pywrap_util_port.IsBuiltWithXLA()
def IsBuiltWithNvcc():
return _pywrap_util_port.IsBuiltWithNvcc()

View File

@ -106,3 +106,9 @@ def is_built_with_rocm():
def is_built_with_gpu_support():
"""Returns whether TensorFlow was built with GPU (i.e. CUDA or ROCm) support."""
return is_built_with_cuda() or is_built_with_rocm()
@tf_export('test.is_built_with_xla')
def is_built_with_xla():
"""Returns whether TensorFlow was built with XLA support."""
return _test_util.IsBuiltWithXLA()

View File

@ -1,8 +1,7 @@
# Description:
# Tools for manipulating TensorFlow graphs.
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "py_binary")
load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test")
package(
default_visibility = ["//visibility:public"],
@ -325,7 +324,10 @@ py_library(
":saved_model_utils",
"//tensorflow/python",
"//tensorflow/python/debug:local_cli_wrapper",
],
"//tensorflow/python:tf_optimizer",
] + if_xla_available(
["//tensorflow/compiler/tf2xla:tf2xla_proto_py"],
),
)
py_test(
@ -339,7 +341,10 @@ py_test(
tags = [
"manual",
"no-internal-py3",
"nosan",
],
# Force-include XLA dependencies of saved_model_cli_lib to ensure we test
# the AOT compilation.
deps = [
":saved_model_cli_lib",
"//tensorflow/core:protos_all_py",

View File

@ -25,34 +25,131 @@ from __future__ import print_function
import argparse
import collections
import copy
import hashlib
import os
import pipes
import re
import shlex
import sys
import warnings
import numpy as np
import six
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import versions
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.platform import sysconfig as sysconfig_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as saver_lib
_XLA_DEBUG_OPTIONS_URL = (
'https://github.com/tensorflow/tensorflow/blob/master/'
'tensorflow/compiler/xla/debug_options_flags.cc')
try:
from tensorflow.compiler.aot import _pywrap_tfcompile # pylint: disable=g-import-not-at-top
except ImportError as e:
_pywrap_tfcompile_import_error = ImportError(
'Unable to import _pywrap_tfcompile; you must build TensorFlow '
'with XLA. You may need to build tensorflow with flag '
'--define=with_xla_support=true. Original error: {}'.format(str(e)))
else:
_pywrap_tfcompile_import_error = None
# Set of ops to blacklist.
_OP_BLACKLIST = set(['WriteFile', 'ReadFile', 'PrintV2'])
def _shlex_quote(s):
if six.PY2:
return pipes.quote(s)
else:
return shlex.quote(s)
def _sysconfig_module():
"""Load tf.sysconfig if available and working (i.e., inside a pip package)."""
try:
_ = sysconfig_lib.get_include()
except ImportError:
return None
return sysconfig_lib
_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""
def _xla_makefile_string(output_prefix):
"""Returns a Makefile string with variables for using XLA binary object files.
Attempts to identify the right include header paths when run from either
an installed TensorFlow pip package, or from bazel run.
Args:
output_prefix: A string containing the output prefix for the XLA AOT
compiled header + object files.
Returns:
A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
"""
sysconfig = _sysconfig_module()
output_dir, _ = os.path.split(output_prefix)
if sysconfig:
tensorflow_includes = _shlex_quote(sysconfig.get_include())
else:
# Try hard to find the real source directory if this is a local bazel run.
if os.path.islink(__file__):
this_file = __file__
while os.path.islink(this_file):
this_file = os.readlink(this_file)
base = os.path.realpath(
os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
else:
base = test.test_src_dir_path('')
expected_header = os.path.join(
base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
if not os.path.exists(expected_header):
logging.error(
'Could not find includes path. Missing file: {}'
.format(expected_header))
tensorflow_includes = base
return _XLA_MAKEFILE_TEMPLATE.format(
tensorflow_includes=tensorflow_includes,
compiled_dir=_shlex_quote(output_dir),
cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
versions.CXX11_ABI_FLAG))
def _show_tag_sets(saved_model_dir):
"""Prints the tag-sets stored in SavedModel directory.
@ -653,7 +750,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
if variable_name:
# if file contains a single ndarray, ignore the input name
if isinstance(data, np.ndarray):
warnings.warn(
logging.warn(
'Input file %s contains a single ndarray. Name key \"%s\" ignored.'
% (filename, variable_name))
tensor_key_feed_dict[input_tensor_key] = data
@ -680,7 +777,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
# When input is a python expression:
for input_tensor_key, py_expr_evaluated in input_exprs.items():
if input_tensor_key in tensor_key_feed_dict:
warnings.warn(
logging.warn(
'input_key %s has been specified with both --inputs and --input_exprs'
' options. Value in --input_exprs will be used.' % input_tensor_key)
tensor_key_feed_dict[input_tensor_key] = py_expr_evaluated
@ -688,7 +785,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
# When input is a tf.Example:
for input_tensor_key, example in input_examples.items():
if input_tensor_key in tensor_key_feed_dict:
warnings.warn(
logging.warn(
'input_key %s has been specified in multiple options. Value in '
'--input_examples will be used.' % input_tensor_key)
tensor_key_feed_dict[input_tensor_key] = example
@ -776,20 +873,193 @@ def convert_with_tensorrt(args):
converter.save(output_saved_model_dir=args.output_dir)
def create_parser():
"""Creates a parser that parse the command line arguments.
def aot_compile_cpu(args):
"""Function triggered by aot_compile_cpu command.
Returns:
A namespace parsed from command line arguments.
Args:
args: A namespace parsed from command line.
"""
parser = argparse.ArgumentParser(
description='saved_model_cli: Command-line interface for SavedModel')
parser.add_argument('-v', '--version', action='version', version='0.1.0')
checkpoint_path = (
args.checkpoint_path
or os.path.join(args.dir, 'variables/variables'))
aot_compile_cpu_meta_graph_def(
checkpoint_path=checkpoint_path,
meta_graph_def=saved_model_utils.get_meta_graph_def(
args.dir, args.tag_set),
signature_def_key=args.signature_def_key,
freeze_graph=args.freeze_graph,
output_prefix=args.output_prefix,
cpp_class=args.cpp_class)
subparsers = parser.add_subparsers(
title='commands', description='valid commands', help='additional help')
# show command
def aot_compile_cpu_meta_graph_def(
checkpoint_path,
meta_graph_def,
output_prefix,
signature_def_key,
cpp_class,
freeze_graph=True):
"""Compile a `MetaGraphDef` to header+object files in `output_prefix`.
Use XLA AOT (`tfcompile`) to convert the given meta graph and
signature into a header + object files. Also create an include makefile
that helps identify the necessary include and library paths
to incorporate these files into your C++ program.
The graph is always optimized with grappler, and optionally (by default)
variables are frozen as constants, before compilation happens.
If `freeze_graph` is `True`, all variables are embedded as constants
into the graph and binary objects. If it is `False`, then the variable
values become inputs and outputs of the compiled class and the C++
caller must set these values manually.
Args:
checkpoint_path: Python string. Path to checkpoints/variables.
meta_graph_def: Instance of `MetaGraphDef`.
output_prefix: Python string. Path prefix for outputs.
signature_def_key: String, the signature_def to use in the SavedModel.
cpp_class: Name of output C++ class.
freeze_graph: Whether to freeze the graph before compilation.
Raises:
RuntimeError: If tensorflow was not built with XLA.
ImportError: If tensorflow was built with XLA but there was another
issue importing the tfcompile python wrapper.
ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
missing or has empty outputs.
"""
if _pywrap_tfcompile_import_error:
raise _pywrap_tfcompile_import_error
signature_def_map = meta_graph_def.signature_def
if signature_def_key not in signature_def_map:
raise ValueError(
'Unable to find signature_def key \'{}\' in signature def map. '
'Available keys: {}'.format(
signature_def_key,
list(signature_def_map.keys())))
signature_def = signature_def_map[signature_def_key]
if not signature_def.outputs:
raise ValueError(
'Signature key {} must have outputs, but saw none:\n{}'.format(
signature_def_key, str(signature_def)))
# This updates graph_def in place.
_replace_input_placeholders_with_default_values(
meta_graph_def.graph_def, signature_def)
graph_def = _optimize_graph(meta_graph_def, signature_def)
if freeze_graph:
# Load the Variables so that we can freeze the graph.
with session.Session(graph=ops_lib.Graph()) as sess:
restorer = saver_lib.import_meta_graph(
meta_graph_def, clear_devices=True)
restorer.restore(sess, checkpoint_path)
graph_def.CopyFrom(
graph_util.convert_variables_to_constants(
sess,
graph_def,
[n.name.split(':')[0] for n in signature_def.outputs.values()]))
temp_dir = test.get_temp_dir()
frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
graph_writer.write(graph_def.SerializeToString())
config = _signature_to_tf2xla_config(
signature_def,
frozen_variables=freeze_graph)
logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
config_writer.write(str(config))
output_dir = os.path.dirname(output_prefix)
file_io.recursive_create_dir(output_dir)
entry_digest = hashlib.md5()
entry_digest.update(str(config).encode())
entry_digest.update(str(graph_def).encode())
entry_digest = entry_digest.hexdigest()
logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))
makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
makefile_writer.write(_xla_makefile_string(output_prefix))
output_prefix = _shlex_quote(output_prefix)
additional_compiler_args = {}
sysconfig = _sysconfig_module()
if sysconfig:
# We're inside PIP and need to pass a customized relative path to the
# appropriate tensorflow headers.
additional_compiler_args['tensorflow_header_root'] = 'tensorflow'
_pywrap_tfcompile.Compile(
graph=frozen_graph_def_location,
config=config_pbtxt_location,
cpp_class=cpp_class,
entry_point='entry_{}'.format(entry_digest),
out_function_object='{}.o'.format(output_prefix),
out_header='{}.h'.format(output_prefix),
out_metadata_object='{}_metadata.o'.format(output_prefix),
gen_name_to_index=True,
# ProgramShape isn't uniquified by entry_point.
gen_program_shape=False,
**additional_compiler_args)
def _optimize_graph(meta_graph_def, signature_def):
"""Optimize `meta_graph_def` using grappler. Returns a `GraphDef`."""
# We need to add a collection called 'train_op' so that grappler
# knows what the outputs are.
new_meta_graph_def = copy.deepcopy(meta_graph_def)
fetch_collection = meta_graph_pb2.CollectionDef()
for tensor_info in (
list(signature_def.inputs.values()) +
list(signature_def.outputs.values())):
fetch_collection.node_list.value.append(tensor_info.name)
new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection)
config = config_pb2.ConfigProto()
return tf_optimizer.OptimizeGraph(config, new_meta_graph_def)
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
"""Replace graphdef's `tf.placeholder` input ops with all-zero constants."""
name_to_node_map = dict((n.name, n) for n in graph_def.node)
temp_graph = ops_lib.Graph()
for name, input_ in signature_def.inputs.items():
tensor_name = input_.name.split(':')[0]
if tensor_name not in name_to_node_map:
raise RuntimeError(
'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
'Graph nodes are: {}'.format(tensor_name,
list(name_to_node_map.keys())))
node = name_to_node_map[tensor_name]
if node.op not in ('Placeholder', 'PlaceholderV2'):
logging.info(
'Tried to convert SavedModel input node \'{}\' from a placeholder, '
'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
node))
continue
shape = tensor_shape.TensorShape(input_.tensor_shape)
if not shape.is_fully_defined():
raise ValueError(
'Expected fully defined input shape for signature_def \'{}\', '
'tensor name: \'{}\'; but shape is: {}.'
.format(name, tensor_name, shape))
with temp_graph.as_default():
const = array_ops.zeros(shape, dtype=input_.dtype, name=tensor_name)
node.CopyFrom(const.op.node_def)
def add_show_subparser(subparsers):
"""Add parser for `show`."""
show_msg = (
'Usage examples:\n'
'To show all tag-sets in a SavedModel:\n'
@ -833,7 +1103,9 @@ def create_parser():
help='key of SignatureDef to display input(s) and output(s) for')
parser_show.set_defaults(func=show)
# run command
def add_run_subparser(subparsers):
"""Add parser for `run`."""
run_msg = ('Usage example:\n'
'To run input tensors from files through a MetaGraphDef and save'
' the output tensors to files:\n'
@ -909,7 +1181,9 @@ def create_parser():
'This option should be only used if the worker is a TPU job.')
parser_run.set_defaults(func=run)
# scan command
def add_scan_subparser(subparsers):
"""Add parser for `scan`."""
scan_msg = ('Usage example:\n'
'To scan for blacklisted ops in SavedModel:\n'
'$saved_model_cli scan --dir /tmp/saved_model\n'
@ -929,7 +1203,9 @@ def create_parser():
help='tag-set of graph in SavedModel to scan, separated by \',\'')
parser_scan.set_defaults(func=scan)
# convert command
def add_convert_subparser(subparsers):
"""Add parser for `convert`."""
convert_msg = ('Usage example:\n'
'To convert the SavedModel to one that has TensorRT ops:\n'
'$saved_model_cli convert \\\n'
@ -983,9 +1259,161 @@ def create_parser():
'in a TensorRT node'))
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
def add_aot_compile_cpu_subparser(subparsers):
"""Add parser for `aot_compile_cpu`."""
compile_msg = '\n'.join(
['Usage example:',
'To compile a SavedModel signature via (CPU) XLA AOT:',
'$saved_model_cli aot_compile_cpu \\',
' --dir /tmp/saved_model \\',
' --tag_set serve \\',
' --output_prefix /tmp/saved_model_xla_aot',
'', '',
'Note: Additional XLA compilation options are available by setting the ',
'XLA_FLAGS environment variable. See the XLA debug options flags for ',
'all the options: ',
' {}'.format(_XLA_DEBUG_OPTIONS_URL),
'',
'For example, to disable XLA fast math when compiling:',
'',
'XLA_FLAGS="--xla_cpu_enable_fast_math=false" $saved_model_cli '
'aot_compile_cpu ...',
'',
'Some possibly useful flags:',
' --xla_cpu_enable_fast_math=false',
' --xla_cpu_multi_thread_eigen=false',
' --xla_force_host_platform_device_count=<num threads>',
' (useful in conjunction with disabling eigen multi threading)'
])
parser_compile = subparsers.add_parser(
'aot_compile_cpu',
description=compile_msg,
formatter_class=argparse.RawTextHelpFormatter)
parser_compile.add_argument(
'--dir',
type=str,
required=True,
help='directory containing the SavedModel to convert')
parser_compile.add_argument(
'--output_prefix',
type=str,
required=True,
help=('output directory + filename prefix for the resulting header(s) '
'and object file(s)'))
parser_compile.add_argument(
'--tag_set',
type=str,
required=True,
help='tag-set of graph in SavedModel to convert, separated by \',\'')
parser_compile.add_argument(
'--signature_def_key',
type=str,
default=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
help=('signature_def key to use. '
'default: DEFAULT_SERVING_SIGNATURE_DEF_KEY'))
parser_compile.add_argument(
'--checkpoint_path',
type=str,
default=None,
help='Custom checkpoint to use (default: use the SavedModel variables)')
parser_compile.add_argument(
'--cpp_class',
type=str,
required=True,
help=('The name of the generated C++ class, wrapping the generated '
'function. The syntax of this flag is '
'[[<optional_namespace>::],...]<class_name>. This mirrors the '
'C++ syntax for referring to a class, where multiple namespaces '
'may precede the class name, separated by double-colons. '
'The class will be generated in the given namespace(s), or if no '
'namespaces are given, within the global namespace.'))
parser_compile.add_argument(
'--freeze_graph',
type=bool,
default=True,
help=('Whether to freeze the tf.Variables into the graph. If false, '
'then all Variables in the closure of the signature graph path '
'will be added as input and output args to the XLA-compiled graph '
'(not currently supported)'))
parser_compile.set_defaults(func=aot_compile_cpu)
def create_parser():
"""Creates a parser that parse the command line arguments.
Returns:
A namespace parsed from command line arguments.
"""
parser = argparse.ArgumentParser(
description='saved_model_cli: Command-line interface for SavedModel')
parser.add_argument('-v', '--version', action='version', version='0.1.0')
subparsers = parser.add_subparsers(
title='commands', description='valid commands', help='additional help')
# show command
add_show_subparser(subparsers)
# run command
add_run_subparser(subparsers)
# scan command
add_scan_subparser(subparsers)
# tensorrt convert command
add_convert_subparser(subparsers)
# aot_compile_cpu command
add_aot_compile_cpu_subparser(subparsers)
return parser
def _signature_to_tf2xla_config(signature_def, frozen_variables):
"""Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto.
Args:
signature_def: Instance of `SignatureDef`.
frozen_variables: Python bool, whether variables are being frozen or not.
Returns:
An instance of `tf2xla.Config` proto.
Raises:
RuntimeError: If TensorFlow was not compiled with XLA.
"""
from tensorflow.compiler.tf2xla import tf2xla_pb2 # pylint: disable=g-import-not-at-top
config = tf2xla_pb2.Config()
tensor_id = tf2xla_pb2.TensorId
for name, input_ in signature_def.inputs.items():
(node_name, output_index) = input_.name.split(':')
output_index = int(output_index)
config.feed.append(
tf2xla_pb2.Feed(
id=tensor_id(node_name=node_name, output_index=output_index),
name=name,
type=input_.dtype,
shape=input_.tensor_shape))
for name, output_ in signature_def.outputs.items():
(node_name, output_index) = output_.name.split(':')
output_index = int(output_index)
config.fetch.append(
tf2xla_pb2.Fetch(
id=tensor_id(node_name=node_name, output_index=output_index),
name=name,
type=output_.dtype,
shape=output_.tensor_shape))
if not frozen_variables:
# Extract all variables along the path and add to config
raise NotImplementedError('Non-frozen graphs are not supported.')
return config
def main():
parser = create_parser()
args = parser.parse_args()

View File

@ -35,6 +35,8 @@ from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save
from tensorflow.python.tools import saved_model_cli
@ -709,6 +711,63 @@ Defined Functions:
output = out.getvalue().strip()
self.assertTrue('\'VariableV2\'' in output)
def testAOTCompileCPUWrongSignatureDefKey(self):
if not test.is_built_with_xla():
self.skipTest('Skipping test because XLA is not compiled in.')
self.parser = saved_model_cli.create_parser()
base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
args = self.parser.parse_args(
['aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
'--output_prefix', output_dir,
'--cpp_class', 'Compiled',
'--signature_def_key', 'MISSING'])
with self.assertRaisesRegexp(ValueError, 'Unable to find signature_def'):
saved_model_cli.aot_compile_cpu(args)
def testAOTCompileCPUFreezesAndCompiles(self):
if not test.is_built_with_xla():
self.skipTest('Skipping test because XLA is not compiled in.')
class DummyModel(tracking.AutoTrackable):
"""Model compatible with XLA compilation."""
def __init__(self):
self.var = variables.Variable(1.0, name='my_var')
@def_function.function(input_signature=[
tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
])
def func2(self, x):
return {'res': x + self.var}
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
dummy_model = DummyModel()
with self.cached_session():
self.evaluate(dummy_model.var.initializer)
save.save(dummy_model, saved_model_dir)
self.parser = saved_model_cli.create_parser()
output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
args = self.parser.parse_args(
['aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
'--output_prefix', output_prefix,
'--cpp_class', 'Generated']) # Use the default serving signature_def_key.
saved_model_cli.aot_compile_cpu(args)
self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
self.assertTrue(
file_io.file_exists('{}_makefile.inc'.format(output_prefix)))
header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
self.assertIn('class Generated', header_contents)
self.assertIn('arg_x_data', header_contents)
self.assertIn('result_res_data', header_contents)
makefile_contents = file_io.read_file_to_string(
'{}_makefile.inc'.format(output_prefix))
self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)
if __name__ == '__main__':
test.main()

View File

@ -20,6 +20,7 @@ limitations under the License.
PYBIND11_MODULE(_pywrap_util_port, m) {
m.def("IsGoogleCudaEnabled", tensorflow::IsGoogleCudaEnabled);
m.def("IsBuiltWithROCm", tensorflow::IsBuiltWithROCm);
m.def("IsBuiltWithXLA", tensorflow::IsBuiltWithXLA);
m.def("IsBuiltWithNvcc", tensorflow::IsBuiltWithNvcc);
m.def("GpuSupportsHalfMatMulAndConv",
tensorflow::GpuSupportsHalfMatMulAndConv);

View File

@ -55,6 +55,11 @@ load(
VERSION = "2.1.0"
VERSION_MAJOR = VERSION.split(".")[0]
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
return str(Label(dep))
def if_v2(a):
return select({
clean_dep("//tensorflow:api_version_2"): a,
@ -76,6 +81,12 @@ def if_nvcc(a):
def if_cuda_is_configured_compat(x):
return if_cuda_is_configured(x)
def if_xla_available(if_true, if_false = []):
return select({
clean_dep("//tensorflow:with_xla_support"): if_true,
"//conditions:default": if_false,
})
# Given a source file, generate a test name.
# i.e. "common_runtime/direct_session_test.cc" becomes
# "common_runtime_direct_session_test"
@ -113,11 +124,6 @@ def tf_portable_proto_library(name, proto_deps, deps = [], **kwargs):
_ignore = [kwargs]
native.cc_library(name = name, deps = deps + [dep + "_cc" for dep in proto_deps])
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
return str(Label(dep))
def if_android_x86(a):
return select({
clean_dep("//tensorflow:android_x86"): a,
@ -304,6 +310,7 @@ def tf_copts(
(if_not_windows(["-fno-exceptions"]) if not allow_exceptions else []) +
if_cuda(["-DGOOGLE_CUDA=1"]) +
if_nvcc(["-DTENSORFLOW_USE_NVCC=1"]) +
if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) +
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
@ -1418,7 +1425,7 @@ def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs):
]) + if_rocm_is_configured(cuda_deps + [
"@local_config_rocm//rocm:rocm_headers",
]),
copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs
)
@ -2458,6 +2465,7 @@ def pybind_extension(
copts = [],
linkopts = [],
deps = [],
defines = [],
visibility = None,
testonly = None,
licenses = None,
@ -2524,6 +2532,7 @@ def pybind_extension(
exported_symbols_file,
version_script_file,
],
defines = defines,
features = features + ["-use_header_modules"],
linkshared = 1,
testonly = testonly,
@ -2569,6 +2578,7 @@ def tf_python_pybind_extension(
copts = [],
hdrs = [],
deps = [],
defines = [],
visibility = None):
"""A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD.
@ -2583,9 +2593,20 @@ def tf_python_pybind_extension(
copts = copts,
hdrs = hdrs,
deps = deps + tf_binary_pybind_deps() + mkl_deps(),
defines = defines,
visibility = visibility,
)
def tf_pybind_cc_library_wrapper(name, deps, visibility = None):
"""Wrapper for cc_library and proto dependencies used by tf_python_pybind_extension.
This wrapper ensures that cc libraries' and protos' headers are made
available to pybind code, without creating ODR violations in the dynamically
linked case. The symbols in these deps should be linked to, and
exported by, the core pywrap_tensorflow_internal.so
"""
cc_header_only_library(name = name, deps = deps, visibility = visibility)
def if_cuda_or_rocm(if_true, if_false = []):
"""Shorthand for select()'ing whether to build for either CUDA or ROCm.
@ -2621,8 +2642,8 @@ def tf_jit_compilation_passes_extra_deps():
def if_mlir(if_true, if_false = []):
return select({
str(Label("//tensorflow:with_mlir_support")): if_true,
"//conditions:default": if_false,
"//tensorflow:with_mlir_support": if_true,
})
def tfcompile_extra_flags():

View File

@ -7,3 +7,4 @@
*TFE_*
*nsync_*
*stream_executor*
*xla*

View File

@ -8,6 +8,7 @@ tensorflow {
*TFE_*;
*nsync_*;
*stream_executor*;
*xla*;
local:
*;
};

View File

@ -56,6 +56,10 @@ tf_module {
name: "is_built_with_rocm"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_built_with_xla"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_gpu_available"
argspec: "args=[\'cuda_only\', \'min_cuda_compute_capability\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "

View File

@ -40,6 +40,10 @@ tf_module {
name: "is_built_with_rocm"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_built_with_xla"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_gpu_available"
argspec: "args=[\'cuda_only\', \'min_cuda_compute_capability\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "

View File

@ -1,7 +1,7 @@
# Description:
# Tools for building the TensorFlow pip package.
load("//tensorflow:tensorflow.bzl", "if_windows", "transitive_hdrs")
load("//tensorflow:tensorflow.bzl", "if_windows", "if_xla_available", "transitive_hdrs")
load("//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
@ -104,7 +104,9 @@ COMMON_PIP_DEPS = [
"//tensorflow/tools/docs:generate_lib",
"//tensorflow/tools/docs:parser",
"//tensorflow/tools/docs:py_guide_parser",
]
] + if_xla_available([
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
])
# On Windows, python binary is a zip file of runfiles tree.
# Add everything to its data dependency for generating a runfiles tree

View File

@ -18,3 +18,4 @@ recursive-include tensorflow_core/include/google *.inc
recursive-include tensorflow_core/include/include *.h
recursive-include tensorflow_core/include/third_party *
recursive-include tensorflow_core/include/unsupported *

View File

@ -245,6 +245,7 @@ else:
EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.so'
headers = (
list(find_files('*.h', 'tensorflow_core/compiler')) +
list(find_files('*.h', 'tensorflow_core/core')) +
list(find_files('*.h', 'tensorflow_core/stream_executor')) +
list(find_files('*.h', 'google/com_google_protobuf/src')) +