From b26e1efeceda14f2f0a72f607ab857c3a5be7978 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Mon, 20 Jan 2020 15:55:51 -0800
Subject: [PATCH] [TF] Fixes and new features to saved_model_cli
 aot_compile_cpu, adds e2e test.

* Can now specify which variables can be fed/fetched.

* Bugfix when signature name contains slashes or starts with integers.

* Prune input config entries from tf2xla config when graph freeze removes
  unused input feed.

* Fixed a bug where third_party/tensorflow/ isn't properly renamed to
  tensorflow/ in opensource HOST build (identified during the new genrule
  test).  Solution: bring back the hardcoded #include in codegen.cc; it's
  always correct.

NOTE: The bugfix to the #include line in the compiler/ codebase is a partial
rollback of the initial tfcompile + saved_model_cli CL, which moved from the
hard-coded include path to a parameterized value.  It turns out we don't need
the complexity of this approach, and it's incorrect in the host opensource
build.

TESTED: Includes a bona fide genrule test that runs saved_model_cli to
generate the header and object files, includes them in a C++ unit test, and
ensures that they compile and that the resulting object runs correctly.

PiperOrigin-RevId: 290655683
Change-Id: I4cfa2c595ebe56f8bdd47853f82371d97b92b081
---
 tensorflow/compiler/aot/codegen.cc            |   5 +-
 tensorflow/compiler/aot/codegen_test.cc       |   1 -
 tensorflow/compiler/aot/compile.cc            |   4 +-
 tensorflow/compiler/aot/compile.h             |   1 -
 tensorflow/compiler/aot/flags.cc              |   2 -
 tensorflow/compiler/aot/flags.h               |   1 -
 tensorflow/compiler/aot/tfcompile_main.cc     |   1 -
 .../compiler/xla/service/cpu/cpu_compiler.cc  |   5 +-
 .../compiler/xla/service/cpu/cpu_compiler.h   |  16 +-
 tensorflow/python/tfcompile_wrapper.cc        |   9 +-
 tensorflow/python/tools/BUILD                 |  59 +++++-
 ...binary_using_aot_compiled_x_plus_y_test.cc |  30 +++
 tensorflow/python/tools/saved_model_cli.py    | 197 +++++++++++++-----
 .../python/tools/saved_model_cli_test.py      |  59 ++++--
 14 files changed, 281 insertions(+), 109 deletions(-)
 create mode 100644 tensorflow/python/tools/binary_using_aot_compiled_x_plus_y_test.cc

diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 188ec6bdfda..53150e991cc 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -457,8 +457,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
 {{INCLUDE_XLA_DATA_PROTO}}
 {{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}

-#include "{{TF_HEADER_ROOT}}/compiler/tf2xla/xla_compiled_cpu_function.h"
-#include "{{TF_HEADER_ROOT}}/core/platform/types.h"
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "tensorflow/core/platform/types.h"

 namespace Eigen { struct ThreadPoolDevice; }
 namespace xla { class ExecutableRunOptions; }
@@ -659,7 +659,6 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
       {"{{CLASS}}", opts.class_name},
       {"{{DECLS_FROM_OBJ_FILE}}",
        absl::StrJoin(metadata_result.header_variable_decls, "\n")},
-      {"{{TF_HEADER_ROOT}}", compile_result.tensorflow_header_root},
       {"{{ENTRY}}", compile_result.entry_point},
       {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
        metadata_result.hlo_profile_printer_data_access_shim},
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index c73724b26b2..a7294323d1d 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -197,7 +197,6 @@ TEST(CodegenTest, Golden) {
   variable3->mutable_shape()->add_dim()->set_size(5);
   variable3->set_type(DT_INT32);
   CompileResult compile_result;
-  compile_result.tensorflow_header_root = "third_party/tensorflow";
   compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
       {},
       {BufferInfo::MakeTempBuffer(1),
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 3d450696aab..bd6c3bc8467 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -85,7 +85,6 @@ Status CompileXla(xla::CompileOnlyClient* client,
       xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
           std::move(aot_or.ValueOrDie().back()));
   compile_result->entry_point = aot_opts.entry_point_name();
-  compile_result->tensorflow_header_root = aot_opts.tensorflow_header_root();
   compile_result->pointer_size =
       xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
   return Status::OK();
@@ -130,8 +129,7 @@ Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
   xla::cpu::CpuAotCompilationOptions aot_opts(
       flags.target_triple, flags.target_cpu, flags.target_features,
       flags.entry_point,
-      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic,
-      flags.tensorflow_header_root);
+      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);

   return CompileXla(client, computation, aot_opts, compile_result);
 }
diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h
index 7b465ccf941..9978d52390d 100644
--- a/tensorflow/compiler/aot/compile.h
+++ b/tensorflow/compiler/aot/compile.h
@@ -35,7 +35,6 @@ struct CompileResult {
   std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
   xla::ProgramShapeProto program_shape;  // Static shape of args and results.
   string entry_point;                    // Name of generated function.
-  string tensorflow_header_root;         // Prefix for tensorflow headers.
   int pointer_size = 0;                  // Size of a pointer in bytes.
 };
diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc
index 2e53f7c02aa..e7040d12b8b 100644
--- a/tensorflow/compiler/aot/flags.cc
+++ b/tensorflow/compiler/aot/flags.cc
@@ -74,8 +74,6 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
        "Generate name-to-index data for Lookup{Arg,Result}Index methods."},
       {"gen_program_shape", &flags->gen_program_shape,
        "Generate program shape data for the ProgramShape method."},
-      {"tensorflow_header_root", &flags->tensorflow_header_root,
-       "Root directory of tensorflow headers."},
   };
   flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
 }
diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h
index 5a8476c001b..451a0455977 100644
--- a/tensorflow/compiler/aot/flags.h
+++ b/tensorflow/compiler/aot/flags.h
@@ -40,7 +40,6 @@ struct MainFlags {
   string out_header;
   string out_session_module;
   string mlir_components;
-  string tensorflow_header_root;

   // C++ codegen options
   bool gen_name_to_index = false;
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 83aa79f0072..d027bae5d04 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -65,7 +65,6 @@ int main(int argc, char** argv) {
   flags.out_metadata_object = "out_helper.o";
   flags.out_header = "out.h";
   flags.entry_point = "entry";
-  flags.tensorflow_header_root = "third_party/tensorflow";

   std::vector<Flag> flag_list;
   AppendMainFlags(&flag_list, &flags);
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index c10448b281e..a04a39b4461 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -119,13 +119,12 @@ using BufferInfo = cpu_function_runtime::BufferInfo;

 CpuAotCompilationOptions::CpuAotCompilationOptions(
     string triple, string cpu_name, string features, string entry_point_name,
-    RelocationModel relocation_model, string tensorflow_header_root)
+    RelocationModel relocation_model)
     : triple_(std::move(triple)),
       cpu_name_(std::move(cpu_name)),
       features_(std::move(features)),
       entry_point_name_(std::move(entry_point_name)),
-      relocation_model_(relocation_model),
-      tensorflow_header_root_(std::move(tensorflow_header_root)) {}
+      relocation_model_(relocation_model) {}

 CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index b7e78c38126..537bf8b87c6 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -53,16 +53,7 @@ class CpuAotCompilationOptions : public AotCompilationOptions {

   CpuAotCompilationOptions(string triple, string cpu_name, string features,
                            string entry_point_name,
-                           RelocationModel relocation_model,
-                           string tensorflow_header_root);
-
-  CpuAotCompilationOptions(string triple, string cpu_name, string features,
-                           string entry_point_name,
-                           RelocationModel relocation_model)
-      : CpuAotCompilationOptions(
-            std::move(triple), std::move(cpu_name), std::move(features),
-            std::move(entry_point_name), relocation_model,
-            /*tensorflow_header_root=*/"third_party/tensorflow") {}
+                           RelocationModel relocation_model);

   ~CpuAotCompilationOptions() override;

@@ -76,10 +67,6 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string& features() const { return features_; }
   // The name to be used for the compiled code's entry point.
   const string& entry_point_name() const { return entry_point_name_; }
-  // The prefix for tensorflow headers, e.g. "third_party/tensorflow".
-  const string& tensorflow_header_root() const {
-    return tensorflow_header_root_;
-  }
   // The relocation model used for compilation.
   RelocationModel relocation_model() const { return relocation_model_; }

@@ -89,7 +76,6 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
   const string features_;
   const string entry_point_name_;
   const RelocationModel relocation_model_;
-  const string tensorflow_header_root_;
 };

 class CpuAotCompilationResult : public AotCompilationResult {
diff --git a/tensorflow/python/tfcompile_wrapper.cc b/tensorflow/python/tfcompile_wrapper.cc
index 7ab251ab1da..ac69d326663 100644
--- a/tensorflow/python/tfcompile_wrapper.cc
+++ b/tensorflow/python/tfcompile_wrapper.cc
@@ -39,8 +39,8 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
          std::string entry_point, std::string cpp_class,
          std::string out_function_object, std::string out_metadata_object,
          std::string out_header, std::string out_session_module,
-         std::string mlir_components, std::string tensorflow_header_root,
-         bool gen_name_to_index, bool gen_program_shape) {
+         std::string mlir_components, bool gen_name_to_index,
+         bool gen_program_shape) {
         tensorflow::tfcompile::MainFlags flags;
         flags.graph = std::move(graph);
         flags.config = std::move(config);
@@ -54,7 +54,6 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
         flags.out_header = std::move(out_header);
         flags.out_session_module = std::move(out_session_module);
         flags.mlir_components = std::move(mlir_components);
-        flags.tensorflow_header_root = std::move(tensorflow_header_root);

         // C++ codegen options
         flags.gen_name_to_index = gen_name_to_index;
@@ -68,8 +67,6 @@ PYBIND11_MODULE(_pywrap_tfcompile, m) {
       py::arg("cpp_class") = "", py::arg("out_function_object") = "out_model.o",
       py::arg("out_metadata_object") = "out_helper.o",
       py::arg("out_header") = "out.h", py::arg("out_session_module") = "",
-      py::arg("mlir_components") = "",
-      py::arg("tensorflow_header_root") = "third_party/tensorflow",
-      py::arg("gen_name_to_index") = false,
+      py::arg("mlir_components") = "", py::arg("gen_name_to_index") = false,
       py::arg("gen_program_shape") = false);
 }
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index fe9cb1bc5a2..ba473808ab0 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -1,7 +1,7 @@
 # Description:
 #   Tools for manipulating TensorFlow graphs.

-load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test")
+load("//tensorflow:tensorflow.bzl", "if_xla_available", "py_binary", "py_test", "tf_cc_test")

 package(
     default_visibility = ["//visibility:public"],
@@ -343,10 +343,63 @@ py_test(
         "no-internal-py3",
         "nosan",
     ],
-    # Force-include XLA dependencies of saved_model_cli_lib to ensure we test
-    # the AOT compilation.
     deps = [
         ":saved_model_cli_lib",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:client_testlib",
+        "@absl_py//absl/testing:parameterized",
     ],
 )
+
+genrule(
+    name = "aot_compiled_x_plus_y_gen",
+    srcs = [
+        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
+        "//tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo/saved_model.pb",
+    ],
+    outs = [
+        "compiled_model.h",
+        "compiled_model.o",
+        "compiled_model_metadata.o",
+        "compiled_model_makefile.inc",
+    ],
+    cmd = (
+        "$(location :saved_model_cli) aot_compile_cpu " +
+        "--dir \"$$(dirname $(location //tensorflow/cc/saved_model:testdata/x_plus_y_v2_debuginfo/saved_model.pb))\" " +
+        "--output_prefix $(@D)/compiled_model " +
+        "--cpp_class CompiledModel " +
+        "--tag_set serve "
+    ),
+    tools = [
+        ":saved_model_cli",
+    ],
+)
+
+cc_library(
+    name = "aot_compiled_x_plus_y",
+    srcs = if_xla_available([
+        ":compiled_model.o",
+        ":compiled_model_metadata.o",
+    ]),
+    hdrs = if_xla_available([
+        ":compiled_model.h",
+    ]),
+    deps = if_xla_available([
+        "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+        "//tensorflow/core/platform:types",
+    ]),
+)
+
+tf_cc_test(
+    name = "binary_using_aot_compiled_x_plus_y_test",
+    srcs = if_xla_available([
+        "binary_using_aot_compiled_x_plus_y_test.cc",
+    ]),
+    deps = [
+        "//tensorflow/core:test_main",
+    ] + if_xla_available([
+        ":aot_compiled_x_plus_y",
+        "//tensorflow/core:test",
+        "//tensorflow/core/platform:logging",
+    ]),
+)
diff --git a/tensorflow/python/tools/binary_using_aot_compiled_x_plus_y_test.cc b/tensorflow/python/tools/binary_using_aot_compiled_x_plus_y_test.cc
new file mode 100644
index 00000000000..3f7cf72cd54
--- /dev/null
+++ b/tensorflow/python/tools/binary_using_aot_compiled_x_plus_y_test.cc
@@ -0,0 +1,30 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/python/tools/compiled_model.h"
+
+namespace tensorflow {
+namespace {
+TEST(AOTCompiledSavedModelTest, Run) {
+  CompiledModel model;
+  *model.arg_feed_x_data() = 3.0f;
+  *model.arg_feed_y_data() = 4.0f;
+  CHECK(model.Run());
+  ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
+}
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 2514ed19d6f..f846f43127f 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -101,6 +101,16 @@ def _sysconfig_module():
   return sysconfig_lib


+def _parse_tensor_name(name):
+  """Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
+  if ':' in name and not name.endswith(':'):
+    node_name = name[:name.rfind(':')]
+    output_slot = int(name[name.rfind(':') + 1:])
+    return node_name, output_slot
+  else:
+    return name, None
+
+
 _XLA_MAKEFILE_TEMPLATE = """
 INC = -I{tensorflow_includes}
 LIB = -L{compiled_dir}
@@ -134,7 +144,11 @@ def _xla_makefile_string(output_prefix):
     base = os.path.realpath(
         os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
   else:
-    base = test.test_src_dir_path('')
+    try:
+      base = test.test_src_dir_path('')
+    except KeyError:  # Can't find TEST_SRCDIR in environment path.
+      base = os.path.realpath(
+          os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
   expected_header = os.path.join(
       base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
   if not os.path.exists(expected_header):
@@ -164,6 +178,47 @@ def _show_tag_sets(saved_model_dir):
     print('%r' % ', '.join(sorted(tag_set)))


+def _get_variable_nodes_from_graph_def(graph_def):
+  """Get the list of Variable nodes from `graph_def`.
+
+  Args:
+    graph_def: An instance of `GraphDef`.
+
+  Returns:
+    A list of `NodeDef` corresponding to variables in the graph.
+  """
+  variables = [n for n in graph_def.node if n.op == 'VarHandleOp']
+
+  for f in graph_def.library.function:
+    variables += [n for n in f.node_def if n.op == 'VarHandleOp']
+
+  return variables
+
+
+def _prune_removed_feed_nodes(signature_def, graph_def):
+  """Identify the inputs in the signature no longer in graph_def, prune them.
+
+  Args:
+    signature_def: A `SignatureDef` instance.
+    graph_def: A `GraphDef` instance.
+
+  Returns:
+    A new pruned `SignatureDef`.
+  """
+  node_names = set([n.name for n in graph_def.node])
+  new_signature_def = meta_graph_pb2.SignatureDef()
+  new_signature_def.CopyFrom(signature_def)
+  for (k, v) in signature_def.inputs.items():
+    tensor_name, _ = _parse_tensor_name(v.name)
+    if tensor_name not in node_names:
+      logging.warn(
+          'Signature input key \'{}\', tensor name \'{}\', has been pruned '
+          'while freezing the graph.  Removing it from the compiled signatures.'
+          .format(k, tensor_name))
+      del new_signature_def.inputs[k]
+  return new_signature_def
+
+
 def _show_signature_def_map_keys(saved_model_dir, tag_set):
   """Prints the keys for each SignatureDef in the SignatureDef map.

@@ -882,23 +937,28 @@ def aot_compile_cpu(args):
   checkpoint_path = (
       args.checkpoint_path or os.path.join(args.dir, 'variables/variables'))

+  if not args.variables_to_feed:
+    variables_to_feed = []
+  elif args.variables_to_feed.lower() == 'all':
+    variables_to_feed = None  # We will identify them after.
+  else:
+    variables_to_feed = args.variables_to_feed.split(',')
+
   aot_compile_cpu_meta_graph_def(
       checkpoint_path=checkpoint_path,
       meta_graph_def=saved_model_utils.get_meta_graph_def(
           args.dir, args.tag_set),
       signature_def_key=args.signature_def_key,
-      freeze_graph=args.freeze_graph,
+      variables_to_feed=variables_to_feed,
       output_prefix=args.output_prefix,
       cpp_class=args.cpp_class)


-def aot_compile_cpu_meta_graph_def(
-    checkpoint_path,
-    meta_graph_def,
-    output_prefix,
-    signature_def_key,
-    cpp_class,
-    freeze_graph=True):
+def aot_compile_cpu_meta_graph_def(checkpoint_path,
+                                   meta_graph_def,
+                                   output_prefix,
+                                   signature_def_key,
+                                   cpp_class,
+                                   variables_to_feed=()):
   """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

   Use XLA AOT (`tfcompile`) to convert the given meta graph and
@@ -920,7 +980,10 @@ def aot_compile_cpu_meta_graph_def(
     output_prefix: Python string.  Path prefix for outputs.
     signature_def_key: String, the signature_def to use in the SavedModel.
     cpp_class: Name of output C++ class.
-    freeze_graph: Whether to freeze the graph before compilation.
+    variables_to_feed: A list of strings, the variables that will be fed by the
+      user; these won't be frozen.  If `None`, then we will extract all the
+      variables in the graph and mark them as to-feed.  The default behavior is
+      an empty tuple: all variables must be frozen.

   Raises:
     RuntimeError: If tensorflow was not built with XLA.
@@ -945,32 +1008,62 @@ def aot_compile_cpu_meta_graph_def(
         'Signature key {} must have outputs, but saw none:\n{}'.format(
             signature_def_key, str(signature_def)))

+  temp_dir = test.get_temp_dir()
+  file_io.recursive_create_dir(temp_dir)
+  if logging.get_verbosity() >= logging.INFO:
+    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
+    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
+      graph_writer.write(meta_graph_def.graph_def.SerializeToString())
+
   # This updates graph_def in place.
   _replace_input_placeholders_with_default_values(
       meta_graph_def.graph_def, signature_def)
+
   graph_def = _optimize_graph(meta_graph_def, signature_def)

-  if freeze_graph:
-    # Load the Variables so that we can freeze the graph.
-    with session.Session(graph=ops_lib.Graph()) as sess:
-      restorer = saver_lib.import_meta_graph(
-          meta_graph_def, clear_devices=True)
-      restorer.restore(sess, checkpoint_path)
-      graph_def.CopyFrom(
-          graph_util.convert_variables_to_constants(
-              sess,
-              graph_def,
-              [n.name.split(':')[0] for n in signature_def.outputs.values()]))
+  all_variables = _get_variable_nodes_from_graph_def(graph_def)
+  if variables_to_feed is None:
+    variable_nodes_to_feed = list(all_variables)
+  else:
+    not_in_graph = (
+        set(variables_to_feed).difference([x.name for x in all_variables]))
+    if not_in_graph:
+      raise ValueError(
+          'Asked to feed variables that were not found in graph: {}.  '
+          'Variables contained in the graph: {}'.format(
+              not_in_graph, set([x.name for x in all_variables])))
+    all_variables_map = dict((x.name, x) for x in all_variables)
+    variable_nodes_to_feed = [
+        all_variables_map[name] for name in variables_to_feed
+    ]
+
+  if logging.get_verbosity() >= logging.INFO:
+    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
+    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
+      graph_writer.write(meta_graph_def.graph_def.SerializeToString())
+
+  # Load the Variables so that we can freeze the graph.
+  with session.Session(graph=ops_lib.Graph()) as sess:
+    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
+    restorer.restore(sess, checkpoint_path)
+    graph_def.CopyFrom(
+        graph_util.convert_variables_to_constants(
+            sess,
+            graph_def,
+            output_node_names=[
+                _parse_tensor_name(n.name)[0]
+                for n in signature_def.outputs.values()
+            ],
+        ))
+
+  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

-  temp_dir = test.get_temp_dir()
   frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
   config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
   logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
   with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
     graph_writer.write(graph_def.SerializeToString())
   config = _signature_to_tf2xla_config(
-      signature_def,
-      frozen_variables=freeze_graph)
+      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
   logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
   with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
     config_writer.write(str(config))
@@ -991,13 +1084,6 @@ def aot_compile_cpu_meta_graph_def(

   output_prefix = _shlex_quote(output_prefix)

-  additional_compiler_args = {}
-  sysconfig = _sysconfig_module()
-  if sysconfig:
-    # We're inside PIP and need to pass a customized relative path to the
-    # appropriate tensorflow headers.
-    additional_compiler_args['tensorflow_header_root'] = 'tensorflow'
-
   _pywrap_tfcompile.Compile(
       graph=frozen_graph_def_location,
       config=config_pbtxt_location,
@@ -1008,8 +1094,7 @@ def aot_compile_cpu_meta_graph_def(
       out_metadata_object='{}_metadata.o'.format(output_prefix),
       gen_name_to_index=True,
       # ProgramShape isn't uniquefied by entry_point.
-      gen_program_shape=False,
-      **additional_compiler_args)
+      gen_program_shape=False)


 def _optimize_graph(meta_graph_def, signature_def):
@@ -1034,7 +1119,7 @@ def _replace_input_placeholders_with_default_values(graph_def, signature_def):
   name_to_node_map = dict((n.name, n) for n in graph_def.node)
   temp_graph = ops_lib.Graph()
   for name, input_ in signature_def.inputs.items():
-    tensor_name = input_.name.split(':')[0]
+    tensor_name, _ = _parse_tensor_name(input_.name)
     if tensor_name not in name_to_node_map:
       raise RuntimeError(
           'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
@@ -1330,13 +1415,16 @@ def add_aot_compile_cpu_subparser(subparsers):
           'The class will be generated in the given namespace(s), or if no '
           'namespaces are given, within the global namespace.'))
   parser_compile.add_argument(
-      '--freeze_graph',
-      type=bool,
-      default=True,
-      help=('Whether to freeze the tf.Variables into the graph.  If false, '
            'then all Variables in the closure of the signature graph path '
            'be be added as input and output args to the XLA-compiled graph '
            '(not currently supported)'))
+      '--variables_to_feed',
+      type=str,
+      default='',
+      help=('The names of variables that will be fed into the network.  '
+            'Options are: empty (default; all variables are frozen, none may '
+            'be fed), \'all\' (all variables may be fed), or a '
+            'comma-delimited list of names of variables that may be fed.  In '
+            'the last case, the non-fed variables will be frozen in the graph.')
+  )
+
   parser_compile.set_defaults(func=aot_compile_cpu)


@@ -1371,12 +1459,13 @@ def create_parser():
   return parser


-def _signature_to_tf2xla_config(signature_def, frozen_variables):
+def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
   """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

   Args:
     signature_def: Instance of `SignatureDef`.
-    frozen_variables: Python bool, whether variables are being frozen or not.
+    variable_nodes_to_feed: List of NodeDefs corresponding to VarHandleOp,
+      the list of variables to feed.

   Returns:
     An instance of `tf2xla.Config` proto.
@@ -1390,7 +1479,9 @@
   tensor_id = tf2xla_pb2.TensorId

   for name, input_ in signature_def.inputs.items():
-    (node_name, output_index) = input_.name.split(':')
+    name = name.replace('/', '_')
+    name = 'feed_{}'.format(name)
+    (node_name, output_index) = _parse_tensor_name(input_.name)
     output_index = int(output_index)
     config.feed.append(
         tf2xla_pb2.Feed(
@@ -1399,7 +1490,9 @@
             type=input_.dtype,
             shape=input_.tensor_shape))

   for name, output_ in signature_def.outputs.items():
-    (node_name, output_index) = output_.name.split(':')
+    name = name.replace('/', '_')
+    name = 'fetch_{}'.format(name)
+    (node_name, output_index) = _parse_tensor_name(output_.name)
     output_index = int(output_index)
     config.fetch.append(
         tf2xla_pb2.Fetch(
@@ -1407,14 +1500,22 @@
             name=name,
             type=output_.dtype,
             shape=output_.tensor_shape))

-  if not frozen_variables:
-    # Extract all variables along the path and add to config
-    raise NotImplementedError('Non-frozen graphs are not supported.')
+  for node in variable_nodes_to_feed:
+    name = node.name.replace('/', '_')
+    name = 'param_{}'.format(name)
+    config.variable.append(
+        tf2xla_pb2.Variable(
+            node_name=node.name,
+            name=name,
+            type=node.attr['dtype'].type,
+            shape=node.attr['shape'].shape,
+            readonly=True))

   return config


 def main():
+  logging.set_verbosity(logging.INFO)
   parser = create_parser()
   args = parser.parse_args()
   if not hasattr(args, 'func'):
diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py
index fd3257e9a73..6e503d1cfe5 100644
--- a/tensorflow/python/tools/saved_model_cli_test.py
+++ b/tensorflow/python/tools/saved_model_cli_test.py
@@ -25,6 +25,7 @@ import pickle
 import shutil
 import sys

+from absl.testing import parameterized
 import numpy as np
 from six import StringIO

@@ -38,6 +39,7 @@ from tensorflow.python.framework import tensor_spec
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import save
 from tensorflow.python.tools import saved_model_cli
 from tensorflow.python.training.tracking import tracking
@@ -56,7 +58,7 @@ def captured_output():
     sys.stdout, sys.stderr = old_out, old_err


-class SavedModelCLITestCase(test.TestCase):
+class SavedModelCLITestCase(test.TestCase, parameterized.TestCase):

   def testShowCommandAll(self):
     base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
@@ -726,35 +728,44 @@ Defined Functions:
     with self.assertRaisesRegexp(ValueError, 'Unable to find signature_def'):
       saved_model_cli.aot_compile_cpu(args)

-  def testAOTCompileCPUFreezesAndCompiles(self):
+  class AOTCompileDummyModel(tracking.AutoTrackable):
+    """Model compatible with XLA compilation."""
+
+    def __init__(self):
+      self.var = variables.Variable(1.0, name='my_var')
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+        tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
+    ])
+    def func2(self, x, y):
+      return {'res': x + self.var}
+
+  @parameterized.named_parameters(('VariablesToFeedNone', ''),
+                                  ('VariablesToFeedAll', 'all'),
+                                  ('VariablesToFeedMyVar', 'my_var'))
+  def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed):
     if not test.is_built_with_xla():
       self.skipTest('Skipping test because XLA is not compiled in.')

-    class DummyModel(tracking.AutoTrackable):
-      """Model compatible with XLA compilation."""
-
-      def __init__(self):
-        self.var = variables.Variable(1.0, name='my_var')
-
-      @def_function.function(input_signature=[
-          tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
-      ])
-      def func2(self, x):
-        return {'res': x + self.var}
-
     saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
-    dummy_model = DummyModel()
+    dummy_model = self.AOTCompileDummyModel()
     with self.cached_session():
       self.evaluate(dummy_model.var.initializer)
       save.save(dummy_model, saved_model_dir)
     self.parser = saved_model_cli.create_parser()
     output_prefix = os.path.join(test.get_temp_dir(), 'aot_compile_cpu_dir/out')
-    args = self.parser.parse_args(
-        ['aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
-         '--output_prefix', output_prefix,
-         '--cpp_class', 'Generated'])  # Use the default seving signature_key.
-    saved_model_cli.aot_compile_cpu(args)
+    args = self.parser.parse_args([
+        'aot_compile_cpu', '--dir', saved_model_dir, '--tag_set', 'serve',
+        '--output_prefix', output_prefix, '--variables_to_feed',
+        variables_to_feed, '--cpp_class', 'Generated'
+    ])  # Use the default serving signature_key.
+    with test.mock.patch.object(logging, 'warn') as captured_warn:
+      saved_model_cli.aot_compile_cpu(args)
+    self.assertRegexpMatches(
+        str(captured_warn.call_args),
+        'Signature input key \'y\'.*has been pruned while freezing the graph.')
     self.assertTrue(file_io.file_exists('{}.o'.format(output_prefix)))
     self.assertTrue(file_io.file_exists('{}.h'.format(output_prefix)))
     self.assertTrue(file_io.file_exists('{}_metadata.o'.format(output_prefix)))
@@ -762,8 +773,12 @@ Defined Functions:
         file_io.file_exists('{}_makefile.inc'.format(output_prefix)))
     header_contents = file_io.read_file_to_string('{}.h'.format(output_prefix))
     self.assertIn('class Generated', header_contents)
-    self.assertIn('arg_x_data', header_contents)
-    self.assertIn('result_res_data', header_contents)
+    self.assertIn('arg_feed_x_data', header_contents)
+    self.assertIn('result_fetch_res_data', header_contents)
+    # arg_y got filtered out as it's not used by the output.
+    self.assertNotIn('arg_feed_y_data', header_contents)
+    if variables_to_feed:
+      self.assertIn('var_param_my_var', header_contents)
     makefile_contents = file_io.read_file_to_string(
         '{}_makefile.inc'.format(output_prefix))
     self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)
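
For reference, the feed/fetch/variable naming that this patch introduces can be
summarized in a short standalone sketch. The helpers below (parse_tensor_name,
feed_name, fetch_name, variable_param_name) are hypothetical illustrations, not
part of the patch; they mirror _parse_tensor_name and the naming logic in
_signature_to_tf2xla_config, which prefix signature keys so that keys that
contain slashes or start with integers still yield valid C++ accessor
identifiers in the generated header:

  # Standalone Python sketch (hypothetical helpers mirroring saved_model_cli.py).
  def parse_tensor_name(name):
    """Convert 'tensor:0' -> ('tensor', 0); a bare node name has no slot."""
    if ':' in name and not name.endswith(':'):
      node_name, slot = name.rsplit(':', 1)
      return node_name, int(slot)
    return name, None

  def feed_name(signature_key):
    # Slashes are replaced and a 'feed_' prefix is added so the generated C++
    # accessor names are valid identifiers.
    return 'feed_{}'.format(signature_key.replace('/', '_'))

  def fetch_name(signature_key):
    return 'fetch_{}'.format(signature_key.replace('/', '_'))

  def variable_param_name(node_name):
    return 'param_{}'.format(node_name.replace('/', '_'))

  assert parse_tensor_name('x:0') == ('x', 0)
  assert parse_tensor_name('scope/x:1') == ('scope/x', 1)
  assert parse_tensor_name('x') == ('x', None)
  assert feed_name('x') == 'feed_x'                        # -> arg_feed_x_data()
  assert fetch_name('output_0') == 'fetch_output_0'        # -> result_fetch_output_0()
  assert variable_param_name('my_var') == 'param_my_var'   # -> var_param_my_var

These prefixes are what the new genrule test and the updated unit test assert
on: the generated header exposes arg_feed_x_data() and result_fetch_output_0()
for the x_plus_y model, and var_param_my_var when the variable is fed.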
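
Similarly, the --variables_to_feed resolution in aot_compile_cpu reduces to the
sketch below; graph_variable_names is a hypothetical stand-in for the
VarHandleOp node names found by _get_variable_nodes_from_graph_def:

  def resolve_variables_to_feed(flag_value, graph_variable_names):
    """'' -> feed nothing (freeze all); 'all' -> feed every variable;
    otherwise a comma-delimited list, validated against the graph."""
    if not flag_value:
      return []
    if flag_value.lower() == 'all':
      return list(graph_variable_names)
    requested = flag_value.split(',')
    missing = set(requested).difference(graph_variable_names)
    if missing:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}'.format(
              missing))
    return requested

  assert resolve_variables_to_feed('', {'my_var'}) == []
  assert resolve_variables_to_feed('all', {'my_var'}) == ['my_var']
  assert resolve_variables_to_feed('my_var', {'my_var'}) == ['my_var']

Variables left out of the list are frozen into constants by
convert_variables_to_constants, and any signature inputs that freezing prunes
from the graph are dropped from the compiled signature by
_prune_removed_feed_nodes.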