Allow users to interact with resource variables in imported 1.x SavedModels

Modifies wrap_function to search for unlifted variables and lift them out as captures. Adds a test for training an imported model.

PiperOrigin-RevId: 233142944
Allen Lavoie 2019-02-08 16:01:19 -08:00 committed by TensorFlower Gardener
parent 1798f30745
commit ce5a8d8ff7
3 changed files with 68 additions and 1 deletion
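
For context, the user-facing workflow this change enables looks roughly like the sketch below. This is a minimal illustration, not code from the commit: it assumes a TF 2.x runtime and the public tf.saved_model.load / tf.compat.v1.saved_model.simple_save APIs, and the export path and values are illustrative.

import tensorflow as tf  # assumes a TF 2.x runtime

export_dir = "/tmp/v1_saved_model"  # illustrative path

# Export a 1.x-style SavedModel that contains a resource variable.
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
  start = tf.compat.v1.placeholder(tf.float32, shape=[], name="start")
  v = tf.compat.v1.get_variable("v", initializer=3., use_resource=True)
  output = tf.identity(start * v, name="output")
  sess.run(v.initializer)
  tf.compat.v1.saved_model.simple_save(
      sess, export_dir, inputs={"start": start}, outputs={"output": output})

# Load it in 2.x: the imported variables are now real Python objects.
imported = tf.saved_model.load(export_dir)
fn = imported.signatures["serving_default"]
print(fn(start=tf.constant(2.))["output"])  # 6.0
imported.variables[0].assign(4.)            # interact with the variable
print(fn(start=tf.constant(2.))["output"])  # 8.0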

@@ -19,12 +19,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import weakref
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -67,6 +71,54 @@ class WrappedFunction(function.ConcreteFunction):
    super(WrappedFunction, self).__init__(
        fn_graph, attrs=attrs, signature=signature)
    self._variable_holder = variable_holder
    if ops.executing_eagerly_outside_functions():
      # TODO(allenl): Make this work in 1.x?
      self._lift_unlifted_variables()

  def _lift_unlifted_variables(self):
    """Finds resource variables and lifts them into the outer context.

    When we import a GraphDef inside a wrap_function, no Python graph building
    code runs. This means we get VarHandleOps which create variable resources,
    but no corresponding Python objects. Leaving them like this works but gives
    the user no way to interact with or modify the variables outside the graph.

    This method searches for variables and lifts them out as regular variable
    objects when possible, indicating to the FuncGraph that they are captures.
    """
    with self.graph.as_default():
      collection_variables = (
          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
          + ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
      existing_captures = set(self.graph.internal_captures)
      lifted_variables = {}
      for old_variable in collection_variables:
        if (old_variable._in_graph_mode  # pylint: disable=protected-access
            and isinstance(old_variable,
                           resource_variable_ops.ResourceVariable)):
          if old_variable.handle in existing_captures:
            continue
          new_variable = def_function.UnliftedInitializerVariable(
              array_ops.placeholder(
                  name="unused_{}_initializer".format(old_variable.op.name),
                  shape=old_variable.shape,
                  dtype=old_variable.dtype),
              name=old_variable.op.name,
              trainable=old_variable.trainable)
          self.graph.captures[new_variable.handle] = old_variable.handle
          existing_captures.add(old_variable.handle)
          lifted_variables[old_variable] = new_variable
          # pylint: disable=protected-access
          self._variable_holder._variables.append(new_variable)
          self.graph._weak_variables.append(weakref.ref(new_variable))
          # pylint: enable=protected-access
      # Update the graph's collections, partly for the user and partly so this
      # function is idempotent when it runs again in prune() calls.
      for collection_name in [ops.GraphKeys.GLOBAL_VARIABLES,
                              ops.GraphKeys.LOCAL_VARIABLES]:
        mutable_collection = ops.get_collection_ref(collection_name)
        for index, current in enumerate(mutable_collection):
          mutable_collection[index] = lifted_variables.get(current, current)

  def prune(self, feeds, fetches):
    flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
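
The lifting can also be observed at the wrap_function level directly. The following is a rough sketch under stated assumptions (a TF 2.x runtime; import_meta_graph rather than import_graph_def, so that the variable collections this method scans are actually populated), not code from the commit:

import tensorflow as tf  # assumes a TF 2.x runtime

# Build a 1.x-style graph holding a resource variable and export a
# MetaGraphDef, which carries the variable collections along.
graph = tf.Graph()
with graph.as_default():
  v = tf.compat.v1.get_variable("v", initializer=2., use_resource=True)
  out = tf.identity(v * 3., name="out")
meta_graph = tf.compat.v1.train.export_meta_graph(graph=graph)

def _import():
  tf.compat.v1.train.import_meta_graph(meta_graph)

wrapped = tf.compat.v1.wrap_function(_import, signature=[])
# The imported VarHandleOp is now backed by a Python variable registered
# on the wrapped FuncGraph. (Its value is uninitialized here; the real
# loader restores values from the checkpoint afterwards.)
print([var.name for var in wrapped.graph.variables])  # ['v:0']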

@@ -108,7 +108,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
    root = tracking.AutoCheckpointable()
    root.signatures = signature_serialization.create_signature_map(
        signature_functions)
    # TODO(allenl): Lift out variables and make a ".variables" property
    root.variables = list(wrapped.graph.variables)
    return root

@@ -21,6 +21,7 @@ from __future__ import print_function
import os
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -49,6 +50,7 @@ class LoadTest(test.TestCase):
      local_variable = variables.VariableV1(
          1.,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          trainable=False,
          use_resource=True)
      output = array_ops.identity(start * v * local_variable, name="output")
      with session_lib.Session() as session:
@@ -70,6 +72,19 @@ class LoadTest(test.TestCase):
    fn(constant_op.constant(2.))
    self.assertEqual({"output": 6.},
                     self.evaluate(fn(start=constant_op.constant(2.))))
    self.assertAllEqual([3., 1.], self.evaluate(imported.variables))
    imported.variables[0].assign(4.)
    self.assertEqual({"output": 8.},
                     self.evaluate(fn(start=constant_op.constant(2.))))
    imported.variables[1].assign(2.)
    self.assertEqual({"output": 24.},
                     self.evaluate(fn(start=constant_op.constant(3.))))
    self.assertTrue(imported.variables[0].trainable)
    self.assertFalse(imported.variables[1].trainable)
    with backprop.GradientTape() as tape:
      output = fn(start=constant_op.constant(4.))
    self.assertEqual(imported.variables[:1], list(tape.watched_variables()))
    self.assertEqual(8., tape.gradient(output, imported.variables[0]).numpy())

  def test_ref_variable_import(self):
    with self.assertRaises(NotImplementedError):
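
Beyond the gradient assertions in the test above, a full training step on an imported model might look like this sketch (the optimizer choice and the tf.keras API are illustrative, not part of this commit; `imported` and `fn` are as in the first sketch):

import tensorflow as tf  # assumes a TF 2.x runtime

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
# Only the lifted global variable is trainable; the local one is not.
trainable = [v for v in imported.variables if v.trainable]

with tf.GradientTape() as tape:
  # Use the signature output as a stand-in loss to minimize.
  loss = fn(start=tf.constant(2.))["output"]
grads = tape.gradient(loss, trainable)
optimizer.apply_gradients(zip(grads, trainable))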