Add function to export metagraph from a Trackable object. (i.e. export SavedModel without checkpoints)
PiperOrigin-RevId: 297685153 Change-Id: I2440a9ace1700c400e09137b4834f2d8cadf154a
This commit is contained in:
parent
aad9a544a4
commit
384f1a5507
@ -941,6 +941,55 @@ def save(obj, export_dir, signatures=None, options=None):
|
|||||||
May not be called from within a function body.
|
May not be called from within a function body.
|
||||||
@end_compatibility
|
@end_compatibility
|
||||||
"""
|
"""
|
||||||
|
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
|
||||||
|
# compatible (no sessions) and share it with this export API rather than
|
||||||
|
# making a SavedModel proto and writing it directly.
|
||||||
|
saved_model = saved_model_pb2.SavedModel()
|
||||||
|
meta_graph_def = saved_model.meta_graphs.add()
|
||||||
|
|
||||||
|
_, exported_graph, object_saver, asset_info = _build_meta_graph(
|
||||||
|
obj, export_dir, signatures, options, meta_graph_def)
|
||||||
|
saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
|
||||||
|
|
||||||
|
# Write the checkpoint, copy assets into the assets directory, and write out
|
||||||
|
# the SavedModel proto itself.
|
||||||
|
utils_impl.get_or_create_variables_dir(export_dir)
|
||||||
|
object_saver.save(utils_impl.get_variables_path(export_dir))
|
||||||
|
builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
|
||||||
|
export_dir)
|
||||||
|
# Note that this needs to be the last file operation when saving the
|
||||||
|
# SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
|
||||||
|
# indication that the SavedModel is completely written.
|
||||||
|
path = os.path.join(
|
||||||
|
compat.as_str(export_dir),
|
||||||
|
compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
|
||||||
|
file_io.atomic_write_string_to_file(
|
||||||
|
path, saved_model.SerializeToString(deterministic=True))
|
||||||
|
|
||||||
|
# Clean reference cycles so repeated export()s don't make work for the garbage
|
||||||
|
# collector. Before this point, we need to keep references to captured
|
||||||
|
# constants in the saved graph.
|
||||||
|
ops.dismantle_graph(exported_graph)
|
||||||
|
|
||||||
|
|
||||||
|
def export_meta_graph(obj, filename, signatures=None, options=None):
|
||||||
|
"""Exports the MetaGraph proto to a file."""
|
||||||
|
export_dir = os.path.dirname(filename)
|
||||||
|
meta_graph_def, exported_graph, _, _ = _build_meta_graph(
|
||||||
|
obj, export_dir, signatures, options)
|
||||||
|
|
||||||
|
file_io.atomic_write_string_to_file(
|
||||||
|
filename, meta_graph_def.SerializeToString(deterministic=True))
|
||||||
|
|
||||||
|
# Clean reference cycles so repeated export()s don't make work for the garbage
|
||||||
|
# collector. Before this point, we need to keep references to captured
|
||||||
|
# constants in the saved graph.
|
||||||
|
ops.dismantle_graph(exported_graph)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_meta_graph(obj, export_dir, signatures, options,
|
||||||
|
meta_graph_def=None):
|
||||||
|
"""Creates a MetaGraph containing the resources and functions of an object."""
|
||||||
if ops.inside_function():
|
if ops.inside_function():
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"tf.saved_model.save is not supported inside a traced "
|
"tf.saved_model.save is not supported inside a traced "
|
||||||
@ -951,6 +1000,7 @@ def save(obj, export_dir, signatures=None, options=None):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected a Trackable object for export, got {}.".format(obj))
|
"Expected a Trackable object for export, got {}.".format(obj))
|
||||||
options = options or save_options.SaveOptions()
|
options = options or save_options.SaveOptions()
|
||||||
|
meta_graph_def = meta_graph_def or meta_graph_pb2.MetaGraphDef()
|
||||||
|
|
||||||
checkpoint_graph_view = _AugmentedGraphView(obj)
|
checkpoint_graph_view = _AugmentedGraphView(obj)
|
||||||
if signatures is None:
|
if signatures is None:
|
||||||
@ -971,12 +1021,6 @@ def save(obj, export_dir, signatures=None, options=None):
|
|||||||
# there can be side effects of creating variables.
|
# there can be side effects of creating variables.
|
||||||
_ = _SaveableView(checkpoint_graph_view)
|
_ = _SaveableView(checkpoint_graph_view)
|
||||||
saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
|
saveable_view = _SaveableView(checkpoint_graph_view, wrapped_functions)
|
||||||
|
|
||||||
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
|
|
||||||
# compatible (no sessions) and share it with this export API rather than
|
|
||||||
# making a SavedModel proto and writing it directly.
|
|
||||||
saved_model = saved_model_pb2.SavedModel()
|
|
||||||
meta_graph_def = saved_model.meta_graphs.add()
|
|
||||||
object_saver = util.TrackableSaver(checkpoint_graph_view)
|
object_saver = util.TrackableSaver(checkpoint_graph_view)
|
||||||
asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def,
|
asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def,
|
||||||
saveable_view, signatures,
|
saveable_view, signatures,
|
||||||
@ -988,18 +1032,7 @@ def save(obj, export_dir, signatures=None, options=None):
|
|||||||
function_aliases[fdef.name] = alias
|
function_aliases[fdef.name] = alias
|
||||||
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
|
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
|
||||||
function_aliases[fdef.name] = alias
|
function_aliases[fdef.name] = alias
|
||||||
saved_model.saved_model_schema_version = (
|
|
||||||
constants.SAVED_MODEL_SCHEMA_VERSION)
|
|
||||||
# So far we've just been generating protocol buffers with no I/O. Now we write
|
|
||||||
# the checkpoint, copy assets into the assets directory, and write out the
|
|
||||||
# SavedModel proto itself.
|
|
||||||
utils_impl.get_or_create_variables_dir(export_dir)
|
|
||||||
object_saver.save(utils_impl.get_variables_path(export_dir))
|
|
||||||
builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
|
|
||||||
export_dir)
|
|
||||||
path = os.path.join(
|
|
||||||
compat.as_str(export_dir),
|
|
||||||
compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
|
|
||||||
object_graph_proto = _serialize_object_graph(saveable_view,
|
object_graph_proto = _serialize_object_graph(saveable_view,
|
||||||
asset_info.asset_index)
|
asset_info.asset_index)
|
||||||
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
|
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
|
||||||
@ -1013,13 +1046,4 @@ def save(obj, export_dir, signatures=None, options=None):
|
|||||||
constants.DEBUG_INFO_FILENAME_PB),
|
constants.DEBUG_INFO_FILENAME_PB),
|
||||||
graph_debug_info.SerializeToString(deterministic=True))
|
graph_debug_info.SerializeToString(deterministic=True))
|
||||||
|
|
||||||
# Note that this needs to be the last file operation when saving the
|
return meta_graph_def, exported_graph, object_saver, asset_info
|
||||||
# SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
|
|
||||||
# indication that the SavedModel is completely written.
|
|
||||||
file_io.atomic_write_string_to_file(
|
|
||||||
path, saved_model.SerializeToString(deterministic=True))
|
|
||||||
|
|
||||||
# Clean reference cycles so repeated export()s don't make work for the garbage
|
|
||||||
# collector. Before this point, we need to keep references to captured
|
|
||||||
# constants in the saved graph.
|
|
||||||
ops.dismantle_graph(exported_graph)
|
|
||||||
|
@ -33,6 +33,7 @@ from tensorflow.python.eager import function
|
|||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import meta_graph
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
@ -53,6 +54,7 @@ from tensorflow.python.saved_model import save
|
|||||||
from tensorflow.python.saved_model import save_options
|
from tensorflow.python.saved_model import save_options
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.saved_model import tag_constants
|
from tensorflow.python.saved_model import tag_constants
|
||||||
|
from tensorflow.python.training import saver
|
||||||
from tensorflow.python.training.tracking import tracking
|
from tensorflow.python.training.tracking import tracking
|
||||||
from tensorflow.python.training.tracking import util
|
from tensorflow.python.training.tracking import util
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
@ -76,6 +78,21 @@ class _ModelWithOptimizer(util.Checkpoint):
|
|||||||
return {"loss": loss}
|
return {"loss": loss}
|
||||||
|
|
||||||
|
|
||||||
|
def _run_signature(session, meta_graph_def, inputs, signature_key):
|
||||||
|
signature = meta_graph_def.signature_def[signature_key]
|
||||||
|
assert set(inputs.keys()) == set(signature.inputs.keys())
|
||||||
|
feed_dict = {}
|
||||||
|
for arg_name in inputs.keys():
|
||||||
|
input_tensor = session.graph.get_tensor_by_name(
|
||||||
|
signature.inputs[arg_name].name)
|
||||||
|
feed_dict[input_tensor] = inputs[arg_name]
|
||||||
|
output_dict = {}
|
||||||
|
for output_name, output_tensor_info in signature.outputs.items():
|
||||||
|
output_dict[output_name] = session.graph.get_tensor_by_name(
|
||||||
|
output_tensor_info.name)
|
||||||
|
return session.run(output_dict, feed_dict=feed_dict)
|
||||||
|
|
||||||
|
|
||||||
def _import_and_infer(
|
def _import_and_infer(
|
||||||
save_dir, inputs,
|
save_dir, inputs,
|
||||||
signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
|
signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY):
|
||||||
@ -83,17 +100,7 @@ def _import_and_infer(
|
|||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
with graph.as_default(), session_lib.Session() as session:
|
with graph.as_default(), session_lib.Session() as session:
|
||||||
model = loader.load(session, [tag_constants.SERVING], save_dir)
|
model = loader.load(session, [tag_constants.SERVING], save_dir)
|
||||||
signature = model.signature_def[signature_key]
|
return _run_signature(session, model, inputs, signature_key)
|
||||||
assert set(inputs.keys()) == set(signature.inputs.keys())
|
|
||||||
feed_dict = {}
|
|
||||||
for arg_name in inputs.keys():
|
|
||||||
feed_dict[graph.get_tensor_by_name(signature.inputs[arg_name].name)] = (
|
|
||||||
inputs[arg_name])
|
|
||||||
output_dict = {}
|
|
||||||
for output_name, output_tensor_info in signature.outputs.items():
|
|
||||||
output_dict[output_name] = graph.get_tensor_by_name(
|
|
||||||
output_tensor_info.name)
|
|
||||||
return session.run(output_dict, feed_dict=feed_dict)
|
|
||||||
|
|
||||||
|
|
||||||
class SaveTest(test.TestCase):
|
class SaveTest(test.TestCase):
|
||||||
@ -685,5 +692,55 @@ class MemoryTests(test.TestCase):
|
|||||||
save.save(self._model, save_dir, self._model.call)
|
save.save(self._model, save_dir, self._model.call)
|
||||||
|
|
||||||
|
|
||||||
|
class ExportMetaGraphTests(test.TestCase):
|
||||||
|
|
||||||
|
def test_export_meta_graph(self):
|
||||||
|
root = tracking.AutoTrackable()
|
||||||
|
root.variable = resource_variable_ops.UninitializedVariable(
|
||||||
|
name="some_variable", dtype=dtypes.float32)
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[tensor_spec.TensorSpec(None)])
|
||||||
|
def multiply_var(x):
|
||||||
|
return root.variable * x
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[tensor_spec.TensorSpec([])])
|
||||||
|
def update(y):
|
||||||
|
root.variable.assign_add(y)
|
||||||
|
# TODO(b/150393409): All functions exported as signatures must have at
|
||||||
|
# least one output.
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@def_function.function(input_signature=[])
|
||||||
|
def initialize():
|
||||||
|
root.variable.assign(1.0)
|
||||||
|
# TODO(b/150393409): All functions exported as signatures must have at
|
||||||
|
# least one output.
|
||||||
|
return 0
|
||||||
|
|
||||||
|
save_path = os.path.join(self.get_temp_dir(), "meta_graph.pb")
|
||||||
|
save.export_meta_graph(
|
||||||
|
root,
|
||||||
|
save_path,
|
||||||
|
signatures={
|
||||||
|
"multiply_var": multiply_var,
|
||||||
|
"initialize": initialize,
|
||||||
|
"update": update
|
||||||
|
})
|
||||||
|
|
||||||
|
with ops.Graph().as_default(), session_lib.Session() as session:
|
||||||
|
saver.import_meta_graph(save_path)
|
||||||
|
meta_graph_def = meta_graph.read_meta_graph_file(save_path)
|
||||||
|
|
||||||
|
# Initialize variable to 1
|
||||||
|
_run_signature(session, meta_graph_def, {}, "initialize")
|
||||||
|
out = _run_signature(session, meta_graph_def, {"x": 3}, "multiply_var")
|
||||||
|
self.assertAllEqual(out, {"output_0": 3})
|
||||||
|
|
||||||
|
# Adds 2 to the variable. Variable is now 3
|
||||||
|
_run_signature(session, meta_graph_def, {"y": 2}, "update")
|
||||||
|
out = _run_signature(session, meta_graph_def, {"x": 4}, "multiply_var")
|
||||||
|
self.assertAllEqual(out, {"output_0": 12})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user