[TF] Fixes and new features for saved_model_cli aot_compile_cpu; adds an e2e test.

* Can now specify which variables may be fed and fetched.
* Bugfix for signature names that contain slashes or start with integers
  (see the sketch after this list).
* Prune input config entries from the tf2xla config when the graph freeze
  removes an unused input feed.
* Fixed a bug where third_party/tensorflow/ wasn't properly renamed to
  tensorflow/ in the opensource HOST build (identified by the new genrule
  test). Solution: bring back the hardcoded #include in codegen.cc; it is
  always correct.
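
A minimal sketch of the naming fix, mirroring the _parse_tensor_name helper
and the name sanitization added to _signature_to_tf2xla_config in
saved_model_cli.py below (the _sanitized_feed_name helper is hypothetical,
for illustration only):

    def _parse_tensor_name(name):
      """Split a tensor name like 'tensor:0' into ('tensor', 0)."""
      if ':' in name and not name.endswith(':'):
        node_name = name[:name.rfind(':')]
        output_slot = int(name[name.rfind(':') + 1:])
        return node_name, output_slot
      else:
        return name, None

    def _sanitized_feed_name(signature_key):
      # The 'feed_' prefix plus '/' -> '_' replacement keeps the generated
      # C++ accessor names (e.g. arg_feed_x_data) valid identifiers even if
      # the signature key contains slashes or starts with an integer.
      return 'feed_{}'.format(signature_key.replace('/', '_'))

    assert _parse_tensor_name('my_scope/x:0') == ('my_scope/x', 0)
    assert _parse_tensor_name('x') == ('x', None)
    assert _sanitized_feed_name('0/inputs/x') == 'feed_0_inputs_x'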

NOTE: The bugfix to the #include line in the compiler/ codebase is a partial
rollback of the initial tfcompile + saved_model_cli CL, which moved from the
hard-coded include path to a parameterized value. It turns out we don't need
that complexity, and the parameterized path is incorrect in the host
opensource build.

TESTED:

Includes a bona fide genrule test that runs saved_model_cli to generate the
header and object files, includes them in a C++ unit test, and ensures that
they compile and that the resulting object runs correctly (a rough Python
invocation sketch follows).
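
For context, the compile step can be driven from Python roughly as follows
(a sketch adapted from the saved_model_cli_test.py changes below; the paths
are placeholders, and it assumes a SavedModel already exists at
saved_model_dir and that TensorFlow was built with XLA):

    from tensorflow.python.tools import saved_model_cli

    saved_model_dir = '/tmp/dummy_model'            # hypothetical path
    output_prefix = '/tmp/aot_compile_cpu_dir/out'  # hypothetical path

    parser = saved_model_cli.create_parser()
    args = parser.parse_args([
        'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
        # '' freezes all variables (default), 'all' makes all feedable, or
        # pass a comma-delimited list of variable names that may be fed.
        '--variables_to_feed', 'my_var',
        '--cpp_class', 'Generated',
    ])
    saved_model_cli.aot_compile_cpu(args)
    # Writes out.h, out.o, out_metadata.o, and out_makefile.inc under the
    # output prefix.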

PiperOrigin-RevId: 290655683
Change-Id: I4cfa2c595ebe56f8bdd47853f82371d97b92b081
commit b26e1efece (parent a901c88061)
Author: Eugene Brevdo
Date: 2020-01-20 15:55:51 -08:00
Committer: TensorFlower Gardener
14 changed files with 281 additions and 109 deletions

tensorflow/compiler/aot/codegen.cc

@@ -457,8 +457,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
 {{INCLUDE_XLA_DATA_PROTO}}
 {{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
-#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
-#include "{{TF_HEADER_ROOT}}/core/platform/types.h"
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "tensorflow/core/platform/types.h"

 namespace Eigen { struct ThreadPoolDevice; }
 namespace xla { class ExecutableRunOptions; }
@@ -659,7 +659,6 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
       {"{{CLASS}}", opts.class_name},
       {"{{DECLS_FROM_OBJ_FILE}}",
        absl::StrJoin(metadata_result.header_variable_decls, "\n")},
-      {"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
       {"{{ENTRY}}", compile_result.entry_point},
       {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
        metadata_result.hlo_profile_printer_data_access_shim},

tensorflow/compiler/aot/codegen_test.cc

@@ -197,7 +197,6 @@ TEST(CodegenTest, Golden) {
   variable3->mutable_shape()->add_dim()->set_size(5);
   variable3->set_type(DT_INT32);
   CompileResult compile_result;
-  compile_result.tensorflow_header_root = "third_party/tensorflow";
   compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
       {},
       {BufferInfo::MakeTempBuffer(1),

tensorflow/compiler/aot/compile.cc

@@ -85,7 +85,6 @@ Status CompileXla(xla::CompileOnlyClient* client,
       xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
           std::move(aot_or.ValueOrDie().back()));
   compile_result->entry_point = aot_opts.entry_point_name();
-  compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
   compile_result->pointer_size =
       xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
   return Status::OK();
@@ -130,8 +129,7 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
   xla::cpu::CpuAotCompilationOptions aot_opts(
       flags.target_triple, flags.target_cpu, flags.target_features,
       flags.entry_point,
-      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
-      flags.tensorflow_header_root);
+      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
   return CompileXla(client, computation, aot_opts, compile_result);
 }

tensorflow/compiler/aot/compile.h

@@ -35,7 +35,6 @@ struct CompileResult {
   std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
   xla::ProgramShapeProto program_shape;  // Static shape of args and results.
   string entry_point;                    // Name of generated function.
-  string tensorflow_header_root;         // Prefix for tensorflow headers.
   int pointer_size = 0;                  // Size of a pointer in bytes.
 };

tensorflow/compiler/aot/flags.cc

@@ -74,8 +74,6 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
        "Generate name-to-index data for Lookup{Arg,Result}Index methods."},
       {"gen_program_shape", &flags->gen_program_shape,
        "Generate program shape data for the ProgramShape method."},
-      {"tensorflow_header_root", &flags->tensorflow_header_root,
-       "Root directory of tensorflow headers."},
   };
   flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
 }

tensorflow/compiler/aot/flags.h

@@ -40,7 +40,6 @@ struct MainFlags {
   string out_header;
   string out_session_module;
   string mlir_components;
-  string tensorflow_header_root;

   // C++ codegen options
   bool gen_name_to_index = false;

tensorflow/compiler/aot/tfcompile_main.cc

@@ -65,7 +65,6 @@ int main(int argc, char** argv) {
   flags.out_metadata_object = "out_helper.o";
   flags.out_header = "out.h";
   flags.entry_point = "entry";
-  flags.tensorflow_header_root = "third_party/tensorflow";
   std::vector<tensorflow::Flag> flag_list;
   AppendMainFlags(&flag_list, &flags);

tensorflow/compiler/xla/service/cpu/cpu_compiler.cc

@@ -119,13 +119,12 @@ using BufferInfo = cpu_function_runtime::BufferInfo;
 CpuAotCompilationOptions::CpuAotCompilationOptions(
     string triple, string cpu_name, string features, string entry_point_name,
-    RelocationModel relocation_model, string tensorflow_header_root)
+    RelocationModel relocation_model)
     : triple_(std::move(triple)),
       cpu_name_(std::move(cpu_name)),
       features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
-      relocation_model_(relocation_model),
-      tensorflow_header_root_(std::move(tensorflow_header_root)) {}
+      relocation_model_(relocation_model) {}

 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;

tensorflow/compiler/xla/service/cpu/cpu_compiler.h

@@ -53,16 +53,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   CpuAotCompilationOptions(string triple, string cpu_name, string features,
                            string entry_point_name,
-                           RelocationModel relocation_model,
-                           string tensorflow_header_root);
-  CpuAotCompilationOptions(string triple, string cpu_name, string features,
-                           string entry_point_name,
-                           RelocationModel relocation_model)
-      : CpuAotCompilationOptions(
-            std::move(triple), std::move(cpu_name), std::move(features),
-            std::move(entry_point_name), relocation_model,
-            /*tensorflow_header_root=*/"third_party/tensorflow") {}
+                           RelocationModel relocation_model);

   ~CpuAotCompilationOptions() override;
@@ -76,10 +67,6 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string& features() const { return features_; }
   // The name to be used for the compiled code's entry point.
   const string& entry_point_name() const { return entry_point_name_; }
-  // The prefix for tensorflow headers, e.g. "third_party/tensorflow".
-  const string& tensorflow_header_root() const {
-    return tensorflow_header_root_;
-  }
   // The relocation model used for compilation.
   RelocationModel relocation_model() const { return relocation_model_; }
@@ -89,7 +76,6 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string features_;
   const string entry_point_name_;
   const RelocationModel relocation_model_;
-  const string tensorflow_header_root_;
 };

 class CpuAotCompilationResult : public AotCompilationResult {

tensorflow/python/tfcompile_wrapper.cc

@@ -39,8 +39,8 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
          std::string entry_point, std::string cpp_class,
          std::string out_function_object, std::string out_metadata_object,
          std::string out_header, std::string out_session_module,
-         std::string mlir_components, std::string tensorflow_header_root,
-         bool gen_name_to_index, bool gen_program_shape) {
+         std::string mlir_components, bool gen_name_to_index,
+         bool gen_program_shape) {
         tensorflow::tfcompile::MainFlags flags;
         flags.graph = std::move(graph);
         flags.config = std::move(config);
@@ -54,7 +54,6 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
        flags.out_header = std::move(out_header);
         flags.out_session_module = std::move(out_session_module);
         flags.mlir_components = std::move(mlir_components);
-        flags.tensorflow_header_root = std::move(tensorflow_header_root);

         // C++ codegen options
         flags.gen_name_to_index = gen_name_to_index;
@@ -68,8 +67,6 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
       py::arg("cpp_class") = "", py::arg("out_function_object") = "out_model.o",
       py::arg("out_metadata_object") = "out_helper.o",
       py::arg("out_header") = "out.h", py::arg("out_session_module") = "",
-      py::arg("mlir_components") = "",
-      py::arg("tensorflow_header_root") = "third_party/tensorflow",
-      py::arg("gen_name_to_index") = false,
+      py::arg("mlir_components") = "", py::arg("gen_name_to_index") = false,
       py::arg("gen_program_shape") = false);
 }

tensorflow/python/tools/BUILD

@@ -1,7 +1,7 @@
 # Description:
 #   Tools for manipulating TensorFlow graphs.

-load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test")
+load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test", "tf_cc_test")

 package(
     default_visibility = ["//visibility:public"],
@@ -343,10 +343,63 @@ py_test(
         "no-internal-py3",
         "nosan",
     ],
+    # Force-include XLA dependencies of saved_model_cli_lib to ensure we test
+    # the AOT compilation.
     deps = [
         ":saved_model_cli_lib",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:client_testlib",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
+
+genrule(
+    name = "aot_compiled_x_plus_y_gen",
+    srcs = [
+        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
+        "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo/saved_model.pb",
+    ],
+    outs = [
+        "compiled_model.h",
+        "compiled_model.o",
+        "compiled_model_metadata.o",
+        "compiled_model_makefile.inc",
+    ],
+    cmd = (
+        "$(location :saved_model_cli) aot_compile_cpu " +
+        "--dir \"$$(dirname $(location //tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo/saved_model.pb))\" " +
+        "--output_prefix $(@D)/compiled_model " +
+        "--cpp_class CompiledModel " +
+        "--tag_set serve "
+    ),
+    tools = [
+        ":saved_model_cli",
+    ],
+)
+
+cc_library(
+    name = "aot_compiled_x_plus_y",
+    srcs = if_xla_available([
+        ":compiled_model.o",
+        ":compiled_model_metadata.o",
+    ]),
+    hdrs = if_xla_available([
+        ":compiled_model.h",
+    ]),
+    deps = if_xla_available([
+        "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+        "//tensorflow/core/platform:types",
+    ]),
+)
+
+tf_cc_test(
+    name = "binary_using_aot_compiled_x_plus_y_test",
+    srcs = if_xla_available([
+        "binary_using_aot_compiled_x_plus_y_test.cc",
+    ]),
+    deps = [
+        "//tensorflow/core:test_main",
+    ] + if_xla_available([
+        ":aot_compiled_x_plus_y",
+        "//tensorflow/core:test",
+        "//tensorflow/core/platform:logging",
+    ]),
+)

tensorflow/python/tools/binary_using_aot_compiled_x_plus_y_test.cc

@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/python/tools/compiled_model.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(AOTCompiledSavedModelTest, Run) {
+  CompiledModel model;
+  *model.arg_feed_x_data() = 3.0f;
+  *model.arg_feed_y_data() = 4.0f;
+  CHECK(model.Run());
+  ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
+}
+
+}  // namespace
+}  // namespace tensorflow

tensorflow/python/tools/saved_model_cli.py

@@ -101,6 +101,16 @@ def _sysconfig_module():
   return sysconfig_lib


+def _parse_tensor_name(name):
+  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
+  if ':' in name and not name.endswith(':'):
+    node_name = name[:name.rfind(':')]
+    output_slot = int(name[name.rfind(':') + 1:])
+    return node_name, output_slot
+  else:
+    return name, None
+
+
 _XLA_MAKEFILE_TEMPLATE = """
 INC = -I{tensorflow_includes}
 LIB = -L{compiled_dir}
@@ -134,7 +144,11 @@ def _xla_makefile_string(output_prefix):
     base = os.path.realpath(
         os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
   else:
-    base = test.test_src_dir_path('')
+    try:
+      base = test.test_src_dir_path('')
+    except KeyError:  # Can't find TEST_SRCDIR in environment path.
+      base = os.path.realpath(
+          os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
   expected_header = os.path.join(
       base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
   if not os.path.exists(expected_header):
@@ -164,6 +178,47 @@ def _show_tag_sets(saved_model_dir):
     print('%r' % ', '.join(sorted(tag_set)))


+def _get_variable_nodes_from_graph_def(graph_def):
+  """Get the list of Variable nodes from `graph_def`.
+
+  Args:
+    graph_def: An instance of `GraphDef`.
+
+  Returns:
+    A list of `NodeDef` corresponding to variables in the graph.
+  """
+  variables = [n for n in graph_def.node if n.op == 'VarHandleOp']
+  for f in graph_def.library.function:
+    variables += [n for n in f.node_def if n.op == 'VarHandleOp']
+  return variables
+
+
+def _prune_removed_feed_nodes(signature_def, graph_def):
+  """Identify the inputs in the signature no longer in graph_def, prune them.
+
+  Args:
+    signature_def: A `SignatureDef` instance.
+    graph_def: A `GraphDef` instance.
+
+  Returns:
+    A new pruned `SignatureDef`.
+  """
+  node_names = set([n.name for n in graph_def.node])
+  new_signature_def = meta_graph_pb2.SignatureDef()
+  new_signature_def.CopyFrom(signature_def)
+  for (k, v) in signature_def.inputs.items():
+    tensor_name, _ = _parse_tensor_name(v.name)
+    if tensor_name not in node_names:
+      logging.warn(
+          'Signature input key \'{}\', tensor name \'{}\', has been pruned '
+          'while freezing the graph.  Removing it from the compiled signatures.'
+          .format(k, tensor_name))
+      del new_signature_def.inputs[k]
+  return new_signature_def
+
+
 def _show_signature_def_map_keys(saved_model_dir, tag_set):
   """Prints the keys for each SignatureDef in the SignatureDef map.
@@ -882,23 +937,28 @@ def aot_compile_cpu(args):
   checkpoint_path = (
       args.checkpoint_path
       or os.path.join(args.dir, 'variables/variables'))
+  if not args.variables_to_feed:
+    variables_to_feed = []
+  elif args.variables_to_feed.lower() == 'all':
+    variables_to_feed = None  # We will identify them after.
+  else:
+    variables_to_feed = args.variables_to_feed.split(',')
   aot_compile_cpu_meta_graph_def(
       checkpoint_path=checkpoint_path,
       meta_graph_def=saved_model_utils.get_meta_graph_def(
           args.dir, args.tag_set),
       signature_def_key=args.signature_def_key,
-      freeze_graph=args.freeze_graph,
+      variables_to_feed=variables_to_feed,
       output_prefix=args.output_prefix,
       cpp_class=args.cpp_class)


-def aot_compile_cpu_meta_graph_def(
-    checkpoint_path,
-    meta_graph_def,
-    output_prefix,
-    signature_def_key,
-    cpp_class,
-    freeze_graph=True):
+def aot_compile_cpu_meta_graph_def(checkpoint_path,
+                                   meta_graph_def,
+                                   output_prefix,
+                                   signature_def_key,
+                                   cpp_class,
+                                   variables_to_feed=()):
   """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

   Use XLA AOT (`tfcompile`) to convert the given meta graph and
@@ -920,7 +980,10 @@ def aot_compile_cpu_meta_graph_def(
     output_prefix: Python string.  Path prefix for outputs.
     signature_def_key: String, the signature_def to use in the SavedModel.
     cpp_class: Name of output C++ class.
-    freeze_graph: Whether to freeze the graph before compilation.
+    variables_to_feed: A list of strings, the variables that will be fed by the
+      user; these won't be frozen.  If `None`, then we will extract all the
+      variables in the graph and mark them as to-feed.  The default behavior is
+      an empty tuple: all variables must be frozen.

   Raises:
     RuntimeError: If tensorflow was not built with XLA.
@@ -945,32 +1008,62 @@ def aot_compile_cpu_meta_graph_def(
         'Signature key {} must have outputs, but saw none:\n{}'.format(
             signature_def_key, str(signature_def)))

+  temp_dir = test.get_temp_dir()
+  file_io.recursive_create_dir(temp_dir)
+  if logging.get_verbosity() >= logging.INFO:
+    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
+    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
+      graph_writer.write(meta_graph_def.graph_def.SerializeToString())
+
   # This updates graph_def in place.
   _replace_input_placeholders_with_default_values(
       meta_graph_def.graph_def, signature_def)
   graph_def = _optimize_graph(meta_graph_def, signature_def)
-  if freeze_graph:
-    # Load the Variables so that we can freeze the graph.
-    with session.Session(graph=ops_lib.Graph()) as sess:
-      restorer = saver_lib.import_meta_graph(
-          meta_graph_def, clear_devices=True)
-      restorer.restore(sess, checkpoint_path)
-      graph_def.CopyFrom(
-          graph_util.convert_variables_to_constants(
-              sess,
-              graph_def,
-              [n.name.split(':')[0] for n in signature_def.outputs.values()]))
+
+  all_variables = _get_variable_nodes_from_graph_def(graph_def)
+  if variables_to_feed is None:
+    variable_nodes_to_feed = list(all_variables)
+  else:
+    not_in_graph = (
+        set(variables_to_feed).difference([x.name for x in all_variables]))
+    if not_in_graph:
+      raise ValueError(
+          'Asked to feed variables that were not found in graph: {}.  '
+          'Variables contained in the graph: {}'.format(
+              not_in_graph, set([x.name for x in all_variables])))
+    all_variables_map = dict((x.name, x) for x in all_variables)
+    variable_nodes_to_feed = [
+        all_variables_map[name] for name in variables_to_feed
+    ]
+
+  if logging.get_verbosity() >= logging.INFO:
+    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
+    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
+      graph_writer.write(meta_graph_def.graph_def.SerializeToString())
+
+  # Load the Variables so that we can freeze the graph.
+  with session.Session(graph=ops_lib.Graph()) as sess:
+    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
+    restorer.restore(sess, checkpoint_path)
+    graph_def.CopyFrom(
+        graph_util.convert_variables_to_constants(
+            sess,
+            graph_def,
+            output_node_names=[
+                _parse_tensor_name(n.name)[0]
+                for n in signature_def.outputs.values()
+            ],
+        ))
+
+  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

-  temp_dir = test.get_temp_dir()
   frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
   config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
   logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
   with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
     graph_writer.write(graph_def.SerializeToString())
   config = _signature_to_tf2xla_config(
-      signature_def,
-      frozen_variables=freeze_graph)
+      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
   logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
   with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
     config_writer.write(str(config))
@@ -991,13 +1084,6 @@ def aot_compile_cpu_meta_graph_def(
   output_prefix = _shlex_quote(output_prefix)

-  additional_compiler_args = {}
-  sysconfig = _sysconfig_module()
-  if sysconfig:
-    # We're inside PIP and need to pass a customized relative path to the
-    # appropriate tensorflow headers.
-    additional_compiler_args['tensorflow_header_root'] = 'tensorflow'
-
   _pywrap_tfcompile.Compile(
       graph=frozen_graph_def_location,
       config=config_pbtxt_location,
@@ -1008,8 +1094,7 @@ def aot_compile_cpu_meta_graph_def(
       out_metadata_object='{}_metadata.o'.format(output_prefix),
       gen_name_to_index=True,
       # ProgramShape isn't uniquefied by entry_point.
-      gen_program_shape=False,
-      **additional_compiler_args)
+      gen_program_shape=False)


 def _optimize_graph(meta_graph_def, signature_def):
@@ -1034,7 +1119,7 @@ def _replace_input_placeholders_with_default_values(graph_def, signature_def):
   name_to_node_map = dict((n.name, n) for n in graph_def.node)
   temp_graph = ops_lib.Graph()
   for name, input_ in signature_def.inputs.items():
-    tensor_name = input_.name.split(':')[0]
+    tensor_name, _ = _parse_tensor_name(input_.name)
     if tensor_name not in name_to_node_map:
       raise RuntimeError(
           'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
@@ -1330,13 +1415,16 @@ def add_aot_compile_cpu_subparser(subparsers):
           'The class will be generated in the given namespace(s), or if no '
           'namespaces are given, within the global namespace.'))
   parser_compile.add_argument(
-      '--freeze_graph',
-      type=bool,
-      default=True,
-      help=('Whether to freeze the tf.Variables into the graph.  If false, '
-            'then all Variables in the closure of the signature graph path '
-            'be be added as input and output args to the XLA-compiled graph '
-            '(not currently supported)'))
+      '--variables_to_feed',
+      type=str,
+      default='',
+      help=('The names of variables that will be fed into the network.  '
+            'Options are: empty (default; all variables are frozen, none may '
+            'be fed), \'all\' (all variables may be fed), or a '
+            'comma-delimited list of names of variables that may be fed.  In '
+            'the last case, the non-fed variables will be frozen in the graph.')
+  )
   parser_compile.set_defaults(func=aot_compile_cpu)
@@ -1371,12 +1459,13 @@ def create_parser():
   return parser


-def _signature_to_tf2xla_config(signature_def, frozen_variables):
+def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
   """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

   Args:
     signature_def: Instance of `SignatureDef`.
-    frozen_variables: Python bool, whether variables are being frozen or not.
+    variable_nodes_to_feed: List NodeDefs corresponding to VarHandleOp,
+      the list of variables to feed.

   Returns:
     An instance of `tf2xla.Config` proto.
@@ -1390,7 +1479,9 @@ def _signature_to_tf2xla_config(signature_def, frozen_variables):
   tensor_id = tf2xla_pb2.TensorId

   for name, input_ in signature_def.inputs.items():
-    (node_name, output_index) = input_.name.split(':')
+    name = name.replace('/', '_')
+    name = 'feed_{}'.format(name)
+    (node_name, output_index) = _parse_tensor_name(input_.name)
     output_index = int(output_index)
     config.feed.append(
         tf2xla_pb2.Feed(
@@ -1399,7 +1490,9 @@ def _signature_to_tf2xla_config(signature_def, frozen_variables):
             type=input_.dtype,
             shape=input_.tensor_shape))
   for name, output_ in signature_def.outputs.items():
-    (node_name, output_index) = output_.name.split(':')
+    name = name.replace('/', '_')
+    name = 'fetch_{}'.format(name)
+    (node_name, output_index) = _parse_tensor_name(output_.name)
     output_index = int(output_index)
     config.fetch.append(
         tf2xla_pb2.Fetch(
@@ -1407,14 +1500,22 @@ def _signature_to_tf2xla_config(signature_def, frozen_variables):
             name=name,
             type=output_.dtype,
             shape=output_.tensor_shape))
-  if not frozen_variables:
-    # Extract all variables along the path and add to config
-    raise NotImplementedError('Non-frozen graphs are not supported.')
+  for node in variable_nodes_to_feed:
+    name = node.name.replace('/', '_')
+    name = 'param_{}'.format(name)
+    config.variable.append(
+        tf2xla_pb2.Variable(
+            node_name=node.name,
+            name=name,
+            type=node.attr['dtype'].type,
+            shape=node.attr['shape'].shape,
+            readonly=True))
   return config


 def main():
+  logging.set_verbosity(logging.INFO)
   parser = create_parser()
   args = parser.parse_args()
   if not hasattr(args, 'func'):

tensorflow/python/tools/saved_model_cli_test.py

@@ -25,6 +25,7 @@ import pickle
 import shutil
 import sys

+from absl.testing import parameterized
 import numpy as np
 from six import StringIO
@@ -38,6 +39,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import save
 from tensorflow.python.tools import saved_model_cli
 from tensorflow.python.training.tracking import tracking
@@ -56,7 +58,7 @@ def captured_output():
     sys.stdout, sys.stderr = old_out, old_err


-class SavedModelCLITestCase(test.TestCase):
+class SavedModelCLITestCase(test.TestCase, parameterized.TestCase):

   def testShowCommandAll(self):
     base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
@@ -726,35 +728,44 @@ Defined Functions:
     with self.assertRaisesRegexp(ValueError, 'Unable to find signature_def'):
       saved_model_cli.aot_compile_cpu(args)

-  def testAOTCompileCPUFreezesAndCompiles(self):
+  class AOTCompileDummyModel(tracking.AutoTrackable):
+    """Model compatible with XLA compilation."""
+
+    def __init__(self):
+      self.var = variables.Variable(1.0, name='my_var')
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
+    ])
+    def func2(self, x, y):
+      return {'res': x + self.var}
+
+  @parameterized.named_parameters(('VariablesToFeedNone', ''),
+                                  ('VariablesToFeedAll', 'all'),
+                                  ('VariablesToFeedMyVar', 'my_var'))
+  def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed):
     if not test.is_built_with_xla():
       self.skipTest('Skipping test because XLA is not compiled in.')

-    class DummyModel(tracking.AutoTrackable):
-      """Model compatible with XLA compilation."""
-
-      def __init__(self):
-        self.var = variables.Variable(1.0, name='my_var')
-
-      @def_function.function(input_signature=[
-          tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
-      ])
-      def func2(self, x):
-        return {'res': x + self.var}
-
     saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
-    dummy_model = DummyModel()
+    dummy_model = self.AOTCompileDummyModel()
     with self.cached_session():
       self.evaluate(dummy_model.var.initializer)
       save.save(dummy_model, saved_model_dir)
     self.parser = saved_model_cli.create_parser()
     output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
-    args = self.parser.parse_args(
-        ['aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
-         '--output_prefix', output_prefix,
-         '--cpp_class', 'Generated'])  # Use the default seving signature_key.
-    saved_model_cli.aot_compile_cpu(args)
+    args = self.parser.parse_args([
+        'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
+        '--output_prefix', output_prefix, '--variables_to_feed',
+        variables_to_feed, '--cpp_class', 'Generated'
+    ])  # Use the default seving signature_key.
+    with test.mock.patch.object(logging, 'warn') as captured_warn:
+      saved_model_cli.aot_compile_cpu(args)
+    self.assertRegexpMatches(
+        str(captured_warn.call_args),
+        'Signature input key \'y\'.*has been pruned while freezing the graph.')
     self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
     self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
     self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
@@ -762,8 +773,12 @@ Defined Functions:
         file_io.file_exists('{}_makefile.inc'.format(output_prefix)))
     header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
     self.assertIn('class Generated', header_contents)
-    self.assertIn('arg_x_data', header_contents)
-    self.assertIn('result_res_data', header_contents)
+    self.assertIn('arg_feed_x_data', header_contents)
+    self.assertIn('result_fetch_res_data', header_contents)
+    # arg_y got filtered out as it's not used by the output.
+    self.assertNotIn('arg_feed_y_data', header_contents)
+    if variables_to_feed:
+      self.assertIn('var_param_my_var', header_contents)
     makefile_contents = file_io.read_file_to_string(
         '{}_makefile.inc'.format(output_prefix))
     self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)