Add an imported representation of init ops for 1.x-style SavedModels
Adds TrackableAssets to a property of the root object, and adds a TrackableResource which runs a function containing the original init op. This allows repeated re-export of 1.x SavedModels without losing information about assets or initialization. Uniquifies shared_names rather than clearing them. 1.x SavedModels rely on shared_names sharing resources across functions (e.g. the init function vs. the function that uses a table) so clearing doesn't really work. This will leak cached kernels, but at least provides the right behavior. PiperOrigin-RevId: 233983446
This commit is contained in:
parent
3c9b46c245
commit
8aa71253c7
@ -212,8 +212,9 @@ def load_function_def_library(library):
|
||||
"""
|
||||
functions = {}
|
||||
|
||||
load_shared_name_suffix = "_load_{}".format(ops.uid())
|
||||
for fdef in _sort_function_defs(library):
|
||||
copy = _fix_fdef(fdef, functions)
|
||||
copy = _fix_fdef(fdef, functions, load_shared_name_suffix)
|
||||
|
||||
func_graph = function_def_lib.function_def_to_graph(copy)
|
||||
for dep in _list_function_deps(fdef):
|
||||
@ -263,7 +264,7 @@ def _sort_function_defs(library):
|
||||
return [reverse[x] for x in output]
|
||||
|
||||
|
||||
def _fix_fdef(orig_fdef, functions):
|
||||
def _fix_fdef(orig_fdef, functions, shared_name_suffix):
|
||||
"""Fixes a FunctionDef proto to be loaded in current context.
|
||||
|
||||
In particular, when loading a function library into an eager context, one
|
||||
@ -272,6 +273,10 @@ def _fix_fdef(orig_fdef, functions):
|
||||
Args:
|
||||
orig_fdef: FunctionDef proto to fix. It is not modified.
|
||||
functions: map from function name to a ConcreteFunction instance.
|
||||
shared_name_suffix: A unique string for this load which helps to avoid
|
||||
`shared_name` collisions across loads. Two functions from the same load
|
||||
using the same `shared_name` still need to share, but functions from
|
||||
different loads with the same `shared_name` should not.
|
||||
|
||||
Returns:
|
||||
A fixed copy of the original FunctionDef.
|
||||
@ -296,10 +301,10 @@ def _fix_fdef(orig_fdef, functions):
|
||||
attr_value.func.name = functions[attr_value.func.name].name
|
||||
|
||||
# TODO(b/124205571): Avoid accidental sharing and destruction of restored
|
||||
# resources. For now drop "shared_name" when loading functions to avoid
|
||||
# resources. For now uniquify "shared_name" when loading functions to avoid
|
||||
# sharing.
|
||||
if "shared_name" in node_def.attr:
|
||||
del node_def.attr["shared_name"]
|
||||
node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix)
|
||||
|
||||
fdef.signature.name = _clean_function_name(fdef.signature.name)
|
||||
return fdef
|
||||
|
@ -22,12 +22,41 @@ import functools
|
||||
|
||||
from tensorflow.python.eager import wrap_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import signature_serialization
|
||||
from tensorflow.python.training import saver as tf_saver
|
||||
from tensorflow.python.training.checkpointable import tracking
|
||||
|
||||
|
||||
class _Initializer(tracking.TrackableResource):
|
||||
"""Represents an initialization operation restored from a SavedModel.
|
||||
|
||||
Without this object re-export of imported 1.x SavedModels would omit the
|
||||
original SavedModel's initialization procedure.
|
||||
|
||||
Created when `tf.saved_model.load` loads a TF 1.x-style SavedModel with an
|
||||
initialization op. This object holds a function which runs the
|
||||
initialization. It does not require any manual user intervention;
|
||||
`tf.saved_model.save` will see this object and automatically add it to the
|
||||
exported SavedModel, and `tf.saved_model.load` runs the initialization
|
||||
function automatically.
|
||||
"""
|
||||
|
||||
def __init__(self, init_fn, asset_paths):
|
||||
super(_Initializer, self).__init__()
|
||||
self._asset_paths = asset_paths
|
||||
self._init_fn = init_fn
|
||||
|
||||
def create_resource(self):
|
||||
return array_ops.placeholder(
|
||||
dtype=dtypes.resource, shape=[], name="unused_resource")
|
||||
|
||||
def initialize(self):
|
||||
self._init_fn(*[path.asset_path for path in self._asset_paths])
|
||||
|
||||
|
||||
class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||
"""Loads a SavedModel without using Sessions."""
|
||||
|
||||
@ -94,6 +123,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||
self.restore_variables(wrapped, saver)
|
||||
with wrapped.graph.as_default():
|
||||
init_op = loader_impl.get_init_op(meta_graph_def)
|
||||
root = tracking.AutoCheckpointable()
|
||||
if init_op is not None:
|
||||
asset_feed_tensors = []
|
||||
asset_paths = []
|
||||
@ -104,9 +134,13 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||
init_fn = wrapped.prune(
|
||||
feeds=asset_feed_tensors,
|
||||
fetches=[wrapped.graph.as_graph_element(init_op)])
|
||||
init_fn(*[path.asset_path for path in asset_paths])
|
||||
initializer = _Initializer(init_fn, asset_paths)
|
||||
initializer.initialize()
|
||||
root.initializer = initializer
|
||||
root.asset_paths = asset_paths
|
||||
else:
|
||||
root.asset_paths = []
|
||||
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
|
||||
root = tracking.AutoCheckpointable()
|
||||
root.signatures = signature_serialization.create_signature_map(
|
||||
signature_functions)
|
||||
root.variables = list(wrapped.graph.variables)
|
||||
|
@ -192,14 +192,19 @@ class LoadTest(test.TestCase):
|
||||
str(ops.uid()))
|
||||
save.save(imported, second_path, signatures=imported.signatures)
|
||||
shutil.rmtree(first_path)
|
||||
self.skipTest(
|
||||
"TODO(b/124321570): save TrackableAssets and make re-saving initialize "
|
||||
"correctly")
|
||||
second_import = load.load(second_path)
|
||||
fn = second_import.signatures["serving_default"]
|
||||
self.assertAllClose({"output": [2, 0]},
|
||||
fn(start=constant_op.constant(["gamma", "alpha"])))
|
||||
|
||||
third_path = os.path.join(self.get_temp_dir(), "saved_model",
|
||||
str(ops.uid()))
|
||||
save.save(second_import, third_path, signatures=second_import.signatures)
|
||||
shutil.rmtree(second_path)
|
||||
third_import = load.load(third_path)
|
||||
fn = third_import.signatures["serving_default"]
|
||||
self.assertAllClose({"output": [2, 0]},
|
||||
fn(start=constant_op.constant(["gamma", "alpha"])))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user