[TF] Fixes and new features to saved_model_cli aot_compile_cpu, adds e2e test.
* Can now specify which variables can be fed/fetched.
* Bugfix when a signature name contains slashes or starts with integers.
* Prune input config entries from the tf2xla config when graph freezing removes an unused input feed.
* Fixed a bug where third_party/tensorflow/ isn't properly renamed to tensorflow/ in the opensource HOST build (identified during the new genrule test). Solution: bring back the hardcoded #include in codegen.cc; it's always correct.

NOTE: The bugfix to the #include line in the compiler/ codebase is a partial rollback of the initial tfcompile + saved_model_cli CL, which moved from the hard-coded include path to a parameterized value. It turns out we don't need the complexity of this approach, and it's incorrect in the host opensource build.

TESTED: Includes a bona fide genrule test which runs saved_model_cli to generate the header and object files, includes them in a C++ unit test, and ensures that they compile and the resulting object runs correctly.

PiperOrigin-RevId: 290655683
Change-Id: I4cfa2c595ebe56f8bdd47853f82371d97b92b081
Parent: a901c88061
Commit: b26e1efece
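For orientation, here is a minimal sketch of driving the new --variables_to_feed flow programmatically, the same way the updated saved_model_cli_test.py below does. The SavedModel path, output prefix, and variable name are illustrative, not part of this commit:

    # Sketch: invoke aot_compile_cpu via the CLI's own parser.
    from tensorflow.python.tools import saved_model_cli

    parser = saved_model_cli.create_parser()
    args = parser.parse_args([
        'aot_compile_cpu',
        '--dir', '/tmp/my_saved_model',            # illustrative path
        '--tag_set', 'serve',
        '--output_prefix', '/tmp/out/compiled_model',
        '--cpp_class', 'CompiledModel',
        '--variables_to_feed', 'my_var',           # '' freezes all; 'all' feeds all
    ])
    # Emits compiled_model{.h,.o,_metadata.o,_makefile.inc} under the prefix.
    saved_model_cli.aot_compile_cpu(args)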
tensorflow/compiler/aot/codegen.cc
@@ -457,8 +457,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
 {{INCLUDE_XLA_DATA_PROTO}}
 {{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
-#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
-#include "{{TF_HEADER_ROOT}}/core/platform/types.h"
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "tensorflow/core/platform/types.h"

 namespace Eigen { struct ThreadPoolDevice; }
 namespace xla { class ExecutableRunOptions; }
@@ -659,7 +659,6 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
       {"{{CLASS}}", opts.class_name},
       {"{{DECLS_FROM_OBJ_FILE}}",
        absl::StrJoin(metadata_result.header_variable_decls, "\n")},
-      {"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
       {"{{ENTRY}}", compile_result.entry_point},
       {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
        metadata_result.hlo_profile_printer_data_access_shim},
tensorflow/compiler/aot/codegen_test.cc
@@ -197,7 +197,6 @@ TEST(CodegenTest, Golden) {
   variable3->mutable_shape()->add_dim()->set_size(5);
   variable3->set_type(DT_INT32);
   CompileResult compile_result;
-  compile_result.tensorflow_header_root = "third_party/tensorflow";
   compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
       {},
       {BufferInfo::MakeTempBuffer(1),
tensorflow/compiler/aot/compile.cc
@@ -85,7 +85,6 @@ Status CompileXla(xla::CompileOnlyClient* client,
       xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
           std::move(aot_or.ValueOrDie().back()));
   compile_result->entry_point = aot_opts.entry_point_name();
-  compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
   compile_result->pointer_size =
       xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
   return Status::OK();
@@ -130,8 +129,7 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
   xla::cpu::CpuAotCompilationOptions aot_opts(
       flags.target_triple, flags.target_cpu, flags.target_features,
       flags.entry_point,
-      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
-      flags.tensorflow_header_root);
+      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);

   return CompileXla(client, computation, aot_opts, compile_result);
 }
tensorflow/compiler/aot/compile.h
@@ -35,7 +35,6 @@ struct CompileResult {
   std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
   xla::ProgramShapeProto program_shape;  // Static shape of args and results.
   string entry_point;                    // Name of generated function.
-  string tensorflow_header_root;         // Prefix for tensorflow headers.
   int pointer_size = 0;                  // Size of a pointer in bytes.
 };
tensorflow/compiler/aot/flags.cc
@@ -74,8 +74,6 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
        "Generate name-to-index data for Lookup{Arg,Result}Index methods."},
       {"gen_program_shape", &flags->gen_program_shape,
        "Generate program shape data for the ProgramShape method."},
-      {"tensorflow_header_root", &flags->tensorflow_header_root,
-       "Root directory of tensorflow headers."},
   };
   flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
 }
tensorflow/compiler/aot/flags.h
@@ -40,7 +40,6 @@ struct MainFlags {
   string out_header;
   string out_session_module;
   string mlir_components;
-  string tensorflow_header_root;

   // C++ codegen options
   bool gen_name_to_index = false;
tensorflow/compiler/aot/tfcompile_main.cc
@@ -65,7 +65,6 @@ int main(int argc, char** argv) {
   flags.out_metadata_object = "out_helper.o";
   flags.out_header = "out.h";
   flags.entry_point = "entry";
-  flags.tensorflow_header_root = "third_party/tensorflow";

   std::vector<tensorflow::Flag> flag_list;
   AppendMainFlags(&flag_list, &flags);
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -119,13 +119,12 @@ using BufferInfo = cpu_function_runtime::BufferInfo;

 CpuAotCompilationOptions::CpuAotCompilationOptions(
     string triple, string cpu_name, string features, string entry_point_name,
-    RelocationModel relocation_model, string tensorflow_header_root)
+    RelocationModel relocation_model)
     : triple_(std::move(triple)),
       cpu_name_(std::move(cpu_name)),
       features_(std::move(features)),
       entry_point_name_(std::move(entry_point_name)),
-      relocation_model_(relocation_model),
-      tensorflow_header_root_(std::move(tensorflow_header_root)) {}
+      relocation_model_(relocation_model) {}

 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;
tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -53,16 +53,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions {

   CpuAotCompilationOptions(string triple, string cpu_name, string features,
                            string entry_point_name,
-                           RelocationModel relocation_model,
-                           string tensorflow_header_root);
-
-  CpuAotCompilationOptions(string triple, string cpu_name, string features,
-                           string entry_point_name,
-                           RelocationModel relocation_model)
-      : CpuAotCompilationOptions(
-            std::move(triple), std::move(cpu_name), std::move(features),
-            std::move(entry_point_name), relocation_model,
-            /*tensorflow_header_root=*/"third_party/tensorflow") {}
+                           RelocationModel relocation_model);

   ~CpuAotCompilationOptions() override;
@@ -76,10 +67,6 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string& features() const { return features_; }
   // The name to be used for the compiled code's entry point.
   const string& entry_point_name() const { return entry_point_name_; }
-  // The prefix for tensorflow headers, e.g. "third_party/tensorflow".
-  const string& tensorflow_header_root() const {
-    return tensorflow_header_root_;
-  }
   // The relocation model used for compilation.
   RelocationModel relocation_model() const { return relocation_model_; }
@@ -89,7 +76,6 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string features_;
   const string entry_point_name_;
   const RelocationModel relocation_model_;
-  const string tensorflow_header_root_;
 };

 class CpuAotCompilationResult : public AotCompilationResult {
@@ -39,8 +39,8 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
          std::string entry_point, std::string cpp_class,
          std::string out_function_object, std::string out_metadata_object,
          std::string out_header, std::string out_session_module,
-         std::string mlir_components, std::string tensorflow_header_root,
-         bool gen_name_to_index, bool gen_program_shape) {
+         std::string mlir_components, bool gen_name_to_index,
+         bool gen_program_shape) {
        tensorflow::tfcompile::MainFlags flags;
        flags.graph = std::move(graph);
        flags.config = std::move(config);
@@ -54,7 +54,6 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
        flags.out_header = std::move(out_header);
        flags.out_session_module = std::move(out_session_module);
        flags.mlir_components = std::move(mlir_components);
-       flags.tensorflow_header_root = std::move(tensorflow_header_root);

        // C++ codegen options
        flags.gen_name_to_index = gen_name_to_index;
@@ -68,8 +67,6 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
       py::arg("cpp_class") = "", py::arg("out_function_object") = "out_model.o",
       py::arg("out_metadata_object") = "out_helper.o",
       py::arg("out_header") = "out.h", py::arg("out_session_module") = "",
-      py::arg("mlir_components") = "",
-      py::arg("tensorflow_header_root") = "third_party/tensorflow",
-      py::arg("gen_name_to_index") = false,
+      py::arg("mlir_components") = "", py::arg("gen_name_to_index") = false,
       py::arg("gen_program_shape") = false);
 }
tensorflow/python/tools/BUILD
@@ -1,7 +1,7 @@
 # Description:
 #   Tools for manipulating TensorFlow graphs.

-load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test")
+load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test", "tf_cc_test")

 package(
     default_visibility = ["//visibility:public"],
@@ -343,10 +343,63 @@ py_test(
         "no-internal-py3",
         "nosan",
     ],
+    # Force-include XLA dependencies of saved_model_cli_lib to ensure we test
+    # the AOT compilation.
     deps = [
         ":saved_model_cli_lib",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client_testlib",
+        "@absl_py//absl/testing:parameterized",
     ],
 )

+genrule(
+    name = "aot_compiled_x_plus_y_gen",
+    srcs = [
+        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
+        "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo/saved_model.pb",
+    ],
+    outs = [
+        "compiled_model.h",
+        "compiled_model.o",
+        "compiled_model_metadata.o",
+        "compiled_model_makefile.inc",
+    ],
+    cmd = (
+        "$(location :saved_model_cli) aot_compile_cpu " +
+        "--dir \"$$(dirname $(location //tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo/saved_model.pb))\" " +
+        "--output_prefix $(@D)/compiled_model " +
+        "--cpp_class CompiledModel " +
+        "--tag_set serve "
+    ),
+    tools = [
+        ":saved_model_cli",
+    ],
+)
+
+cc_library(
+    name = "aot_compiled_x_plus_y",
+    srcs = if_xla_available([
+        ":compiled_model.o",
+        ":compiled_model_metadata.o",
+    ]),
+    hdrs = if_xla_available([
+        ":compiled_model.h",
+    ]),
+    deps = if_xla_available([
+        "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+        "//tensorflow/core/platform:types",
+    ]),
+)
+
+tf_cc_test(
+    name = "binary_using_aot_compiled_x_plus_y_test",
+    srcs = if_xla_available([
+        "binary_using_aot_compiled_x_plus_y_test.cc",
+    ]),
+    deps = [
+        "//tensorflow/core:test_main",
+    ] + if_xla_available([
+        ":aot_compiled_x_plus_y",
+        "//tensorflow/core:test",
+        "//tensorflow/core/platform:logging",
+    ]),
+)
|
@ -0,0 +1,30 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/python/tools/compiled_model.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
TEST(AOTCompiledSavedModelTest, Run) {
|
||||
CompiledModel model;
|
||||
*model.arg_feed_x_data() = 3.0f;
|
||||
*model.arg_feed_y_data() = 4.0f;
|
||||
CHECK(model.Run());
|
||||
ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
|
||||
}
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
tensorflow/python/tools/saved_model_cli.py
@@ -101,6 +101,16 @@ def _sysconfig_module():
   return sysconfig_lib


+def _parse_tensor_name(name):
+  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
+  if ':' in name and not name.endswith(':'):
+    node_name = name[:name.rfind(':')]
+    output_slot = int(name[name.rfind(':') + 1:])
+    return node_name, output_slot
+  else:
+    return name, None
+
+
 _XLA_MAKEFILE_TEMPLATE = """
 INC = -I{tensorflow_includes}
 LIB = -L{compiled_dir}
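Illustrative behavior of the new helper (the assertions are a sketch, not part of the commit):

    # _parse_tensor_name splits on the LAST ':' so node names containing
    # slashes or other ':'-free punctuation parse correctly.
    assert saved_model_cli._parse_tensor_name('dense/kernel:0') == ('dense/kernel', 0)
    assert saved_model_cli._parse_tensor_name('init_op') == ('init_op', None)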
@@ -134,7 +144,11 @@ def _xla_makefile_string(output_prefix):
     base = os.path.realpath(
         os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
   else:
-    base = test.test_src_dir_path('')
+    try:
+      base = test.test_src_dir_path('')
+    except KeyError:  # Can't find TEST_SRCDIR in environment path.
+      base = os.path.realpath(
+          os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
   expected_header = os.path.join(
       base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
   if not os.path.exists(expected_header):
@@ -164,6 +178,47 @@ def _show_tag_sets(saved_model_dir):
     print('%r' % ', '.join(sorted(tag_set)))


+def _get_variable_nodes_from_graph_def(graph_def):
+  """Get the list of Variable nodes from `graph_def`.
+
+  Args:
+    graph_def: An instance of `GraphDef`.
+
+  Returns:
+    A list of `NodeDef` corresponding to variables in the graph.
+  """
+  variables = [n for n in graph_def.node if n.op == 'VarHandleOp']
+
+  for f in graph_def.library.function:
+    variables += [n for n in f.node_def if n.op == 'VarHandleOp']
+
+  return variables
+
+
+def _prune_removed_feed_nodes(signature_def, graph_def):
+  """Identify the inputs in the signature no longer in graph_def, prune them.
+
+  Args:
+    signature_def: A `SignatureDef` instance.
+    graph_def: A `GraphDef` instance.
+
+  Returns:
+    A new pruned `SignatureDef`.
+  """
+  node_names = set([n.name for n in graph_def.node])
+  new_signature_def = meta_graph_pb2.SignatureDef()
+  new_signature_def.CopyFrom(signature_def)
+  for (k, v) in signature_def.inputs.items():
+    tensor_name, _ = _parse_tensor_name(v.name)
+    if tensor_name not in node_names:
+      logging.warn(
+          'Signature input key \'{}\', tensor name \'{}\', has been pruned '
+          'while freezing the graph. Removing it from the compiled signatures.'
+          .format(k, tensor_name))
+      del new_signature_def.inputs[k]
+  return new_signature_def
+
+
 def _show_signature_def_map_keys(saved_model_dir, tag_set):
   """Prints the keys for each SignatureDef in the SignatureDef map.
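A minimal sketch of the pruning behavior, built from real protos; the node and key names here are illustrative only:

    from tensorflow.core.framework import graph_pb2
    from tensorflow.core.protobuf import meta_graph_pb2

    graph_def = graph_pb2.GraphDef()
    graph_def.node.add(name='x')          # pretend 'y' was removed by freezing

    signature_def = meta_graph_pb2.SignatureDef()
    signature_def.inputs['x'].name = 'x:0'
    signature_def.inputs['y'].name = 'y:0'

    pruned = saved_model_cli._prune_removed_feed_nodes(signature_def, graph_def)
    assert set(pruned.inputs.keys()) == {'x'}   # 'y' is warned about and dropped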
@@ -882,23 +937,28 @@ def aot_compile_cpu(args):
   checkpoint_path = (
       args.checkpoint_path
       or os.path.join(args.dir, 'variables/variables'))
+  if not args.variables_to_feed:
+    variables_to_feed = []
+  elif args.variables_to_feed.lower() == 'all':
+    variables_to_feed = None  # We will identify them after.
+  else:
+    variables_to_feed = args.variables_to_feed.split(',')
   aot_compile_cpu_meta_graph_def(
       checkpoint_path=checkpoint_path,
       meta_graph_def=saved_model_utils.get_meta_graph_def(
           args.dir, args.tag_set),
       signature_def_key=args.signature_def_key,
-      freeze_graph=args.freeze_graph,
+      variables_to_feed=variables_to_feed,
       output_prefix=args.output_prefix,
       cpp_class=args.cpp_class)


-def aot_compile_cpu_meta_graph_def(
-    checkpoint_path,
-    meta_graph_def,
-    output_prefix,
-    signature_def_key,
-    cpp_class,
-    freeze_graph=True):
+def aot_compile_cpu_meta_graph_def(checkpoint_path,
+                                   meta_graph_def,
+                                   output_prefix,
+                                   signature_def_key,
+                                   cpp_class,
+                                   variables_to_feed=()):
   """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

   Use XLA AOT (`tfcompile`) to convert the given meta graph and
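How the flag value maps to the `variables_to_feed` argument, as a standalone mirror of the branch above (the helper name is hypothetical, for illustration only):

    def _variables_to_feed_from_flag(value):
      # Hypothetical mirror of the branch in aot_compile_cpu() above.
      if not value:
        return []                # default: freeze every variable
      if value.lower() == 'all':
        return None              # leave every variable feedable
      return value.split(',')    # feed only the named variables

    assert _variables_to_feed_from_flag('') == []
    assert _variables_to_feed_from_flag('all') is None
    assert _variables_to_feed_from_flag('my_var,other_var') == ['my_var', 'other_var']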
@@ -920,7 +980,10 @@ def aot_compile_cpu_meta_graph_def(
     output_prefix: Python string. Path prefix for outputs.
     signature_def_key: String, the signature_def to use in the SavedModel.
     cpp_class: Name of output C++ class.
-    freeze_graph: Whether to freeze the graph before compilation.
+    variables_to_feed: A list of strings, the variables that will be fed by the
+      user; these won't be frozen. If `None`, then we will extract all the
+      variables in the graph and mark them as to-feed. The default behavior is
+      an empty tuple: all variables must be frozen.

   Raises:
     RuntimeError: If tensorflow was not built with XLA.
@@ -945,32 +1008,62 @@ def aot_compile_cpu_meta_graph_def(
         'Signature key {} must have outputs, but saw none:\n{}'.format(
             signature_def_key, str(signature_def)))

+  temp_dir = test.get_temp_dir()
+  file_io.recursive_create_dir(temp_dir)
+  if logging.get_verbosity() >= logging.INFO:
+    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
+    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
+      graph_writer.write(meta_graph_def.graph_def.SerializeToString())
+
   # This updates graph_def in place.
   _replace_input_placeholders_with_default_values(
       meta_graph_def.graph_def, signature_def)
   graph_def = _optimize_graph(meta_graph_def, signature_def)

-  if freeze_graph:
-    # Load the Variables so that we can freeze the graph.
-    with session.Session(graph=ops_lib.Graph()) as sess:
-      restorer = saver_lib.import_meta_graph(
-          meta_graph_def, clear_devices=True)
-      restorer.restore(sess, checkpoint_path)
-      graph_def.CopyFrom(
-          graph_util.convert_variables_to_constants(
-              sess,
-              graph_def,
-              [n.name.split(':')[0] for n in signature_def.outputs.values()]))
+  all_variables = _get_variable_nodes_from_graph_def(graph_def)
+  if variables_to_feed is None:
+    variable_nodes_to_feed = list(all_variables)
+  else:
+    not_in_graph = (
+        set(variables_to_feed).difference([x.name for x in all_variables]))
+    if not_in_graph:
+      raise ValueError(
+          'Asked to feed variables that were not found in graph: {}. '
+          'Variables contained in the graph: {}'.format(
+              not_in_graph, set([x.name for x in all_variables])))
+    all_variables_map = dict((x.name, x) for x in all_variables)
+    variable_nodes_to_feed = [
+        all_variables_map[name] for name in variables_to_feed
+    ]
+
+  if logging.get_verbosity() >= logging.INFO:
+    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
+    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
+      graph_writer.write(meta_graph_def.graph_def.SerializeToString())
+
+  # Load the Variables so that we can freeze the graph.
+  with session.Session(graph=ops_lib.Graph()) as sess:
+    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
+    restorer.restore(sess, checkpoint_path)
+    graph_def.CopyFrom(
+        graph_util.convert_variables_to_constants(
+            sess,
+            graph_def,
+            output_node_names=[
+                _parse_tensor_name(n.name)[0]
+                for n in signature_def.outputs.values()
+            ],
+        ))
+
+  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

-  temp_dir = test.get_temp_dir()
   frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
   config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
   logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
   with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
     graph_writer.write(graph_def.SerializeToString())
   config = _signature_to_tf2xla_config(
-      signature_def,
-      frozen_variables=freeze_graph)
+      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
   logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
   with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
     config_writer.write(str(config))
@@ -991,13 +1084,6 @@ def aot_compile_cpu_meta_graph_def(

   output_prefix = _shlex_quote(output_prefix)

-  additional_compiler_args = {}
-  sysconfig = _sysconfig_module()
-  if sysconfig:
-    # We're inside PIP and need to pass a customized relative path to the
-    # appropriate tensorflow headers.
-    additional_compiler_args['tensorflow_header_root'] = 'tensorflow'
-
   _pywrap_tfcompile.Compile(
       graph=frozen_graph_def_location,
       config=config_pbtxt_location,
@@ -1008,8 +1094,7 @@ def aot_compile_cpu_meta_graph_def(
       out_metadata_object='{}_metadata.o'.format(output_prefix),
       gen_name_to_index=True,
       # ProgramShape isn't uniquefied by entry_point.
-      gen_program_shape=False,
-      **additional_compiler_args)
+      gen_program_shape=False)


 def _optimize_graph(meta_graph_def, signature_def):
@@ -1034,7 +1119,7 @@ def _replace_input_placeholders_with_default_values(graph_def, signature_def):
   name_to_node_map = dict((n.name, n) for n in graph_def.node)
   temp_graph = ops_lib.Graph()
   for name, input_ in signature_def.inputs.items():
-    tensor_name = input_.name.split(':')[0]
+    tensor_name, _ = _parse_tensor_name(input_.name)
     if tensor_name not in name_to_node_map:
       raise RuntimeError(
           'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
@@ -1330,13 +1415,16 @@ def add_aot_compile_cpu_subparser(subparsers):
         'The class will be generated in the given namespace(s), or if no '
         'namespaces are given, within the global namespace.'))
   parser_compile.add_argument(
-      '--freeze_graph',
-      type=bool,
-      default=True,
-      help=('Whether to freeze the tf.Variables into the graph. If false, '
-            'then all Variables in the closure of the signature graph path '
-            'be be added as input and output args to the XLA-compiled graph '
-            '(not currently supported)'))
+      '--variables_to_feed',
+      type=str,
+      default='',
+      help=('The names of variables that will be fed into the network. '
+            'Options are: empty (default; all variables are frozen, none may '
+            'be fed), \'all\' (all variables may be fed), or a '
+            'comma-delimited list of names of variables that may be fed. In '
+            'the last case, the non-fed variables will be frozen in the graph.')
+  )

   parser_compile.set_defaults(func=aot_compile_cpu)
@@ -1371,12 +1459,13 @@ def create_parser():
   return parser


-def _signature_to_tf2xla_config(signature_def, frozen_variables):
+def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
   """Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto.

   Args:
     signature_def: Instance of `SignatureDef`.
-    frozen_variables: Python bool, whether variables are being frozen or not.
+    variable_nodes_to_feed: List of NodeDefs corresponding to VarHandleOp,
+      the list of variables to feed.

   Returns:
     An instance of `tf2xla.Config` proto.
@@ -1390,7 +1479,9 @@ def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
   tensor_id = tf2xla_pb2.TensorId

   for name, input_ in signature_def.inputs.items():
-    (node_name, output_index) = input_.name.split(':')
+    name = name.replace('/', '_')
+    name = 'feed_{}'.format(name)
+    (node_name, output_index) = _parse_tensor_name(input_.name)
     output_index = int(output_index)
     config.feed.append(
         tf2xla_pb2.Feed(
@@ -1399,7 +1490,9 @@ def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
           type=input_.dtype,
           shape=input_.tensor_shape))
   for name, output_ in signature_def.outputs.items():
-    (node_name, output_index) = output_.name.split(':')
+    name = name.replace('/', '_')
+    name = 'fetch_{}'.format(name)
+    (node_name, output_index) = _parse_tensor_name(output_.name)
     output_index = int(output_index)
     config.fetch.append(
         tf2xla_pb2.Fetch(
@@ -1407,14 +1500,22 @@ def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
           name=name,
           type=output_.dtype,
           shape=output_.tensor_shape))
-  if not frozen_variables:
-    # Extract all variables along the path and add to config
-    raise NotImplementedError('Non-frozen graphs are not supported.')
+  for node in variable_nodes_to_feed:
+    name = node.name.replace('/', '_')
+    name = 'param_{}'.format(name)
+    config.variable.append(
+        tf2xla_pb2.Variable(
+            node_name=node.name,
+            name=name,
+            type=node.attr['dtype'].type,
+            shape=node.attr['shape'].shape,
+            readonly=True))

   return config


 def main():
   logging.set_verbosity(logging.INFO)
   parser = create_parser()
   args = parser.parse_args()
   if not hasattr(args, 'func'):
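The renaming these loops apply is what the slash bugfix in the commit message refers to, and it determines the accessor names (arg_feed_*, result_fetch_*, var_param_*) asserted in the test changes below. A sketch with a hypothetical helper, not code from this commit:

    def _tf2xla_config_name(prefix, key):
      # Mirror of the renaming above: sanitize slashes, then add the prefix.
      return '{}_{}'.format(prefix, key.replace('/', '_'))

    assert _tf2xla_config_name('feed', 'x') == 'feed_x'
    assert _tf2xla_config_name('fetch', 'res') == 'fetch_res'
    assert _tf2xla_config_name('param', 'my/var') == 'param_my_var'  # slash fix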
tensorflow/python/tools/saved_model_cli_test.py
@@ -25,6 +25,7 @@ import pickle
 import shutil
 import sys

+from absl.testing import parameterized
 import numpy as np
 from six import StringIO

@@ -38,6 +39,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import save
 from tensorflow.python.tools import saved_model_cli
 from tensorflow.python.training.tracking import tracking
@@ -56,7 +58,7 @@ def captured_output():
     sys.stdout, sys.stderr = old_out, old_err


-class SavedModelCLITestCase(test.TestCase):
+class SavedModelCLITestCase(test.TestCase, parameterized.TestCase):

   def testShowCommandAll(self):
     base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
@@ -726,35 +728,44 @@ Defined Functions:
     with self.assertRaisesRegexp(ValueError, 'Unable to find signature_def'):
       saved_model_cli.aot_compile_cpu(args)

-  def testAOTCompileCPUFreezesAndCompiles(self):
+  class AOTCompileDummyModel(tracking.AutoTrackable):
+    """Model compatible with XLA compilation."""
+
+    def __init__(self):
+      self.var = variables.Variable(1.0, name='my_var')
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
+    ])
+    def func2(self, x, y):
+      return {'res': x + self.var}
+
+  @parameterized.named_parameters(('VariablesToFeedNone', ''),
+                                  ('VariablesToFeedAll', 'all'),
+                                  ('VariablesToFeedMyVar', 'my_var'))
+  def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed):
     if not test.is_built_with_xla():
       self.skipTest('Skipping test because XLA is not compiled in.')

-    class DummyModel(tracking.AutoTrackable):
-      """Model compatible with XLA compilation."""
-
-      def __init__(self):
-        self.var = variables.Variable(1.0, name='my_var')
-
-      @def_function.function(input_signature=[
-          tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
-      ])
-      def func2(self, x):
-        return {'res': x + self.var}
-
     saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
-    dummy_model = DummyModel()
+    dummy_model = self.AOTCompileDummyModel()
     with self.cached_session():
       self.evaluate(dummy_model.var.initializer)
       save.save(dummy_model, saved_model_dir)

     self.parser = saved_model_cli.create_parser()
     output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
-    args = self.parser.parse_args(
-        ['aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
-         '--output_prefix', output_prefix,
-         '--cpp_class', 'Generated'])  # Use the default serving signature_key.
-    saved_model_cli.aot_compile_cpu(args)
+    args = self.parser.parse_args([
+        'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
+        '--output_prefix', output_prefix, '--variables_to_feed',
+        variables_to_feed, '--cpp_class', 'Generated'
+    ])  # Use the default serving signature_key.
+    with test.mock.patch.object(logging, 'warn') as captured_warn:
+      saved_model_cli.aot_compile_cpu(args)
+    self.assertRegexpMatches(
+        str(captured_warn.call_args),
+        'Signature input key \'y\'.*has been pruned while freezing the graph.')
     self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
     self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
     self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
@@ -762,8 +773,12 @@ Defined Functions:
         file_io.file_exists('{}_makefile.inc'.format(output_prefix)))
     header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
     self.assertIn('class Generated', header_contents)
-    self.assertIn('arg_x_data', header_contents)
-    self.assertIn('result_res_data', header_contents)
+    self.assertIn('arg_feed_x_data', header_contents)
+    self.assertIn('result_fetch_res_data', header_contents)
+    # arg_y got filtered out as it's not used by the output.
+    self.assertNotIn('arg_feed_y_data', header_contents)
+    if variables_to_feed:
+      self.assertIn('var_param_my_var', header_contents)
     makefile_contents = file_io.read_file_to_string(
         '{}_makefile.inc'.format(output_prefix))
     self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)