[TF XLA] Add ability to convert SavedModel subgraphs to compiled [XLA CPU] objects via saved_model_cli.

You can now run, e.g.:

    saved_model_cli aot_compile_cpu \
        --dir /path/to/saved_model \
        --tag_set serve \
        --signature_def_key action \
        --output_prefix /tmp/out \
        --cpp_class Serving::Action

which creates the files /tmp/{out.h, out.o, out_metadata.o, out_makefile.inc}, where out.h defines something like:

    namespace Serving {
    class Action {
      ...
    };
    }

and out_makefile.inc provides the additional flags required to include the header and object files in your build.

You can optionally also point aot_compile_cpu at a newer set of checkpoints (weight values) via the optional --checkpoint_path argument.

Also added `tf.test.is_built_with_xla()`.

TESTED:
* bazel test -c opt :saved_model_cli_test passes
* Built and installed the pip wheel, then ran the test from the bazel directory via:
      TEST_SRCDIR=/tmp/tfcompile/bazel-bin/tensorflow/python/tools/saved_model_cli_test.runfiles/ python saved_model_cli_test.py
  and checked the output files to ensure the proper includes and header directories are set in out_makefile.inc and out.h.

PiperOrigin-RevId: 290144104
Change-Id: If8eb6c3334b3042c4b9c24813b1b52c06d8fbc06
This commit is contained in: commit 9959c04433 (parent 3bcfb829bb)

Changed files:
    tensorflow/
        BUILD
        compiler/
            aot/
                BUILD, aot_only_var_handle_op.cc, aot_only_var_handle_op.h,
                codegen.cc, codegen_test.cc, compile.cc, compile.h, flags.cc,
                flags.h, tfcompile_main.cc, tfcompile_wrapper.cc
            tf2xla/
            xla/service/cpu/
        core/
        python/
        tensorflow.bzl
        tf_exported_symbols.lds
        tf_version_script.lds
        tools/
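For reference, the new AOT path can also be driven programmatically through the helper this change adds to saved_model_cli. A minimal sketch, assuming an XLA-enabled build; all paths and the signature key are placeholders:

    # Python sketch (hypothetical paths; mirrors what the aot_compile_cpu
    # subcommand does internally).
    from tensorflow.python.tools import saved_model_cli
    from tensorflow.python.tools import saved_model_utils

    saved_model_dir = '/path/to/saved_model'
    meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir, 'serve')

    saved_model_cli.aot_compile_cpu_meta_graph_def(
        checkpoint_path=saved_model_dir + '/variables/variables',
        meta_graph_def=meta_graph_def,
        output_prefix='/tmp/out',
        signature_def_key='action',
        cpp_class='Serving::Action',
        freeze_graph=True)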
tensorflow/BUILD
@@ -2,6 +2,7 @@
# TensorFlow is a computational framework, primarily for use in machine
# learning applications.

+load("@bazel_skylib//lib:selects.bzl", "selects")
load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
load(
    "//tensorflow/core/platform:build_config.bzl",
tensorflow/compiler/aot/BUILD
@@ -1,6 +1,13 @@
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")

package(
    default_visibility = ["//visibility:private"],
    licenses = ["notice"],  # Apache 2.0

@@ -27,9 +34,14 @@ cc_library(
        "compile.h",
        "flags.h",
    ],
    defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
    visibility = ["//tensorflow/python:__pkg__"],
    deps = [
        ":aot_only_var_handle_op",
        ":embedded_protocol_buffers",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "//tensorflow/compiler/tf2xla",
        "//tensorflow/compiler/tf2xla:mlir_tf2xla",
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",

@@ -53,12 +65,45 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:arm_code_gen",  # fixdeps: keep
        "@llvm-project//llvm:powerpc_code_gen",  # fixdeps: keep
        "@llvm-project//llvm:target",
        "@llvm-project//llvm:x86_code_gen",  # fixdeps: keep
    ] + if_llvm_aarch64_available([
        "//third_party/llvm/llvm-project/llvm:aarch64_target",  # fixdeps: keep
    ]),
)

+# Necessary for the pywrap inclusion below.
+tf_pybind_cc_library_wrapper(
+    name = "tfcompile_headers_lib",
+    deps = [
+        ":tfcompile_lib",
+    ],
+)
+
+tf_python_pybind_extension(
+    name = "_pywrap_tfcompile",
+    srcs = ["tfcompile_wrapper.cc"],
+    features = ["-layering_check"],
+    module_name = "_pywrap_tfcompile",
+    visibility = ["//tensorflow/python:__pkg__"],
+    deps = [
+        ":tfcompile_headers_lib",
+        "@pybind11",
+        "//third_party/python_runtime:headers",
+        "//tensorflow/python:pybind11_lib",
+        "//tensorflow/python:pybind11_status",
+        # These headers cannot be brought in via cc_header_only_library
+        "@llvm-project//llvm:arm_code_gen",  # fixdeps: keep
+        "@llvm-project//llvm:powerpc_code_gen",  # fixdeps: keep
+        "@llvm-project//llvm:target",
+        "@llvm-project//llvm:x86_code_gen",  # fixdeps: keep
+    ] + if_llvm_aarch64_available([
+        "//third_party/llvm/llvm-project/llvm:aarch64_target",  # fixdeps: keep
+    ]),
+)

tf_cc_test(
    name = "codegen_test",
    srcs = ["codegen_test.cc"],

@@ -104,11 +149,6 @@ cc_library(
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:aarch64_code_gen",  # fixdeps: keep
-        "@llvm-project//llvm:arm_code_gen",  # fixdeps: keep
-        "@llvm-project//llvm:powerpc_code_gen",  # fixdeps: keep
-        "@llvm-project//llvm:target",
-        "@llvm-project//llvm:x86_code_gen",  # fixdeps: keep
    ],
)

@@ -214,8 +254,13 @@ cc_library(
cc_library(
    name = "aot_only_var_handle_op",
    srcs = ["aot_only_var_handle_op.cc"],
+    hdrs = ["aot_only_var_handle_op.h"],
+    visibility = [
+        "//tensorflow/compiler/tf2xla:__pkg__",
+    ],
    deps = [
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/core:framework",
    ],
+    alwayslink = 1,
)
tensorflow/compiler/aot/aot_only_var_handle_op.cc
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

+#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
+
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
namespace {

@@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
}
}  // namespace

-REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp);
+REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
+    .Doc(R"doc(
+Internal VarHandleOp registration used for XLA AOT compilation.
+)doc")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .Attr("dtype: type")
+    .Attr("shape: shape")
+    .Output("resource: resource")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      DataType t;
+      TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
+      PartialTensorShape p;
+      TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
+      shape_inference::ShapeHandle s;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
+      c->set_output_handle_shapes_and_types(
+          0, std::vector<shape_inference::ShapeAndType>{{s, t}});
+
+      return Status::OK();
+    });
+
+REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
+                XlaAotOnlyVarHandleOp);

}  // namespace tensorflow
tensorflow/compiler/aot/aot_only_var_handle_op.h (new file, 27 lines)
@@ -0,0 +1,27 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_

namespace tensorflow {
namespace tfcompile {

static constexpr const char* const kXlaAotOnlyVarHandleOp =
    "_XlaAotOnlyVarHandleOp";

}  // namespace tfcompile
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
tensorflow/compiler/aot/codegen.cc
@@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
      GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
  const string include_xla_data_proto =
      opts.gen_program_shape
-          ?
-          R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
+          ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
          : "";

  const string include_hlo_profile_printer_data_proto =

@@ -458,8 +457,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,

{{INCLUDE_XLA_DATA_PROTO}}
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
-#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
-#include "tensorflow/core/platform/types.h"
+#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "{{TF_HEADER_ROOT}}/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

@@ -660,6 +659,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
          {"{{CLASS}}", opts.class_name},
          {"{{DECLS_FROM_OBJ_FILE}}",
           absl::StrJoin(metadata_result.header_variable_decls, "\n")},
+          {"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
          {"{{ENTRY}}", compile_result.entry_point},
          {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
           metadata_result.hlo_profile_printer_data_access_shim},
tensorflow/compiler/aot/codegen_test.cc
@@ -197,6 +197,7 @@ TEST(CodegenTest, Golden) {
  variable3->mutable_shape()->add_dim()->set_size(5);
  variable3->set_type(DT_INT32);
  CompileResult compile_result;
+  compile_result.tensorflow_header_root = "third_party/tensorflow";
  compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
      {},
      {BufferInfo::MakeTempBuffer(1),
tensorflow/compiler/aot/compile.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>

+#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/flags.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

@@ -83,6 +85,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
      xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
          std::move(aot_or.ValueOrDie().back()));
  compile_result->entry_point = aot_opts.entry_point_name();
+  compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
  compile_result->pointer_size =
      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
  return Status::OK();

@@ -90,7 +93,7 @@ Status CompileXla(xla::CompileOnlyClient* client,

}  // namespace

-Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result) {
  // Converts the graph into an XLA computation, and compiles the
  // computation.

@@ -108,8 +111,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
    if (!flags.mlir_components.empty()) {
      return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
    }
-    TF_RETURN_IF_ERROR(
-        ConvertGraphDefToXla(graph_def, config, client, &computation));
+    TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
+                                            client, &computation));
  }
  if (!flags.out_session_module.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,

@@ -127,10 +130,102 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
  xla::cpu::CpuAotCompilationOptions aot_opts(
      flags.target_triple, flags.target_cpu, flags.target_features,
      flags.entry_point,
-      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
+      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
+      flags.tensorflow_header_root);

  return CompileXla(client, computation, aot_opts, compile_result);
}

+static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
+  if (absl::EndsWith(fname, ".pbtxt")) {
+    return ReadTextProto(Env::Default(), fname, proto);
+  } else {
+    return ReadBinaryProto(Env::Default(), fname, proto);
+  }
+}
+
+static std::once_flag targets_init;
+
+static void InitializeTargets() {
+  // Initialize all LLVM targets so we can cross compile.
+#if TF_LLVM_AARCH64_AVAILABLE
+  LLVMInitializeAArch64Target();
+  LLVMInitializeAArch64TargetInfo();
+  LLVMInitializeAArch64TargetMC();
+  LLVMInitializeAArch64AsmPrinter();
+#endif
+  LLVMInitializeARMTarget();
+  LLVMInitializeARMTargetInfo();
+  LLVMInitializeARMTargetMC();
+  LLVMInitializeARMAsmPrinter();
+  LLVMInitializePowerPCTarget();
+  LLVMInitializePowerPCTargetInfo();
+  LLVMInitializePowerPCTargetMC();
+  LLVMInitializePowerPCAsmPrinter();
+  LLVMInitializeX86Target();
+  LLVMInitializeX86TargetInfo();
+  LLVMInitializeX86TargetMC();
+  LLVMInitializeX86AsmPrinter();
+}
+
+Status Main(const MainFlags& flags) {
+  std::call_once(targets_init, &InitializeTargets);
+
+  // Process config.
+  tf2xla::Config config;
+  if (flags.config.empty()) {
+    return errors::InvalidArgument("Must specify --config");
+  }
+  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
+  TF_RETURN_IF_ERROR(ValidateConfig(config));
+  if (flags.dump_fetch_nodes) {
+    std::set<string> nodes;
+    for (const tf2xla::Fetch& fetch : config.fetch()) {
+      nodes.insert(fetch.id().node_name());
+    }
+    std::cout << absl::StrJoin(nodes, ",");
+    return Status::OK();
+  }
+
+  // Read and initialize the graph.
+  if (flags.graph.empty()) {
+    return errors::InvalidArgument("Must specify --graph");
+  }
+  GraphDef graph_def;
+  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
+  CompileResult compile_result;
+  TF_RETURN_IF_ERROR(
+      CompileGraph(std::move(graph_def), config, flags, &compile_result));
+
+  // Write output files.
+  Env* env = Env::Default();
+  const std::vector<char>& obj = compile_result.aot->object_file_data();
+  TF_RETURN_IF_ERROR(
+      WriteStringToFile(env, flags.out_function_object,
+                        absl::string_view(obj.data(), obj.size())));
+  CodegenOpts codegen_opts;
+  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
+  codegen_opts.gen_program_shape = flags.gen_program_shape;
+  codegen_opts.target_triple = flags.target_triple;
+  if (flags.cpp_class.empty()) {
+    return errors::InvalidArgument("Must specify --cpp_class");
+  }
+  codegen_opts.gen_hlo_profile_printer_data =
+      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
+  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
+                                   &codegen_opts.namespaces));
+
+  MetadataResult metadata_result;
+  TF_RETURN_IF_ERROR(
+      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
+  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
+                                       metadata_result.object_file_data));
+  string header;
+  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
+                                    metadata_result, &header));
+  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
+  return Status::OK();
+}
+
}  // namespace tfcompile
}  // namespace tensorflow
tensorflow/compiler/aot/compile.h
@@ -35,6 +35,7 @@ struct CompileResult {
  std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
  xla::ProgramShapeProto program_shape;  // Static shape of args and results.
  string entry_point;                    // Name of generated function.
+  string tensorflow_header_root;         // Prefix for tensorflow headers.
  int pointer_size = 0;                  // Size of a pointer in bytes.
};

@@ -42,9 +43,12 @@ struct CompileResult {
// that performs the graph operations.
//
// The XLA compilation options are specified in the flags.
-Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result);

+// The full compilation method, for reuse in a library setting.
+Status Main(const MainFlags& flags);
+
}  // namespace tfcompile
}  // namespace tensorflow
tensorflow/compiler/aot/flags.cc
@@ -74,6 +74,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
       "Generate name-to-index data for Lookup{Arg,Result}Index methods."},
      {"gen_program_shape", &flags->gen_program_shape,
       "Generate program shape data for the ProgramShape method."},
+      {"tensorflow_header_root", &flags->tensorflow_header_root,
+       "Root directory of tensorflow headers."},
  };
  flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
}
tensorflow/compiler/aot/flags.h
@@ -25,6 +25,7 @@ namespace tensorflow {
namespace tfcompile {

// Flags for the tfcompile binary. See *.cc file for descriptions.

struct MainFlags {
  string graph;
  string config;

@@ -39,6 +40,7 @@ struct MainFlags {
  string out_header;
  string out_session_module;
  string mlir_components;
+  string tensorflow_header_root;

  // C++ codegen options
  bool gen_name_to_index = false;
tensorflow/compiler/aot/tfcompile_main.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
-#include "llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"

@@ -56,88 +55,6 @@ const char kUsageHeader[] =
    "--cpp_class=\"mynamespace::MyComputation\"\n"
    "\n";

-Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
-  if (absl::EndsWith(fname, ".pbtxt")) {
-    return ReadTextProto(Env::Default(), fname, proto);
-  } else {
-    return ReadBinaryProto(Env::Default(), fname, proto);
-  }
-}
-
-Status Main(const MainFlags& flags) {
-  // Initialize all LLVM targets so we can cross compile.
-  LLVMInitializeAArch64Target();
-  LLVMInitializeAArch64TargetInfo();
-  LLVMInitializeAArch64TargetMC();
-  LLVMInitializeAArch64AsmPrinter();
-  LLVMInitializeARMTarget();
-  LLVMInitializeARMTargetInfo();
-  LLVMInitializeARMTargetMC();
-  LLVMInitializeARMAsmPrinter();
-  LLVMInitializePowerPCTarget();
-  LLVMInitializePowerPCTargetInfo();
-  LLVMInitializePowerPCTargetMC();
-  LLVMInitializePowerPCAsmPrinter();
-  LLVMInitializeX86Target();
-  LLVMInitializeX86TargetInfo();
-  LLVMInitializeX86TargetMC();
-  LLVMInitializeX86AsmPrinter();
-
-  // Process config.
-  tf2xla::Config config;
-  if (flags.config.empty()) {
-    return errors::InvalidArgument("Must specify --config");
-  }
-  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
-  TF_RETURN_IF_ERROR(ValidateConfig(config));
-  if (flags.dump_fetch_nodes) {
-    std::set<string> nodes;
-    for (const tf2xla::Fetch& fetch : config.fetch()) {
-      nodes.insert(fetch.id().node_name());
-    }
-    std::cout << absl::StrJoin(nodes, ",");
-    return Status::OK();
-  }
-
-  // Read and initialize the graph.
-  if (flags.graph.empty()) {
-    return errors::InvalidArgument("Must specify --graph");
-  }
-  GraphDef graph_def;
-  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
-  CompileResult compile_result;
-  TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
-
-  // Write output files.
-  Env* env = Env::Default();
-  const std::vector<char>& obj = compile_result.aot->object_file_data();
-  TF_RETURN_IF_ERROR(
-      WriteStringToFile(env, flags.out_function_object,
-                        absl::string_view(obj.data(), obj.size())));
-  CodegenOpts codegen_opts;
-  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
-  codegen_opts.gen_program_shape = flags.gen_program_shape;
-  codegen_opts.target_triple = flags.target_triple;
-  if (flags.cpp_class.empty()) {
-    return errors::InvalidArgument("Must specify --cpp_class");
-  }
-  codegen_opts.gen_hlo_profile_printer_data =
-      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
-  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
-                                   &codegen_opts.namespaces));
-
-  MetadataResult metadata_result;
-  TF_RETURN_IF_ERROR(
-      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
-  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
-                                       metadata_result.object_file_data));
-  string header;
-  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
-                                    metadata_result, &header));
-  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
-  return Status::OK();
-}
-
}  // end namespace tfcompile
}  // end namespace tensorflow

@@ -148,6 +65,7 @@ int main(int argc, char** argv) {
  flags.out_metadata_object = "out_helper.o";
  flags.out_header = "out.h";
  flags.entry_point = "entry";
+  flags.tensorflow_header_root = "third_party/tensorflow";

  std::vector<tensorflow::Flag> flag_list;
  AppendMainFlags(&flag_list, &flags);
tensorflow/compiler/aot/tfcompile_wrapper.cc (new file, 75 lines)
@@ -0,0 +1,75 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "include/pybind11/cast.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/pytypes.h"
#include "include/pybind11/stl.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"

namespace py = pybind11;

PYBIND11_MODULE(_pywrap_tfcompile, m) {
  m.doc() = R"pbdoc(
    _pywrap_tfcompile
    -----
  )pbdoc";

  m.def(
      "Compile",
      [](std::string graph, std::string config, std::string target_triple,
         std::string target_cpu, std::string target_features,
         std::string entry_point, std::string cpp_class,
         std::string out_function_object, std::string out_metadata_object,
         std::string out_header, std::string out_session_module,
         std::string mlir_components, std::string tensorflow_header_root,
         bool gen_name_to_index, bool gen_program_shape) {
        tensorflow::tfcompile::MainFlags flags;
        flags.graph = std::move(graph);
        flags.config = std::move(config);
        flags.target_triple = std::move(target_triple);
        flags.target_cpu = std::move(target_cpu);
        flags.target_features = std::move(target_features);
        flags.entry_point = std::move(entry_point);
        flags.cpp_class = std::move(cpp_class);
        flags.out_function_object = std::move(out_function_object);
        flags.out_metadata_object = std::move(out_metadata_object);
        flags.out_header = std::move(out_header);
        flags.out_session_module = std::move(out_session_module);
        flags.mlir_components = std::move(mlir_components);
        flags.tensorflow_header_root = std::move(tensorflow_header_root);

        // C++ codegen options
        flags.gen_name_to_index = gen_name_to_index;
        flags.gen_program_shape = gen_program_shape;

        tensorflow::MaybeRaiseFromStatus(tensorflow::tfcompile::Main(flags));
      },
      py::arg("graph") = "", py::arg("config") = "",
      py::arg("target_triple") = "x86_64-pc-linux", py::arg("target_cpu") = "",
      py::arg("target_features") = "", py::arg("entry_point") = "entry",
      py::arg("cpp_class") = "", py::arg("out_function_object") = "out_model.o",
      py::arg("out_metadata_object") = "out_helper.o",
      py::arg("out_header") = "out.h", py::arg("out_session_module") = "",
      py::arg("mlir_components") = "",
      py::arg("tensorflow_header_root") = "third_party/tensorflow",
      py::arg("gen_name_to_index") = false,
      py::arg("gen_program_shape") = false);
}
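For illustration, a minimal sketch of invoking this wrapper directly from Python (assumes an XLA-enabled build; all paths are placeholders, and unspecified keyword arguments fall back to the py::arg defaults above):

    from tensorflow.compiler.aot import _pywrap_tfcompile

    _pywrap_tfcompile.Compile(
        graph='/tmp/frozen_graph.pb',      # serialized GraphDef (.pb or .pbtxt)
        config='/tmp/config.pbtxt',        # tf2xla.Config textproto
        cpp_class='Serving::Action',       # generated C++ class, with namespaces
        entry_point='entry_custom',        # symbol name emitted in the object file
        out_function_object='/tmp/out.o',
        out_metadata_object='/tmp/out_metadata.o',
        out_header='/tmp/out.h',
        gen_name_to_index=True)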
tensorflow/compiler/tf2xla/BUILD
@@ -5,6 +5,7 @@ load(
)
load(
    "//tensorflow/core/platform:build_config.bzl",
+    "tf_proto_library",
    "tf_proto_library_cc",
)
load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library")

@@ -62,7 +63,7 @@ tf_cc_binary(
    deps = [":tf2xla_supported_ops_lib"],
)

-tf_proto_library_cc(
+tf_proto_library(
    name = "tf2xla_proto",
    srcs = ["tf2xla.proto"],
    cc_api_version = 2,

@@ -140,6 +141,7 @@ cc_library(
        ":tf2xla_proto_cc",
        ":tf2xla_util",
        ":xla_compiler",
+        "//tensorflow/compiler/aot:aot_only_var_handle_op",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/client",
        "//tensorflow/compiler/xla/client:xla_computation",
tensorflow/compiler/tf2xla/tf2xla.cc
@@ -24,6 +24,7 @@ limitations under the License.

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
#include "tensorflow/compiler/tf2xla/graph_compiler_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

@@ -126,12 +127,28 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
  return Status::OK();
}

+void ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
+  for (auto& node : *graph_def->mutable_node()) {
+    if (node.op() == "VarHandleOp") {
+      node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
+    }
+  }
+  for (auto& fn : *graph_def->mutable_library()->mutable_function()) {
+    for (auto& node : *fn.mutable_node_def()) {
+      if (node.op() == "VarHandleOp") {
+        node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
+      }
+    }
+  }
+}
+
}  // namespace

-Status ConvertGraphDefToXla(const GraphDef& graph_def,
-                            const tf2xla::Config& config, xla::Client* client,
+Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
+                            xla::Client* client,
                            xla::XlaComputation* computation) {
  std::unique_ptr<Graph> graph;
+  ConvertVarHandlesToAotVarHandles(&graph_def);
  TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
  TF_RETURN_IF_ERROR(
      ConvertGraphToXla(std::move(graph), config, client, computation));
tensorflow/compiler/tf2xla/tf2xla.h
@@ -30,8 +30,8 @@ namespace tensorflow {
//
// The computation is built in the context of the given `client`, which may
// subsequently be used to compile or execute the computation.
-Status ConvertGraphDefToXla(const GraphDef& graph_def,
-                            const tf2xla::Config& config, xla::Client* client,
+Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
+                            xla::Client* client,
                            xla::XlaComputation* computation);

// Similar to ConvertGraphDefToXla, but uses MLIR.
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -119,12 +119,13 @@ using BufferInfo = cpu_function_runtime::BufferInfo;

CpuAotCompilationOptions::CpuAotCompilationOptions(
    string triple, string cpu_name, string features, string entry_point_name,
-    RelocationModel relocation_model)
+    RelocationModel relocation_model, string tensorflow_header_root)
    : triple_(std::move(triple)),
      cpu_name_(std::move(cpu_name)),
      features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
-      relocation_model_(relocation_model) {}
+      relocation_model_(relocation_model),
+      tensorflow_header_root_(std::move(tensorflow_header_root)) {}

CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;
tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -53,7 +53,17 @@ class CpuAotCompilationOptions : public AotCompilationOptions {

  CpuAotCompilationOptions(string triple, string cpu_name, string features,
                           string entry_point_name,
-                           RelocationModel relocation_model);
+                           RelocationModel relocation_model,
+                           string tensorflow_header_root);
+
+  CpuAotCompilationOptions(string triple, string cpu_name, string features,
+                           string entry_point_name,
+                           RelocationModel relocation_model)
+      : CpuAotCompilationOptions(
+            std::move(triple), std::move(cpu_name), std::move(features),
+            std::move(entry_point_name), relocation_model,
+            /*tensorflow_header_root=*/"third_party/tensorflow") {}

  ~CpuAotCompilationOptions() override;

  se::Platform::Id PlatformId() const override;

@@ -66,6 +76,10 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
  const string& features() const { return features_; }
  // The name to be used for the compiled code's entry point.
  const string& entry_point_name() const { return entry_point_name_; }
+  // The prefix for tensorflow headers, e.g. "third_party/tensorflow".
+  const string& tensorflow_header_root() const {
+    return tensorflow_header_root_;
+  }
  // The relocation model used for compilation.
  RelocationModel relocation_model() const { return relocation_model_; }

@@ -75,6 +89,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
  const string features_;
  const string entry_point_name_;
  const RelocationModel relocation_model_;
+  const string tensorflow_header_root_;
};

class CpuAotCompilationResult : public AotCompilationResult {
tensorflow/core/platform/build_config.bzl
@@ -2,6 +2,7 @@

load(
    "//tensorflow/core/platform/default:build_config.bzl",
+    _if_llvm_aarch64_available = "if_llvm_aarch64_available",
    _pyx_library = "pyx_library",
    _tf_additional_all_protos = "tf_additional_all_protos",
    _tf_additional_binary_deps = "tf_additional_binary_deps",

@@ -80,3 +81,4 @@ tf_protos_profiler_impl = _tf_protos_profiler_impl
tf_py_clif_cc = _tf_py_clif_cc
tf_pyclif_proto_library = _tf_pyclif_proto_library
tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps
+if_llvm_aarch64_available = _if_llvm_aarch64_available
@@ -770,3 +770,8 @@ def tf_google_mobile_srcs_no_runtime():

def tf_google_mobile_srcs_only_runtime():
    return []
+
+def if_llvm_aarch64_available(then, otherwise = []):
+    # TODO(b/...): The TF XLA build fails when adding a dependency on
+    # @llvm/llvm-project/llvm:aarch64_target.
+    return otherwise
@@ -34,6 +34,14 @@ bool IsBuiltWithROCm() {
#endif
}

+bool IsBuiltWithXLA() {
+#if TENSORFLOW_USE_XLA
+  return true;
+#else
+  return false;
+#endif
+}
+
bool IsBuiltWithNvcc() {
#if TENSORFLOW_USE_NVCC
  return true;
@@ -24,6 +24,9 @@ bool IsGoogleCudaEnabled();
// Returns true if TENSORFLOW_USE_ROCM is defined. (i.e. TF is built with ROCm)
bool IsBuiltWithROCm();

+// Returns true if TENSORFLOW_USE_XLA is defined. (i.e. TF is built with XLA)
+bool IsBuiltWithXLA();
+
// Returns true if TENSORFLOW_USE_NVCC is defined. (i.e. TF is built with nvcc)
bool IsBuiltWithNvcc();
tensorflow/python/BUILD
@@ -3,7 +3,7 @@
# Public targets:
# ":platform" - Low-level and platform-specific Python code.

-load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow:tensorflow.bzl", "pybind_extension")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")

@@ -1109,10 +1109,12 @@ py_library(
        ":tensor_util",
        ":type_spec",
        ":util",
-        "//tensorflow/python/eager:context",
        "//third_party/py/numpy",
        "@six_archive//:six",
-    ],
+        "//tensorflow/python/eager:context",
+    ] + if_xla_available([
+        "//tensorflow/compiler/aot:_pywrap_tfcompile",
+    ]),
)

py_library(

@@ -5553,6 +5555,8 @@ tf_py_wrap_cc(
    ] + (tf_additional_lib_deps() +
         tf_additional_plugin_deps()) + if_ngraph([
        "@ngraph_tf//:ngraph_tf",
    ]) + if_xla_available([
        "//tensorflow/compiler/aot:tfcompile_lib",
    ]),
)
tensorflow/python/framework/test_util.py
@@ -284,6 +284,10 @@ def IsBuiltWithROCm():
  return _pywrap_util_port.IsBuiltWithROCm()


+def IsBuiltWithXLA():
+  return _pywrap_util_port.IsBuiltWithXLA()
+
+
def IsBuiltWithNvcc():
  return _pywrap_util_port.IsBuiltWithNvcc()
tensorflow/python/platform/test.py
@@ -106,3 +106,9 @@ def is_built_with_rocm():
def is_built_with_gpu_support():
  """Returns whether TensorFlow was built with GPU (i.e. CUDA or ROCm) support."""
  return is_built_with_cuda() or is_built_with_rocm()
+
+
+@tf_export('test.is_built_with_xla')
+def is_built_with_xla():
+  """Returns whether TensorFlow was built with XLA support."""
+  return _test_util.IsBuiltWithXLA()
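A small usage sketch of the new predicate (hypothetical guard in user code):

    import tensorflow as tf

    # Skip AOT compilation paths on builds without XLA support.
    if not tf.test.is_built_with_xla():
      print('TensorFlow was built without XLA; skipping aot_compile_cpu.')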
tensorflow/python/tools/BUILD
@@ -1,8 +1,7 @@
# Description:
#   Tools for manipulating TensorFlow graphs.

-load("//tensorflow:tensorflow.bzl", "py_test")
-load("//tensorflow:tensorflow.bzl", "py_binary")
+load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test")

package(
    default_visibility = ["//visibility:public"],

@@ -325,7 +324,10 @@ py_library(
        ":saved_model_utils",
        "//tensorflow/python",
        "//tensorflow/python/debug:local_cli_wrapper",
-    ],
+        "//tensorflow/python:tf_optimizer",
+    ] + if_xla_available(
+        ["//tensorflow/compiler/tf2xla:tf2xla_proto_py"],
+    ),
)

py_test(

@@ -339,7 +341,10 @@ py_test(
    tags = [
        "manual",
        "no-internal-py3",
        "nosan",
    ],
    # Force-include XLA dependencies of saved_model_cli_lib to ensure we test
    # the AOT compilation.
    deps = [
        ":saved_model_cli_lib",
        "//tensorflow/core:protos_all_py",
tensorflow/python/tools/saved_model_cli.py
@@ -25,34 +25,131 @@ from __future__ import print_function

import argparse
import collections
import copy
import hashlib
import os
import pipes
import re
import shlex
import sys
import warnings

import numpy as np
import six

from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import versions
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import app  # pylint: disable=unused-import
from tensorflow.python.platform import sysconfig as sysconfig_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as saver_lib


_XLA_DEBUG_OPTIONS_URL = (
    'https://github.com/tensorflow/tensorflow/blob/master/'
    'tensorflow/compiler/xla/debug_options_flags.cc')


try:
  from tensorflow.compiler.aot import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
except ImportError as e:
  _pywrap_tfcompile_import_error = ImportError(
      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
      'with XLA. You may need to build tensorflow with flag '
      '--define=with_xla_support=true. Original error: {}'.format(str(e)))
else:
  _pywrap_tfcompile_import_error = None


# Set of ops to blacklist.
_OP_BLACKLIST = set(['WriteFile', 'ReadFile', 'PrintV2'])


def _shlex_quote(s):
  if six.PY2:
    return pipes.quote(s)
  else:
    return shlex.quote(s)


def _sysconfig_module():
  """Load tf.sysconfig if available and working (i.e., inside a pip package)."""
  try:
    _ = sysconfig_lib.get_include()
  except ImportError:
    return None
  return sysconfig_lib


_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""


def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    if os.path.islink(__file__):
      this_file = __file__
      while os.path.islink(this_file):
        this_file = os.readlink(this_file)
      base = os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    else:
      base = test.test_src_dir_path('')
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      logging.error(
          'Could not find includes path. Missing file: {}'
          .format(expected_header))
    tensorflow_includes = base

  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))


def _show_tag_sets(saved_model_dir):
  """Prints the tag-sets stored in SavedModel directory.
@@ -653,7 +750,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
    if variable_name:
      # if file contains a single ndarray, ignore the input name
      if isinstance(data, np.ndarray):
-        warnings.warn(
+        logging.warn(
            'Input file %s contains a single ndarray. Name key \"%s\" ignored.'
            % (filename, variable_name))
        tensor_key_feed_dict[input_tensor_key] = data

@@ -680,7 +777,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
  # When input is a python expression:
  for input_tensor_key, py_expr_evaluated in input_exprs.items():
    if input_tensor_key in tensor_key_feed_dict:
-      warnings.warn(
+      logging.warn(
          'input_key %s has been specified with both --inputs and --input_exprs'
          ' options. Value in --input_exprs will be used.' % input_tensor_key)
    tensor_key_feed_dict[input_tensor_key] = py_expr_evaluated

@@ -688,7 +785,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
  # When input is a tf.Example:
  for input_tensor_key, example in input_examples.items():
    if input_tensor_key in tensor_key_feed_dict:
-      warnings.warn(
+      logging.warn(
          'input_key %s has been specified in multiple options. Value in '
          '--input_examples will be used.' % input_tensor_key)
    tensor_key_feed_dict[input_tensor_key] = example
@@ -776,20 +873,193 @@ def convert_with_tensorrt(args):
  converter.save(output_saved_model_dir=args.output_dir)


def aot_compile_cpu(args):
  """Function triggered by aot_compile_cpu command.

  Args:
    args: A namespace parsed from command line.
  """
  checkpoint_path = (
      args.checkpoint_path
      or os.path.join(args.dir, 'variables/variables'))
  aot_compile_cpu_meta_graph_def(
      checkpoint_path=checkpoint_path,
      meta_graph_def=saved_model_utils.get_meta_graph_def(
          args.dir, args.tag_set),
      signature_def_key=args.signature_def_key,
      freeze_graph=args.freeze_graph,
      output_prefix=args.output_prefix,
      cpp_class=args.cpp_class)


def aot_compile_cpu_meta_graph_def(
    checkpoint_path,
    meta_graph_def,
    output_prefix,
    signature_def_key,
    cpp_class,
    freeze_graph=True):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files. Also create an include makefile
  that helps identify the necessary include and library paths to
  incorporate these files into your C++ program.

  The graph is always optimized with grappler, and optionally (by default)
  variables are frozen as constants before compilation.

  If `freeze_graph` is `True`, all variables are embedded as constants
  into the graph and binary objects. If it is `False`, then the variable
  values become inputs and outputs of the compiled class and the C++
  caller must set these values manually.

  Args:
    checkpoint_path: Python string. Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.
    output_prefix: Python string. Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: Name of output C++ class.
    freeze_graph: Whether to freeze the graph before compilation.

  Raises:
    RuntimeError: If tensorflow was not built with XLA.
    ImportError: If tensorflow was built with XLA but there was another
      issue importing the tfcompile python wrapper.
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map. '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)
  graph_def = _optimize_graph(meta_graph_def, signature_def)

  if freeze_graph:
    # Load the Variables so that we can freeze the graph.
    with session.Session(graph=ops_lib.Graph()) as sess:
      restorer = saver_lib.import_meta_graph(
          meta_graph_def, clear_devices=True)
      restorer.restore(sess, checkpoint_path)
      graph_def.CopyFrom(
          graph_util.convert_variables_to_constants(
              sess,
              graph_def,
              [n.name.split(':')[0] for n in signature_def.outputs.values()]))

  temp_dir = test.get_temp_dir()
  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def,
      frozen_variables=freeze_graph)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  entry_digest = hashlib.md5()
  entry_digest.update(str(config).encode())
  entry_digest.update(str(graph_def).encode())
  entry_digest = entry_digest.hexdigest()

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  additional_compiler_args = {}
  sysconfig = _sysconfig_module()
  if sysconfig:
    # We're inside PIP and need to pass a customized relative path to the
    # appropriate tensorflow headers.
    additional_compiler_args['tensorflow_header_root'] = 'tensorflow'

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      entry_point='entry_{}'.format(entry_digest),
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquified by entry_point.
      gen_program_shape=False,
      **additional_compiler_args)


def _optimize_graph(meta_graph_def, signature_def):
  """Optimize `meta_graph_def` using grappler. Returns a `GraphDef`."""
  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  new_meta_graph_def = copy.deepcopy(meta_graph_def)
  fetch_collection = meta_graph_pb2.CollectionDef()
  for tensor_info in (
      list(signature_def.inputs.values()) +
      list(signature_def.outputs.values())):
    fetch_collection.node_list.value.append(tensor_info.name)

  new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection)

  config = config_pb2.ConfigProto()
  return tf_optimizer.OptimizeGraph(config, new_meta_graph_def)


def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Replace graphdef's `tf.placeholder` input ops with all-zero constants."""
  name_to_node_map = dict((n.name, n) for n in graph_def.node)
  temp_graph = ops_lib.Graph()
  for name, input_ in signature_def.inputs.items():
    tensor_name = input_.name.split(':')[0]
    if tensor_name not in name_to_node_map:
      raise RuntimeError(
          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
          'Graph nodes are: {}'.format(tensor_name,
                                       list(name_to_node_map.keys())))
    node = name_to_node_map[tensor_name]
    if node.op not in ('Placeholder', 'PlaceholderV2'):
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                               node))
      continue
    shape = tensor_shape.TensorShape(input_.tensor_shape)
    if not shape.is_fully_defined():
      raise ValueError(
          'Expected fully defined input shape for signature_def \'{}\', '
          'tensor name: \'{}\'; but shape is: {}.'
          .format(name, tensor_name, shape))
    with temp_graph.as_default():
      const = array_ops.zeros(shape, dtype=input_.dtype, name=tensor_name)
    node.CopyFrom(const.op.node_def)


def add_show_subparser(subparsers):
  """Add parser for `show`."""
  show_msg = (
      'Usage examples:\n'
      'To show all tag-sets in a SavedModel:\n'
@ -833,7 +1103,9 @@ def create_parser():
|
||||
help='key of SignatureDef to display input(s) and output(s) for')
|
||||
parser_show.set_defaults(func=show)
|
||||
|
||||
# run command
|
||||
|
||||
def add_run_subparser(subparsers):
|
||||
"""Add parser for `run`."""
|
||||
run_msg = ('Usage example:\n'
|
||||
'To run input tensors from files through a MetaGraphDef and save'
|
||||
' the output tensors to files:\n'
|
||||
@ -909,7 +1181,9 @@ def create_parser():
|
||||
'This option should be only used if the worker is a TPU job.')
|
||||
parser_run.set_defaults(func=run)
|
||||
|
||||
# scan command
|
||||
|
||||
def add_scan_subparser(subparsers):
|
||||
"""Add parser for `scan`."""
|
||||
scan_msg = ('Usage example:\n'
|
||||
'To scan for blacklisted ops in SavedModel:\n'
|
||||
'$saved_model_cli scan --dir /tmp/saved_model\n'
|
||||
@ -929,7 +1203,9 @@ def create_parser():
|
||||
help='tag-set of graph in SavedModel to scan, separated by \',\'')
|
||||
parser_scan.set_defaults(func=scan)
|
||||
|
||||
# convert command
|
||||
|
||||
def add_convert_subparser(subparsers):
|
||||
"""Add parser for `convert`."""
|
||||
convert_msg = ('Usage example:\n'
|
||||
'To convert the SavedModel to one that have TensorRT ops:\n'
|
||||
'$saved_model_cli convert \\\n'
|
||||
@ -983,9 +1259,161 @@ def create_parser():
|
||||
'in a TensorRT node'))
|
||||
parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
|
||||
|
||||
|
||||
def add_aot_compile_cpu_subparser(subparsers):
  """Add parser for `aot_compile_cpu`."""
  compile_msg = '\n'.join(
      ['Usage example:',
       'To compile a SavedModel signature via (CPU) XLA AOT:',
       '$saved_model_cli aot_compile_cpu \\',
       ' --dir /tmp/saved_model \\',
       ' --tag_set serve \\',
       ' --output_prefix /tmp/saved_model_xla_aot',
       '', '',
       'Note: Additional XLA compilation options are available by setting the ',
       'XLA_FLAGS environment variable. See the XLA debug options flags for ',
       'all the options: ',
       ' {}'.format(_XLA_DEBUG_OPTIONS_URL),
       '',
       'For example, to disable XLA fast math when compiling:',
       '',
       'XLA_FLAGS="--xla_cpu_enable_fast_math=false" $saved_model_cli '
       'aot_compile_cpu ...',
       '',
       'Some possibly useful flags:',
       ' --xla_cpu_enable_fast_math=false',
       ' --xla_cpu_multi_thread_eigen=false',
       ' --xla_force_host_platform_device_count=<num threads>',
       ' (useful in conjunction with disabling eigen multi threading)'
      ])

  parser_compile = subparsers.add_parser(
      'aot_compile_cpu',
      description=compile_msg,
      formatter_class=argparse.RawTextHelpFormatter)
  parser_compile.add_argument(
      '--dir',
      type=str,
      required=True,
      help='directory containing the SavedModel to convert')
  parser_compile.add_argument(
      '--output_prefix',
      type=str,
      required=True,
      help=('output directory + filename prefix for the resulting header(s) '
            'and object file(s)'))
  parser_compile.add_argument(
      '--tag_set',
      type=str,
      required=True,
      help='tag-set of graph in SavedModel to convert, separated by \',\'')
  parser_compile.add_argument(
      '--signature_def_key',
      type=str,
      default=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
      help=('signature_def key to use. '
            'default: DEFAULT_SERVING_SIGNATURE_DEF_KEY'))
  parser_compile.add_argument(
      '--checkpoint_path',
      type=str,
      default=None,
      help='Custom checkpoint to use (default: use the SavedModel variables)')
  parser_compile.add_argument(
      '--cpp_class',
      type=str,
      required=True,
      help=('The name of the generated C++ class, wrapping the generated '
            'function. The syntax of this flag is '
            '[[<optional_namespace>::],...]<class_name>. This mirrors the '
            'C++ syntax for referring to a class, where multiple namespaces '
            'may precede the class name, separated by double-colons. '
            'The class will be generated in the given namespace(s), or if no '
            'namespaces are given, within the global namespace.'))
  parser_compile.add_argument(
      '--freeze_graph',
      type=bool,
      default=True,
      help=('Whether to freeze the tf.Variables into the graph. If false, '
            'then all Variables in the closure of the signature graph path '
            'will be added as input and output args to the XLA-compiled graph '
            '(not currently supported)'))
  parser_compile.set_defaults(func=aot_compile_cpu)
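
A minimal usage sketch of this subparser (the SavedModel path and class name are illustrative; the new tests below drive it the same way):

    parser = create_parser()
    args = parser.parse_args(
        ['aot_compile_cpu', '--dir', '/tmp/saved_model', '--tag_set', 'serve',
         '--output_prefix', '/tmp/out', '--cpp_class', 'Serving::Action'])
    args.func(args)  # set_defaults above routes this to aot_compile_cpu(args)
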
def create_parser():
  """Creates a parser that parses the command line arguments.

  Returns:
    An argparse.ArgumentParser, with subparsers registered for each of the
    saved_model_cli commands.
  """
  parser = argparse.ArgumentParser(
      description='saved_model_cli: Command-line interface for SavedModel')
  parser.add_argument('-v', '--version', action='version', version='0.1.0')

  subparsers = parser.add_subparsers(
      title='commands', description='valid commands', help='additional help')

  # show command
  add_show_subparser(subparsers)

  # run command
  add_run_subparser(subparsers)

  # scan command
  add_scan_subparser(subparsers)

  # tensorrt convert command
  add_convert_subparser(subparsers)

  # aot_compile_cpu command
  add_aot_compile_cpu_subparser(subparsers)

  return parser


def _signature_to_tf2xla_config(signature_def, frozen_variables):
  """Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto.

  Args:
    signature_def: Instance of `SignatureDef`.
    frozen_variables: Python bool, whether variables are being frozen or not.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    RuntimeError: If TensorFlow was not compiled with XLA.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  config = tf2xla_pb2.Config()
  tensor_id = tf2xla_pb2.TensorId

  for name, input_ in signature_def.inputs.items():
    (node_name, output_index) = input_.name.split(':')
    output_index = int(output_index)
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=input_.dtype,
            shape=input_.tensor_shape))
  for name, output_ in signature_def.outputs.items():
    (node_name, output_index) = output_.name.split(':')
    output_index = int(output_index)
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=tensor_id(node_name=node_name, output_index=output_index),
            name=name,
            type=output_.dtype,
            shape=output_.tensor_shape))
  if not frozen_variables:
    # Extract all variables along the path and add to config
    raise NotImplementedError('Non-frozen graphs are not supported.')

  return config
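
A minimal sketch of the feed construction above, using a hypothetical input tensor named 'x:0'; the split relies on TensorFlow's 'node_name:output_index' tensor-naming convention:

    from tensorflow.compiler.tf2xla import tf2xla_pb2

    config = tf2xla_pb2.Config()
    node_name, output_index = 'x:0'.split(':')
    config.feed.append(
        tf2xla_pb2.Feed(
            id=tf2xla_pb2.TensorId(node_name=node_name,
                                   output_index=int(output_index)),
            name='x'))
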
def main():
  parser = create_parser()
  args = parser.parse_args()

@ -35,6 +35,8 @@ from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import save
from tensorflow.python.tools import saved_model_cli
@ -709,6 +711,63 @@ Defined Functions:
    output = out.getvalue().strip()
    self.assertTrue('\'VariableV2\'' in output)

  def testAOTCompileCPUWrongSignatureDefKey(self):
    if not test.is_built_with_xla():
      self.skipTest('Skipping test because XLA is not compiled in.')

    self.parser = saved_model_cli.create_parser()
    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
    args = self.parser.parse_args(
        ['aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
         '--output_prefix', output_dir,
         '--cpp_class', 'Compiled',
         '--signature_def_key', 'MISSING'])
    with self.assertRaisesRegexp(ValueError, 'Unable to find signature_def'):
      saved_model_cli.aot_compile_cpu(args)

  def testAOTCompileCPUFreezesAndCompiles(self):
    if not test.is_built_with_xla():
      self.skipTest('Skipping test because XLA is not compiled in.')

    class DummyModel(tracking.AutoTrackable):
      """Model compatible with XLA compilation."""

      def __init__(self):
        self.var = variables.Variable(1.0, name='my_var')

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
      ])
      def func2(self, x):
        return {'res': x + self.var}

    saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
    dummy_model = DummyModel()
    with self.cached_session():
      self.evaluate(dummy_model.var.initializer)
      save.save(dummy_model, saved_model_dir)

    self.parser = saved_model_cli.create_parser()
    output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
    args = self.parser.parse_args(
        ['aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
         '--output_prefix', output_prefix,
         '--cpp_class', 'Generated'])  # Use the default serving signature_def_key.
    saved_model_cli.aot_compile_cpu(args)
    self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
    self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
    self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
    self.assertTrue(
        file_io.file_exists('{}_makefile.inc'.format(output_prefix)))
    header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
    self.assertIn('class Generated', header_contents)
    self.assertIn('arg_x_data', header_contents)
    self.assertIn('result_res_data', header_contents)
    makefile_contents = file_io.read_file_to_string(
        '{}_makefile.inc'.format(output_prefix))
    self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)


if __name__ == '__main__':
  test.main()

@ -20,6 +20,7 @@ limitations under the License.
PYBIND11_MODULE(_pywrap_util_port, m) {
  m.def("IsGoogleCudaEnabled", tensorflow::IsGoogleCudaEnabled);
  m.def("IsBuiltWithROCm", tensorflow::IsBuiltWithROCm);
  m.def("IsBuiltWithXLA", tensorflow::IsBuiltWithXLA);
  m.def("IsBuiltWithNvcc", tensorflow::IsBuiltWithNvcc);
  m.def("GpuSupportsHalfMatMulAndConv",
        tensorflow::GpuSupportsHalfMatMulAndConv);
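
On the Python side the new IsBuiltWithXLA binding surfaces as `tf.test.is_built_with_xla()` (see the API goldens below); a minimal guard, mirroring the skip logic in the tests above:

    import tensorflow as tf

    if not tf.test.is_built_with_xla():
        print('TensorFlow was built without XLA; aot_compile_cpu will not work.')
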
@ -55,6 +55,11 @@ load(
VERSION = "2.1.0"
VERSION_MAJOR = VERSION.split(".")[0]

# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
    return str(Label(dep))

def if_v2(a):
    return select({
        clean_dep("//tensorflow:api_version_2"): a,
@ -76,6 +81,12 @@ def if_nvcc(a):
def if_cuda_is_configured_compat(x):
    return if_cuda_is_configured(x)

def if_xla_available(if_true, if_false = []):
    return select({
        clean_dep("//tensorflow:with_xla_support"): if_true,
        "//conditions:default": if_false,
    })

# Given a source file, generate a test name.
# i.e. "common_runtime/direct_session_test.cc" becomes
#      "common_runtime_direct_session_test"
@ -113,11 +124,6 @@ def tf_portable_proto_library(name, proto_deps, deps = [], **kwargs):
    _ignore = [kwargs]
    native.cc_library(name = name, deps = deps + [dep + "_cc" for dep in proto_deps])

# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
    return str(Label(dep))

def if_android_x86(a):
    return select({
        clean_dep("//tensorflow:android_x86"): a,
@ -304,6 +310,7 @@ def tf_copts(
        (if_not_windows(["-fno-exceptions"]) if not allow_exceptions else []) +
        if_cuda(["-DGOOGLE_CUDA=1"]) +
        if_nvcc(["-DTENSORFLOW_USE_NVCC=1"]) +
        if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) +
        if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
        if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
        if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
@ -1418,7 +1425,7 @@ def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs):
        ]) + if_rocm_is_configured(cuda_deps + [
            "@local_config_rocm//rocm:rocm_headers",
        ]),
        copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
        copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
        **kwargs
    )

@ -2458,6 +2465,7 @@ def pybind_extension(
        copts = [],
        linkopts = [],
        deps = [],
        defines = [],
        visibility = None,
        testonly = None,
        licenses = None,
@ -2524,6 +2532,7 @@ def pybind_extension(
            exported_symbols_file,
            version_script_file,
        ],
        defines = defines,
        features = features + ["-use_header_modules"],
        linkshared = 1,
        testonly = testonly,
@ -2569,6 +2578,7 @@ def tf_python_pybind_extension(
        copts = [],
        hdrs = [],
        deps = [],
        defines = [],
        visibility = None):
    """A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD.

@ -2583,9 +2593,20 @@ def tf_python_pybind_extension(
        copts = copts,
        hdrs = hdrs,
        deps = deps + tf_binary_pybind_deps() + mkl_deps(),
        defines = defines,
        visibility = visibility,
    )

def tf_pybind_cc_library_wrapper(name, deps, visibility = None):
    """Wrapper for cc_library and proto dependencies used by tf_python_pybind_extension.

    This wrapper ensures that cc libraries' and protos' headers are made
    available to pybind code, without creating ODR violations in the dynamically
    linked case. The symbols in these deps should be linked to, and
    exported by, the core pywrap_tensorflow_internal.so.
    """
    cc_header_only_library(name = name, deps = deps, visibility = visibility)

def if_cuda_or_rocm(if_true, if_false = []):
    """Shorthand for select()'ing whether to build for either CUDA or ROCm.

@ -2621,8 +2642,8 @@ def tf_jit_compilation_passes_extra_deps():

def if_mlir(if_true, if_false = []):
    return select({
        str(Label("//tensorflow:with_mlir_support")): if_true,
        "//conditions:default": if_false,
        "//tensorflow:with_mlir_support": if_true,
    })

def tfcompile_extra_flags():
@ -7,3 +7,4 @@
*TFE_*
*nsync_*
*stream_executor*
*xla*
@ -8,6 +8,7 @@ tensorflow {
    *TFE_*;
    *nsync_*;
    *stream_executor*;
    *xla*;
  local:
    *;
};
@ -56,6 +56,10 @@ tf_module {
    name: "is_built_with_rocm"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "is_built_with_xla"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "is_gpu_available"
    argspec: "args=[\'cuda_only\', \'min_cuda_compute_capability\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@ -40,6 +40,10 @@ tf_module {
    name: "is_built_with_rocm"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "is_built_with_xla"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "is_gpu_available"
    argspec: "args=[\'cuda_only\', \'min_cuda_compute_capability\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@ -1,7 +1,7 @@
# Description:
#   Tools for building the TensorFlow pip package.

load("//tensorflow:tensorflow.bzl", "if_windows", "transitive_hdrs")
load("//tensorflow:tensorflow.bzl", "if_windows", "if_xla_available", "transitive_hdrs")
load("//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
@ -104,7 +104,9 @@ COMMON_PIP_DEPS = [
    "//tensorflow/tools/docs:generate_lib",
    "//tensorflow/tools/docs:parser",
    "//tensorflow/tools/docs:py_guide_parser",
]
] + if_xla_available([
    "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
])

# On Windows, python binary is a zip file of runfiles tree.
# Add everything to its data dependency for generating a runfiles tree
@ -18,3 +18,4 @@ recursive-include tensorflow_core/include/google *.inc
recursive-include tensorflow_core/include/include *.h
recursive-include tensorflow_core/include/third_party *
recursive-include tensorflow_core/include/unsupported *
@ -245,6 +245,7 @@ else:
  EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.so'

headers = (
    list(find_files('*.h', 'tensorflow_core/compiler')) +
    list(find_files('*.h', 'tensorflow_core/core')) +
    list(find_files('*.h', 'tensorflow_core/stream_executor')) +
    list(find_files('*.h', 'google/com_google_protobuf/src')) +