Support loading v1 SavedModels in v1 graph mode with tf.saved_model.load
Mostly just some fiddling with collections and variable initializers. PiperOrigin-RevId: 247081188
This commit is contained in:
parent
de3e26b431
commit
289a0af9b0
@ -84,6 +84,30 @@ class VariableHolder(object):
|
|||||||
return wrapped
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def _lift_single_variable(old_variable, graph, variable_holder):
|
||||||
|
"""Lifts `old_variable` out of the `FuncGraph` `graph`."""
|
||||||
|
new_variable = resource_variable_ops.UninitializedVariable(
|
||||||
|
shape=old_variable.shape,
|
||||||
|
dtype=old_variable.dtype,
|
||||||
|
name=old_variable.op.name,
|
||||||
|
trainable=old_variable.trainable,
|
||||||
|
extra_handle_data=old_variable.handle)
|
||||||
|
new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access
|
||||||
|
graph.inputs.append(old_variable.handle)
|
||||||
|
graph.captures[new_variable.handle] = old_variable.handle
|
||||||
|
# Now that we've added the new variable to graph.captures,
|
||||||
|
# graph.capture will use that cached value and do some post-processing
|
||||||
|
# on the capture like recording it on the tape.
|
||||||
|
graph.capture(new_variable.handle)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
variable_name = new_variable.name.split(":")[0]
|
||||||
|
variable_holder._variables_by_name[variable_name] = new_variable
|
||||||
|
graph._weak_variables.append(weakref.ref(new_variable))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
graph.watch_variable(new_variable)
|
||||||
|
return new_variable
|
||||||
|
|
||||||
|
|
||||||
def _lift_unlifted_variables(graph, variable_holder):
|
def _lift_unlifted_variables(graph, variable_holder):
|
||||||
"""Finds resource variables and lifts them into the outer context.
|
"""Finds resource variables and lifts them into the outer context.
|
||||||
|
|
||||||
@ -100,39 +124,44 @@ def _lift_unlifted_variables(graph, variable_holder):
|
|||||||
variable_holder: A VariableHolder to record the lifted variables in.
|
variable_holder: A VariableHolder to record the lifted variables in.
|
||||||
"""
|
"""
|
||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
collection_variables = (
|
global_collection_variables = ops.get_collection(
|
||||||
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
|
ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
|
local_collection_variables = ops.get_collection(
|
||||||
|
ops.GraphKeys.LOCAL_VARIABLES)
|
||||||
existing_captures = set(graph.internal_captures)
|
existing_captures = set(graph.internal_captures)
|
||||||
lifted_variables = {}
|
lifted_variables = {}
|
||||||
for old_variable in collection_variables:
|
|
||||||
if (old_variable._in_graph_mode # pylint: disable=protected-access
|
def _should_lift_variable(v):
|
||||||
and
|
return ((v._in_graph_mode # pylint: disable=protected-access
|
||||||
isinstance(old_variable, resource_variable_ops.ResourceVariable)):
|
and v.graph.building_function)
|
||||||
if old_variable.handle in existing_captures:
|
and isinstance(v, resource_variable_ops.ResourceVariable)
|
||||||
continue
|
and v.handle not in existing_captures)
|
||||||
new_variable = resource_variable_ops.UninitializedVariable(
|
|
||||||
shape=old_variable.shape,
|
for old_variable in global_collection_variables:
|
||||||
dtype=old_variable.dtype,
|
if _should_lift_variable(old_variable):
|
||||||
name=old_variable.op.name,
|
new_variable = _lift_single_variable(
|
||||||
trainable=old_variable.trainable,
|
old_variable, graph, variable_holder)
|
||||||
extra_handle_data=old_variable.handle)
|
|
||||||
new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access
|
|
||||||
graph.inputs.append(old_variable.handle)
|
|
||||||
graph.captures[new_variable.handle] = old_variable.handle
|
|
||||||
# Now that we've added the new variable to graph.captures,
|
|
||||||
# graph.capture will use that cached value and do some post-processing
|
|
||||||
# on the capture like recording it on the tape.
|
|
||||||
graph.capture(new_variable.handle)
|
|
||||||
existing_captures.add(old_variable.handle)
|
|
||||||
lifted_variables[old_variable] = new_variable
|
lifted_variables[old_variable] = new_variable
|
||||||
# pylint: disable=protected-access
|
existing_captures.add(old_variable.handle)
|
||||||
variable_name = new_variable.name.split(":")[0]
|
|
||||||
variable_holder._variables_by_name[variable_name] = new_variable
|
for old_variable in local_collection_variables:
|
||||||
graph._weak_variables.append(weakref.ref(new_variable))
|
if _should_lift_variable(old_variable):
|
||||||
# pylint: enable=protected-access
|
new_variable = _lift_single_variable(
|
||||||
graph.watch_variable(new_variable)
|
old_variable, graph, variable_holder)
|
||||||
# Update the graph's collections, partly for the user and partly so this
|
lifted_variables[old_variable] = new_variable
|
||||||
|
existing_captures.add(old_variable.handle)
|
||||||
|
if new_variable._in_graph_mode: # pylint: disable=protected-access
|
||||||
|
outer_graph = new_variable.graph
|
||||||
|
# Variables are added to the global collection by default. In this
|
||||||
|
# case we only want the variable in the local collection, so we'll pop
|
||||||
|
# it out.
|
||||||
|
global_collection = outer_graph.get_collection_ref(
|
||||||
|
ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
|
global_collection.remove(new_variable)
|
||||||
|
outer_graph.add_to_collection(
|
||||||
|
ops.GraphKeys.LOCAL_VARIABLES, new_variable)
|
||||||
|
|
||||||
|
# Update the FuncGraph's collections, partly for the user and partly so this
|
||||||
# function is idempotent when it runs again in prune() calls.
|
# function is idempotent when it runs again in prune() calls.
|
||||||
for collection_name in [
|
for collection_name in [
|
||||||
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
|
ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
|
||||||
@ -148,9 +177,7 @@ class WrappedFunction(function.ConcreteFunction):
|
|||||||
|
|
||||||
def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
|
def __init__(self, fn_graph, variable_holder, attrs=None, signature=None):
|
||||||
self._variable_holder = variable_holder
|
self._variable_holder = variable_holder
|
||||||
if ops.executing_eagerly_outside_functions():
|
_lift_unlifted_variables(fn_graph, variable_holder)
|
||||||
# TODO(allenl): Make this work in 1.x?
|
|
||||||
_lift_unlifted_variables(fn_graph, variable_holder)
|
|
||||||
# We call __init__ after lifting variables so that the function's signature
|
# We call __init__ after lifting variables so that the function's signature
|
||||||
# properly reflects the new captured inputs.
|
# properly reflects the new captured inputs.
|
||||||
super(WrappedFunction, self).__init__(
|
super(WrappedFunction, self).__init__(
|
||||||
|
|||||||
@ -20,10 +20,12 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import lift_to_graph
|
from tensorflow.python.eager import lift_to_graph
|
||||||
from tensorflow.python.eager import wrap_function
|
from tensorflow.python.eager import wrap_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.saved_model import loader_impl
|
from tensorflow.python.saved_model import loader_impl
|
||||||
from tensorflow.python.saved_model import signature_serialization
|
from tensorflow.python.saved_model import signature_serialization
|
||||||
@ -56,7 +58,7 @@ class _Initializer(tracking.CapturableResource):
|
|||||||
dtype=dtypes.resource, shape=[], name="unused_resource")
|
dtype=dtypes.resource, shape=[], name="unused_resource")
|
||||||
|
|
||||||
def _initialize(self):
|
def _initialize(self):
|
||||||
self._init_fn(*[path.asset_path for path in self._asset_paths])
|
return self._init_fn(*[path.asset_path for path in self._asset_paths])
|
||||||
|
|
||||||
|
|
||||||
class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||||
@ -90,11 +92,21 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
|||||||
"""Restores variables from the checkpoint."""
|
"""Restores variables from the checkpoint."""
|
||||||
if saver is not None:
|
if saver is not None:
|
||||||
saver_def = saver.saver_def
|
saver_def = saver.saver_def
|
||||||
|
filename_tensor = wrapped.graph.as_graph_element(
|
||||||
|
saver_def.filename_tensor_name)
|
||||||
|
# We both feed and fetch filename_tensor so we have an operation to use to
|
||||||
|
# feed into variable initializers (only relevant for v1 graph building).
|
||||||
restore_fn = wrapped.prune(
|
restore_fn = wrapped.prune(
|
||||||
feeds=[wrapped.graph.as_graph_element(
|
feeds=[filename_tensor],
|
||||||
saver_def.filename_tensor_name)],
|
fetches=[filename_tensor,
|
||||||
fetches=[wrapped.graph.as_graph_element(saver_def.restore_op_name)])
|
wrapped.graph.as_graph_element(saver_def.restore_op_name)])
|
||||||
restore_fn(constant_op.constant(self._variables_path))
|
initializer, _ = restore_fn(constant_op.constant(self._variables_path))
|
||||||
|
if not ops.executing_eagerly_outside_functions():
|
||||||
|
for variable in wrapped.graph.get_collection_ref(
|
||||||
|
ops.GraphKeys.GLOBAL_VARIABLES):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
variable._initializer_op = initializer
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
def _extract_signatures(self, wrapped, meta_graph_def):
|
def _extract_signatures(self, wrapped, meta_graph_def):
|
||||||
"""Creates ConcreteFunctions for signatures in `meta_graph_def`."""
|
"""Creates ConcreteFunctions for signatures in `meta_graph_def`."""
|
||||||
@ -151,6 +163,8 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
|||||||
with wrapped.graph.as_default():
|
with wrapped.graph.as_default():
|
||||||
init_op = loader_impl.get_init_op(
|
init_op = loader_impl.get_init_op(
|
||||||
meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
|
meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
|
||||||
|
# Add a dummy Tensor we know we can fetch to add control dependencies to.
|
||||||
|
init_anchor = constant_op.constant(0., name="dummy_fetch")
|
||||||
|
|
||||||
root = tracking.AutoTrackable()
|
root = tracking.AutoTrackable()
|
||||||
asset_feed_tensors = []
|
asset_feed_tensors = []
|
||||||
@ -161,9 +175,19 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
|||||||
asset_paths.append(tracking.TrackableAsset(value))
|
asset_paths.append(tracking.TrackableAsset(value))
|
||||||
init_fn = wrapped.prune(
|
init_fn = wrapped.prune(
|
||||||
feeds=asset_feed_tensors,
|
feeds=asset_feed_tensors,
|
||||||
fetches=[wrapped.graph.as_graph_element(init_op)])
|
fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)])
|
||||||
initializer = _Initializer(init_fn, asset_paths)
|
initializer = _Initializer(init_fn, asset_paths)
|
||||||
initializer._initialize() # pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
local_init_op, _ = initializer._initialize()
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
with ops.init_scope():
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
|
||||||
|
for variable in wrapped.graph.get_collection_ref(
|
||||||
|
ops.GraphKeys.LOCAL_VARIABLES):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
variable._initializer_op = local_init_op
|
||||||
|
# pylint: enable=protected-access
|
||||||
root.initializer = initializer
|
root.initializer = initializer
|
||||||
root.asset_paths = asset_paths
|
root.asset_paths = asset_paths
|
||||||
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
|
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
|
||||||
@ -182,3 +206,4 @@ def load(export_dir, tags):
|
|||||||
"""Load a v1-style SavedModel as an object."""
|
"""Load a v1-style SavedModel as an object."""
|
||||||
loader = _EagerSavedModelLoader(export_dir)
|
loader = _EagerSavedModelLoader(export_dir)
|
||||||
return loader.load(tags=tags)
|
return loader.load(tags=tags)
|
||||||
|
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from tensorflow.python.eager import test
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.framework import versions
|
from tensorflow.python.framework import versions
|
||||||
from tensorflow.python.lib.io import file_io
|
from tensorflow.python.lib.io import file_io
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
@ -51,7 +52,7 @@ class LoadTest(test.TestCase):
|
|||||||
export_graph = ops.Graph()
|
export_graph = ops.Graph()
|
||||||
with export_graph.as_default():
|
with export_graph.as_default():
|
||||||
start = array_ops.placeholder(
|
start = array_ops.placeholder(
|
||||||
shape=[None], dtype=dtypes.float32, name="start")
|
shape=None, dtype=dtypes.float32, name="start")
|
||||||
if use_resource:
|
if use_resource:
|
||||||
distractor = variables.RefVariable(-1., name="distractor")
|
distractor = variables.RefVariable(-1., name="distractor")
|
||||||
v = resource_variable_ops.ResourceVariable(3., name="v")
|
v = resource_variable_ops.ResourceVariable(3., name="v")
|
||||||
@ -81,17 +82,20 @@ class LoadTest(test.TestCase):
|
|||||||
legacy_init_op=local_variable.initializer)
|
legacy_init_op=local_variable.initializer)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_resource_variable_import(self):
|
def test_resource_variable_import(self):
|
||||||
imported = load.load(self._v1_single_metagraph_saved_model(
|
imported = load.load(self._v1_single_metagraph_saved_model(
|
||||||
use_resource=True))
|
use_resource=True))
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
self.evaluate(variables.local_variables_initializer())
|
||||||
fn = imported.signatures["serving_default"]
|
fn = imported.signatures["serving_default"]
|
||||||
self.assertEqual({"output": 6.},
|
self.assertEqual({"output": 6.},
|
||||||
self.evaluate(fn(constant_op.constant(2.))))
|
self.evaluate(fn(constant_op.constant(2.))))
|
||||||
self.assertAllEqual([3., 1.], self.evaluate(imported.variables))
|
self.assertAllEqual([3., 1.], self.evaluate(imported.variables))
|
||||||
imported.variables[0].assign(4.)
|
self.evaluate(imported.variables[0].assign(4.))
|
||||||
self.assertEqual({"output": 8.},
|
self.assertEqual({"output": 8.},
|
||||||
self.evaluate(fn(start=constant_op.constant(2.))))
|
self.evaluate(fn(start=constant_op.constant(2.))))
|
||||||
imported.variables[1].assign(2.)
|
self.evaluate(imported.variables[1].assign(2.))
|
||||||
self.assertEqual({"output": 24.},
|
self.assertEqual({"output": 24.},
|
||||||
self.evaluate(fn(start=constant_op.constant(3.))))
|
self.evaluate(fn(start=constant_op.constant(3.))))
|
||||||
self.assertTrue(imported.variables[0].trainable)
|
self.assertTrue(imported.variables[0].trainable)
|
||||||
@ -99,7 +103,9 @@ class LoadTest(test.TestCase):
|
|||||||
with backprop.GradientTape() as tape:
|
with backprop.GradientTape() as tape:
|
||||||
output = fn(start=constant_op.constant(4.))
|
output = fn(start=constant_op.constant(4.))
|
||||||
self.assertEqual(imported.variables[:1], list(tape.watched_variables()))
|
self.assertEqual(imported.variables[:1], list(tape.watched_variables()))
|
||||||
self.assertEqual(8., tape.gradient(output, imported.variables[0]).numpy())
|
self.assertEqual(
|
||||||
|
8.,
|
||||||
|
self.evaluate(tape.gradient(output, imported.variables[0])))
|
||||||
|
|
||||||
def test_ref_variable_import(self):
|
def test_ref_variable_import(self):
|
||||||
saved = self._v1_single_metagraph_saved_model(use_resource=False)
|
saved = self._v1_single_metagraph_saved_model(use_resource=False)
|
||||||
@ -185,9 +191,11 @@ class LoadTest(test.TestCase):
|
|||||||
file_io.delete_file(vocab_path)
|
file_io.delete_file(vocab_path)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
def test_asset_loading(self):
|
def test_asset_loading(self):
|
||||||
first_path = self._v1_asset_saved_model()
|
first_path = self._v1_asset_saved_model()
|
||||||
imported = load.load(first_path)
|
imported = load.load(first_path)
|
||||||
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
fn = imported.signatures["serving_default"]
|
fn = imported.signatures["serving_default"]
|
||||||
self.assertAllClose({"output": [2, 0]},
|
self.assertAllClose({"output": [2, 0]},
|
||||||
fn(start=constant_op.constant(["gamma", "alpha"])))
|
fn(start=constant_op.constant(["gamma", "alpha"])))
|
||||||
@ -195,7 +203,9 @@ class LoadTest(test.TestCase):
|
|||||||
str(ops.uid()))
|
str(ops.uid()))
|
||||||
save.save(imported, second_path, signatures=imported.signatures)
|
save.save(imported, second_path, signatures=imported.signatures)
|
||||||
shutil.rmtree(first_path)
|
shutil.rmtree(first_path)
|
||||||
|
del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
|
||||||
second_import = load.load(second_path)
|
second_import = load.load(second_path)
|
||||||
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
fn = second_import.signatures["serving_default"]
|
fn = second_import.signatures["serving_default"]
|
||||||
self.assertAllClose({"output": [2, 0]},
|
self.assertAllClose({"output": [2, 0]},
|
||||||
fn(start=constant_op.constant(["gamma", "alpha"])))
|
fn(start=constant_op.constant(["gamma", "alpha"])))
|
||||||
@ -204,7 +214,9 @@ class LoadTest(test.TestCase):
|
|||||||
str(ops.uid()))
|
str(ops.uid()))
|
||||||
save.save(second_import, third_path, signatures=second_import.signatures)
|
save.save(second_import, third_path, signatures=second_import.signatures)
|
||||||
shutil.rmtree(second_path)
|
shutil.rmtree(second_path)
|
||||||
|
del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
|
||||||
third_import = load.load(third_path)
|
third_import = load.load(third_path)
|
||||||
|
self.evaluate(lookup_ops.tables_initializer())
|
||||||
fn = third_import.signatures["serving_default"]
|
fn = third_import.signatures["serving_default"]
|
||||||
self.assertAllClose({"output": [2, 0]},
|
self.assertAllClose({"output": [2, 0]},
|
||||||
fn(start=constant_op.constant(["gamma", "alpha"])))
|
fn(start=constant_op.constant(["gamma", "alpha"])))
|
||||||
@ -368,3 +380,4 @@ class LoadTest(test.TestCase):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user