Allow users to interact with resource variables in imported 1.x SavedModels
Modifies wrap_function to search for unlifted variables and lift them out as captures. Adds a test for training an imported model. PiperOrigin-RevId: 233142944
This commit is contained in:
parent
1798f30745
commit
ce5a8d8ff7
@ -19,12 +19,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import weakref
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import lift_to_graph
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -67,6 +71,54 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
super(WrappedFunction, self).__init__(
|
||||
fn_graph, attrs=attrs, signature=signature)
|
||||
self._variable_holder = variable_holder
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
# TODO(allenl): Make this work in 1.x?
|
||||
self._lift_unlifted_variables()
|
||||
|
||||
def _lift_unlifted_variables(self):
|
||||
"""Finds resource variables and lifts them into the outer context.
|
||||
|
||||
When we import a GraphDef inside a wrap_function, no Python graph building
|
||||
code runs. This means we get VarHandleOps which create variable resources,
|
||||
but no corresponding Python objects. Leaving them like this works but gives
|
||||
the user no way to interact with or modify the variables outside the graph.
|
||||
|
||||
This method searches for variables and lifts them out as regular variable
|
||||
objects when possible, indicating to the FuncGraph that they are captures.
|
||||
"""
|
||||
with self.graph.as_default():
|
||||
collection_variables = (
|
||||
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
+ ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
|
||||
existing_captures = set(self.graph.internal_captures)
|
||||
lifted_variables = {}
|
||||
for old_variable in collection_variables:
|
||||
if (old_variable._in_graph_mode # pylint: disable=protected-access
|
||||
and isinstance(old_variable,
|
||||
resource_variable_ops.ResourceVariable)):
|
||||
if old_variable.handle in existing_captures:
|
||||
continue
|
||||
new_variable = def_function.UnliftedInitializerVariable(
|
||||
array_ops.placeholder(
|
||||
name="unused_{}_initializer".format(old_variable.op.name),
|
||||
shape=old_variable.shape,
|
||||
dtype=old_variable.dtype),
|
||||
name=old_variable.op.name,
|
||||
trainable=old_variable.trainable)
|
||||
self.graph.captures[new_variable.handle] = old_variable.handle
|
||||
existing_captures.add(old_variable.handle)
|
||||
lifted_variables[old_variable] = new_variable
|
||||
# pylint: disable=protected-access
|
||||
self._variable_holder._variables.append(new_variable)
|
||||
self.graph._weak_variables.append(weakref.ref(new_variable))
|
||||
# pylint: enable=protected-access
|
||||
# Update the graph's collections, partly for the user and partly so this
|
||||
# function is idempotent when it runs again in prune() calls.
|
||||
for collection_name in [ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
ops.GraphKeys.LOCAL_VARIABLES]:
|
||||
mutable_collection = ops.get_collection_ref(collection_name)
|
||||
for index, current in enumerate(mutable_collection):
|
||||
mutable_collection[index] = lifted_variables.get(current, current)
|
||||
|
||||
def prune(self, feeds, fetches):
|
||||
flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
|
||||
|
@ -108,7 +108,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
||||
root = tracking.AutoCheckpointable()
|
||||
root.signatures = signature_serialization.create_signature_map(
|
||||
signature_functions)
|
||||
# TODO(allenl): Lift out variables and make a ".variables" property
|
||||
root.variables = list(wrapped.graph.variables)
|
||||
return root
|
||||
|
||||
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -49,6 +50,7 @@ class LoadTest(test.TestCase):
|
||||
local_variable = variables.VariableV1(
|
||||
1.,
|
||||
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
||||
trainable=False,
|
||||
use_resource=True)
|
||||
output = array_ops.identity(start * v * local_variable, name="output")
|
||||
with session_lib.Session() as session:
|
||||
@ -70,6 +72,19 @@ class LoadTest(test.TestCase):
|
||||
fn(constant_op.constant(2.))
|
||||
self.assertEqual({"output": 6.},
|
||||
self.evaluate(fn(start=constant_op.constant(2.))))
|
||||
self.assertAllEqual([3., 1.], self.evaluate(imported.variables))
|
||||
imported.variables[0].assign(4.)
|
||||
self.assertEqual({"output": 8.},
|
||||
self.evaluate(fn(start=constant_op.constant(2.))))
|
||||
imported.variables[1].assign(2.)
|
||||
self.assertEqual({"output": 24.},
|
||||
self.evaluate(fn(start=constant_op.constant(3.))))
|
||||
self.assertTrue(imported.variables[0].trainable)
|
||||
self.assertFalse(imported.variables[1].trainable)
|
||||
with backprop.GradientTape() as tape:
|
||||
output = fn(start=constant_op.constant(4.))
|
||||
self.assertEqual(imported.variables[:1], list(tape.watched_variables()))
|
||||
self.assertEqual(8., tape.gradient(output, imported.variables[0]).numpy())
|
||||
|
||||
def test_ref_variable_import(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
|
Loading…
Reference in New Issue
Block a user