Allow users to interact with resource variables in imported 1.x SavedModels
Modifies wrap_function to search for unlifted variables and lift them out as captures. Adds a test for training an imported model. PiperOrigin-RevId: 233142944
This commit is contained in:
parent
1798f30745
commit
ce5a8d8ff7
@ -19,12 +19,16 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.eager import lift_to_graph
|
from tensorflow.python.eager import lift_to_graph
|
||||||
from tensorflow.python.framework import func_graph
|
from tensorflow.python.framework import func_graph
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
@ -67,6 +71,54 @@ class WrappedFunction(function.ConcreteFunction):
|
|||||||
super(WrappedFunction, self).__init__(
|
super(WrappedFunction, self).__init__(
|
||||||
fn_graph, attrs=attrs, signature=signature)
|
fn_graph, attrs=attrs, signature=signature)
|
||||||
self._variable_holder = variable_holder
|
self._variable_holder = variable_holder
|
||||||
|
if ops.executing_eagerly_outside_functions():
|
||||||
|
# TODO(allenl): Make this work in 1.x?
|
||||||
|
self._lift_unlifted_variables()
|
||||||
|
|
||||||
|
def _lift_unlifted_variables(self):
|
||||||
|
"""Finds resource variables and lifts them into the outer context.
|
||||||
|
|
||||||
|
When we import a GraphDef inside a wrap_function, no Python graph building
|
||||||
|
code runs. This means we get VarHandleOps which create variable resources,
|
||||||
|
but no corresponding Python objects. Leaving them like this works but gives
|
||||||
|
the user no way to interact with or modify the variables outside the graph.
|
||||||
|
|
||||||
|
This method searches for variables and lifts them out as regular variable
|
||||||
|
objects when possible, indicating to the FuncGraph that they are captures.
|
||||||
|
"""
|
||||||
|
with self.graph.as_default():
|
||||||
|
collection_variables = (
|
||||||
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||||
|
+ ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
|
||||||
|
existing_captures = set(self.graph.internal_captures)
|
||||||
|
lifted_variables = {}
|
||||||
|
for old_variable in collection_variables:
|
||||||
|
if (old_variable._in_graph_mode # pylint: disable=protected-access
|
||||||
|
and isinstance(old_variable,
|
||||||
|
resource_variable_ops.ResourceVariable)):
|
||||||
|
if old_variable.handle in existing_captures:
|
||||||
|
continue
|
||||||
|
new_variable = def_function.UnliftedInitializerVariable(
|
||||||
|
array_ops.placeholder(
|
||||||
|
name="unused_{}_initializer".format(old_variable.op.name),
|
||||||
|
shape=old_variable.shape,
|
||||||
|
dtype=old_variable.dtype),
|
||||||
|
name=old_variable.op.name,
|
||||||
|
trainable=old_variable.trainable)
|
||||||
|
self.graph.captures[new_variable.handle] = old_variable.handle
|
||||||
|
existing_captures.add(old_variable.handle)
|
||||||
|
lifted_variables[old_variable] = new_variable
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self._variable_holder._variables.append(new_variable)
|
||||||
|
self.graph._weak_variables.append(weakref.ref(new_variable))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
# Update the graph's collections, partly for the user and partly so this
|
||||||
|
# function is idempotent when it runs again in prune() calls.
|
||||||
|
for collection_name in [ops.GraphKeys.GLOBAL_VARIABLES,
|
||||||
|
ops.GraphKeys.LOCAL_VARIABLES]:
|
||||||
|
mutable_collection = ops.get_collection_ref(collection_name)
|
||||||
|
for index, current in enumerate(mutable_collection):
|
||||||
|
mutable_collection[index] = lifted_variables.get(current, current)
|
||||||
|
|
||||||
def prune(self, feeds, fetches):
|
def prune(self, feeds, fetches):
|
||||||
flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
|
flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
|
||||||
|
@ -108,7 +108,7 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
|
|||||||
root = tracking.AutoCheckpointable()
|
root = tracking.AutoCheckpointable()
|
||||||
root.signatures = signature_serialization.create_signature_map(
|
root.signatures = signature_serialization.create_signature_map(
|
||||||
signature_functions)
|
signature_functions)
|
||||||
# TODO(allenl): Lift out variables and make a ".variables" property
|
root.variables = list(wrapped.graph.variables)
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from tensorflow.python.client import session as session_lib
|
from tensorflow.python.client import session as session_lib
|
||||||
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -49,6 +50,7 @@ class LoadTest(test.TestCase):
|
|||||||
local_variable = variables.VariableV1(
|
local_variable = variables.VariableV1(
|
||||||
1.,
|
1.,
|
||||||
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
collections=[ops.GraphKeys.LOCAL_VARIABLES],
|
||||||
|
trainable=False,
|
||||||
use_resource=True)
|
use_resource=True)
|
||||||
output = array_ops.identity(start * v * local_variable, name="output")
|
output = array_ops.identity(start * v * local_variable, name="output")
|
||||||
with session_lib.Session() as session:
|
with session_lib.Session() as session:
|
||||||
@ -70,6 +72,19 @@ class LoadTest(test.TestCase):
|
|||||||
fn(constant_op.constant(2.))
|
fn(constant_op.constant(2.))
|
||||||
self.assertEqual({"output": 6.},
|
self.assertEqual({"output": 6.},
|
||||||
self.evaluate(fn(start=constant_op.constant(2.))))
|
self.evaluate(fn(start=constant_op.constant(2.))))
|
||||||
|
self.assertAllEqual([3., 1.], self.evaluate(imported.variables))
|
||||||
|
imported.variables[0].assign(4.)
|
||||||
|
self.assertEqual({"output": 8.},
|
||||||
|
self.evaluate(fn(start=constant_op.constant(2.))))
|
||||||
|
imported.variables[1].assign(2.)
|
||||||
|
self.assertEqual({"output": 24.},
|
||||||
|
self.evaluate(fn(start=constant_op.constant(3.))))
|
||||||
|
self.assertTrue(imported.variables[0].trainable)
|
||||||
|
self.assertFalse(imported.variables[1].trainable)
|
||||||
|
with backprop.GradientTape() as tape:
|
||||||
|
output = fn(start=constant_op.constant(4.))
|
||||||
|
self.assertEqual(imported.variables[:1], list(tape.watched_variables()))
|
||||||
|
self.assertEqual(8., tape.gradient(output, imported.variables[0]).numpy())
|
||||||
|
|
||||||
def test_ref_variable_import(self):
|
def test_ref_variable_import(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
|
Loading…
Reference in New Issue
Block a user