Add a function to export a MetaGraph from a Trackable object (i.e., export a SavedModel without checkpoints).

PiperOrigin-RevId: 297685153
Change-Id: I2440a9ace1700c400e09137b4834f2d8cadf154a
This commit is contained in:
Katherine Wu 2020-02-27 14:13:28 -08:00 committed by TensorFlower Gardener
parent aad9a544a4
commit 384f1a5507
2 changed files with 120 additions and 39 deletions

View File

@ -941,6 +941,55 @@ def save(obj, export_dir, signatures=None, options=None):
May not be called from within a function body. May not be called from within a function body.
@end_compatibility @end_compatibility
""" """
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
# compatible (no sessions) and share it with this export API rather than
# making a SavedModel proto and writing it directly.
saved_model = saved_model_pb2.SavedModel()
meta_graph_def = saved_model.meta_graphs.add()
_, exported_graph, object_saver, asset_info = _build_meta_graph(
obj, export_dir, signatures, options, meta_graph_def)
saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
# Write the checkpoint, copy assets into the assets directory, and write out
# the SavedModel proto itself.
utils_impl.get_or_create_variables_dir(export_dir)
object_saver.save(utils_impl.get_variables_path(export_dir))
builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
export_dir)
# Note that this needs to be the last file operation when saving the
# SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
# indication that the SavedModel is completely written.
path = os.path.join(
compat.as_str(export_dir),
compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
file_io.atomic_write_string_to_file(
path, saved_model.SerializeToString(deterministic=True))
# Clean reference cycles so repeated export()s don't make work for the garbage
# collector. Before this point, we need to keep references to captured
# constants in the saved graph.
ops.dismantle_graph(exported_graph)
def export_meta_graph(obj, filename, signatures=None, options=None):
  """Exports the MetaGraph proto of `obj` to `filename`.

  Unlike `save()`, no checkpoint or assets are written — only the serialized
  MetaGraphDef.

  Args:
    obj: A Trackable object to export.
    filename: Path of the file the MetaGraph proto is written to.
    signatures: Optional signatures argument, as for `tf.saved_model.save`.
    options: Optional `tf.saved_model.SaveOptions` object.
  """
  meta_graph_def, exported_graph, _, _ = _build_meta_graph(
      obj, os.path.dirname(filename), signatures, options)
  # Serialize deterministically and write atomically so a reader never
  # observes a partially written proto.
  serialized = meta_graph_def.SerializeToString(deterministic=True)
  file_io.atomic_write_string_to_file(filename, serialized)
  # Break reference cycles so repeated export()s don't make work for the
  # garbage collector. Captured constants in the saved graph must stay
  # reachable until after serialization above, so this comes last.
  ops.dismantle_graph(exported_graph)
def _build_meta_graph(obj, export_dir, signatures, options,
meta_graph_def=None):
"""Creates a MetaGraph containing the resources and functions of an object."""
if ops.inside_function(): if ops.inside_function():
raise AssertionError( raise AssertionError(
"tf.saved_model.save is not supported inside a traced " "tf.saved_model.save is not supported inside a traced "
@ -951,6 +1000,7 @@ def save(obj, export_dir, signatures=None, options=None):
raise ValueError( raise ValueError(
"Expected a Trackable object for export, got {}.".format(obj)) "Expected a Trackable object for export, got {}.".format(obj))
options = options or save_options.SaveOptions() options = options or save_options.SaveOptions()
meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
checkpoint_graph_view = _AugmentedGraphView(obj) checkpoint_graph_view = _AugmentedGraphView(obj)
if signatures is None: if signatures is None:
@ -971,12 +1021,6 @@ def save(obj, export_dir, signatures=None, options=None):
# there can be side effects of creating variables. # there can be side effects of creating variables.
_ = _SaveableView(checkpoint_graph_view) _ = _SaveableView(checkpoint_graph_view)
saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions) saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
# compatible (no sessions) and share it with this export API rather than
# making a SavedModel proto and writing it directly.
saved_model = saved_model_pb2.SavedModel()
meta_graph_def = saved_model.meta_graphs.add()
object_saver = util.TrackableSaver(checkpoint_graph_view) object_saver = util.TrackableSaver(checkpoint_graph_view)
asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def, asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def,
saveable_view, signatures, saveable_view, signatures,
@ -988,18 +1032,7 @@ def save(obj, export_dir, signatures=None, options=None):
function_aliases[fdef.name] = alias function_aliases[fdef.name] = alias
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
function_aliases[fdef.name] = alias function_aliases[fdef.name] = alias
saved_model.saved_model_schema_version = (
constants.SAVED_MODEL_SCHEMA_VERSION)
# So far we've just been generating protocol buffers with no I/O. Now we write
# the checkpoint, copy assets into the assets directory, and write out the
# SavedModel proto itself.
utils_impl.get_or_create_variables_dir(export_dir)
object_saver.save(utils_impl.get_variables_path(export_dir))
builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
export_dir)
path = os.path.join(
compat.as_str(export_dir),
compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
object_graph_proto = _serialize_object_graph(saveable_view, object_graph_proto = _serialize_object_graph(saveable_view,
asset_info.asset_index) asset_info.asset_index)
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
@ -1013,13 +1046,4 @@ def save(obj, export_dir, signatures=None, options=None):
constants.DEBUG_INFO_FILENAME_PB), constants.DEBUG_INFO_FILENAME_PB),
graph_debug_info.SerializeToString(deterministic=True)) graph_debug_info.SerializeToString(deterministic=True))
# Note that this needs to be the last file operation when saving the return meta_graph_def, exported_graph, object_saver, asset_info
# SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
# indication that the SavedModel is completely written.
file_io.atomic_write_string_to_file(
path, saved_model.SerializeToString(deterministic=True))
# Clean reference cycles so repeated export()s don't make work for the garbage
# collector. Before this point, we need to keep references to captured
# constants in the saved graph.
ops.dismantle_graph(exported_graph)

