Add an imported representation of init ops for 1.x-style SavedModels

Adds TrackableAssets to a property of the root object, and adds a TrackableResource which runs a function containing the original init op. This allows repeated re-export of 1.x SavedModels without losing information about assets or initialization.

Uniquifies shared_names rather than clearing them. 1.x SavedModels rely on shared_names to share resources across functions (e.g. between the init function and the function that uses a table), so clearing them doesn't really work. This will leak cached kernels, but at least provides the right behavior.
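
Illustrative usage, as a hedged sketch: the paths below and the 2.x `tf.saved_model.load`/`tf.saved_model.save` calls are assumptions used for illustration, not part of this change.

```python
import tensorflow as tf

# Load a 1.x SavedModel with the 2.x object-based loader. With this change the
# loader attaches the original init op (wrapped in a TrackableResource) and the
# asset paths to the returned root object, and runs initialization once.
first = tf.saved_model.load("/tmp/v1_saved_model")  # path is illustrative

# Re-export. Because the initializer and assets are tracked properties of the
# root object, the initialization behavior survives the round trip ...
tf.saved_model.save(first, "/tmp/reexport_1", signatures=first.signatures)

# ... and a second round trip as well, which is what the updated test exercises.
second = tf.saved_model.load("/tmp/reexport_1")
tf.saved_model.save(second, "/tmp/reexport_2", signatures=second.signatures)
```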

PiperOrigin-RevId: 233983446
Allen Lavoie 2019-02-14 10:33:15 -08:00 committed by TensorFlower Gardener
parent 3c9b46c245
commit 8aa71253c7
3 changed files with 53 additions and 9 deletions


@@ -212,8 +212,9 @@ def load_function_def_library(library):
   """
   functions = {}
 
+  load_shared_name_suffix = "_load_{}".format(ops.uid())
   for fdef in _sort_function_defs(library):
-    copy = _fix_fdef(fdef, functions)
+    copy = _fix_fdef(fdef, functions, load_shared_name_suffix)
 
     func_graph = function_def_lib.function_def_to_graph(copy)
     for dep in _list_function_deps(fdef):
@@ -263,7 +264,7 @@ def _sort_function_defs(library):
   return [reverse[x] for x in output]
 
 
-def _fix_fdef(orig_fdef, functions):
+def _fix_fdef(orig_fdef, functions, shared_name_suffix):
   """Fixes a FunctionDef proto to be loaded in current context.
 
   In particular, when loading a function library into an eager context, one
@@ -272,6 +273,10 @@ def _fix_fdef(orig_fdef, functions):
   Args:
     orig_fdef: FunctionDef proto to fix. It is not modified.
     functions: map from function name to a ConcreteFunction instance.
+    shared_name_suffix: A unique string for this load which helps to avoid
+      `shared_name` collisions across loads. Two functions from the same load
+      using the same `shared_name` still need to share, but functions from
+      different loads with the same `shared_name` should not.
 
   Returns:
     A fixed copy of the original FunctionDef.
@@ -296,10 +301,10 @@ def _fix_fdef(orig_fdef, functions):
         attr_value.func.name = functions[attr_value.func.name].name
 
     # TODO(b/124205571): Avoid accidental sharing and destruction of restored
-    # resources. For now drop "shared_name" when loading functions to avoid
+    # resources. For now uniquify "shared_name" when loading functions to avoid
     # sharing.
     if "shared_name" in node_def.attr:
-      del node_def.attr["shared_name"]
+      node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix)
 
   fdef.signature.name = _clean_function_name(fdef.signature.name)
   return fdef
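
The `shared_name` handling above can also be read in isolation. Below is a minimal sketch of the same uniquification idea; the standalone helper `_uniquify_shared_names` is hypothetical, while the proto fields and the `compat`/`ops` utilities are the ones the diff itself uses.

```python
from tensorflow.python.framework import ops
from tensorflow.python.util import compat


def _uniquify_shared_names(fdef, shared_name_suffix):
  """Appends a per-load suffix to every `shared_name` in a FunctionDef.

  Ops from the same load end up with identical suffixed names and therefore
  still share resources; ops from a different load get a different suffix and
  do not collide. (Hypothetical standalone version of the _fix_fdef change.)
  """
  for node_def in fdef.node_def:
    if "shared_name" in node_def.attr:
      node_def.attr["shared_name"].s += compat.as_bytes(shared_name_suffix)
  return fdef


# One suffix per load, as in load_function_def_library above.
load_shared_name_suffix = "_load_{}".format(ops.uid())
```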


@@ -22,12 +22,41 @@ import functools
 from tensorflow.python.eager import wrap_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
 from tensorflow.python.saved_model import loader_impl
 from tensorflow.python.saved_model import signature_serialization
 from tensorflow.python.training import saver as tf_saver
 from tensorflow.python.training.checkpointable import tracking
 
 
+class _Initializer(tracking.TrackableResource):
+  """Represents an initialization operation restored from a SavedModel.
+
+  Without this object re-export of imported 1.x SavedModels would omit the
+  original SavedModel's initialization procedure.
+
+  Created when `tf.saved_model.load` loads a TF 1.x-style SavedModel with an
+  initialization op. This object holds a function which runs the
+  initialization. It does not require any manual user intervention;
+  `tf.saved_model.save` will see this object and automatically add it to the
+  exported SavedModel, and `tf.saved_model.load` runs the initialization
+  function automatically.
+  """
+
+  def __init__(self, init_fn, asset_paths):
+    super(_Initializer, self).__init__()
+    self._asset_paths = asset_paths
+    self._init_fn = init_fn
+
+  def create_resource(self):
+    return array_ops.placeholder(
+        dtype=dtypes.resource, shape=[], name="unused_resource")
+
+  def initialize(self):
+    self._init_fn(*[path.asset_path for path in self._asset_paths])
+
+
 class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
   """Loads a SavedModel without using Sessions."""
 
@@ -94,6 +123,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
     self.restore_variables(wrapped, saver)
     with wrapped.graph.as_default():
       init_op = loader_impl.get_init_op(meta_graph_def)
+    root = tracking.AutoCheckpointable()
     if init_op is not None:
       asset_feed_tensors = []
       asset_paths = []
@@ -104,9 +134,13 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
       init_fn = wrapped.prune(
          feeds=asset_feed_tensors,
          fetches=[wrapped.graph.as_graph_element(init_op)])
-      init_fn(*[path.asset_path for path in asset_paths])
+      initializer = _Initializer(init_fn, asset_paths)
+      initializer.initialize()
+      root.initializer = initializer
+      root.asset_paths = asset_paths
+    else:
+      root.asset_paths = []
     signature_functions = self._extract_signatures(wrapped, meta_graph_def)
-    root = tracking.AutoCheckpointable()
     root.signatures = signature_serialization.create_signature_map(
         signature_functions)
     root.variables = list(wrapped.graph.variables)
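
The pattern `_Initializer` follows can be written down more generally. Below is a hedged sketch against the same internal `tracking.TrackableResource` API used at this revision; the `_RunOnLoad` helper is hypothetical, and the save/load behavior it relies on is the one described in `_Initializer`'s docstring above.

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.training.checkpointable import tracking


class _RunOnLoad(tracking.TrackableResource):
  """Hypothetical generalization of _Initializer: re-runs `fn` after loading.

  Attaching an instance to a tracked attribute of the exported root object is
  enough for the SavedModel machinery to serialize it and to call
  `initialize()` again once the model is loaded back.
  """

  def __init__(self, fn):
    super(_RunOnLoad, self).__init__()
    self._fn = fn

  def create_resource(self):
    # No real resource handle is needed; a placeholder satisfies the
    # TrackableResource contract, mirroring _Initializer above.
    return array_ops.placeholder(
        dtype=dtypes.resource, shape=[], name="unused_resource")

  def initialize(self):
    return self._fn()
```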


@@ -192,14 +192,19 @@ class LoadTest(test.TestCase):
                                str(ops.uid()))
     save.save(imported, second_path, signatures=imported.signatures)
     shutil.rmtree(first_path)
-    self.skipTest(
-        "TODO(b/124321570): save TrackableAssets and make re-saving initialize "
-        "correctly")
     second_import = load.load(second_path)
     fn = second_import.signatures["serving_default"]
     self.assertAllClose({"output": [2, 0]},
                         fn(start=constant_op.constant(["gamma", "alpha"])))
+    third_path = os.path.join(self.get_temp_dir(), "saved_model",
+                              str(ops.uid()))
+    save.save(second_import, third_path, signatures=second_import.signatures)
+    shutil.rmtree(second_path)
+    third_import = load.load(third_path)
+    fn = third_import.signatures["serving_default"]
+    self.assertAllClose({"output": [2, 0]},
+                        fn(start=constant_op.constant(["gamma", "alpha"])))
 
 
 if __name__ == "__main__":
   test.main()