[TF XLA] Add ability to convert SavedModel subgraphs to compiled [XLA CPU] objects via saved_model_cli.

You can now run, e.g.:

    saved_model_cli aot_compile_cpu \
      --dir /path/to/saved_model \
      --tag_set serve \
      --signature_def_key action \
      --output_prefix /tmp/out \
      --cpp_class Serving::Action

This creates the files /tmp/{out.h, out.o, out_metadata.o, out_makefile.inc}, where out.h defines something like:

    namespace Serving {
      class Action {
        ...
      }
    }

and out_makefile.inc provides the additional flags required to include the header and object files in your build. You can optionally point aot_compile_cpu at a newer set of checkpoints (weight values) via the optional argument --checkpoint_path.

Also added `tf.test.is_built_with_xla()`.

TESTED:
* bazel test -c opt :saved_model_cli_test passes
* Built and installed the pip wheel, then tested in the bazel directory via:
      TEST_SRCDIR=/tmp/tfcompile/bazel-bin/tensorflow/python/tools/saved_model_cli_test.runfiles/ python saved_model_cli_test.py
  and checked the output files to ensure the proper includes and header directories are set in out_makefile.inc and out.h.

PiperOrigin-RevId: 290144104
Change-Id: If8eb6c3334b3042c4b9c24813b1b52c06d8fbc06
This commit is contained in:
parent 3bcfb829bb · commit 9959c04433
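Before the diff, a minimal Python sketch of driving the new path programmatically rather than via the CLI. It mirrors the `aot_compile_cpu_meta_graph_def` helper this commit adds to saved_model_cli.py (see its definition near the end of the diff); it is an illustration only, and the SavedModel path, tag set, signature key, and output prefix below are placeholder values:

    # Hypothetical usage sketch of the helper added in this commit.
    from tensorflow.python.tools import saved_model_cli
    from tensorflow.python.tools import saved_model_utils

    saved_model_dir = '/path/to/saved_model'  # placeholder
    saved_model_cli.aot_compile_cpu_meta_graph_def(
        checkpoint_path=saved_model_dir + '/variables/variables',
        meta_graph_def=saved_model_utils.get_meta_graph_def(
            saved_model_dir, 'serve'),
        signature_def_key='action',
        output_prefix='/tmp/out',
        cpp_class='Serving::Action',
        freeze_graph=True)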
@@ -2,6 +2,7 @@
 # TensorFlow is a computational framework, primarily for use in machine
 # learning applications.

+load("@bazel_skylib//lib:selects.bzl", "selects")
 load("//tensorflow:tensorflow.bzl", "VERSION", "tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_native_cc_binary")
 load(
     "//tensorflow/core/platform:build_config.bzl",
@@ -1,6 +1,13 @@
 load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
 load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
+
+# buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
+load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available")

 package(
     default_visibility = ["//visibility:private"],
     licenses = ["notice"],  # Apache 2.0
@@ -27,9 +34,14 @@ cc_library(
         "compile.h",
         "flags.h",
     ],
+    defines = if_llvm_aarch64_available(["TF_LLVM_AARCH64_AVAILABLE=1"]),
+    visibility = ["//tensorflow/python:__pkg__"],
     deps = [
         ":aot_only_var_handle_op",
         ":embedded_protocol_buffers",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
         "//tensorflow/compiler/tf2xla",
         "//tensorflow/compiler/tf2xla:mlir_tf2xla",
         "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
@@ -53,12 +65,45 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
-        "@com_google_absl//absl/memory",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
+        "@llvm-project//llvm:arm_code_gen",  # fixdeps: keep
+        "@llvm-project//llvm:powerpc_code_gen",  # fixdeps: keep
+        "@llvm-project//llvm:target",
+        "@llvm-project//llvm:x86_code_gen",  # fixdeps: keep
+    ] + if_llvm_aarch64_available([
+        "//third_party/llvm/llvm-project/llvm:aarch64_target",  # fixdeps: keep
+    ]),
+)
+
+# Necessary for the pywrap inclusion below.
+tf_pybind_cc_library_wrapper(
+    name = "tfcompile_headers_lib",
+    deps = [
+        ":tfcompile_lib",
     ],
 )
+
+tf_python_pybind_extension(
+    name = "_pywrap_tfcompile",
+    srcs = ["tfcompile_wrapper.cc"],
+    features = ["-layering_check"],
+    module_name = "_pywrap_tfcompile",
+    visibility = ["//tensorflow/python:__pkg__"],
+    deps = [
+        ":tfcompile_headers_lib",
+        "@pybind11",
+        "//third_party/python_runtime:headers",
+        "//tensorflow/python:pybind11_lib",
+        "//tensorflow/python:pybind11_status",
+        # These headers cannot be brought in via cc_header_only_library
+        "@llvm-project//llvm:arm_code_gen",  # fixdeps: keep
+        "@llvm-project//llvm:powerpc_code_gen",  # fixdeps: keep
+        "@llvm-project//llvm:target",
+        "@llvm-project//llvm:x86_code_gen",  # fixdeps: keep
+    ] + if_llvm_aarch64_available([
+        "//third_party/llvm/llvm-project/llvm:aarch64_target",  # fixdeps: keep
+    ]),
+)

 tf_cc_test(
     name = "codegen_test",
     srcs = ["codegen_test.cc"],
@@ -104,11 +149,6 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:aarch64_code_gen",  # fixdeps: keep
-        "@llvm-project//llvm:arm_code_gen",  # fixdeps: keep
-        "@llvm-project//llvm:powerpc_code_gen",  # fixdeps: keep
-        "@llvm-project//llvm:target",
-        "@llvm-project//llvm:x86_code_gen",  # fixdeps: keep
     ],
 )

@@ -214,8 +254,13 @@ cc_library(
 cc_library(
     name = "aot_only_var_handle_op",
     srcs = ["aot_only_var_handle_op.cc"],
+    hdrs = ["aot_only_var_handle_op.h"],
+    visibility = [
+        "//tensorflow/compiler/tf2xla:__pkg__",
+    ],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/core:framework",
     ],
     alwayslink = 1,
 )
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
+
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/shape_inference.h"

 namespace tensorflow {
 namespace {
@@ -51,6 +54,31 @@ void XlaAotOnlyVarHandleOp::Compile(XlaOpKernelContext* context) {
 }
 }  // namespace

-REGISTER_XLA_OP(Name("VarHandleOp").CompilationOnly(), XlaAotOnlyVarHandleOp);
+REGISTER_OP(tfcompile::kXlaAotOnlyVarHandleOp)
+    .Doc(R"doc(
+Internal VarHandleOp registration used for XLA AOT compilation.
+)doc")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .Attr("dtype: type")
+    .Attr("shape: shape")
+    .Output("resource: resource")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      DataType t;
+      TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
+      PartialTensorShape p;
+      TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
+      shape_inference::ShapeHandle s;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
+      c->set_output_handle_shapes_and_types(
+          0, std::vector<shape_inference::ShapeAndType>{{s, t}});
+
+      return Status::OK();
+    });
+
+REGISTER_XLA_OP(Name(tfcompile::kXlaAotOnlyVarHandleOp).CompilationOnly(),
+                XlaAotOnlyVarHandleOp);
+
 }  // namespace tensorflow
tensorflow/compiler/aot/aot_only_var_handle_op.h (new file, 27 lines)
@@ -0,0 +1,27 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
+#define TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
+
+namespace tensorflow {
+namespace tfcompile {
+
+static constexpr const char* const kXlaAotOnlyVarHandleOp =
+    "_XlaAotOnlyVarHandleOp";
+
+}  // namespace tfcompile
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_AOT_AOT_ONLY_VAR_HANDLE_OP_H_
@@ -423,8 +423,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
       GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
   const string include_xla_data_proto =
       opts.gen_program_shape
-          ?
-          R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
+          ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
           : "";

   const string include_hlo_profile_printer_data_proto =
@@ -458,8 +457,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,

 {{INCLUDE_XLA_DATA_PROTO}}
 {{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
-#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
-#include "tensorflow/core/platform/types.h"
+#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "{{TF_HEADER_ROOT}}/core/platform/types.h"

 namespace Eigen { struct ThreadPoolDevice; }
 namespace xla { class ExecutableRunOptions; }
|
|||||||
{"{{CLASS}}", opts.class_name},
|
{"{{CLASS}}", opts.class_name},
|
||||||
{"{{DECLS_FROM_OBJ_FILE}}",
|
{"{{DECLS_FROM_OBJ_FILE}}",
|
||||||
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
|
absl::StrJoin(metadata_result.header_variable_decls, "\n")},
|
||||||
|
{"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
|
||||||
{"{{ENTRY}}", compile_result.entry_point},
|
{"{{ENTRY}}", compile_result.entry_point},
|
||||||
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
|
{"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
|
||||||
metadata_result.hlo_profile_printer_data_access_shim},
|
metadata_result.hlo_profile_printer_data_access_shim},
|
||||||
|
@@ -197,6 +197,7 @@ TEST(CodegenTest, Golden) {
   variable3->mutable_shape()->add_dim()->set_size(5);
   variable3->set_type(DT_INT32);
   CompileResult compile_result;
+  compile_result.tensorflow_header_root = "third_party/tensorflow";
   compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
       {},
       {BufferInfo::MakeTempBuffer(1),
@@ -20,6 +20,8 @@ limitations under the License.
 #include <utility>
 #include <vector>

+#include "llvm-c/Target.h"
+#include "tensorflow/compiler/aot/codegen.h"
 #include "tensorflow/compiler/aot/flags.h"
 #include "tensorflow/compiler/tf2xla/tf2xla.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -83,6 +85,7 @@ Status CompileXla(xla::CompileOnlyClient* client,
       xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
           std::move(aot_or.ValueOrDie().back()));
   compile_result->entry_point = aot_opts.entry_point_name();
+  compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
   compile_result->pointer_size =
       xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
   return Status::OK();
@@ -90,7 +93,7 @@ Status CompileXla(xla::CompileOnlyClient* client,

 }  // namespace

-Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                     const MainFlags& flags, CompileResult* compile_result) {
   // Converts the graph into an XLA computation, and compiles the
   // computation.
|
|||||||
if (!flags.mlir_components.empty()) {
|
if (!flags.mlir_components.empty()) {
|
||||||
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
return errors::Unknown("Unknown mlir_components ", flags.mlir_components);
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
|
||||||
ConvertGraphDefToXla(graph_def, config, client, &computation));
|
client, &computation));
|
||||||
}
|
}
|
||||||
if (!flags.out_session_module.empty()) {
|
if (!flags.out_session_module.empty()) {
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
|
||||||
@ -127,10 +130,102 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
|
|||||||
xla::cpu::CpuAotCompilationOptions aot_opts(
|
xla::cpu::CpuAotCompilationOptions aot_opts(
|
||||||
flags.target_triple, flags.target_cpu, flags.target_features,
|
flags.target_triple, flags.target_cpu, flags.target_features,
|
||||||
flags.entry_point,
|
flags.entry_point,
|
||||||
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
|
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
|
||||||
|
flags.tensorflow_header_root);
|
||||||
|
|
||||||
return CompileXla(client, computation, aot_opts, compile_result);
|
return CompileXla(client, computation, aot_opts, compile_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
|
||||||
|
if (absl::EndsWith(fname, ".pbtxt")) {
|
||||||
|
return ReadTextProto(Env::Default(), fname, proto);
|
||||||
|
} else {
|
||||||
|
return ReadBinaryProto(Env::Default(), fname, proto);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::once_flag targets_init;
|
||||||
|
|
||||||
|
static void InitializeTargets() {
|
||||||
|
// Initialize all LLVM targets so we can cross compile.
|
||||||
|
#if TF_LLVM_AARCH64_AVAILABLE
|
||||||
|
LLVMInitializeAArch64Target();
|
||||||
|
LLVMInitializeAArch64TargetInfo();
|
||||||
|
LLVMInitializeAArch64TargetMC();
|
||||||
|
LLVMInitializeAArch64AsmPrinter();
|
||||||
|
#endif
|
||||||
|
LLVMInitializeARMTarget();
|
||||||
|
LLVMInitializeARMTargetInfo();
|
||||||
|
LLVMInitializeARMTargetMC();
|
||||||
|
LLVMInitializeARMAsmPrinter();
|
||||||
|
LLVMInitializePowerPCTarget();
|
||||||
|
LLVMInitializePowerPCTargetInfo();
|
||||||
|
LLVMInitializePowerPCTargetMC();
|
||||||
|
LLVMInitializePowerPCAsmPrinter();
|
||||||
|
LLVMInitializeX86Target();
|
||||||
|
LLVMInitializeX86TargetInfo();
|
||||||
|
LLVMInitializeX86TargetMC();
|
||||||
|
LLVMInitializeX86AsmPrinter();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Main(const MainFlags& flags) {
|
||||||
|
std::call_once(targets_init, &InitializeTargets);
|
||||||
|
|
||||||
|
// Process config.
|
||||||
|
tf2xla::Config config;
|
||||||
|
if (flags.config.empty()) {
|
||||||
|
return errors::InvalidArgument("Must specify --config");
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
|
||||||
|
TF_RETURN_IF_ERROR(ValidateConfig(config));
|
||||||
|
if (flags.dump_fetch_nodes) {
|
||||||
|
std::set<string> nodes;
|
||||||
|
for (const tf2xla::Fetch& fetch : config.fetch()) {
|
||||||
|
nodes.insert(fetch.id().node_name());
|
||||||
|
}
|
||||||
|
std::cout << absl::StrJoin(nodes, ",");
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and initialize the graph.
|
||||||
|
if (flags.graph.empty()) {
|
||||||
|
return errors::InvalidArgument("Must specify --graph");
|
||||||
|
}
|
||||||
|
GraphDef graph_def;
|
||||||
|
TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
|
||||||
|
CompileResult compile_result;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
CompileGraph(std::move(graph_def), config, flags, &compile_result));
|
||||||
|
|
||||||
|
// Write output files.
|
||||||
|
Env* env = Env::Default();
|
||||||
|
const std::vector<char>& obj = compile_result.aot->object_file_data();
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
WriteStringToFile(env, flags.out_function_object,
|
||||||
|
absl::string_view(obj.data(), obj.size())));
|
||||||
|
CodegenOpts codegen_opts;
|
||||||
|
codegen_opts.gen_name_to_index = flags.gen_name_to_index;
|
||||||
|
codegen_opts.gen_program_shape = flags.gen_program_shape;
|
||||||
|
codegen_opts.target_triple = flags.target_triple;
|
||||||
|
if (flags.cpp_class.empty()) {
|
||||||
|
return errors::InvalidArgument("Must specify --cpp_class");
|
||||||
|
}
|
||||||
|
codegen_opts.gen_hlo_profile_printer_data =
|
||||||
|
xla::GetDebugOptionsFromFlags().xla_hlo_profile();
|
||||||
|
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
|
||||||
|
&codegen_opts.namespaces));
|
||||||
|
|
||||||
|
MetadataResult metadata_result;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
GenerateMetadata(codegen_opts, compile_result, &metadata_result));
|
||||||
|
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
|
||||||
|
metadata_result.object_file_data));
|
||||||
|
string header;
|
||||||
|
TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
|
||||||
|
metadata_result, &header));
|
||||||
|
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tfcompile
|
} // namespace tfcompile
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@@ -35,6 +35,7 @@ struct CompileResult {
   std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
   xla::ProgramShapeProto program_shape;  // Static shape of args and results.
   string entry_point;                    // Name of generated function.
+  string tensorflow_header_root;         // Prefix for tensorflow headers.
   int pointer_size = 0;                  // Size of a pointer in bytes.
 };

@@ -42,9 +43,12 @@ struct CompileResult {
 // that performs the graph operations.
 //
 // The XLA compilation options are specified in the flags.
-Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
+Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                     const MainFlags& flags, CompileResult* compile_result);

+// The full compilation method, for reuse in a library setting.
+Status Main(const MainFlags& flags);
+
 }  // namespace tfcompile
 }  // namespace tensorflow

@@ -74,6 +74,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
        "Generate name-to-index data for Lookup{Arg,Result}Index methods."},
       {"gen_program_shape", &flags->gen_program_shape,
        "Generate program shape data for the ProgramShape method."},
+      {"tensorflow_header_root", &flags->tensorflow_header_root,
+       "Root directory of tensorflow headers."},
   };
   flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
 }
@@ -25,6 +25,7 @@ namespace tensorflow {
 namespace tfcompile {

 // Flags for the tfcompile binary. See *.cc file for descriptions.
+
 struct MainFlags {
   string graph;
   string config;
@@ -39,6 +40,7 @@ struct MainFlags {
   string out_header;
   string out_session_module;
   string mlir_components;
+  string tensorflow_header_root;

   // C++ codegen options
   bool gen_name_to_index = false;
@@ -21,7 +21,6 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "absl/strings/str_join.h"
 #include "absl/strings/string_view.h"
-#include "llvm-c/Target.h"
 #include "tensorflow/compiler/aot/codegen.h"
 #include "tensorflow/compiler/aot/compile.h"
 #include "tensorflow/compiler/aot/flags.h"
@@ -56,88 +55,6 @@ const char kUsageHeader[] =
     "--cpp_class=\"mynamespace::MyComputation\"\n"
     "\n";

-Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
-  if (absl::EndsWith(fname, ".pbtxt")) {
-    return ReadTextProto(Env::Default(), fname, proto);
-  } else {
-    return ReadBinaryProto(Env::Default(), fname, proto);
-  }
-}
-
-Status Main(const MainFlags& flags) {
-  // Initialize all LLVM targets so we can cross compile.
-  LLVMInitializeAArch64Target();
-  LLVMInitializeAArch64TargetInfo();
-  LLVMInitializeAArch64TargetMC();
-  LLVMInitializeAArch64AsmPrinter();
-  LLVMInitializeARMTarget();
-  LLVMInitializeARMTargetInfo();
-  LLVMInitializeARMTargetMC();
-  LLVMInitializeARMAsmPrinter();
-  LLVMInitializePowerPCTarget();
-  LLVMInitializePowerPCTargetInfo();
-  LLVMInitializePowerPCTargetMC();
-  LLVMInitializePowerPCAsmPrinter();
-  LLVMInitializeX86Target();
-  LLVMInitializeX86TargetInfo();
-  LLVMInitializeX86TargetMC();
-  LLVMInitializeX86AsmPrinter();
-
-  // Process config.
-  tf2xla::Config config;
-  if (flags.config.empty()) {
-    return errors::InvalidArgument("Must specify --config");
-  }
-  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
-  TF_RETURN_IF_ERROR(ValidateConfig(config));
-  if (flags.dump_fetch_nodes) {
-    std::set<string> nodes;
-    for (const tf2xla::Fetch& fetch : config.fetch()) {
-      nodes.insert(fetch.id().node_name());
-    }
-    std::cout << absl::StrJoin(nodes, ",");
-    return Status::OK();
-  }
-
-  // Read and initialize the graph.
-  if (flags.graph.empty()) {
-    return errors::InvalidArgument("Must specify --graph");
-  }
-  GraphDef graph_def;
-  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
-  CompileResult compile_result;
-  TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result));
-
-  // Write output files.
-  Env* env = Env::Default();
-  const std::vector<char>& obj = compile_result.aot->object_file_data();
-  TF_RETURN_IF_ERROR(
-      WriteStringToFile(env, flags.out_function_object,
-                        absl::string_view(obj.data(), obj.size())));
-  CodegenOpts codegen_opts;
-  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
-  codegen_opts.gen_program_shape = flags.gen_program_shape;
-  codegen_opts.target_triple = flags.target_triple;
-  if (flags.cpp_class.empty()) {
-    return errors::InvalidArgument("Must specify --cpp_class");
-  }
-  codegen_opts.gen_hlo_profile_printer_data =
-      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
-  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
-                                   &codegen_opts.namespaces));
-
-  MetadataResult metadata_result;
-  TF_RETURN_IF_ERROR(
-      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
-  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
-                                       metadata_result.object_file_data));
-  string header;
-  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
-                                    metadata_result, &header));
-  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
-  return Status::OK();
-}
-
 }  // end namespace tfcompile
 }  // end namespace tensorflow

|
|||||||
flags.out_metadata_object = "out_helper.o";
|
flags.out_metadata_object = "out_helper.o";
|
||||||
flags.out_header = "out.h";
|
flags.out_header = "out.h";
|
||||||
flags.entry_point = "entry";
|
flags.entry_point = "entry";
|
||||||
|
flags.tensorflow_header_root = "third_party/tensorflow";
|
||||||
|
|
||||||
std::vector<tensorflow::Flag> flag_list;
|
std::vector<tensorflow::Flag> flag_list;
|
||||||
AppendMainFlags(&flag_list, &flags);
|
AppendMainFlags(&flag_list, &flags);
|
||||||
tensorflow/compiler/aot/tfcompile_wrapper.cc (new file, 75 lines)
@@ -0,0 +1,75 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "include/pybind11/cast.h"
+#include "include/pybind11/pybind11.h"
+#include "include/pybind11/pytypes.h"
+#include "include/pybind11/stl.h"
+#include "tensorflow/compiler/aot/compile.h"
+#include "tensorflow/compiler/aot/flags.h"
+#include "tensorflow/python/lib/core/pybind11_lib.h"
+#include "tensorflow/python/lib/core/pybind11_status.h"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_pywrap_tfcompile, m) {
+  m.doc() = R"pbdoc(
+    _pywrap_tfcompile
+    -----
+  )pbdoc";
+
+  m.def(
+      "Compile",
+      [](std::string graph, std::string config, std::string target_triple,
+         std::string target_cpu, std::string target_features,
+         std::string entry_point, std::string cpp_class,
+         std::string out_function_object, std::string out_metadata_object,
+         std::string out_header, std::string out_session_module,
+         std::string mlir_components, std::string tensorflow_header_root,
+         bool gen_name_to_index, bool gen_program_shape) {
+        tensorflow::tfcompile::MainFlags flags;
+        flags.graph = std::move(graph);
+        flags.config = std::move(config);
+        flags.target_triple = std::move(target_triple);
+        flags.target_cpu = std::move(target_cpu);
+        flags.target_features = std::move(target_features);
+        flags.entry_point = std::move(entry_point);
+        flags.cpp_class = std::move(cpp_class);
+        flags.out_function_object = std::move(out_function_object);
+        flags.out_metadata_object = std::move(out_metadata_object);
+        flags.out_header = std::move(out_header);
+        flags.out_session_module = std::move(out_session_module);
+        flags.mlir_components = std::move(mlir_components);
+        flags.tensorflow_header_root = std::move(tensorflow_header_root);
+
+        // C++ codegen options
+        flags.gen_name_to_index = gen_name_to_index;
+        flags.gen_program_shape = gen_program_shape;
+
+        tensorflow::MaybeRaiseFromStatus(tensorflow::tfcompile::Main(flags));
+      },
+      py::arg("graph") = "", py::arg("config") = "",
+      py::arg("target_triple") = "x86_64-pc-linux", py::arg("target_cpu") = "",
+      py::arg("target_features") = "", py::arg("entry_point") = "entry",
+      py::arg("cpp_class") = "", py::arg("out_function_object") = "out_model.o",
+      py::arg("out_metadata_object") = "out_helper.o",
+      py::arg("out_header") = "out.h", py::arg("out_session_module") = "",
+      py::arg("mlir_components") = "",
+      py::arg("tensorflow_header_root") = "third_party/tensorflow",
+      py::arg("gen_name_to_index") = false,
+      py::arg("gen_program_shape") = false);
+}
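To make the binding above concrete, here is a hedged Python sketch of invoking it directly; it assumes an XLA-enabled build in which the extension imports successfully, and all file paths are placeholders. The keyword names mirror the py::arg defaults declared in the wrapper:

    from tensorflow.compiler.aot import _pywrap_tfcompile

    # Hypothetical invocation; every argument has a default in the binding,
    # so only the values that differ need to be passed.
    _pywrap_tfcompile.Compile(
        graph='/tmp/frozen_graph.pb',        # placeholder path
        config='/tmp/tf2xla_config.pbtxt',   # placeholder path
        cpp_class='Serving::Action',
        out_function_object='/tmp/out.o',
        out_metadata_object='/tmp/out_helper.o',
        out_header='/tmp/out.h',
        gen_program_shape=True)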
@@ -5,6 +5,7 @@ load(
 )
 load(
     "//tensorflow/core/platform:build_config.bzl",
+    "tf_proto_library",
     "tf_proto_library_cc",
 )
 load("//tensorflow/compiler/xla:xla.bzl", "xla_py_proto_library")
@@ -62,7 +63,7 @@ tf_cc_binary(
     deps = [":tf2xla_supported_ops_lib"],
 )

-tf_proto_library_cc(
+tf_proto_library(
     name = "tf2xla_proto",
     srcs = ["tf2xla.proto"],
     cc_api_version = 2,
@@ -140,6 +141,7 @@ cc_library(
         ":tf2xla_proto_cc",
         ":tf2xla_util",
         ":xla_compiler",
+        "//tensorflow/compiler/aot:aot_only_var_handle_op",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/client",
        "//tensorflow/compiler/xla/client:xla_computation",
@@ -24,6 +24,7 @@ limitations under the License.

 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
+#include "tensorflow/compiler/aot/aot_only_var_handle_op.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
@@ -126,12 +127,28 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   return Status::OK();
 }

+void ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
+  for (auto& node : *graph_def->mutable_node()) {
+    if (node.op() == "VarHandleOp") {
+      node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
+    }
+  }
+  for (auto& fn : *graph_def->mutable_library()->mutable_function()) {
+    for (auto& node : *fn.mutable_node_def()) {
+      if (node.op() == "VarHandleOp") {
+        node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
+      }
+    }
+  }
+}
+
 }  // namespace

-Status ConvertGraphDefToXla(const GraphDef& graph_def,
-                            const tf2xla::Config& config, xla::Client* client,
+Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
+                            xla::Client* client,
                             xla::XlaComputation* computation) {
   std::unique_ptr<Graph> graph;
+  ConvertVarHandlesToAotVarHandles(&graph_def);
   TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
   TF_RETURN_IF_ERROR(
       ConvertGraphToXla(std::move(graph), config, client, computation));
@@ -30,8 +30,8 @@ namespace tensorflow {
 //
 // The computation is built in the context of the given `client`, which may
 // subsequently be used to compile or execute the computation.
-Status ConvertGraphDefToXla(const GraphDef& graph_def,
-                            const tf2xla::Config& config, xla::Client* client,
+Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
+                            xla::Client* client,
                             xla::XlaComputation* computation);

 // Similar to ConvertGraphDefToXla, but uses MLIR.
@@ -119,12 +119,13 @@ using BufferInfo = cpu_function_runtime::BufferInfo;

 CpuAotCompilationOptions::CpuAotCompilationOptions(
     string triple, string cpu_name, string features, string entry_point_name,
-    RelocationModel relocation_model)
+    RelocationModel relocation_model, string tensorflow_header_root)
     : triple_(std::move(triple)),
       cpu_name_(std::move(cpu_name)),
       features_(std::move(features)),
       entry_point_name_(std::move(entry_point_name)),
-      relocation_model_(relocation_model) {}
+      relocation_model_(relocation_model),
+      tensorflow_header_root_(std::move(tensorflow_header_root)) {}

 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;

@@ -53,7 +53,17 @@ class CpuAotCompilationOptions : public AotCompilationOptions {

   CpuAotCompilationOptions(string triple, string cpu_name, string features,
                            string entry_point_name,
-                           RelocationModel relocation_model);
+                           RelocationModel relocation_model,
+                           string tensorflow_header_root);
+
+  CpuAotCompilationOptions(string triple, string cpu_name, string features,
+                           string entry_point_name,
+                           RelocationModel relocation_model)
+      : CpuAotCompilationOptions(
+            std::move(triple), std::move(cpu_name), std::move(features),
+            std::move(entry_point_name), relocation_model,
+            /*tensorflow_header_root=*/"third_party/tensorflow") {}
+
   ~CpuAotCompilationOptions() override;

   se::Platform::Id PlatformId() const override;
@@ -66,6 +76,10 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string& features() const { return features_; }
   // The name to be used for the compiled code's entry point.
   const string& entry_point_name() const { return entry_point_name_; }
+  // The prefix for tensorflow headers, e.g. "third_party/tensorflow".
+  const string& tensorflow_header_root() const {
+    return tensorflow_header_root_;
+  }
   // The relocation model used for compilation.
   RelocationModel relocation_model() const { return relocation_model_; }

@@ -75,6 +89,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string features_;
   const string entry_point_name_;
   const RelocationModel relocation_model_;
+  const string tensorflow_header_root_;
 };

 class CpuAotCompilationResult : public AotCompilationResult {
@@ -2,6 +2,7 @@

 load(
     "//tensorflow/core/platform/default:build_config.bzl",
+    _if_llvm_aarch64_available = "if_llvm_aarch64_available",
     _pyx_library = "pyx_library",
     _tf_additional_all_protos = "tf_additional_all_protos",
     _tf_additional_binary_deps = "tf_additional_binary_deps",
@@ -80,3 +81,4 @@ tf_protos_profiler_impl = _tf_protos_profiler_impl
 tf_py_clif_cc = _tf_py_clif_cc
 tf_pyclif_proto_library = _tf_pyclif_proto_library
 tf_windows_aware_platform_deps = _tf_windows_aware_platform_deps
+if_llvm_aarch64_available = _if_llvm_aarch64_available
@@ -770,3 +770,8 @@ def tf_google_mobile_srcs_no_runtime():

 def tf_google_mobile_srcs_only_runtime():
     return []
+
+def if_llvm_aarch64_available(then, otherwise = []):
+    # TODO(b/...): The TF XLA build fails when adding a dependency on
+    # @llvm/llvm-project/llvm:aarch64_target.
+    return otherwise
@@ -34,6 +34,14 @@ bool IsBuiltWithROCm() {
 #endif
 }

+bool IsBuiltWithXLA() {
+#if TENSORFLOW_USE_XLA
+  return true;
+#else
+  return false;
+#endif
+}
+
 bool IsBuiltWithNvcc() {
 #if TENSORFLOW_USE_NVCC
   return true;
@@ -24,6 +24,9 @@ bool IsGoogleCudaEnabled();
 // Returns true if TENSORFLOW_USE_ROCM is defined. (i.e. TF is built with ROCm)
 bool IsBuiltWithROCm();

+// Returns true if TENSORFLOW_USE_XLA is defined. (i.e. TF is built with XLA)
+bool IsBuiltWithXLA();
+
 // Returns true if TENSORFLOW_USE_NVCC is defined. (i.e. TF is built with nvcc)
 bool IsBuiltWithNvcc();

@@ -3,7 +3,7 @@
 # Public targets:
 #  ":platform" - Low-level and platform-specific Python code.

-load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
+load("//tensorflow:tensorflow.bzl", "cc_header_only_library", "if_mlir", "if_not_windows", "if_xla_available", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
 load("//tensorflow:tensorflow.bzl", "pybind_extension")
 load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
@@ -1109,10 +1109,12 @@ py_library(
         ":tensor_util",
         ":type_spec",
         ":util",
-        "//tensorflow/python/eager:context",
         "//third_party/py/numpy",
         "@six_archive//:six",
-    ],
+        "//tensorflow/python/eager:context",
+    ] + if_xla_available([
+        "//tensorflow/compiler/aot:_pywrap_tfcompile",
+    ]),
 )

 py_library(
@@ -5553,6 +5555,8 @@ tf_py_wrap_cc(
     ] + (tf_additional_lib_deps() +
          tf_additional_plugin_deps()) + if_ngraph([
         "@ngraph_tf//:ngraph_tf",
+    ]) + if_xla_available([
+        "//tensorflow/compiler/aot:tfcompile_lib",
     ]),
 )

@@ -284,6 +284,10 @@ def IsBuiltWithROCm():
   return _pywrap_util_port.IsBuiltWithROCm()


+def IsBuiltWithXLA():
+  return _pywrap_util_port.IsBuiltWithXLA()
+
+
 def IsBuiltWithNvcc():
   return _pywrap_util_port.IsBuiltWithNvcc()

@@ -106,3 +106,9 @@ def is_built_with_rocm():
 def is_built_with_gpu_support():
   """Returns whether TensorFlow was built with GPU (i.e. CUDA or ROCm) support."""
   return is_built_with_cuda() or is_built_with_rocm()
+
+
+@tf_export('test.is_built_with_xla')
+def is_built_with_xla():
+  """Returns whether TensorFlow was built with XLA support."""
+  return _test_util.IsBuiltWithXLA()
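As a usage illustration (not part of the diff), the new predicate is intended to be used the same way as the existing tf.test.is_built_with_cuda(), for example to guard XLA-dependent logic:

    import tensorflow as tf

    # Hypothetical guard: skip XLA-only behavior on non-XLA builds.
    if not tf.test.is_built_with_xla():
        print('TensorFlow was built without XLA; skipping AOT checks.')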
@@ -1,8 +1,7 @@
 # Description:
 #   Tools for manipulating TensorFlow graphs.

-load("//tensorflow:tensorflow.bzl", "py_test")
-load("//tensorflow:tensorflow.bzl", "py_binary")
+load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test")

 package(
     default_visibility = ["//visibility:public"],
@@ -325,7 +324,10 @@ py_library(
         ":saved_model_utils",
         "//tensorflow/python",
         "//tensorflow/python/debug:local_cli_wrapper",
-    ],
+        "//tensorflow/python:tf_optimizer",
+    ] + if_xla_available(
+        ["//tensorflow/compiler/tf2xla:tf2xla_proto_py"],
+    ),
 )

 py_test(
@@ -339,7 +341,10 @@ py_test(
     tags = [
         "manual",
         "no-internal-py3",
+        "nosan",
     ],
+    # Force-include XLA dependencies of saved_model_cli_lib to ensure we test
+    # the AOT compilation.
     deps = [
         ":saved_model_cli_lib",
         "//tensorflow/core:protos_all_py",
@@ -25,34 +25,131 @@ from __future__ import print_function

 import argparse
 import collections
+import copy
+import hashlib
 import os
+import pipes
 import re
+import shlex
 import sys
-import warnings

 import numpy as np
 import six

 from tensorflow.core.example import example_pb2
 from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.python.client import session
 from tensorflow.python.debug.wrappers import local_cli_wrapper
 from tensorflow.python.eager import def_function
 from tensorflow.python.eager import function as defun
+from tensorflow.python.framework import graph_util
 from tensorflow.python.framework import meta_graph as meta_graph_lib
 from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import versions
+from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import app  # pylint: disable=unused-import
+from tensorflow.python.platform import sysconfig as sysconfig_lib
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import load
 from tensorflow.python.saved_model import loader
 from tensorflow.python.saved_model import save
+from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.tools import saved_model_utils
+from tensorflow.python.training import saver as saver_lib
+
+
+_XLA_DEBUG_OPTIONS_URL = (
+    'https://github.com/tensorflow/tensorflow/blob/master/'
+    'tensorflow/compiler/xla/debug_options_flags.cc')
+
+
+try:
+  from tensorflow.compiler.aot import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
+except ImportError as e:
+  _pywrap_tfcompile_import_error = ImportError(
+      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
+      'with XLA. You may need to build tensorflow with flag '
+      '--define=with_xla_support=true. Original error: {}'.format(str(e)))
+else:
+  _pywrap_tfcompile_import_error = None
+

 # Set of ops to blacklist.
 _OP_BLACKLIST = set(['WriteFile', 'ReadFile', 'PrintV2'])
+
+
+def _shlex_quote(s):
+  if six.PY2:
+    return pipes.quote(s)
+  else:
+    return shlex.quote(s)
+
+
+def _sysconfig_module():
+  """Load tf.sysconfig if available and working (i.e., inside a pip package)."""
+  try:
+    _ = sysconfig_lib.get_include()
+  except ImportError:
+    return None
+  return sysconfig_lib
+
+
+_XLA_MAKEFILE_TEMPLATE = """
+INC = -I{tensorflow_includes}
+LIB = -L{compiled_dir}
+CXXFLAGS = {cxx_flags}
+"""
+
+
+def _xla_makefile_string(output_prefix):
+  """Returns a Makefile string with variables for using XLA binary object files.
+
+  Attempts to identify the right include header paths when run from either
+  an installed TensorFlow pip package, or from bazel run.
+
+  Args:
+    output_prefix: A string containing the output prefix for the XLA AOT
+      compiled header + object files.
+
+  Returns:
+    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
+  """
+  sysconfig = _sysconfig_module()
+  output_dir, _ = os.path.split(output_prefix)
+  if sysconfig:
+    tensorflow_includes = _shlex_quote(sysconfig.get_include())
+  else:
+    # Try hard to find the real source directory if this is a local bazel run.
+    if os.path.islink(__file__):
+      this_file = __file__
+      while os.path.islink(this_file):
+        this_file = os.readlink(this_file)
+      base = os.path.realpath(
+          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
+    else:
+      base = test.test_src_dir_path('')
+    expected_header = os.path.join(
+        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
+    if not os.path.exists(expected_header):
+      logging.error(
+          'Could not find includes path. Missing file: {}'
+          .format(expected_header))
+    tensorflow_includes = base
+
+  return _XLA_MAKEFILE_TEMPLATE.format(
+      tensorflow_includes=tensorflow_includes,
+      compiled_dir=_shlex_quote(output_dir),
+      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
+          versions.CXX11_ABI_FLAG))
+

 def _show_tag_sets(saved_model_dir):
   """Prints the tag-sets stored in SavedModel directory.
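A quick illustration of the makefile helper defined in the hunk above (assuming the module is importable; exact paths depend on the local installation). For an output prefix of /tmp/out, the rendered template has this shape:

    from tensorflow.python.tools import saved_model_cli

    # Hypothetical call; the include path and ABI flag vary per build.
    print(saved_model_cli._xla_makefile_string('/tmp/out'))
    # Expected shape of the output:
    #   INC = -I<tensorflow include dir>
    #   LIB = -L/tmp
    #   CXXFLAGS = -D_GLIBCXX_USE_CXX11_ABI=<0 or 1>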
@@ -653,7 +750,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
     if variable_name:
       # if file contains a single ndarray, ignore the input name
       if isinstance(data, np.ndarray):
-        warnings.warn(
+        logging.warn(
             'Input file %s contains a single ndarray. Name key \"%s\" ignored.'
             % (filename, variable_name))
         tensor_key_feed_dict[input_tensor_key] = data
@@ -680,7 +777,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
   # When input is a python expression:
   for input_tensor_key, py_expr_evaluated in input_exprs.items():
     if input_tensor_key in tensor_key_feed_dict:
-      warnings.warn(
+      logging.warn(
           'input_key %s has been specified with both --inputs and --input_exprs'
           ' options. Value in --input_exprs will be used.' % input_tensor_key)
     tensor_key_feed_dict[input_tensor_key] = py_expr_evaluated
@ -688,7 +785,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
|
|||||||
# When input is a tf.Example:
|
# When input is a tf.Example:
|
||||||
for input_tensor_key, example in input_examples.items():
|
for input_tensor_key, example in input_examples.items():
|
||||||
if input_tensor_key in tensor_key_feed_dict:
|
if input_tensor_key in tensor_key_feed_dict:
|
||||||
warnings.warn(
|
logging.warn(
|
||||||
'input_key %s has been specified in multiple options. Value in '
|
'input_key %s has been specified in multiple options. Value in '
|
||||||
'--input_examples will be used.' % input_tensor_key)
|
'--input_examples will be used.' % input_tensor_key)
|
||||||
tensor_key_feed_dict[input_tensor_key] = example
|
tensor_key_feed_dict[input_tensor_key] = example
|
||||||
@ -776,20 +873,193 @@ def convert_with_tensorrt(args):
   converter.save(output_saved_model_dir=args.output_dir)
 
 
-def create_parser():
-  """Creates a parser that parses the command line arguments.
-
-  Returns:
-    An argparse.ArgumentParser object.
-  """
-  parser = argparse.ArgumentParser(
-      description='saved_model_cli: Command-line interface for SavedModel')
-  parser.add_argument('-v', '--version', action='version', version='0.1.0')
-
-  subparsers = parser.add_subparsers(
-      title='commands', description='valid commands', help='additional help')
-
-  # show command
+def aot_compile_cpu(args):
+  """Function triggered by aot_compile_cpu command.
+
+  Args:
+    args: A namespace parsed from command line.
+  """
+  checkpoint_path = (
+      args.checkpoint_path
+      or os.path.join(args.dir, 'variables/variables'))
+  aot_compile_cpu_meta_graph_def(
+      checkpoint_path=checkpoint_path,
+      meta_graph_def=saved_model_utils.get_meta_graph_def(
+          args.dir, args.tag_set),
+      signature_def_key=args.signature_def_key,
+      freeze_graph=args.freeze_graph,
+      output_prefix=args.output_prefix,
+      cpp_class=args.cpp_class)
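(Editorial usage sketch, not part of the patch; paths and the C++ class name are illustrative, and this mirrors the tests added further down.)

# Assumes an XLA-enabled TensorFlow build and a SavedModel with a 'serve'
# tag-set under /tmp/saved_model.
from tensorflow.python.tools import saved_model_cli

parser = saved_model_cli.create_parser()
args = parser.parse_args(
    ['aot_compile_cpu', '--dir', '/tmp/saved_model', '--tag_set', 'serve',
     '--output_prefix', '/tmp/out', '--cpp_class', 'MyNamespace::MyClass'])
saved_model_cli.aot_compile_cpu(args)
# Expected artifacts: /tmp/out.h, /tmp/out.o, /tmp/out_metadata.o and
# /tmp/out_makefile.inc.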
+
+
+def aot_compile_cpu_meta_graph_def(
+    checkpoint_path,
+    meta_graph_def,
+    output_prefix,
+    signature_def_key,
+    cpp_class,
+    freeze_graph=True):
+  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.
+
+  Use XLA AOT (`tfcompile`) to convert the given meta graph and
+  signature into a header + object files.  Also create an include makefile
+  that identifies the necessary include and library paths for incorporating
+  these files into your C++ program.
+
+  The graph is always optimized with grappler, and optionally (by default)
+  variables are frozen as constants, before compilation happens.
+
+  If `freeze_graph` is `True`, all variables are embedded as constants
+  into the graph and binary objects.  If it is `False`, then the variable
+  values become inputs and outputs of the compiled class and the C++
+  caller must set these values manually.
+
+  Args:
+    checkpoint_path: Python string.  Path to checkpoints/variables.
+    meta_graph_def: Instance of `MetaGraphDef`.
+    output_prefix: Python string.  Path prefix for outputs.
+    signature_def_key: String, the signature_def to use in the SavedModel.
+    cpp_class: Name of output C++ class.
+    freeze_graph: Whether to freeze the graph before compilation.
+
+  Raises:
+    RuntimeError: If tensorflow was not built with XLA.
+    ImportError: If tensorflow was built with XLA but there was another
+      issue importing the tfcompile python wrapper.
+    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
+      missing or has empty outputs.
+  """
+  if _pywrap_tfcompile_import_error:
+    raise _pywrap_tfcompile_import_error
+
+  signature_def_map = meta_graph_def.signature_def
+  if signature_def_key not in signature_def_map:
+    raise ValueError(
+        'Unable to find signature_def key \'{}\' in signature def map.  '
+        'Available keys: {}'.format(
+            signature_def_key,
+            list(signature_def_map.keys())))
+  signature_def = signature_def_map[signature_def_key]
+  if not signature_def.outputs:
+    raise ValueError(
+        'Signature key {} must have outputs, but saw none:\n{}'.format(
+            signature_def_key, str(signature_def)))
+
+  # This updates graph_def in place.
+  _replace_input_placeholders_with_default_values(
+      meta_graph_def.graph_def, signature_def)
+  graph_def = _optimize_graph(meta_graph_def, signature_def)
+
+  if freeze_graph:
+    # Load the Variables so that we can freeze the graph.
+    with session.Session(graph=ops_lib.Graph()) as sess:
+      restorer = saver_lib.import_meta_graph(
+          meta_graph_def, clear_devices=True)
+      restorer.restore(sess, checkpoint_path)
+      graph_def.CopyFrom(
+          graph_util.convert_variables_to_constants(
+              sess,
+              graph_def,
+              [n.name.split(':')[0] for n in signature_def.outputs.values()]))
+
+  temp_dir = test.get_temp_dir()
+  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
+  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
+  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
+  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
+    graph_writer.write(graph_def.SerializeToString())
+  config = _signature_to_tf2xla_config(
+      signature_def,
+      frozen_variables=freeze_graph)
+  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
+  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
+    config_writer.write(str(config))
+
+  output_dir = os.path.dirname(output_prefix)
+  file_io.recursive_create_dir(output_dir)
+
+  entry_digest = hashlib.md5()
+  entry_digest.update(str(config).encode())
+  entry_digest.update(str(graph_def).encode())
+  entry_digest = entry_digest.hexdigest()
+
+  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))
+
+  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
+  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
+    makefile_writer.write(_xla_makefile_string(output_prefix))
+
+  output_prefix = _shlex_quote(output_prefix)
+
+  additional_compiler_args = {}
+  sysconfig = _sysconfig_module()
+  if sysconfig:
+    # We're inside PIP and need to pass a customized relative path to the
+    # appropriate tensorflow headers.
+    additional_compiler_args['tensorflow_header_root'] = 'tensorflow'
+
+  _pywrap_tfcompile.Compile(
+      graph=frozen_graph_def_location,
+      config=config_pbtxt_location,
+      cpp_class=cpp_class,
+      entry_point='entry_{}'.format(entry_digest),
+      out_function_object='{}.o'.format(output_prefix),
+      out_header='{}.h'.format(output_prefix),
+      out_metadata_object='{}_metadata.o'.format(output_prefix),
+      gen_name_to_index=True,
+      # ProgramShape isn't uniquefied by entry_point.
+      gen_program_shape=False,
+      **additional_compiler_args)
+
+
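(Editorial aside, not part of the patch: the entry-point symbol above is content-addressed. A compressed restatement, where the byte strings stand in for `str(config)` and `str(graph_def)`.)

import hashlib

digest = hashlib.md5()
digest.update(b'<config text proto>')   # str(config).encode()
digest.update(b'<optimized GraphDef>')  # str(graph_def).encode()
entry_point = 'entry_{}'.format(digest.hexdigest())
# Distinct signature/graph combinations therefore get distinct symbols, which
# avoids collisions when several compiled objects are linked into one binary.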
+def _optimize_graph(meta_graph_def, signature_def):
+  """Optimize `meta_graph_def` using grappler.  Returns a `GraphDef`."""
+  # We need to add a collection called 'train_op' so that grappler
+  # knows what the outputs are.
+  new_meta_graph_def = copy.deepcopy(meta_graph_def)
+  fetch_collection = meta_graph_pb2.CollectionDef()
+  for tensor_info in (
+      list(signature_def.inputs.values()) +
+      list(signature_def.outputs.values())):
+    fetch_collection.node_list.value.append(tensor_info.name)
+
+  new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection)
+
+  config = config_pb2.ConfigProto()
+  return tf_optimizer.OptimizeGraph(config, new_meta_graph_def)
+
+
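(Editorial note, not part of the patch: the `'train_op'` collection is simply how grappler is told which tensors must survive optimization. A minimal restatement with hypothetical tensor names.)

from tensorflow.core.protobuf import meta_graph_pb2

fetch_collection = meta_graph_pb2.CollectionDef()
# Grappler keeps everything reachable from 'train_op', so listing the
# signature's inputs and outputs here prevents them from being pruned away.
fetch_collection.node_list.value.extend(['x:0', 'res:0'])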
+def _replace_input_placeholders_with_default_values(graph_def, signature_def):
+  """Replace graphdef's `tf.placeholder` input ops with all-zero constants."""
+  name_to_node_map = dict((n.name, n) for n in graph_def.node)
+  temp_graph = ops_lib.Graph()
+  for name, input_ in signature_def.inputs.items():
+    tensor_name = input_.name.split(':')[0]
+    if tensor_name not in name_to_node_map:
+      raise RuntimeError(
+          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
+          'Graph nodes are: {}'.format(tensor_name,
+                                       list(name_to_node_map.keys())))
+    node = name_to_node_map[tensor_name]
+    if node.op not in ('Placeholder', 'PlaceholderV2'):
+      logging.info(
+          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
+          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
+                                                               node))
+      continue
+    shape = tensor_shape.TensorShape(input_.tensor_shape)
+    if not shape.is_fully_defined():
+      raise ValueError(
+          'Expected fully defined input shape for signature_def \'{}\', '
+          'tensor name: \'{}\'; but shape is: {}.'
+          .format(name, tensor_name, shape))
+    with temp_graph.as_default():
+      const = array_ops.zeros(shape, dtype=input_.dtype, name=tensor_name)
+    node.CopyFrom(const.op.node_def)
+
+
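(Editorial sketch, not part of the patch: a standalone illustration of the NodeDef swap performed above; the name and shape are hypothetical.)

from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.ops import array_ops

temp_graph = ops_lib.Graph()
with temp_graph.as_default():
  # Build a zeros constant under the same name the Placeholder had ('x' here);
  # its NodeDef is then grafted over the Placeholder via node.CopyFrom(...).
  const = array_ops.zeros((2, 2), name='x')
print(const.op.node_def.op)  # typically 'Const' for a fully-defined shape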
+def add_show_subparser(subparsers):
+  """Add parser for `show`."""
   show_msg = (
       'Usage examples:\n'
       'To show all tag-sets in a SavedModel:\n'
@ -833,7 +1103,9 @@ def create_parser():
       help='key of SignatureDef to display input(s) and output(s) for')
   parser_show.set_defaults(func=show)
 
-  # run command
+
+def add_run_subparser(subparsers):
+  """Add parser for `run`."""
   run_msg = ('Usage example:\n'
              'To run input tensors from files through a MetaGraphDef and save'
              ' the output tensors to files:\n'
@ -909,7 +1181,9 @@ def create_parser():
       'This option should be only used if the worker is a TPU job.')
   parser_run.set_defaults(func=run)
 
-  # scan command
+
+def add_scan_subparser(subparsers):
+  """Add parser for `scan`."""
   scan_msg = ('Usage example:\n'
               'To scan for blacklisted ops in SavedModel:\n'
               '$saved_model_cli scan --dir /tmp/saved_model\n'
@ -929,7 +1203,9 @@ def create_parser():
       help='tag-set of graph in SavedModel to scan, separated by \',\'')
   parser_scan.set_defaults(func=scan)
 
-  # convert command
+
+def add_convert_subparser(subparsers):
+  """Add parser for `convert`."""
   convert_msg = ('Usage example:\n'
                  'To convert the SavedModel to one that has TensorRT ops:\n'
                  '$saved_model_cli convert \\\n'
@ -983,9 +1259,161 @@ def create_parser():
           'in a TensorRT node'))
   parser_convert_with_tensorrt.set_defaults(func=convert_with_tensorrt)
 
 
+def add_aot_compile_cpu_subparser(subparsers):
+  """Add parser for `aot_compile_cpu`."""
+  compile_msg = '\n'.join(
+      ['Usage example:',
+       'To compile a SavedModel signature via (CPU) XLA AOT:',
+       '$saved_model_cli aot_compile_cpu \\',
+       ' --dir /tmp/saved_model \\',
+       ' --tag_set serve \\',
+       ' --output_prefix /tmp/saved_model_xla_aot/out \\',
+       ' --cpp_class MyClass',
+       '', '',
+       'Note: Additional XLA compilation options are available by setting the ',
+       'XLA_FLAGS environment variable.  See the XLA debug options flags for ',
+       'all the options: ',
+       '  {}'.format(_XLA_DEBUG_OPTIONS_URL),
+       '',
+       'For example, to disable XLA fast math when compiling:',
+       '',
+       'XLA_FLAGS="--xla_cpu_enable_fast_math=false" $saved_model_cli '
+       'aot_compile_cpu ...',
+       '',
+       'Some possibly useful flags:',
+       '  --xla_cpu_enable_fast_math=false',
+       '  --xla_cpu_multi_thread_eigen=false',
+       '  --xla_force_host_platform_device_count=<num threads>',
+       '    (useful in conjunction with disabling eigen multi threading)'
+      ])
+
+  parser_compile = subparsers.add_parser(
+      'aot_compile_cpu',
+      description=compile_msg,
+      formatter_class=argparse.RawTextHelpFormatter)
+  parser_compile.add_argument(
+      '--dir',
+      type=str,
+      required=True,
+      help='directory containing the SavedModel to convert')
+  parser_compile.add_argument(
+      '--output_prefix',
+      type=str,
+      required=True,
+      help=('output directory + filename prefix for the resulting header(s) '
+            'and object file(s)'))
+  parser_compile.add_argument(
+      '--tag_set',
+      type=str,
+      required=True,
+      help='tag-set of graph in SavedModel to convert, separated by \',\'')
+  parser_compile.add_argument(
+      '--signature_def_key',
+      type=str,
+      default=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+      help=('signature_def key to use.  '
+            'default: DEFAULT_SERVING_SIGNATURE_DEF_KEY'))
+  parser_compile.add_argument(
+      '--checkpoint_path',
+      type=str,
+      default=None,
+      help='Custom checkpoint to use (default: use the SavedModel variables)')
+  parser_compile.add_argument(
+      '--cpp_class',
+      type=str,
+      required=True,
+      help=('The name of the generated C++ class, wrapping the generated '
+            'function.  The syntax of this flag is '
+            '[[<optional_namespace>::],...]<class_name>.  This mirrors the '
+            'C++ syntax for referring to a class, where multiple namespaces '
+            'may precede the class name, separated by double-colons.  '
+            'The class will be generated in the given namespace(s), or if no '
+            'namespaces are given, within the global namespace.'))
+  parser_compile.add_argument(
+      '--freeze_graph',
+      type=bool,
+      default=True,
+      help=('Whether to freeze the tf.Variables into the graph.  If false, '
+            'then all Variables in the closure of the signature graph path '
+            'will be added as input and output args to the XLA-compiled graph '
+            '(not currently supported)'))
+  parser_compile.set_defaults(func=aot_compile_cpu)
+
+
+def create_parser():
+  """Creates a parser that parses the command line arguments.
+
+  Returns:
+    An argparse.ArgumentParser object.
+  """
+  parser = argparse.ArgumentParser(
+      description='saved_model_cli: Command-line interface for SavedModel')
+  parser.add_argument('-v', '--version', action='version', version='0.1.0')
+
+  subparsers = parser.add_subparsers(
+      title='commands', description='valid commands', help='additional help')
+
+  # show command
+  add_show_subparser(subparsers)
+
+  # run command
+  add_run_subparser(subparsers)
+
+  # scan command
+  add_scan_subparser(subparsers)
+
+  # tensorrt convert command
+  add_convert_subparser(subparsers)
+
+  # aot_compile_cpu command
+  add_aot_compile_cpu_subparser(subparsers)
+
   return parser
 
 
+def _signature_to_tf2xla_config(signature_def, frozen_variables):
+  """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.
+
+  Args:
+    signature_def: Instance of `SignatureDef`.
+    frozen_variables: Python bool, whether variables are being frozen or not.
+
+  Returns:
+    An instance of `tf2xla.Config` proto.
+
+  Raises:
+    RuntimeError: If TensorFlow was not compiled with XLA.
+  """
+  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top
+
+  config = tf2xla_pb2.Config()
+  tensor_id = tf2xla_pb2.TensorId
+
+  for name, input_ in signature_def.inputs.items():
+    (node_name, output_index) = input_.name.split(':')
+    output_index = int(output_index)
+    config.feed.append(
+        tf2xla_pb2.Feed(
+            id=tensor_id(node_name=node_name, output_index=output_index),
+            name=name,
+            type=input_.dtype,
+            shape=input_.tensor_shape))
+  for name, output_ in signature_def.outputs.items():
+    (node_name, output_index) = output_.name.split(':')
+    output_index = int(output_index)
+    config.fetch.append(
+        tf2xla_pb2.Fetch(
+            id=tensor_id(node_name=node_name, output_index=output_index),
+            name=name,
+            type=output_.dtype,
+            shape=output_.tensor_shape))
+  if not frozen_variables:
+    # Extract all variables along the path and add to config.
+    raise NotImplementedError('Non-frozen graphs are not supported.')
+
+  return config
+
+
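(Editorial sketch, not part of the patch: for a hypothetical signature whose input 'x' maps to tensor 'x:0' and whose output 'res' maps to 'add:0', the emitted config resembles the following.)

# config = _signature_to_tf2xla_config(signature_def, frozen_variables=True)
# str(config) would then look roughly like (abridged):
#   feed {
#     id { node_name: "x" }
#     name: "x"
#   }
#   fetch {
#     id { node_name: "add" }
#     name: "res"
#   }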
 def main():
   parser = create_parser()
   args = parser.parse_args()
@ -35,6 +35,8 @@ from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.saved_model import save
 from tensorflow.python.tools import saved_model_cli
@ -709,6 +711,63 @@ Defined Functions:
     output = out.getvalue().strip()
     self.assertTrue('\'VariableV2\'' in output)
 
+  def testAOTCompileCPUWrongSignatureDefKey(self):
+    if not test.is_built_with_xla():
+      self.skipTest('Skipping test because XLA is not compiled in.')
+
+    self.parser = saved_model_cli.create_parser()
+    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+    output_dir = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir')
+    args = self.parser.parse_args(
+        ['aot_compile_cpu', '--dir', base_path, '--tag_set', 'serve',
+         '--output_prefix', output_dir,
+         '--cpp_class', 'Compiled',
+         '--signature_def_key', 'MISSING'])
+    with self.assertRaisesRegexp(ValueError, 'Unable to find signature_def'):
+      saved_model_cli.aot_compile_cpu(args)
+
+  def testAOTCompileCPUFreezesAndCompiles(self):
+    if not test.is_built_with_xla():
+      self.skipTest('Skipping test because XLA is not compiled in.')
+
+    class DummyModel(tracking.AutoTrackable):
+      """Model compatible with XLA compilation."""
+
+      def __init__(self):
+        self.var = variables.Variable(1.0, name='my_var')
+
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
+      ])
+      def func2(self, x):
+        return {'res': x + self.var}
+
+    saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
+    dummy_model = DummyModel()
+    with self.cached_session():
+      self.evaluate(dummy_model.var.initializer)
+      save.save(dummy_model, saved_model_dir)
+
+    self.parser = saved_model_cli.create_parser()
+    output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
+    args = self.parser.parse_args(
+        ['aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
+         '--output_prefix', output_prefix,
+         '--cpp_class', 'Generated'])  # Use the default serving signature_key.
+    saved_model_cli.aot_compile_cpu(args)
+    self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
+    self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
+    self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
+    self.assertTrue(
+        file_io.file_exists('{}_makefile.inc'.format(output_prefix)))
+    header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
+    self.assertIn('class Generated', header_contents)
+    self.assertIn('arg_x_data', header_contents)
+    self.assertIn('result_res_data', header_contents)
+    makefile_contents = file_io.read_file_to_string(
+        '{}_makefile.inc'.format(output_prefix))
+    self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)
+
 
 if __name__ == '__main__':
   test.main()
@ -20,6 +20,7 @@ limitations under the License.
 PYBIND11_MODULE(_pywrap_util_port, m) {
   m.def("IsGoogleCudaEnabled", tensorflow::IsGoogleCudaEnabled);
   m.def("IsBuiltWithROCm", tensorflow::IsBuiltWithROCm);
+  m.def("IsBuiltWithXLA", tensorflow::IsBuiltWithXLA);
   m.def("IsBuiltWithNvcc", tensorflow::IsBuiltWithNvcc);
   m.def("GpuSupportsHalfMatMulAndConv",
         tensorflow::GpuSupportsHalfMatMulAndConv);
@ -55,6 +55,11 @@ load(
 VERSION = "2.1.0"
 VERSION_MAJOR = VERSION.split(".")[0]
 
+# Sanitize a dependency so that it works correctly from code that includes
+# TensorFlow as a submodule.
+def clean_dep(dep):
+    return str(Label(dep))
+
 def if_v2(a):
     return select({
         clean_dep("//tensorflow:api_version_2"): a,
@ -76,6 +81,12 @@ def if_nvcc(a):
 def if_cuda_is_configured_compat(x):
     return if_cuda_is_configured(x)
 
+def if_xla_available(if_true, if_false = []):
+    return select({
+        clean_dep("//tensorflow:with_xla_support"): if_true,
+        "//conditions:default": if_false,
+    })
+
 # Given a source file, generate a test name.
 # i.e. "common_runtime/direct_session_test.cc" becomes
 #      "common_runtime_direct_session_test"
@ -113,11 +124,6 @@ def tf_portable_proto_library(name, proto_deps, deps = [], **kwargs):
     _ignore = [kwargs]
     native.cc_library(name = name, deps = deps + [dep + "_cc" for dep in proto_deps])
 
-# Sanitize a dependency so that it works correctly from code that includes
-# TensorFlow as a submodule.
-def clean_dep(dep):
-    return str(Label(dep))
-
 def if_android_x86(a):
     return select({
         clean_dep("//tensorflow:android_x86"): a,
@ -304,6 +310,7 @@ def tf_copts(
         (if_not_windows(["-fno-exceptions"]) if not allow_exceptions else []) +
         if_cuda(["-DGOOGLE_CUDA=1"]) +
         if_nvcc(["-DTENSORFLOW_USE_NVCC=1"]) +
+        if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) +
         if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
         if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"]) +
         if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) +
@ -1418,7 +1425,7 @@ def tf_gpu_library(deps = None, cuda_deps = None, copts = tf_copts(), **kwargs):
         ]) + if_rocm_is_configured(cuda_deps + [
             "@local_config_rocm//rocm:rocm_headers",
         ]),
-        copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
+        copts = (copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_mkl_open_source_only(["-DINTEL_MKL_DNN_ONLY"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
         **kwargs
     )
@ -2458,6 +2465,7 @@ def pybind_extension(
         copts = [],
         linkopts = [],
         deps = [],
+        defines = [],
         visibility = None,
         testonly = None,
         licenses = None,
@ -2524,6 +2532,7 @@ def pybind_extension(
             exported_symbols_file,
             version_script_file,
         ],
+        defines = defines,
         features = features + ["-use_header_modules"],
         linkshared = 1,
         testonly = testonly,
@ -2569,6 +2578,7 @@ def tf_python_pybind_extension(
         copts = [],
         hdrs = [],
         deps = [],
+        defines = [],
         visibility = None):
     """A wrapper macro for pybind_extension that is used in tensorflow/python/BUILD.
 
@ -2583,9 +2593,20 @@ def tf_python_pybind_extension(
         copts = copts,
         hdrs = hdrs,
         deps = deps + tf_binary_pybind_deps() + mkl_deps(),
+        defines = defines,
         visibility = visibility,
     )
 
+def tf_pybind_cc_library_wrapper(name, deps, visibility = None):
+    """Wrapper for cc_library and proto dependencies used by tf_python_pybind_extension.
+
+    This wrapper ensures that cc libraries' and protos' headers are made
+    available to pybind code, without creating ODR violations in the dynamically
+    linked case.  The symbols in these deps should be linked to, and
+    exported by, the core pywrap_tensorflow_internal.so.
+    """
+    cc_header_only_library(name = name, deps = deps, visibility = visibility)
+
 def if_cuda_or_rocm(if_true, if_false = []):
     """Shorthand for select()'ing whether to build for either CUDA or ROCm.
 
@ -2621,8 +2642,8 @@ def tf_jit_compilation_passes_extra_deps():
 
 def if_mlir(if_true, if_false = []):
     return select({
+        str(Label("//tensorflow:with_mlir_support")): if_true,
         "//conditions:default": if_false,
-        "//tensorflow:with_mlir_support": if_true,
     })
 
 def tfcompile_extra_flags():
@ -7,3 +7,4 @@
 *TFE_*
 *nsync_*
 *stream_executor*
+*xla*
@ -8,6 +8,7 @@ tensorflow {
     *TFE_*;
     *nsync_*;
     *stream_executor*;
+    *xla*;
   local:
     *;
 };
@ -56,6 +56,10 @@ tf_module {
     name: "is_built_with_rocm"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "is_built_with_xla"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "is_gpu_available"
     argspec: "args=[\'cuda_only\', \'min_cuda_compute_capability\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
@ -40,6 +40,10 @@ tf_module {
     name: "is_built_with_rocm"
     argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "is_built_with_xla"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "is_gpu_available"
     argspec: "args=[\'cuda_only\', \'min_cuda_compute_capability\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
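(Editorial usage sketch, not part of the patch: these golden updates document the new public API, which is backed by the IsBuiltWithXLA binding added above.)

import tensorflow as tf

# tf.test.is_built_with_xla() reports whether this TensorFlow binary was
# compiled with XLA support; use it to guard XLA-dependent code paths.
if not tf.test.is_built_with_xla():
  print('XLA is not compiled into this build; skipping AOT compilation.')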
@ -1,7 +1,7 @@
 # Description:
 # Tools for building the TensorFlow pip package.
 
-load("//tensorflow:tensorflow.bzl", "if_windows", "transitive_hdrs")
+load("//tensorflow:tensorflow.bzl", "if_windows", "if_xla_available", "transitive_hdrs")
 load("//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml")
 load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
 load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib")
@ -104,7 +104,9 @@ COMMON_PIP_DEPS = [
     "//tensorflow/tools/docs:generate_lib",
     "//tensorflow/tools/docs:parser",
     "//tensorflow/tools/docs:py_guide_parser",
-]
+] + if_xla_available([
+    "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+])
 
 # On Windows, python binary is a zip file of runfiles tree.
 # Add everything to its data dependency for generating a runfiles tree
@ -18,3 +18,4 @@ recursive-include tensorflow_core/include/google *.inc
 recursive-include tensorflow_core/include/include *.h
 recursive-include tensorflow_core/include/third_party *
 recursive-include tensorflow_core/include/unsupported *
@ -245,6 +245,7 @@ else:
   EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.so'
 
 headers = (
+    list(find_files('*.h', 'tensorflow_core/compiler')) +
     list(find_files('*.h', 'tensorflow_core/core')) +
     list(find_files('*.h', 'tensorflow_core/stream_executor')) +
     list(find_files('*.h', 'google/com_google_protobuf/src')) +