View File

@ -33,6 +33,7 @@ from tensorflow.python.eager import function
from tensorflow.python.eager import test from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -53,6 +54,7 @@ from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_options from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util from tensorflow.python.training.tracking import util
from tensorflow.python.util import compat from tensorflow.python.util import compat
@ -76,6 +78,21 @@ class _ModelWithOptimizer(util.Checkpoint):
return {"loss": loss} return {"loss": loss}
def _run_signature(session, meta_graph_def, inputs, signature_key):
  """Feeds `inputs` to the named signature and returns the fetched outputs.

  Args:
    session: A v1 `Session` whose graph contains the imported MetaGraph.
    meta_graph_def: The `MetaGraphDef` proto holding the signature defs.
    inputs: Dict mapping signature argument names to feed values.
    signature_key: Key of the signature in `meta_graph_def.signature_def`.

  Returns:
    Dict mapping the signature's output names to the values computed by
    `session.run`.
  """
  signature = meta_graph_def.signature_def[signature_key]
  # Callers must supply exactly the arguments the signature declares.
  assert set(inputs.keys()) == set(signature.inputs.keys())
  lookup = session.graph.get_tensor_by_name
  feeds = {
      lookup(signature.inputs[name].name): value
      for name, value in inputs.items()
  }
  fetches = {
      name: lookup(tensor_info.name)
      for name, tensor_info in signature.outputs.items()
  }
  return session.run(fetches, feed_dict=feeds)
def _import_and_infer( def _import_and_infer(
save_dir, inputs, save_dir, inputs,
signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY): signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
@ -83,17 +100,7 @@ def _import_and_infer(
graph = ops.Graph() graph = ops.Graph()
with graph.as_default(), session_lib.Session() as session: with graph.as_default(), session_lib.Session() as session:
model = loader.load(session, [tag_constants.SERVING], save_dir) model = loader.load(session, [tag_constants.SERVING], save_dir)
signature = model.signature_def[signature_key] return _run_signature(session, model, inputs, signature_key)
assert set(inputs.keys()) == set(signature.inputs.keys())
feed_dict = {}
for arg_name in inputs.keys():
feed_dict[graph.get_tensor_by_name(signature.inputs[arg_name].name)] = (
inputs[arg_name])
output_dict = {}
for output_name, output_tensor_info in signature.outputs.items():
output_dict[output_name] = graph.get_tensor_by_name(
output_tensor_info.name)
return session.run(output_dict, feed_dict=feed_dict)
class SaveTest(test.TestCase): class SaveTest(test.TestCase):
@ -685,5 +692,55 @@ class MemoryTests(test.TestCase):
save.save(self._model, save_dir, self._model.call) save.save(self._model, save_dir, self._model.call)
class ExportMetaGraphTests(test.TestCase):
  """Tests for exporting only the MetaGraph (no checkpoint) from a Trackable."""

  def test_export_meta_graph(self):
    # The variable is deliberately left uninitialized: the exported MetaGraph
    # carries no checkpoint, so its value must be set at load time via the
    # "initialize"/"update" signatures below.
    root = tracking.AutoTrackable()
    root.variable = resource_variable_ops.UninitializedVariable(
        name="some_variable", dtype=dtypes.float32)

    @def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
    def multiply_var(x):
      return root.variable * x

    @def_function.function(input_signature=[tensor_spec.TensorSpec([])])
    def update(y):
      root.variable.assign_add(y)
      # TODO(b/150393409): All functions exported as signatures must have at
      # least one output.
      return 0

    @def_function.function(input_signature=[])
    def initialize():
      root.variable.assign(1.0)
      # TODO(b/150393409): All functions exported as signatures must have at
      # least one output.
      return 0

    save_path = os.path.join(self.get_temp_dir(), "meta_graph.pb")
    save.export_meta_graph(
        root,
        save_path,
        signatures={
            "multiply_var": multiply_var,
            "initialize": initialize,
            "update": update
        })

    # Round-trip: import the MetaGraph into a fresh v1 graph/session and
    # drive it purely through the exported signatures.
    with ops.Graph().as_default(), session_lib.Session() as session:
      saver.import_meta_graph(save_path)
      meta_graph_def = meta_graph.read_meta_graph_file(save_path)

      # Initialize variable to 1
      _run_signature(session, meta_graph_def, {}, "initialize")
      out = _run_signature(session, meta_graph_def, {"x": 3}, "multiply_var")
      self.assertAllEqual(out, {"output_0": 3})

      # Adds 2 to the variable. Variable is now 3
      _run_signature(session, meta_graph_def, {"y": 2}, "update")
      out = _run_signature(session, meta_graph_def, {"x": 4}, "multiply_var")
      self.assertAllEqual(out, {"output_0": 12})
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()