Support loading v1 SavedModels in v1 graph mode with tf.saved_model.load

Mostly just some fiddling with collections and variable initializers.

PiperOrigin-RevId: 247081188
This commit is contained in:
Allen Lavoie 2019-05-07 13:37:46 -07:00 committed by TensorFlower Gardener
parent de3e26b431
commit 289a0af9b0
3 changed files with 109 additions and 44 deletions

View File

@ -84,6 +84,30 @@ class VariableHolder(object):
return wrapped
def _lift_single_variable(old_variable, graph, variable_holder):
"""Lifts `old_variable` out of the `FuncGraph` `graph`."""
new_variable = resource_variable_ops.UninitializedVariable(
shape=old_variable.shape,
dtype=old_variable.dtype,
name=old_variable.op.name,
trainable=old_variable.trainable,
extra_handle_data=old_variable.handle)
new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access
graph.inputs.append(old_variable.handle)
graph.captures[new_variable.handle] = old_variable.handle
# Now that we've added the new variable to graph.captures,
# graph.capture will use that cached value and do some post-processing
# on the capture like recording it on the tape.
graph.capture(new_variable.handle)
# pylint: disable=protected-access
variable_name = new_variable.name.split(":")[0]
variable_holder._variables_by_name[variable_name] = new_variable
graph._weak_variables.append(weakref.ref(new_variable))
# pylint: enable=protected-access
graph.watch_variable(new_variable)
return new_variable
def _lift_unlifted_variables(graph, variable_holder):
"""Finds resource variables and lifts them into the outer context.
@ -100,39 +124,44 @@ def _lift_unlifted_variables(graph, variable_holder):
variable_holder: A VariableHolder to record the lifted variables in.
"""
with graph.as_default():
collection_variables = (
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
global_collection_variables = ops.get_collection(
ops.GraphKeys.GLOBAL_VARIABLES)
local_collection_variables = ops.get_collection(
ops.GraphKeys.LOCAL_VARIABLES)
existing_captures = set(graph.internal_captures)
lifted_variables = {}
for old_variable in collection_variables:
if (old_variable._in_graph_mode # pylint: disable=protected-access
and
isinstance(old_variable, resource_variable_ops.ResourceVariable)):
if old_variable.handle in existing_captures:
continue
new_variable = resource_variable_ops.UninitializedVariable(
shape=old_variable.shape,
dtype=old_variable.dtype,
name=old_variable.op.name,
trainable=old_variable.trainable,
extra_handle_data=old_variable.handle)
new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access
graph.inputs.append(old_variable.handle)
graph.captures[new_variable.handle] = old_variable.handle
# Now that we've added the new variable to graph.captures,
# graph.capture will use that cached value and do some post-processing
# on the capture like recording it on the tape.
graph.capture(new_variable.handle)
existing_captures.add(old_variable.handle)
def _should_lift_variable(v):
return ((v._in_graph_mode # pylint: disable=protected-access
and v.graph.building_function)
and isinstance(v, resource_variable_ops.ResourceVariable)
and v.handle not in existing_captures)
for old_variable in global_collection_variables:
if _should_lift_variable(old_variable):
new_variable = _lift_single_variable(
old_variable, graph, variable_holder)
lifted_variables[old_variable] = new_variable
# pylint: disable=protected-access
variable_name = new_variable.name.split(":")[0]
variable_holder._variables_by_name[variable_name] = new_variable
graph._weak_variables.append(weakref.ref(new_variable))
# pylint: enable=protected-access
graph.watch_variable(new_variable)
# Update the graph's collections, partly for the user and partly so this
existing_captures.add(old_variable.handle)
for old_variable in local_collection_variables:
if _should_lift_variable(old_variable):
new_variable = _lift_single_variable(
old_variable, graph, variable_holder)
lifted_variables[old_variable] = new_variable
existing_captures.add(old_variable.handle)
if new_variable._in_graph_mode: # pylint: disable=protected-access
outer_graph = new_variable.graph
# Variables are added to the global collection by default. In this
# case we only want the variable in the local collection, so we'll pop
# it out.
global_collection = outer_graph.get_collection_ref(
ops.GraphKeys.GLOBAL_VARIABLES)
global_collection.remove(new_variable)
outer_graph.add_to_collection(
ops.GraphKeys.LOCAL_VARIABLES, new_variable)
# Update the FuncGraph's collections, partly for the user and partly so this
# function is idempotent when it runs again in prune() calls.
for collection_name in [
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
@ -148,9 +177,7 @@ class WrappedFunction(function.ConcreteFunction):
def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
self._variable_holder = variable_holder
if ops.executing_eagerly_outside_functions():
# TODO(allenl): Make this work in 1.x?
_lift_unlifted_variables(fn_graph, variable_holder)
_lift_unlifted_variables(fn_graph, variable_holder)
# We call __init__ after lifting variables so that the function's signature
# properly reflects the new captured inputs.
super(WrappedFunction, self).__init__(

View File

@ -20,10 +20,12 @@ from __future__ import print_function
import functools
from tensorflow.python.eager import context
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_serialization
@ -56,7 +58,7 @@ class _Initializer(tracking.CapturableResource):
dtype=dtypes.resource, shape=[], name="unused_resource")
def _initialize(self):
self._init_fn(*[path.asset_path for path in self._asset_paths])
return self._init_fn(*[path.asset_path for path in self._asset_paths])
class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
@ -90,11 +92,21 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
"""Restores variables from the checkpoint."""
if saver is not None:
saver_def = saver.saver_def
filename_tensor = wrapped.graph.as_graph_element(
saver_def.filename_tensor_name)
# We both feed and fetch filename_tensor so we have an operation to use to
# feed into variable initializers (only relevant for v1 graph building).
restore_fn = wrapped.prune(
feeds=[wrapped.graph.as_graph_element(
saver_def.filename_tensor_name)],
fetches=[wrapped.graph.as_graph_element(saver_def.restore_op_name)])
restore_fn(constant_op.constant(self._variables_path))
feeds=[filename_tensor],
fetches=[filename_tensor,
wrapped.graph.as_graph_element(saver_def.restore_op_name)])
initializer, _ = restore_fn(constant_op.constant(self._variables_path))
if not ops.executing_eagerly_outside_functions():
for variable in wrapped.graph.get_collection_ref(
ops.GraphKeys.GLOBAL_VARIABLES):
# pylint: disable=protected-access
variable._initializer_op = initializer
# pylint: enable=protected-access
def _extract_signatures(self, wrapped, meta_graph_def):
"""Creates ConcreteFunctions for signatures in `meta_graph_def`."""
@ -151,6 +163,8 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
with wrapped.graph.as_default():
init_op = loader_impl.get_init_op(
meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
# Add a dummy Tensor we know we can fetch to add control dependencies to.
init_anchor = constant_op.constant(0., name="dummy_fetch")
root = tracking.AutoTrackable()
asset_feed_tensors = []
@ -161,9 +175,19 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
asset_paths.append(tracking.TrackableAsset(value))
init_fn = wrapped.prune(
feeds=asset_feed_tensors,
fetches=[wrapped.graph.as_graph_element(init_op)])
fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
initializer = _Initializer(init_fn, asset_paths)
initializer._initialize() # pylint: disable=protected-access
# pylint: disable=protected-access
local_init_op, _ = initializer._initialize()
# pylint: enable=protected-access
with ops.init_scope():
if not context.executing_eagerly():
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
for variable in wrapped.graph.get_collection_ref(
ops.GraphKeys.LOCAL_VARIABLES):
# pylint: disable=protected-access
variable._initializer_op = local_init_op
# pylint: enable=protected-access
root.initializer = initializer
root.asset_paths = asset_paths
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
@ -182,3 +206,4 @@ def load(export_dir, tags):
"""Load a v1-style SavedModel as an object."""
loader = _EagerSavedModelLoader(export_dir)
return loader.load(tags=tags)

View File

@ -28,6 +28,7 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
@ -51,7 +52,7 @@ class LoadTest(test.TestCase):
export_graph = ops.Graph()
with export_graph.as_default():
start = array_ops.placeholder(
shape=[None], dtype=dtypes.float32, name="start")
shape=None, dtype=dtypes.float32, name="start")
if use_resource:
distractor = variables.RefVariable(-1., name="distractor")
v = resource_variable_ops.ResourceVariable(3., name="v")
@ -81,17 +82,20 @@ class LoadTest(test.TestCase):
legacy_init_op=local_variable.initializer)
return path
@test_util.run_in_graph_and_eager_modes
def test_resource_variable_import(self):
imported = load.load(self._v1_single_metagraph_saved_model(
use_resource=True))
self.evaluate(variables.global_variables_initializer())
self.evaluate(variables.local_variables_initializer())
fn = imported.signatures["serving_default"]
self.assertEqual({"output": 6.},
self.evaluate(fn(constant_op.constant(2.))))
self.assertAllEqual([3., 1.], self.evaluate(imported.variables))
imported.variables[0].assign(4.)
self.evaluate(imported.variables[0].assign(4.))
self.assertEqual({"output": 8.},
self.evaluate(fn(start=constant_op.constant(2.))))
imported.variables[1].assign(2.)
self.evaluate(imported.variables[1].assign(2.))
self.assertEqual({"output": 24.},
self.evaluate(fn(start=constant_op.constant(3.))))
self.assertTrue(imported.variables[0].trainable)
@ -99,7 +103,9 @@ class LoadTest(test.TestCase):
with backprop.GradientTape() as tape:
output = fn(start=constant_op.constant(4.))
self.assertEqual(imported.variables[:1], list(tape.watched_variables()))
self.assertEqual(8., tape.gradient(output, imported.variables[0]).numpy())
self.assertEqual(
8.,
self.evaluate(tape.gradient(output, imported.variables[0])))
def test_ref_variable_import(self):
saved = self._v1_single_metagraph_saved_model(use_resource=False)
@ -185,9 +191,11 @@ class LoadTest(test.TestCase):
file_io.delete_file(vocab_path)
return path
@test_util.run_in_graph_and_eager_modes
def test_asset_loading(self):
first_path = self._v1_asset_saved_model()
imported = load.load(first_path)
self.evaluate(lookup_ops.tables_initializer())
fn = imported.signatures["serving_default"]
self.assertAllClose({"output": [2, 0]},
fn(start=constant_op.constant(["gamma", "alpha"])))
@ -195,7 +203,9 @@ class LoadTest(test.TestCase):
str(ops.uid()))
save.save(imported, second_path, signatures=imported.signatures)
shutil.rmtree(first_path)
del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
second_import = load.load(second_path)
self.evaluate(lookup_ops.tables_initializer())
fn = second_import.signatures["serving_default"]
self.assertAllClose({"output": [2, 0]},
fn(start=constant_op.constant(["gamma", "alpha"])))
@ -204,7 +214,9 @@ class LoadTest(test.TestCase):
str(ops.uid()))
save.save(second_import, third_path, signatures=second_import.signatures)
shutil.rmtree(second_path)
del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
third_import = load.load(third_path)
self.evaluate(lookup_ops.tables_initializer())
fn = third_import.signatures["serving_default"]
self.assertAllClose({"output": [2, 0]},
fn(start=constant_op.constant(["gamma", "alpha"])))
@ -368,3 +380,4 @@ class LoadTest(test.TestCase):
if __name__ == "__main__":
test.main()