From ce5a8d8ff726e210ecb0d0402eb4f1de43b04346 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Fri, 8 Feb 2019 16:01:19 -0800 Subject: [PATCH] Allow users to interact with resource variables in imported 1.x SavedModels Modifies wrap_function to search for unlifted variables and lift them out as captures. Adds a test for training an imported model. PiperOrigin-RevId: 233142944 --- tensorflow/python/eager/wrap_function.py | 52 +++++++++++++++++++ .../python/saved_model/load_v1_in_v2.py | 2 +- .../python/saved_model/load_v1_in_v2_test.py | 15 ++++++ 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index f8fbda861ca..de384bd4564 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -19,12 +19,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import weakref + +from tensorflow.python.eager import def_function from tensorflow.python.eager import function from tensorflow.python.eager import lift_to_graph from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -67,6 +71,54 @@ class WrappedFunction(function.ConcreteFunction): super(WrappedFunction, self).__init__( fn_graph, attrs=attrs, signature=signature) self._variable_holder = variable_holder + if ops.executing_eagerly_outside_functions(): + # TODO(allenl): Make this work in 1.x? + self._lift_unlifted_variables() + + def _lift_unlifted_variables(self): + """Finds resource variables and lifts them into the outer context. + + When we import a GraphDef inside a wrap_function, no Python graph building + code runs. This means we get VarHandleOps which create variable resources, + but no corresponding Python objects. Leaving them like this works but gives + the user no way to interact with or modify the variables outside the graph. + + This method searches for variables and lifts them out as regular variable + objects when possible, indicating to the FuncGraph that they are captures. + """ + with self.graph.as_default(): + collection_variables = ( + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + + ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)) + existing_captures = set(self.graph.internal_captures) + lifted_variables = {} + for old_variable in collection_variables: + if (old_variable._in_graph_mode # pylint: disable=protected-access + and isinstance(old_variable, + resource_variable_ops.ResourceVariable)): + if old_variable.handle in existing_captures: + continue + new_variable = def_function.UnliftedInitializerVariable( + array_ops.placeholder( + name="unused_{}_initializer".format(old_variable.op.name), + shape=old_variable.shape, + dtype=old_variable.dtype), + name=old_variable.op.name, + trainable=old_variable.trainable) + self.graph.captures[new_variable.handle] = old_variable.handle + existing_captures.add(old_variable.handle) + lifted_variables[old_variable] = new_variable + # pylint: disable=protected-access + self._variable_holder._variables.append(new_variable) + self.graph._weak_variables.append(weakref.ref(new_variable)) + # pylint: enable=protected-access + # Update the graph's collections, partly for the user and partly so this + # function is idempotent when it runs again in prune() calls. + for collection_name in [ops.GraphKeys.GLOBAL_VARIABLES, + ops.GraphKeys.LOCAL_VARIABLES]: + mutable_collection = ops.get_collection_ref(collection_name) + for index, current in enumerate(mutable_collection): + mutable_collection[index] = lifted_variables.get(current, current) def prune(self, feeds, fetches): flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches) diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py index c7c8d8d8040..5980b848788 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2.py +++ b/tensorflow/python/saved_model/load_v1_in_v2.py @@ -108,7 +108,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): root = tracking.AutoCheckpointable() root.signatures = signature_serialization.create_signature_map( signature_functions) - # TODO(allenl): Lift out variables and make a ".variables" property + root.variables = list(wrapped.graph.variables) return root diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index 42e13e782f0..2bb6b93b3b8 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os from tensorflow.python.client import session as session_lib +from tensorflow.python.eager import backprop from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -49,6 +50,7 @@ class LoadTest(test.TestCase): local_variable = variables.VariableV1( 1., collections=[ops.GraphKeys.LOCAL_VARIABLES], + trainable=False, use_resource=True) output = array_ops.identity(start * v * local_variable, name="output") with session_lib.Session() as session: @@ -70,6 +72,19 @@ class LoadTest(test.TestCase): fn(constant_op.constant(2.)) self.assertEqual({"output": 6.}, self.evaluate(fn(start=constant_op.constant(2.)))) + self.assertAllEqual([3., 1.], self.evaluate(imported.variables)) + imported.variables[0].assign(4.) + self.assertEqual({"output": 8.}, + self.evaluate(fn(start=constant_op.constant(2.)))) + imported.variables[1].assign(2.) + self.assertEqual({"output": 24.}, + self.evaluate(fn(start=constant_op.constant(3.)))) + self.assertTrue(imported.variables[0].trainable) + self.assertFalse(imported.variables[1].trainable) + with backprop.GradientTape() as tape: + output = fn(start=constant_op.constant(4.)) + self.assertEqual(imported.variables[:1], list(tape.watched_variables())) + self.assertEqual(8., tape.gradient(output, imported.variables[0]).numpy()) def test_ref_variable_import(self): with self.assertRaises(NotImplementedError):