[TF] Fixes and new features for saved_model_cli aot_compile_cpu; adds an e2e test.

* Can now specify which variables can be fed/fetched (see the sketch after
  this list).
* Bugfix for signature names that contain slashes or start with digits.
* Prune input config entries from the tf2xla config when the graph freeze
  removes an unused input feed.
* Fixed a bug where third_party/tensorflow/ wasn't properly renamed to
  tensorflow/ in the open-source HOST build (identified by the new genrule
  test).  Solution: bring back the hardcoded #include in codegen.cc; it's
  always correct.
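
For illustration, a minimal sketch of the new --variables_to_feed flag driven
through the Python parser (the SavedModel path is hypothetical; the flag
values mirror the e2e test added in this CL):

    # Compile ahead-of-time, feeding only 'my_var'; all other variables are
    # frozen into the graph as constants.
    from tensorflow.python.tools import saved_model_cli

    parser = saved_model_cli.create_parser()
    args = parser.parse_args([
        'aot_compile_cpu',
        '--dir', '/tmp/my_saved_model',   # hypothetical path
        '--tag_set', 'serve',
        '--output_prefix', '/tmp/out/compiled_model',
        '--cpp_class', 'MyModel',
        '--variables_to_feed', 'my_var',  # '' freezes all; 'all' feeds all
    ])
    saved_model_cli.aot_compile_cpu(args)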

NOTE: The bugfix to the #include line in the compiler/ codebase is a partial
rollback of the initial tfcompile + saved_model_cli CL, which moved from the
hard-coded include path to a parameterized value.  It turns out we don't need
the complexity of this approach, and it's incorrect in the host open-source
build.

TESTED:

Includes a bona fide genrule test that runs saved_model_cli to generate the
header and object files, includes them in a C++ unit test, and ensures that
they compile and the resulting object runs correctly.

PiperOrigin-RevId: 290655683
Change-Id: I4cfa2c595ebe56f8bdd47853f82371d97b92b081
Eugene Brevdo, 2020-01-20 15:55:51 -08:00; committed by TensorFlower Gardener
parent a901c88061, commit b26e1efece
14 changed files with 281 additions and 109 deletions


@@ -457,8 +457,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
 {{INCLUDE_XLA_DATA_PROTO}}
 {{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
-#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
-#include "{{TF_HEADER_ROOT}}/core/platform/types.h"
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "tensorflow/core/platform/types.h"

 namespace Eigen { struct ThreadPoolDevice; }
 namespace xla { class ExecutableRunOptions; }
@@ -659,7 +659,6 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
       {"{{CLASS}}", opts.class_name},
       {"{{DECLS_FROM_OBJ_FILE}}",
        absl::StrJoin(metadata_result.header_variable_decls, "\n")},
-      {"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
       {"{{ENTRY}}", compile_result.entry_point},
       {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
        metadata_result.hlo_profile_printer_data_access_shim},


@@ -197,7 +197,6 @@ TEST(CodegenTest, Golden) {
   variable3->mutable_shape()->add_dim()->set_size(5);
   variable3->set_type(DT_INT32);
   CompileResult compile_result;
-  compile_result.tensorflow_header_root = "third_party/tensorflow";
   compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
       {},
       {BufferInfo::MakeTempBuffer(1),


@@ -85,7 +85,6 @@ Status CompileXla(xla::CompileOnlyClient* client,
       xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
           std::move(aot_or.ValueOrDie().back()));
   compile_result->entry_point = aot_opts.entry_point_name();
-  compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
   compile_result->pointer_size =
       xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
   return Status::OK();
@@ -130,8 +129,7 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
   xla::cpu::CpuAotCompilationOptions aot_opts(
       flags.target_triple, flags.target_cpu, flags.target_features,
       flags.entry_point,
-      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
-      flags.tensorflow_header_root);
+      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
   return CompileXla(client, computation, aot_opts, compile_result);
 }


@@ -35,7 +35,6 @@ struct CompileResult {
   std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
   xla::ProgramShapeProto program_shape;  // Static shape of args and results.
   string entry_point;                    // Name of generated function.
-  string tensorflow_header_root;         // Prefix for tensorflow headers.
   int pointer_size = 0;                  // Size of a pointer in bytes.
 };


@@ -74,8 +74,6 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
        "Generate name-to-index data for Lookup{Arg,Result}Index methods."},
       {"gen_program_shape", &flags->gen_program_shape,
        "Generate program shape data for the ProgramShape method."},
-      {"tensorflow_header_root", &flags->tensorflow_header_root,
-       "Root directory of tensorflow headers."},
   };
   flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
 }


@@ -40,7 +40,6 @@ struct MainFlags {
   string out_header;
   string out_session_module;
   string mlir_components;
-  string tensorflow_header_root;

   // C++ codegen options
   bool gen_name_to_index = false;


@@ -65,7 +65,6 @@ int main(int argc, char** argv) {
   flags.out_metadata_object = "out_helper.o";
   flags.out_header = "out.h";
   flags.entry_point = "entry";
-  flags.tensorflow_header_root = "third_party/tensorflow";
   std::vector<tensorflow::Flag> flag_list;
   AppendMainFlags(&flag_list, &flags);


@@ -119,13 +119,12 @@ using BufferInfo = cpu_function_runtime::BufferInfo;
 CpuAotCompilationOptions::CpuAotCompilationOptions(
     string triple, string cpu_name, string features, string entry_point_name,
-    RelocationModel relocation_model, string tensorflow_header_root)
+    RelocationModel relocation_model)
     : triple_(std::move(triple)),
      cpu_name_(std::move(cpu_name)),
      features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
-     relocation_model_(relocation_model),
-     tensorflow_header_root_(std::move(tensorflow_header_root)) {}
+     relocation_model_(relocation_model) {}

 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;


@@ -53,16 +53,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   CpuAotCompilationOptions(string triple, string cpu_name, string features,
                            string entry_point_name,
-                           RelocationModel relocation_model,
-                           string tensorflow_header_root);
-  CpuAotCompilationOptions(string triple, string cpu_name, string features,
-                           string entry_point_name,
-                           RelocationModel relocation_model)
-      : CpuAotCompilationOptions(
-            std::move(triple), std::move(cpu_name), std::move(features),
-            std::move(entry_point_name), relocation_model,
-            /*tensorflow_header_root=*/"third_party/tensorflow") {}
+                           RelocationModel relocation_model);

   ~CpuAotCompilationOptions() override;
@@ -76,10 +67,6 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string& features() const { return features_; }
   // The name to be used for the compiled code's entry point.
   const string& entry_point_name() const { return entry_point_name_; }
-  // The prefix for tensorflow headers, e.g. "third_party/tensorflow".
-  const string& tensorflow_header_root() const {
-    return tensorflow_header_root_;
-  }
   // The relocation model used for compilation.
   RelocationModel relocation_model() const { return relocation_model_; }
@@ -89,7 +76,6 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string features_;
   const string entry_point_name_;
   const RelocationModel relocation_model_;
-  const string tensorflow_header_root_;
 };

 class CpuAotCompilationResult : public AotCompilationResult {


@@ -39,8 +39,8 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
          std::string entry_point, std::string cpp_class,
          std::string out_function_object, std::string out_metadata_object,
          std::string out_header, std::string out_session_module,
-         std::string mlir_components, std::string tensorflow_header_root,
-         bool gen_name_to_index, bool gen_program_shape) {
+         std::string mlir_components, bool gen_name_to_index,
+         bool gen_program_shape) {
         tensorflow::tfcompile::MainFlags flags;
         flags.graph = std::move(graph);
         flags.config = std::move(config);
@@ -54,7 +54,6 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
         flags.out_header = std::move(out_header);
         flags.out_session_module = std::move(out_session_module);
         flags.mlir_components = std::move(mlir_components);
-        flags.tensorflow_header_root = std::move(tensorflow_header_root);

         // C++ codegen options
         flags.gen_name_to_index = gen_name_to_index;
@@ -68,8 +67,6 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
       py::arg("cpp_class") = "", py::arg("out_function_object") = "out_model.o",
       py::arg("out_metadata_object") = "out_helper.o",
       py::arg("out_header") = "out.h", py::arg("out_session_module") = "",
-      py::arg("mlir_components") = "",
-      py::arg("tensorflow_header_root") = "third_party/tensorflow",
-      py::arg("gen_name_to_index") = false,
+      py::arg("mlir_components") = "", py::arg("gen_name_to_index") = false,
       py::arg("gen_program_shape") = false);
 }
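
With this change the Python-visible Compile() entry point no longer accepts
tensorflow_header_root.  A hedged sketch of a call, using only keyword
arguments visible in this diff (the import path and the omitted arguments are
assumptions, not confirmed by the hunks above):

    from tensorflow.python import _pywrap_tfcompile  # import path assumed

    _pywrap_tfcompile.Compile(
        graph='frozen_graph.pb',   # serialized GraphDef
        config='config.pbtxt',     # tf2xla.Config, text format
        cpp_class='MyModel',
        out_header='out.h',
        gen_name_to_index=True,
        gen_program_shape=False)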


@@ -1,7 +1,7 @@
 # Description:
 # Tools for manipulating TensorFlow graphs.

-load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test")
+load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test", "tf_cc_test")

 package(
     default_visibility = ["//visibility:public"],
@@ -343,10 +343,63 @@ py_test(
         "no-internal-py3",
         "nosan",
     ],
+    # Force-include XLA dependencies of saved_model_cli_lib to ensure we test
+    # the AOT compilation.
     deps = [
         ":saved_model_cli_lib",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:client_testlib",
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+genrule(
+    name = "aot_compiled_x_plus_y_gen",
+    srcs = [
+        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
+        "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo/saved_model.pb",
+    ],
+    outs = [
+        "compiled_model.h",
+        "compiled_model.o",
+        "compiled_model_metadata.o",
+        "compiled_model_makefile.inc",
+    ],
+    cmd = (
+        "$(location :saved_model_cli) aot_compile_cpu " +
+        "--dir \"$$(dirname $(location //tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo/saved_model.pb))\" " +
+        "--output_prefix $(@D)/compiled_model " +
+        "--cpp_class CompiledModel " +
+        "--tag_set serve "
+    ),
+    tools = [
+        ":saved_model_cli",
+    ],
+)
+
+cc_library(
+    name = "aot_compiled_x_plus_y",
+    srcs = if_xla_available([
+        ":compiled_model.o",
+        ":compiled_model_metadata.o",
+    ]),
+    hdrs = if_xla_available([
+        ":compiled_model.h",
+    ]),
+    deps = if_xla_available([
+        "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+        "//tensorflow/core/platform:types",
+    ]),
+)
+
+tf_cc_test(
+    name = "binary_using_aot_compiled_x_plus_y_test",
+    srcs = if_xla_available([
+        "binary_using_aot_compiled_x_plus_y_test.cc",
+    ]),
+    deps = [
+        "//tensorflow/core:test_main",
+    ] + if_xla_available([
+        ":aot_compiled_x_plus_y",
+        "//tensorflow/core:test",
+        "//tensorflow/core/platform:logging",
+    ]),
+)
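
Roughly, the genrule's cmd amounts to the following invocation once Bazel has
resolved $(location ...) and $(@D); expressed here as a Python subprocess
sketch with illustrative paths:

    import subprocess

    # Paths are illustrative; Bazel substitutes the real ones at build time.
    subprocess.check_call([
        'saved_model_cli', 'aot_compile_cpu',
        '--dir', 'tensorflow/cc/saved_model/testdata/x_plus_y_v2_debuginfo',
        '--output_prefix', 'genfiles/tensorflow/python/tools/compiled_model',
        '--cpp_class', 'CompiledModel',
        '--tag_set', 'serve',
    ])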


@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/python/tools/compiled_model.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(AOTCompiledSavedModelTest, Run) {
+  CompiledModel model;
+  *model.arg_feed_x_data() = 3.0f;
+  *model.arg_feed_y_data() = 4.0f;
+  CHECK(model.Run());
+  ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
+}
+
+}  // namespace
+}  // namespace tensorflow


@@ -101,6 +101,16 @@ def _sysconfig_module():
   return sysconfig_lib


+def _parse_tensor_name(name):
+  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
+  if ':' in name and not name.endswith(':'):
+    node_name = name[:name.rfind(':')]
+    output_slot = int(name[name.rfind(':') + 1:])
+    return node_name, output_slot
+  else:
+    return name, None
+
+
 _XLA_MAKEFILE_TEMPLATE = """
 INC = -I{tensorflow_includes}
 LIB = -L{compiled_dir}
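
A quick check of the new helper's behavior, grounded in the definition above:

    # The split happens at the last ':'; a bare node name yields slot None.
    assert _parse_tensor_name('dense/kernel:0') == ('dense/kernel', 0)
    assert _parse_tensor_name('x') == ('x', None)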
@@ -134,7 +144,11 @@ def _xla_makefile_string(output_prefix):
     base = os.path.realpath(
         os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
   else:
-    base = test.test_src_dir_path('')
+    try:
+      base = test.test_src_dir_path('')
+    except KeyError:  # Can't find TEST_SRCDIR in environment path.
+      base = os.path.realpath(
+          os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
   expected_header = os.path.join(
       base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
   if not os.path.exists(expected_header):
@@ -164,6 +178,47 @@ def _show_tag_sets(saved_model_dir):
     print('%r' % ', '.join(sorted(tag_set)))


+def _get_variable_nodes_from_graph_def(graph_def):
+  """Get the list of Variable nodes from `graph_def`.
+
+  Args:
+    graph_def: An instance of `GraphDef`.
+
+  Returns:
+    A list of `NodeDef` corresponding to variables in the graph.
+  """
+  variables = [n for n in graph_def.node if n.op == 'VarHandleOp']
+  for f in graph_def.library.function:
+    variables += [n for n in f.node_def if n.op == 'VarHandleOp']
+  return variables
+
+
+def _prune_removed_feed_nodes(signature_def, graph_def):
+  """Identify the inputs in the signature no longer in graph_def, prune them.
+
+  Args:
+    signature_def: A `SignatureDef` instance.
+    graph_def: A `GraphDef` instance.
+
+  Returns:
+    A new pruned `SignatureDef`.
+  """
+  node_names = set([n.name for n in graph_def.node])
+  new_signature_def = meta_graph_pb2.SignatureDef()
+  new_signature_def.CopyFrom(signature_def)
+  for (k, v) in signature_def.inputs.items():
+    tensor_name, _ = _parse_tensor_name(v.name)
+    if tensor_name not in node_names:
+      logging.warn(
+          'Signature input key \'{}\', tensor name \'{}\', has been pruned '
+          'while freezing the graph. Removing it from the compiled signatures.'
+          .format(k, tensor_name))
+      del new_signature_def.inputs[k]
+  return new_signature_def
+
+
 def _show_signature_def_map_keys(saved_model_dir, tag_set):
   """Prints the keys for each SignatureDef in the SignatureDef map.
@@ -882,23 +937,28 @@ def aot_compile_cpu(args):
   checkpoint_path = (
       args.checkpoint_path
       or os.path.join(args.dir, 'variables/variables'))
+  if not args.variables_to_feed:
+    variables_to_feed = []
+  elif args.variables_to_feed.lower() == 'all':
+    variables_to_feed = None  # We will identify them after.
+  else:
+    variables_to_feed = args.variables_to_feed.split(',')
+
   aot_compile_cpu_meta_graph_def(
       checkpoint_path=checkpoint_path,
       meta_graph_def=saved_model_utils.get_meta_graph_def(
           args.dir, args.tag_set),
       signature_def_key=args.signature_def_key,
-      freeze_graph=args.freeze_graph,
+      variables_to_feed=variables_to_feed,
       output_prefix=args.output_prefix,
       cpp_class=args.cpp_class)


-def aot_compile_cpu_meta_graph_def(
-    checkpoint_path,
-    meta_graph_def,
-    output_prefix,
-    signature_def_key,
-    cpp_class,
-    freeze_graph=True):
+def aot_compile_cpu_meta_graph_def(checkpoint_path,
+                                   meta_graph_def,
+                                   output_prefix,
+                                   signature_def_key,
+                                   cpp_class,
+                                   variables_to_feed=()):
   """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

   Use XLA AOT (`tfcompile`) to convert the given meta graph and
@@ -920,7 +980,10 @@
     output_prefix: Python string. Path prefix for outputs.
     signature_def_key: String, the signature_def to use in the SavedModel.
     cpp_class: Name of output C++ class.
-    freeze_graph: Whether to freeze the graph before compilation.
+    variables_to_feed: A list of strings, the variables that will be fed by the
+      user; these won't be frozen. If `None`, then we will extract all the
+      variables in the graph and mark them as to-feed. The default behavior is
+      an empty tuple: all variables must be frozen.

   Raises:
     RuntimeError: If tensorflow was not built with XLA.
@@ -945,32 +1008,62 @@
         'Signature key {} must have outputs, but saw none:\n{}'.format(
             signature_def_key, str(signature_def)))

+  temp_dir = test.get_temp_dir()
+  file_io.recursive_create_dir(temp_dir)
+  if logging.get_verbosity() >= logging.INFO:
+    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
+    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
+      graph_writer.write(meta_graph_def.graph_def.SerializeToString())
+
   # This updates graph_def in place.
   _replace_input_placeholders_with_default_values(
       meta_graph_def.graph_def, signature_def)
   graph_def = _optimize_graph(meta_graph_def, signature_def)

-  if freeze_graph:
-    # Load the Variables so that we can freeze the graph.
-    with session.Session(graph=ops_lib.Graph()) as sess:
-      restorer = saver_lib.import_meta_graph(
-          meta_graph_def, clear_devices=True)
-      restorer.restore(sess, checkpoint_path)
-      graph_def.CopyFrom(
-          graph_util.convert_variables_to_constants(
-              sess,
-              graph_def,
-              [n.name.split(':')[0] for n in signature_def.outputs.values()]))
+  all_variables = _get_variable_nodes_from_graph_def(graph_def)
+  if variables_to_feed is None:
+    variable_nodes_to_feed = list(all_variables)
+  else:
+    not_in_graph = (
+        set(variables_to_feed).difference([x.name for x in all_variables]))
+    if not_in_graph:
+      raise ValueError(
+          'Asked to feed variables that were not found in graph: {}. '
+          'Variables contained in the graph: {}'.format(
+              not_in_graph, set([x.name for x in all_variables])))
+    all_variables_map = dict((x.name, x) for x in all_variables)
+    variable_nodes_to_feed = [
+        all_variables_map[name] for name in variables_to_feed
+    ]
+
+  if logging.get_verbosity() >= logging.INFO:
+    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
+    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
+      graph_writer.write(meta_graph_def.graph_def.SerializeToString())
+
+  # Load the Variables so that we can freeze the graph.
+  with session.Session(graph=ops_lib.Graph()) as sess:
+    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
+    restorer.restore(sess, checkpoint_path)
+    graph_def.CopyFrom(
+        graph_util.convert_variables_to_constants(
+            sess,
+            graph_def,
+            output_node_names=[
+                _parse_tensor_name(n.name)[0]
+                for n in signature_def.outputs.values()
+            ],
+        ))
+
+  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

-  temp_dir = test.get_temp_dir()
   frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
   config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
   logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
   with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
     graph_writer.write(graph_def.SerializeToString())
   config = _signature_to_tf2xla_config(
-      signature_def,
-      frozen_variables=freeze_graph)
+      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
   logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
   with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
     config_writer.write(str(config))
@@ -991,13 +1084,6 @@ def aot_compile_cpu_meta_graph_def(
   output_prefix = _shlex_quote(output_prefix)

-  additional_compiler_args = {}
-  sysconfig = _sysconfig_module()
-  if sysconfig:
-    # We're inside PIP and need to pass a customized relative path to the
-    # appropriate tensorflow headers.
-    additional_compiler_args['tensorflow_header_root'] = 'tensorflow'
   _pywrap_tfcompile.Compile(
       graph=frozen_graph_def_location,
       config=config_pbtxt_location,
@@ -1008,8 +1094,7 @@
       out_metadata_object='{}_metadata.o'.format(output_prefix),
       gen_name_to_index=True,
       # ProgramShape isn't uniquefied by entry_point.
-      gen_program_shape=False,
-      **additional_compiler_args)
+      gen_program_shape=False)


 def _optimize_graph(meta_graph_def, signature_def):
@@ -1034,7 +1119,7 @@ def _replace_input_placeholders_with_default_values(graph_def, signature_def):
   name_to_node_map = dict((n.name, n) for n in graph_def.node)
   temp_graph = ops_lib.Graph()
   for name, input_ in signature_def.inputs.items():
-    tensor_name = input_.name.split(':')[0]
+    tensor_name, _ = _parse_tensor_name(input_.name)
     if tensor_name not in name_to_node_map:
       raise RuntimeError(
           'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
@@ -1330,13 +1415,16 @@ def add_aot_compile_cpu_subparser(subparsers):
           'The class will be generated in the given namespace(s), or if no '
           'namespaces are given, within the global namespace.'))
   parser_compile.add_argument(
-      '--freeze_graph',
-      type=bool,
-      default=True,
-      help=('Whether to freeze the tf.Variables into the graph. If false, '
-            'then all Variables in the closure of the signature graph path '
-            'be be added as input and output args to the XLA-compiled graph '
-            '(not currently supported)'))
+      '--variables_to_feed',
+      type=str,
+      default='',
+      help=('The names of variables that will be fed into the network. '
+            'Options are: empty (default; all variables are frozen, none may '
+            'be fed), \'all\' (all variables may be fed), or a '
+            'comma-delimited list of names of variables that may be fed. In '
+            'the last case, the non-fed variables will be frozen in the graph.')
+  )
   parser_compile.set_defaults(func=aot_compile_cpu)
@@ -1371,12 +1459,13 @@ def create_parser():
   return parser


-def _signature_to_tf2xla_config(signature_def, frozen_variables):
+def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
   """Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto.

   Args:
     signature_def: Instance of `SignatureDef`.
-    frozen_variables: Python bool, whether variables are being frozen or not.
+    variable_nodes_to_feed: List of NodeDefs corresponding to VarHandleOp,
+      the list of variables to feed.

   Returns:
     An instance of `tf2xla.Config` proto.
@@ -1390,7 +1479,9 @@
   tensor_id = tf2xla_pb2.TensorId
   for name, input_ in signature_def.inputs.items():
-    (node_name, output_index) = input_.name.split(':')
+    name = name.replace('/', '_')
+    name = 'feed_{}'.format(name)
+    (node_name, output_index) = _parse_tensor_name(input_.name)
     output_index = int(output_index)
     config.feed.append(
         tf2xla_pb2.Feed(
@@ -1399,7 +1490,9 @@
             type=input_.dtype,
             shape=input_.tensor_shape))
   for name, output_ in signature_def.outputs.items():
-    (node_name, output_index) = output_.name.split(':')
+    name = name.replace('/', '_')
+    name = 'fetch_{}'.format(name)
+    (node_name, output_index) = _parse_tensor_name(output_.name)
     output_index = int(output_index)
     config.fetch.append(
         tf2xla_pb2.Fetch(
@@ -1407,14 +1500,22 @@
             name=name,
             type=output_.dtype,
             shape=output_.tensor_shape))
-  if not frozen_variables:
-    # Extract all variables along the path and add to config
-    raise NotImplementedError('Non-frozen graphs are not supported.')
+  for node in variable_nodes_to_feed:
+    name = node.name.replace('/', '_')
+    name = 'param_{}'.format(name)
+    config.variable.append(
+        tf2xla_pb2.Variable(
+            node_name=node.name,
+            name=name,
+            type=node.attr['dtype'].type,
+            shape=node.attr['shape'].shape,
+            readonly=True))
   return config


 def main():
   logging.set_verbosity(logging.INFO)
   parser = create_parser()
   args = parser.parse_args()
   if not hasattr(args, 'func'):
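
To make the new naming concrete: signature keys may contain '/' or start with
a digit, neither of which is a valid C++ identifier fragment, so keys are
rewritten with a 'feed_'/'fetch_' prefix ('param_' for fed variables) and '/'
replaced by '_'.  A quick sketch (the helper name is illustrative, not from
the code above):

    def _sanitize(prefix, key):
      # Mirrors the replace-and-prefix logic in _signature_to_tf2xla_config.
      return '{}_{}'.format(prefix, key.replace('/', '_'))

    assert _sanitize('feed', '0_my/scope/x') == 'feed_0_my_scope_x'
    assert _sanitize('fetch', 'output/0') == 'fetch_output_0'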


@@ -25,6 +25,7 @@ import pickle
 import shutil
 import sys

+from absl.testing import parameterized
 import numpy as np
 from six import StringIO
@@ -38,6 +39,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import save
 from tensorflow.python.tools import saved_model_cli
 from tensorflow.python.training.tracking import tracking
@@ -56,7 +58,7 @@ def captured_output():
   sys.stdout, sys.stderr = old_out, old_err


-class SavedModelCLITestCase(test.TestCase):
+class SavedModelCLITestCase(test.TestCase, parameterized.TestCase):

   def testShowCommandAll(self):
     base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
@@ -726,35 +728,44 @@ Defined Functions:
     with self.assertRaisesRegexp(ValueError, 'Unable to find signature_def'):
       saved_model_cli.aot_compile_cpu(args)

-  def testAOTCompileCPUFreezesAndCompiles(self):
+  class AOTCompileDummyModel(tracking.AutoTrackable):
+    """Model compatible with XLA compilation."""
+
+    def __init__(self):
+      self.var = variables.Variable(1.0, name='my_var')
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
+    ])
+    def func2(self, x, y):
+      return {'res': x + self.var}
+
+  @parameterized.named_parameters(('VariablesToFeedNone', ''),
+                                  ('VariablesToFeedAll', 'all'),
+                                  ('VariablesToFeedMyVar', 'my_var'))
+  def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed):
     if not test.is_built_with_xla():
       self.skipTest('Skipping test because XLA is not compiled in.')

-    class DummyModel(tracking.AutoTrackable):
-      """Model compatible with XLA compilation."""
-
-      def __init__(self):
-        self.var = variables.Variable(1.0, name='my_var')
-
-      @def_function.function(input_signature=[
-          tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
-      ])
-      def func2(self, x):
-        return {'res': x + self.var}
-
     saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
-    dummy_model = DummyModel()
+    dummy_model = self.AOTCompileDummyModel()
     with self.cached_session():
       self.evaluate(dummy_model.var.initializer)
       save.save(dummy_model, saved_model_dir)

     self.parser = saved_model_cli.create_parser()
     output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
-    args = self.parser.parse_args(
-        ['aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
-         '--output_prefix', output_prefix,
-         '--cpp_class', 'Generated'])  # Use the default seving signature_key.
-    saved_model_cli.aot_compile_cpu(args)
+    args = self.parser.parse_args([
+        'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
+        '--output_prefix', output_prefix, '--variables_to_feed',
+        variables_to_feed, '--cpp_class', 'Generated'
+    ])  # Use the default serving signature_key.
+    with test.mock.patch.object(logging, 'warn') as captured_warn:
+      saved_model_cli.aot_compile_cpu(args)
+    self.assertRegexpMatches(
+        str(captured_warn.call_args),
+        'Signature input key \'y\'.*has been pruned while freezing the graph.')
     self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
     self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
     self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
@@ -762,8 +773,12 @@ Defined Functions:
         file_io.file_exists('{}_makefile.inc'.format(output_prefix)))
     header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
     self.assertIn('class Generated', header_contents)
-    self.assertIn('arg_x_data', header_contents)
-    self.assertIn('result_res_data', header_contents)
+    self.assertIn('arg_feed_x_data', header_contents)
+    self.assertIn('result_fetch_res_data', header_contents)
+    # arg_y got filtered out as it's not used by the output.
+    self.assertNotIn('arg_feed_y_data', header_contents)
+    if variables_to_feed:
+      self.assertIn('var_param_my_var', header_contents)
     makefile_contents = file_io.read_file_to_string(
         '{}_makefile.inc'.format(output_prefix))
     self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)