Add an imported representation of init ops for 1.x-style SavedModels
Adds TrackableAssets to a property of the root object, and adds a TrackableResource which runs a function containing the original init op. This allows repeated re-export of 1.x SavedModels without losing information about assets or initialization. Uniquifies shared_names rather than clearing them. 1.x SavedModels rely on shared_names sharing resources across functions (e.g. the init function vs. the function that uses a table) so clearing doesn't really work. This will leak cached kernels, but at least provides the right behavior. PiperOrigin-RevId: 233983446
This commit is contained in:
parent
3c9b46c245
commit
8aa71253c7
@ -212,8 +212,9 @@ def load_function_def_library(library):
|
|||||||
"""
|
"""
|
||||||
functions = {}
|
functions = {}
|
||||||
|
|
||||||
|
load_shared_name_suffix = "_load_{}".format(ops.uid())
|
||||||
for fdef in _sort_function_defs(library):
|
for fdef in _sort_function_defs(library):
|
||||||
copy = _fix_fdef(fdef, functions)
|
copy = _fix_fdef(fdef, functions, load_shared_name_suffix)
|
||||||
|
|
||||||
func_graph = function_def_lib.function_def_to_graph(copy)
|
func_graph = function_def_lib.function_def_to_graph(copy)
|
||||||
for dep in _list_function_deps(fdef):
|
for dep in _list_function_deps(fdef):
|
||||||
@ -263,7 +264,7 @@ def _sort_function_defs(library):
|
|||||||
return [reverse[x] for x in output]
|
return [reverse[x] for x in output]
|
||||||
|
|
||||||
|
|
||||||
def _fix_fdef(orig_fdef, functions):
|
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
|
||||||
"""Fixes a FunctionDef proto to be loaded in current context.
|
"""Fixes a FunctionDef proto to be loaded in current context.
|
||||||
|
|
||||||
In particular, when loading a function library into an eager context, one
|
In particular, when loading a function library into an eager context, one
|
||||||
@ -272,6 +273,10 @@ def _fix_fdef(orig_fdef, functions):
|
|||||||
Args:
|
Args:
|
||||||
orig_fdef: FunctionDef proto to fix. It is not modified.
|
orig_fdef: FunctionDef proto to fix. It is not modified.
|
||||||
functions: map from function name to a ConcreteFunction instance.
|
functions: map from function name to a ConcreteFunction instance.
|
||||||
|
shared_name_suffix: A unique string for this load which helps to avoid
|
||||||
|
`shared_name` collisions across loads. Two functions from the same load
|
||||||
|
using the same `shared_name` still need to share, but functions from
|
||||||
|
different loads with the same `shared_name` should not.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A fixed copy of the original FunctionDef.
|
A fixed copy of the original FunctionDef.
|
||||||
@ -296,10 +301,10 @@ def _fix_fdef(orig_fdef, functions):
|
|||||||
attr_value.func.name = functions[attr_value.func.name].name
|
attr_value.func.name = functions[attr_value.func.name].name
|
||||||
|
|
||||||
# TODO(b/124205571): Avoid accidental sharing and destruction of restored
|
# TODO(b/124205571): Avoid accidental sharing and destruction of restored
|
||||||
# resources. For now drop "shared_name" when loading functions to avoid
|
# resources. For now uniquify "shared_name" when loading functions to avoid
|
||||||
# sharing.
|
# sharing.
|
||||||
if "shared_name" in node_def.attr:
|
if "shared_name" in node_def.attr:
|
||||||
del node_def.attr["shared_name"]
|
node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix)
|
||||||
|
|
||||||
fdef.signature.name = _clean_function_name(fdef.signature.name)
|
fdef.signature.name = _clean_function_name(fdef.signature.name)
|
||||||
return fdef
|
return fdef
|
||||||
|
@ -22,12 +22,41 @@ import functools
|
|||||||
|
|
||||||
from tensorflow.python.eager import wrap_function
|
from tensorflow.python.eager import wrap_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.saved_model import loader_impl
|
from tensorflow.python.saved_model import loader_impl
|
||||||
from tensorflow.python.saved_model import signature_serialization
|
from tensorflow.python.saved_model import signature_serialization
|
||||||
from tensorflow.python.training import saver as tf_saver
|
from tensorflow.python.training import saver as tf_saver
|
||||||
from tensorflow.python.training.checkpointable import tracking
|
from tensorflow.python.training.checkpointable import tracking
|
||||||
|
|
||||||
|
|
||||||
|
class _Initializer(tracking.TrackableResource):
|
||||||
|
"""Represents an initialization operation restored from a SavedModel.
|
||||||
|
|
||||||
|
Without this object re-export of imported 1.x SavedModels would omit the
|
||||||
|
original SavedModel's initialization procedure.
|
||||||
|
|
||||||
|
Created when `tf.saved_model.load` loads a TF 1.x-style SavedModel with an
|
||||||
|
initialization op. This object holds a function which runs the
|
||||||
|
initialization. It does not require any manual user intervention;
|
||||||
|
`tf.saved_model.save` will see this object and automatically add it to the
|
||||||
|
exported SavedModel, and `tf.saved_model.load` runs the initialization
|
||||||
|
function automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, init_fn, asset_paths):
|
||||||
|
super(_Initializer, self).__init__()
|
||||||
|
self._asset_paths = asset_paths
|
||||||
|
self._init_fn = init_fn
|
||||||
|
|
||||||
|
def create_resource(self):
|
||||||
|
return array_ops.placeholder(
|
||||||
|
dtype=dtypes.resource, shape=[], name="unused_resource")
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
self._init_fn(*[path.asset_path for path in self._asset_paths])
|
||||||
|
|
||||||
|
|
||||||
class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||||
"""Loads a SavedModel without using Sessions."""
|
"""Loads a SavedModel without using Sessions."""
|
||||||
|
|
||||||
@ -94,6 +123,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
|||||||
self.restore_variables(wrapped, saver)
|
self.restore_variables(wrapped, saver)
|
||||||
with wrapped.graph.as_default():
|
with wrapped.graph.as_default():
|
||||||
init_op = loader_impl.get_init_op(meta_graph_def)
|
init_op = loader_impl.get_init_op(meta_graph_def)
|
||||||
|
root = tracking.AutoCheckpointable()
|
||||||
if init_op is not None:
|
if init_op is not None:
|
||||||
asset_feed_tensors = []
|
asset_feed_tensors = []
|
||||||
asset_paths = []
|
asset_paths = []
|
||||||
@ -104,9 +134,13 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
|||||||
init_fn = wrapped.prune(
|
init_fn = wrapped.prune(
|
||||||
feeds=asset_feed_tensors,
|
feeds=asset_feed_tensors,
|
||||||
fetches=[wrapped.graph.as_graph_element(init_op)])
|
fetches=[wrapped.graph.as_graph_element(init_op)])
|
||||||
init_fn(*[path.asset_path for path in asset_paths])
|
initializer = _Initializer(init_fn, asset_paths)
|
||||||
|
initializer.initialize()
|
||||||
|
root.initializer = initializer
|
||||||
|
root.asset_paths = asset_paths
|
||||||
|
else:
|
||||||
|
root.asset_paths = []
|
||||||
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
|
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
|
||||||
root = tracking.AutoCheckpointable()
|
|
||||||
root.signatures = signature_serialization.create_signature_map(
|
root.signatures = signature_serialization.create_signature_map(
|
||||||
signature_functions)
|
signature_functions)
|
||||||
root.variables = list(wrapped.graph.variables)
|
root.variables = list(wrapped.graph.variables)
|
||||||
|
@ -192,14 +192,19 @@ class LoadTest(test.TestCase):
|
|||||||
str(ops.uid()))
|
str(ops.uid()))
|
||||||
save.save(imported, second_path, signatures=imported.signatures)
|
save.save(imported, second_path, signatures=imported.signatures)
|
||||||
shutil.rmtree(first_path)
|
shutil.rmtree(first_path)
|
||||||
self.skipTest(
|
|
||||||
"TODO(b/124321570): save TrackableAssets and make re-saving initialize "
|
|
||||||
"correctly")
|
|
||||||
second_import = load.load(second_path)
|
second_import = load.load(second_path)
|
||||||
fn = second_import.signatures["serving_default"]
|
fn = second_import.signatures["serving_default"]
|
||||||
self.assertAllClose({"output": [2, 0]},
|
self.assertAllClose({"output": [2, 0]},
|
||||||
fn(start=constant_op.constant(["gamma", "alpha"])))
|
fn(start=constant_op.constant(["gamma", "alpha"])))
|
||||||
|
|
||||||
|
third_path = os.path.join(self.get_temp_dir(), "saved_model",
|
||||||
|
str(ops.uid()))
|
||||||
|
save.save(second_import, third_path, signatures=second_import.signatures)
|
||||||
|
shutil.rmtree(second_path)
|
||||||
|
third_import = load.load(third_path)
|
||||||
|
fn = third_import.signatures["serving_default"]
|
||||||
|
self.assertAllClose({"output": [2, 0]},
|
||||||
|
fn(start=constant_op.constant(["gamma", "alpha"])))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user