Support loading v1 SavedModels in v1 graph mode with tf.saved_model.load
Mostly just some fiddling with collections and variable initializers.

PiperOrigin-RevId: 247081188
commit 289a0af9b0 (parent de3e26b431)
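
To make the intent concrete, here is a minimal usage sketch (not part of the commit) of what this enables: loading a v1 SavedModel through the 2.x object-based loader while building a v1 graph, then running the stock v1 initializers that the loader now hooks the restore and init ops into. The export path and the signature input name `start` are assumptions for illustration; the pattern mirrors the updated tests below.

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    export_dir = "/tmp/some_v1_saved_model"   # hypothetical existing v1 SavedModel
    imported = tf.saved_model.load(export_dir, tags=["serve"])
    fn = imported.signatures["serving_default"]
    outputs = fn(start=tf.constant(2.0))      # symbolic tensors in graph mode

    with tf.compat.v1.Session() as sess:
      # With this change, the checkpoint restore and the legacy init op are
      # attached to the variables' initializers and the usual collections, so
      # the standard v1 initializers are enough before running the signature.
      sess.run(tf.compat.v1.global_variables_initializer())
      sess.run(tf.compat.v1.local_variables_initializer())
      sess.run(tf.compat.v1.tables_initializer())
      print(sess.run(outputs))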
@@ -84,6 +84,30 @@ class VariableHolder(object):
     return wrapped
 
 
+def _lift_single_variable(old_variable, graph, variable_holder):
+  """Lifts `old_variable` out of the `FuncGraph` `graph`."""
+  new_variable = resource_variable_ops.UninitializedVariable(
+      shape=old_variable.shape,
+      dtype=old_variable.dtype,
+      name=old_variable.op.name,
+      trainable=old_variable.trainable,
+      extra_handle_data=old_variable.handle)
+  new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
+  graph.inputs.append(old_variable.handle)
+  graph.captures[new_variable.handle] = old_variable.handle
+  # Now that we've added the new variable to graph.captures,
+  # graph.capture will use that cached value and do some post-processing
+  # on the capture like recording it on the tape.
+  graph.capture(new_variable.handle)
+  # pylint: disable=protected-access
+  variable_name = new_variable.name.split(":")[0]
+  variable_holder._variables_by_name[variable_name] = new_variable
+  graph._weak_variables.append(weakref.ref(new_variable))
+  # pylint: enable=protected-access
+  graph.watch_variable(new_variable)
+  return new_variable
+
+
 def _lift_unlifted_variables(graph, variable_holder):
   """Finds resource variables and lifts them into the outer context.
 
@@ -100,39 +124,44 @@ def _lift_unlifted_variables(graph, variable_holder):
     variable_holder: A VariableHolder to record the lifted variables in.
   """
   with graph.as_default():
-    collection_variables = (
-        ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
-        ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
+    global_collection_variables = ops.get_collection(
+        ops.GraphKeys.GLOBAL_VARIABLES)
+    local_collection_variables = ops.get_collection(
+        ops.GraphKeys.LOCAL_VARIABLES)
     existing_captures = set(graph.internal_captures)
     lifted_variables = {}
-    for old_variable in collection_variables:
-      if (old_variable._in_graph_mode  # pylint: disable=protected-access
-          and
-          isinstance(old_variable, resource_variable_ops.ResourceVariable)):
-        if old_variable.handle in existing_captures:
-          continue
-        new_variable = resource_variable_ops.UninitializedVariable(
-            shape=old_variable.shape,
-            dtype=old_variable.dtype,
-            name=old_variable.op.name,
-            trainable=old_variable.trainable,
-            extra_handle_data=old_variable.handle)
-        new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
-        graph.inputs.append(old_variable.handle)
-        graph.captures[new_variable.handle] = old_variable.handle
-        # Now that we've added the new variable to graph.captures,
-        # graph.capture will use that cached value and do some post-processing
-        # on the capture like recording it on the tape.
-        graph.capture(new_variable.handle)
-        existing_captures.add(old_variable.handle)
-        lifted_variables[old_variable] = new_variable
-        # pylint: disable=protected-access
-        variable_name = new_variable.name.split(":")[0]
-        variable_holder._variables_by_name[variable_name] = new_variable
-        graph._weak_variables.append(weakref.ref(new_variable))
-        # pylint: enable=protected-access
-        graph.watch_variable(new_variable)
-    # Update the graph's collections, partly for the user and partly so this
+
+    def _should_lift_variable(v):
+      return ((v._in_graph_mode  # pylint: disable=protected-access
+               and v.graph.building_function)
+              and isinstance(v, resource_variable_ops.ResourceVariable)
+              and v.handle not in existing_captures)
+
+    for old_variable in global_collection_variables:
+      if _should_lift_variable(old_variable):
+        new_variable = _lift_single_variable(
+            old_variable, graph, variable_holder)
+        lifted_variables[old_variable] = new_variable
+        existing_captures.add(old_variable.handle)
+
+    for old_variable in local_collection_variables:
+      if _should_lift_variable(old_variable):
+        new_variable = _lift_single_variable(
+            old_variable, graph, variable_holder)
+        lifted_variables[old_variable] = new_variable
+        existing_captures.add(old_variable.handle)
+        if new_variable._in_graph_mode:  # pylint: disable=protected-access
+          outer_graph = new_variable.graph
+          # Variables are added to the global collection by default. In this
+          # case we only want the variable in the local collection, so we'll pop
+          # it out.
+          global_collection = outer_graph.get_collection_ref(
+              ops.GraphKeys.GLOBAL_VARIABLES)
+          global_collection.remove(new_variable)
+          outer_graph.add_to_collection(
+              ops.GraphKeys.LOCAL_VARIABLES, new_variable)
+
+    # Update the FuncGraph's collections, partly for the user and partly so this
     # function is idempotent when it runs again in prune() calls.
     for collection_name in [
         ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
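
For a consumer-level view of the lifting this helper performs, here is a small sketch (not from this change) using the public `tf.compat.v1.wrap_function` API in eager mode: variables created inside the wrapped `FuncGraph` surface as variable objects on the wrapped function, which is the same bookkeeping `_lift_unlifted_variables` now also does for variables that arrive via the GLOBAL_VARIABLES/LOCAL_VARIABLES collections (for example from an imported v1 GraphDef).

    import tensorflow as tf

    def forward(x):
      v = tf.Variable(3.0, name="v")   # created inside the wrapped FuncGraph
      return v * x

    wrapped = tf.compat.v1.wrap_function(forward, [tf.TensorSpec([], tf.float32)])
    # The variable is lifted out of the FuncGraph and is visible (and usable)
    # from the outer eager context.
    print([v.name for v in wrapped.graph.variables])   # e.g. ['v:0']
    print(float(wrapped(tf.constant(2.0))))            # 6.0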
@@ -148,9 +177,7 @@ class WrappedFunction(function.ConcreteFunction):
 
   def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
     self._variable_holder = variable_holder
-    if ops.executing_eagerly_outside_functions():
-      # TODO(allenl): Make this work in 1.x?
-      _lift_unlifted_variables(fn_graph, variable_holder)
+    _lift_unlifted_variables(fn_graph, variable_holder)
     # We call __init__ after lifting variables so that the function's signature
     # properly reflects the new captured inputs.
     super(WrappedFunction, self).__init__(
@@ -20,10 +20,12 @@ from __future__ import print_function
 
 import functools
 
+from tensorflow.python.eager import context
 from tensorflow.python.eager import lift_to_graph
 from tensorflow.python.eager import wrap_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.saved_model import loader_impl
 from tensorflow.python.saved_model import signature_serialization
@@ -56,7 +58,7 @@ class _Initializer(tracking.CapturableResource):
         dtype=dtypes.resource, shape=[], name="unused_resource")
 
   def _initialize(self):
-    self._init_fn(*[path.asset_path for path in self._asset_paths])
+    return self._init_fn(*[path.asset_path for path in self._asset_paths])
 
 
 class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
@@ -90,11 +92,21 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
     """Restores variables from the checkpoint."""
     if saver is not None:
       saver_def = saver.saver_def
+      filename_tensor = wrapped.graph.as_graph_element(
+          saver_def.filename_tensor_name)
+      # We both feed and fetch filename_tensor so we have an operation to use to
+      # feed into variable initializers (only relevant for v1 graph building).
       restore_fn = wrapped.prune(
-          feeds=[wrapped.graph.as_graph_element(
-              saver_def.filename_tensor_name)],
-          fetches=[wrapped.graph.as_graph_element(saver_def.restore_op_name)])
-      restore_fn(constant_op.constant(self._variables_path))
+          feeds=[filename_tensor],
+          fetches=[filename_tensor,
+                   wrapped.graph.as_graph_element(saver_def.restore_op_name)])
+      initializer, _ = restore_fn(constant_op.constant(self._variables_path))
+      if not ops.executing_eagerly_outside_functions():
+        for variable in wrapped.graph.get_collection_ref(
+            ops.GraphKeys.GLOBAL_VARIABLES):
+          # pylint: disable=protected-access
+          variable._initializer_op = initializer
+          # pylint: enable=protected-access
 
   def _extract_signatures(self, wrapped, meta_graph_def):
     """Creates ConcreteFunctions for signatures in `meta_graph_def`."""
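
The reason pointing each variable's `_initializer_op` at the restore works in v1 graph building: `global_variables_initializer()` is just a grouping of every global variable's reported initializer op, so swapping that op makes the one stock call perform the checkpoint restore. A tiny public-API sketch of that convention (illustrative only, not the loader's code):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    v = tf.Variable(0.0, name="v")
    init = tf.compat.v1.global_variables_initializer()
    # `init` is a NoOp with control dependencies on each variable's own
    # `initializer` op -- whatever op a variable reports as its initializer is
    # what this single call runs.
    with tf.compat.v1.Session() as sess:
      sess.run(init)
      print(sess.run(v))   # 0.0, the value produced by v's initializer op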
@@ -151,6 +163,8 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
     with wrapped.graph.as_default():
       init_op = loader_impl.get_init_op(
           meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
+      # Add a dummy Tensor we know we can fetch to add control dependencies to.
+      init_anchor = constant_op.constant(0., name="dummy_fetch")
 
     root = tracking.AutoTrackable()
     asset_feed_tensors = []
@@ -161,9 +175,19 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
       asset_paths.append(tracking.TrackableAsset(value))
     init_fn = wrapped.prune(
         feeds=asset_feed_tensors,
-        fetches=[wrapped.graph.as_graph_element(init_op)])
+        fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
     initializer = _Initializer(init_fn, asset_paths)
-    initializer._initialize()  # pylint: disable=protected-access
+    # pylint: disable=protected-access
+    local_init_op, _ = initializer._initialize()
+    # pylint: enable=protected-access
+    with ops.init_scope():
+      if not context.executing_eagerly():
+        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
+        for variable in wrapped.graph.get_collection_ref(
+            ops.GraphKeys.LOCAL_VARIABLES):
+          # pylint: disable=protected-access
+          variable._initializer_op = local_init_op
+          # pylint: enable=protected-access
     root.initializer = initializer
     root.asset_paths = asset_paths
     signature_functions = self._extract_signatures(wrapped, meta_graph_def)
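
The collection populated here is the same one v1 lookup tables use, so in graph mode a single `tf.compat.v1.tables_initializer()` call covers both native tables and the SavedModel's lifted legacy init op. A standalone sketch of that convention (assuming TF 2.x with v1 behavior enabled; not the loader's code):

    import tensorflow as tf

    tf.compat.v1.disable_eager_execution()

    keys = tf.constant(["alpha", "beta"])
    values = tf.constant([0, 1], dtype=tf.int64)
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)
    lookup = table.lookup(tf.constant(["beta", "gamma"]))

    with tf.compat.v1.Session() as sess:
      # tables_initializer() simply groups the GraphKeys.TABLE_INITIALIZERS
      # collection, which the loader above now appends local_init_op to.
      sess.run(tf.compat.v1.tables_initializer())
      print(sess.run(lookup))   # [ 1 -1]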
@@ -182,3 +206,4 @@ def load(export_dir, tags):
   """Load a v1-style SavedModel as an object."""
   loader = _EagerSavedModelLoader(export_dir)
   return loader.load(tags=tags)
+
@@ -28,6 +28,7 @@ from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.framework import versions
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
@@ -51,7 +52,7 @@ class LoadTest(test.TestCase):
     export_graph = ops.Graph()
     with export_graph.as_default():
       start = array_ops.placeholder(
-          shape=[None], dtype=dtypes.float32, name="start")
+          shape=None, dtype=dtypes.float32, name="start")
       if use_resource:
         distractor = variables.RefVariable(-1., name="distractor")
         v = resource_variable_ops.ResourceVariable(3., name="v")
@@ -81,17 +82,20 @@ class LoadTest(test.TestCase):
           legacy_init_op=local_variable.initializer)
     return path
 
+  @test_util.run_in_graph_and_eager_modes
   def test_resource_variable_import(self):
     imported = load.load(self._v1_single_metagraph_saved_model(
         use_resource=True))
+    self.evaluate(variables.global_variables_initializer())
+    self.evaluate(variables.local_variables_initializer())
     fn = imported.signatures["serving_default"]
     self.assertEqual({"output": 6.},
                      self.evaluate(fn(constant_op.constant(2.))))
     self.assertAllEqual([3., 1.], self.evaluate(imported.variables))
-    imported.variables[0].assign(4.)
+    self.evaluate(imported.variables[0].assign(4.))
     self.assertEqual({"output": 8.},
                      self.evaluate(fn(start=constant_op.constant(2.))))
-    imported.variables[1].assign(2.)
+    self.evaluate(imported.variables[1].assign(2.))
     self.assertEqual({"output": 24.},
                      self.evaluate(fn(start=constant_op.constant(3.))))
     self.assertTrue(imported.variables[0].trainable)
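
For reference, a self-contained roundtrip in the spirit of `_v1_single_metagraph_saved_model` and the test above: export a session-based SavedModel with the v1 APIs, then read it back with the 2.x loader. It is shown in eager mode for brevity; the point of this commit is that the same load also works while building a v1 graph. Paths and signature names are illustrative.

    import os
    import tempfile
    import tensorflow as tf

    export_dir = os.path.join(tempfile.mkdtemp(), "v1_model")

    # Export with the v1 session-based APIs.
    export_graph = tf.Graph()
    with export_graph.as_default():
      start = tf.compat.v1.placeholder(tf.float32, shape=None, name="start")
      v = tf.Variable(3.0, name="v")
      output = start * v
      with tf.compat.v1.Session() as session:
        session.run(v.initializer)
        tf.compat.v1.saved_model.simple_save(
            session, export_dir,
            inputs={"start": start}, outputs={"output": output})

    # Read it back with the object-based loader.
    imported = tf.saved_model.load(export_dir, tags=["serve"])
    fn = imported.signatures["serving_default"]
    print(fn(start=tf.constant(2.0))["output"].numpy())   # 6.0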
@@ -99,7 +103,9 @@ class LoadTest(test.TestCase):
     with backprop.GradientTape() as tape:
       output = fn(start=constant_op.constant(4.))
     self.assertEqual(imported.variables[:1], list(tape.watched_variables()))
-    self.assertEqual(8., tape.gradient(output, imported.variables[0]).numpy())
+    self.assertEqual(
+        8.,
+        self.evaluate(tape.gradient(output, imported.variables[0])))
 
   def test_ref_variable_import(self):
     saved = self._v1_single_metagraph_saved_model(use_resource=False)
@@ -185,9 +191,11 @@ class LoadTest(test.TestCase):
       file_io.delete_file(vocab_path)
     return path
 
+  @test_util.run_in_graph_and_eager_modes
   def test_asset_loading(self):
     first_path = self._v1_asset_saved_model()
     imported = load.load(first_path)
+    self.evaluate(lookup_ops.tables_initializer())
     fn = imported.signatures["serving_default"]
     self.assertAllClose({"output": [2, 0]},
                         fn(start=constant_op.constant(["gamma", "alpha"])))
@@ -195,7 +203,9 @@ class LoadTest(test.TestCase):
                                str(ops.uid()))
     save.save(imported, second_path, signatures=imported.signatures)
     shutil.rmtree(first_path)
+    del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
     second_import = load.load(second_path)
+    self.evaluate(lookup_ops.tables_initializer())
     fn = second_import.signatures["serving_default"]
     self.assertAllClose({"output": [2, 0]},
                         fn(start=constant_op.constant(["gamma", "alpha"])))
@@ -204,7 +214,9 @@ class LoadTest(test.TestCase):
                               str(ops.uid()))
     save.save(second_import, third_path, signatures=second_import.signatures)
     shutil.rmtree(second_path)
+    del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
     third_import = load.load(third_path)
+    self.evaluate(lookup_ops.tables_initializer())
     fn = third_import.signatures["serving_default"]
     self.assertAllClose({"output": [2, 0]},
                         fn(start=constant_op.constant(["gamma", "alpha"])))
@@ -368,3 +380,4 @@ class LoadTest(test.TestCase):
 
 if __name__ == "__main__":
   test.main()
+