From 384f1a5507febeab911de830357aa698c8f65032 Mon Sep 17 00:00:00 2001
From: Katherine Wu
Date: Thu, 27 Feb 2020 14:13:28 -0800
Subject: [PATCH] Add function to export metagraph from a Trackable object.
 (i.e. export SavedModel without checkpoints)

PiperOrigin-RevId: 297685153
Change-Id: I2440a9ace1700c400e09137b4834f2d8cadf154a
---
 tensorflow/python/saved_model/save.py      | 80 ++++++++++++++--------
 tensorflow/python/saved_model/save_test.py | 79 ++++++++++++++++++---
 2 files changed, 120 insertions(+), 39 deletions(-)

diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index ced4135526a..c2774a98b86 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -941,6 +941,55 @@ def save(obj, export_dir, signatures=None, options=None):
   May not be called from within a function body.
   @end_compatibility
   """
+  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
+  # compatible (no sessions) and share it with this export API rather than
+  # making a SavedModel proto and writing it directly.
+  saved_model = saved_model_pb2.SavedModel()
+  meta_graph_def = saved_model.meta_graphs.add()
+
+  _, exported_graph, object_saver, asset_info = _build_meta_graph(
+      obj, export_dir, signatures, options, meta_graph_def)
+  saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
+
+  # Write the checkpoint, copy assets into the assets directory, and write out
+  # the SavedModel proto itself.
+  utils_impl.get_or_create_variables_dir(export_dir)
+  object_saver.save(utils_impl.get_variables_path(export_dir))
+  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
+                                              export_dir)
+  # Note that this needs to be the last file operation when saving the
+  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
+  # indication that the SavedModel is completely written.
+  path = os.path.join(
+      compat.as_str(export_dir),
+      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
+  file_io.atomic_write_string_to_file(
+      path, saved_model.SerializeToString(deterministic=True))
+
+  # Clean reference cycles so repeated export()s don't make work for the garbage
+  # collector. Before this point, we need to keep references to captured
+  # constants in the saved graph.
+  ops.dismantle_graph(exported_graph)
+
+
+def export_meta_graph(obj, filename, signatures=None, options=None):
+  """Exports the MetaGraph proto to a file."""
+  export_dir = os.path.dirname(filename)
+  meta_graph_def, exported_graph, _, _ = _build_meta_graph(
+      obj, export_dir, signatures, options)
+
+  file_io.atomic_write_string_to_file(
+      filename, meta_graph_def.SerializeToString(deterministic=True))
+
+  # Clean reference cycles so repeated export()s don't make work for the garbage
+  # collector. Before this point, we need to keep references to captured
+  # constants in the saved graph.
+  ops.dismantle_graph(exported_graph)
+
+
+def _build_meta_graph(obj, export_dir, signatures, options,
+                      meta_graph_def=None):
+  """Creates a MetaGraph containing the resources and functions of an object."""
   if ops.inside_function():
     raise AssertionError(
         "tf.saved_model.save is not supported inside a traced "
@@ -951,6 +1000,7 @@ def save(obj, export_dir, signatures=None, options=None):
     raise ValueError(
         "Expected a Trackable object for export, got {}.".format(obj))
   options = options or save_options.SaveOptions()
+  meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
 
   checkpoint_graph_view = _AugmentedGraphView(obj)
   if signatures is None:
@@ -971,12 +1021,6 @@ def save(obj, export_dir, signatures=None, options=None):
   # there can be side effects of creating variables.
   _ = _SaveableView(checkpoint_graph_view)
   saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
-
-  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
-  # compatible (no sessions) and share it with this export API rather than
-  # making a SavedModel proto and writing it directly.
-  saved_model = saved_model_pb2.SavedModel()
-  meta_graph_def = saved_model.meta_graphs.add()
   object_saver = util.TrackableSaver(checkpoint_graph_view)
   asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def,
                                                     saveable_view, signatures,
@@ -988,18 +1032,7 @@ def save(obj, export_dir, signatures=None, options=None):
         function_aliases[fdef.name] = alias
       for fdef in func._stateless_fn._function_cache.all_values():  # pylint: disable=protected-access
         function_aliases[fdef.name] = alias
-  saved_model.saved_model_schema_version = (
-      constants.SAVED_MODEL_SCHEMA_VERSION)
-  # So far we've just been generating protocol buffers with no I/O. Now we write
-  # the checkpoint, copy assets into the assets directory, and write out the
-  # SavedModel proto itself.
-  utils_impl.get_or_create_variables_dir(export_dir)
-  object_saver.save(utils_impl.get_variables_path(export_dir))
-  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
-                                              export_dir)
-  path = os.path.join(
-      compat.as_str(export_dir),
-      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
+
   object_graph_proto = _serialize_object_graph(saveable_view,
                                                asset_info.asset_index)
   meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
@@ -1013,13 +1046,4 @@ def save(obj, export_dir, signatures=None, options=None):
                    constants.DEBUG_INFO_FILENAME_PB),
       graph_debug_info.SerializeToString(deterministic=True))
 
-  # Note that this needs to be the last file operation when saving the
-  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
-  # indication that the SavedModel is completely written.
-  file_io.atomic_write_string_to_file(
-      path, saved_model.SerializeToString(deterministic=True))
-
-  # Clean reference cycles so repeated export()s don't make work for the garbage
-  # collector. Before this point, we need to keep references to captured
-  # constants in the saved graph.
-  ops.dismantle_graph(exported_graph)
+  return meta_graph_def, exported_graph, object_saver, asset_info
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index 05187c92b81..cae8c4c7c96 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.eager import function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import meta_graph
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
@@ -53,6 +54,7 @@ from tensorflow.python.saved_model import save
 from tensorflow.python.saved_model import save_options
 from tensorflow.python.saved_model import signature_constants
 from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.training import saver
 from tensorflow.python.training.tracking import tracking
 from tensorflow.python.training.tracking import util
 from tensorflow.python.util import compat
@@ -76,6 +78,21 @@ class _ModelWithOptimizer(util.Checkpoint):
     return {"loss": loss}
 
 
+def _run_signature(session, meta_graph_def, inputs, signature_key):
+  signature = meta_graph_def.signature_def[signature_key]
+  assert set(inputs.keys()) == set(signature.inputs.keys())
+  feed_dict = {}
+  for arg_name in inputs.keys():
+    input_tensor = session.graph.get_tensor_by_name(
+        signature.inputs[arg_name].name)
+    feed_dict[input_tensor] = inputs[arg_name]
+  output_dict = {}
+  for output_name, output_tensor_info in signature.outputs.items():
+    output_dict[output_name] = session.graph.get_tensor_by_name(
+        output_tensor_info.name)
+  return session.run(output_dict, feed_dict=feed_dict)
+
+
 def _import_and_infer(
     save_dir, inputs,
     signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
@@ -83,17 +100,7 @@ def _import_and_infer(
   graph = ops.Graph()
   with graph.as_default(), session_lib.Session() as session:
     model = loader.load(session, [tag_constants.SERVING], save_dir)
-    signature = model.signature_def[signature_key]
-    assert set(inputs.keys()) == set(signature.inputs.keys())
-    feed_dict = {}
-    for arg_name in inputs.keys():
-      feed_dict[graph.get_tensor_by_name(signature.inputs[arg_name].name)] = (
-          inputs[arg_name])
-    output_dict = {}
-    for output_name, output_tensor_info in signature.outputs.items():
-      output_dict[output_name] = graph.get_tensor_by_name(
-          output_tensor_info.name)
-    return session.run(output_dict, feed_dict=feed_dict)
+    return _run_signature(session, model, inputs, signature_key)
 
 
 class SaveTest(test.TestCase):
@@ -685,5 +692,55 @@ class MemoryTests(test.TestCase):
       save.save(self._model, save_dir, self._model.call)
 
 
+class ExportMetaGraphTests(test.TestCase):
+
+  def test_export_meta_graph(self):
+    root = tracking.AutoTrackable()
+    root.variable = resource_variable_ops.UninitializedVariable(
+        name="some_variable", dtype=dtypes.float32)
+
+    @def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
+    def multiply_var(x):
+      return root.variable * x
+
+    @def_function.function(input_signature=[tensor_spec.TensorSpec([])])
+    def update(y):
+      root.variable.assign_add(y)
+      # TODO(b/150393409): All functions exported as signatures must have at
+      # least one output.
+      return 0
+
+    @def_function.function(input_signature=[])
+    def initialize():
+      root.variable.assign(1.0)
+      # TODO(b/150393409): All functions exported as signatures must have at
+      # least one output.
+      return 0
+
+    save_path = os.path.join(self.get_temp_dir(), "meta_graph.pb")
+    save.export_meta_graph(
+        root,
+        save_path,
+        signatures={
+            "multiply_var": multiply_var,
+            "initialize": initialize,
+            "update": update
+        })
+
+    with ops.Graph().as_default(), session_lib.Session() as session:
+      saver.import_meta_graph(save_path)
+      meta_graph_def = meta_graph.read_meta_graph_file(save_path)
+
+      # Initialize variable to 1
+      _run_signature(session, meta_graph_def, {}, "initialize")
+      out = _run_signature(session, meta_graph_def, {"x": 3}, "multiply_var")
+      self.assertAllEqual(out, {"output_0": 3})
+
+      # Adds 2 to the variable. Variable is now 3
+      _run_signature(session, meta_graph_def, {"y": 2}, "update")
+      out = _run_signature(session, meta_graph_def, {"x": 4}, "multiply_var")
+      self.assertAllEqual(out, {"output_0": 12})
+
+
 if __name__ == "__main__":
   test.main()
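
Usage sketch (not part of the patch): the snippet below shows how the new
save.export_meta_graph function added above might be called from user code,
mirroring the test case. The root object, variable, function, and file names
are illustrative assumptions; tensorflow.python.saved_model.save is the
private module this patch touches rather than a public tf.saved_model API,
and the exported file can be re-imported with
tf.compat.v1.train.import_meta_graph, as the test does.

    import tensorflow as tf
    # Private module changed by this patch; not a stable public API.
    from tensorflow.python.saved_model import save

    # Any Trackable object works as the export root; tf.Module is used here
    # only as an illustrative assumption.
    root = tf.Module()
    root.v = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
    def scale(x):
      # Functions exported as signatures must have at least one output
      # (see the TODO(b/150393409) comments in the test above).
      return root.v * x

    # Writes only the MetaGraph proto; unlike tf.saved_model.save, no
    # variables/ checkpoint directory or assets directory is created.
    save.export_meta_graph(root, "/tmp/scale_meta_graph.pb",
                           signatures={"scale": scale})

Design-wise, the change keeps one code path: the SavedModel-building logic
that save() already used is factored into _build_meta_graph(), and
export_meta_graph() serializes the resulting MetaGraphDef directly, skipping
the checkpoint write and asset copy.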