Convert multiple v1 functions into ConcreteFunctions while sharing variables.

Uses name-based matching to reuse variables.
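To make the intent concrete, here is a minimal usage sketch adapted from the `WrappedGraph` docstring example added in this change. The `tensorflow.python.eager.wrap_function` import path, the eager-mode driver code, and the printed value are assumptions of the sketch, not part of the commit:

```
import tensorflow as tf
from tensorflow.python.eager import wrap_function  # assumed import path

def add_v1(x):
  with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
    v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
  return v + x

def increment_var_v1(x):
  with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
    v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
  return v.assign_add(x)

# Both v1 functions are converted into ConcreteFunctions in one shared graph.
g = wrap_function.WrappedGraph()
signature = [tf.TensorSpec([], tf.int32)]
add = g.wrap_function(add_v1, signature)
increment_var = g.wrap_function(increment_var_v1, signature)

# 'vars/v' is created once and reused, so only one variable exists.
assert len(g.variables) == 1
increment_var(tf.constant(5))
print(add(tf.constant(1)).numpy())  # 6: the shared variable now holds 5
```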

PiperOrigin-RevId: 238345493
Katherine Wu 2019-03-13 17:30:36 -07:00 committed by TensorFlower Gardener
parent c94aab2f47
commit 50713af4c3
2 changed files with 268 additions and 10 deletions

tensorflow/python/eager/wrap_function.py

@@ -37,21 +37,40 @@ from tensorflow.python.util.tf_export import tf_export
 class VariableHolder(object):
   """Holds variables for a python function."""
 
-  def __init__(self, fn):
+  def __init__(self, fn=None, share_variables=False):
     self._fn = fn
     self._variables = []
 
+    self._share_variables = share_variables
+    self._variables_by_name = {}
+
   @property
   def variables(self):
     return self._variables
 
   def variable_creator_scope(self, next_creator, **kwargs):
     """Creates variables & adds them to collections to match legacy code."""
-    v = next_creator(**kwargs)
-    self._variables.append(v)
+    collections = kwargs.pop("collections", None)
+    v = None
 
-    collections = kwargs.get("collections")
-    trainable = v.trainable
+    # Get expected variable name.
+    name = kwargs.get("name", None)
+    with ops.name_scope(name, "Variable") as name_scope:
+      name = name_scope
+
+    if self._share_variables:
+      v = self._variables_by_name.get(name, None)
+
+    if v is None:
+      v = next_creator(**kwargs)
+      self._variables.append(v)
+      if self._share_variables:
+        self._variables_by_name[name] = v
+
     if collections is None:
       collections = [ops.GraphKeys.GLOBAL_VARIABLES]
-    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
+    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
     ops.add_to_collections(collections, v)
@@ -59,8 +78,13 @@ class VariableHolder(object):
     return v
 
   def __call__(self, *args, **kwargs):
-    with variable_scope.variable_creator_scope(self.variable_creator_scope):
-      return self._fn(*args, **kwargs)
+    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)
+
+  def call_with_variable_creator_scope(self, fn):
+
+    def wrapped(*args, **kwargs):
+      with variable_scope.variable_creator_scope(self.variable_creator_scope):
+        return fn(*args, **kwargs)
+
+    return wrapped
 
 
 # TODO(allenl): make this trackable
@@ -120,7 +144,8 @@ class WrappedFunction(function.ConcreteFunction):
       for index, current in enumerate(mutable_collection):
         mutable_collection[index] = lifted_variables.get(current, current)
 
-  def prune(self, feeds, fetches):
+  def prune(self, feeds, fetches, name=None):
+    name = name or "pruned"
     flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
     for f in flat_feeds:
       if not isinstance(f, ops.Tensor):
@@ -148,7 +173,7 @@ class WrappedFunction(function.ConcreteFunction):
             "are from this graph (%s). Tensor %s from graph %s" % (
                 self._func_graph, f, f.graph))
     with self._func_graph.as_default():
-      pruned_graph = func_graph.FuncGraph("pruned")
+      pruned_graph = func_graph.FuncGraph(name)
       with ops.control_dependencies(operation_fetches):
         if tensor_fetches:
           identity_fetches = array_ops.identity_n(tensor_fetches)
@@ -187,6 +212,89 @@ class WrappedFunction(function.ConcreteFunction):
     return pruned_fn
 
 
+class WrappedGraph(object):
+  """Class for wrapping multiple TF 1.X functions in a single graph.
+
+  Maintains a dictionary mapping names to wrapped functions. See
+  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.
+
+  Functions wrapped using this class have access to variables and collections
+  created in other wrapped functions, using the standard TF 1.X API
+  (`tf.compat.v1.get_variable` or
+  `tf.compat.v1.get_default_graph().get_collection(...)`).
+
+  Outside a function, variables and collections may be accessed using the
+  `variables` and `graph` properties.
+
+  Example:
+
+  ```
+  def add_v1(x):
+    with tf.compat.v1.variable_scope('vars', reuse=tf.AUTO_REUSE):
+      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
+    return v + x
+
+  def increment_var_v1(x):
+    with tf.compat.v1.variable_scope('vars', reuse=tf.AUTO_REUSE):
+      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
+    return v.assign_add(x)
+
+  g = WrappedGraph()
+  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
+  increment_var = g.wrap_function(increment_var_v1,
+                                  [tf.TensorSpec([], tf.int32)])
+
+  assert len(g.variables) == 1
+  assert g.variables[0].numpy() == 0
+  increment_var(tf.constant(5))
+  assert g.variables[0].numpy() == 5
+  ```
+  """
+
+  def __init__(self, variable_holder=None, **kwargs):
+    self._variable_holder = (
+        variable_holder or VariableHolder(share_variables=True))
+
+    name = kwargs.pop("name", "wrapped_function_graph")
+    # Always start with empty collections, unless otherwise specified. Setting
+    # `collections=None` will copy the collections from the outer graph.
+    collections = kwargs.pop("collections", {})
+    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)
+
+    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
+    self._functions = {}
+
+  @property
+  def functions(self):
+    return self._functions
+
+  @property
+  def variables(self):
+    return self._variable_holder.variables
+
+  def wrap_function(self, fn, signature, name=None):
+    """Wraps a TF 1.X function and saves it to the functions dictionary."""
+    func_graph.func_graph_from_py_func(
+        None,  # Name is unused.
+        self._variable_holder.call_with_variable_creator_scope(fn),
+        args=None, kwargs=None, signature=signature,
+        add_control_dependencies=False,
+        func_graph=self.graph)
+
+    # This code relies on questionable behavior from `func_graph_from_py_func`.
+    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
+    # and structured outputs are overwritten. Pretty sure this is a bug,
+    # because the structured outputs don't match up with the outputs...
+    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]
+    fn_outputs = self.graph.structured_outputs
+
+    wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs)
+    name = name or fn.__name__
+    self._functions[name] = wrapped_function
+    return wrapped_function
+
+
 @tf_export(v1=["wrap_function"])
 def wrap_function(fn, signature, name=None):
   """Wraps the TF 1.x function fn into a graph function.
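Because `WrappedGraph.__init__` accepts an explicit `variable_holder`, one `VariableHolder(share_variables=True)` can also back several `WrappedGraph` instances, so variables created in different graphs are matched up by name; the new `testShareVariablesDifferentGraphs` test below exercises exactly this. A minimal sketch of that pattern follows (the import path and the printed value are assumptions of the sketch):

```
import tensorflow as tf
from tensorflow.python.eager import wrap_function  # assumed import path

vh = wrap_function.VariableHolder(share_variables=True)
new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)

def add_v1(x):
  v = tf.Variable(3, name='v')  # first creation wins: 'v:0' starts at 3
  return v + x

def increment_v1(x):
  v = tf.Variable(6, name='v')  # same name, so the existing 'v:0' is reused
  return v.assign_add(x)

signature = [tf.TensorSpec([], tf.int32)]
add = new_graph().wrap_function(add_v1, signature)
increment = new_graph().wrap_function(increment_v1, signature)

increment(tf.constant(7))            # shared variable: 3 -> 10
print(add(tf.constant(1)).numpy())   # 11
```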

tensorflow/python/eager/wrap_function_test.py

@@ -26,6 +26,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
@@ -257,6 +259,154 @@ class WrapFunctionTest(test.TestCase):
     self.assertEqual(2., revived_function(constant_op.constant(1.)).numpy())
 
 
+class WrappedGraphTest(test.TestCase):
+
+  def testAddFunction(self):
+
+    def fn(x):
+      v = variables.Variable(3, name='v')
+      v2 = variable_scope.get_variable(
+          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
+      return v + v2 + x
+
+    with self.cached_session() as sess:
+      result = fn(constant_op.constant(5))
+      sess.run(variables.global_variables_initializer())
+      expected = sess.run(result)
+
+    g = wrap_function.WrappedGraph()
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+    wrapped_fn = g.wrap_function(fn, signature)
+    self.assertEqual(expected, wrapped_fn(constant_op.constant(5)).numpy())
+
+  def testCollections(self):
+
+    def fn(x):
+      v = variables.VariableV1(3, name='v', trainable=False, collections=['a'])
+      v2 = variable_scope.get_variable(
+          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32,
+          collections=['a', 'b'])
+      return v + v2 + x
+
+    def assert_collections(graph):
+      self.assertLen(
+          graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 1)
+      self.assertLen(graph.get_collection('a'), 2)
+      self.assertLen(graph.get_collection('b'), 1)
+
+    g = wrap_function.WrappedGraph()
+    g.wrap_function(fn, [tensor_spec.TensorSpec([], dtypes.int32)])
+    assert_collections(g.graph)
+
+    def assert_fn():
+      assert_collections(ops.get_default_graph())
+      return 1  # Return is required
+
+    # Assert that collections are accessible within a wrapped function.
+    g.wrap_function(assert_fn, [])
+
+  def testShareVariablesSameGraph(self):
+
+    def add_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(3), shape=[],
+            dtype=dtypes.int32)
+      return v + x
+
+    def subtract_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(4), shape=[],
+            dtype=dtypes.int32)
+      return v - x
+
+    def different_variable_fn_v1(x):
+      with variable_scope.variable_scope(
+          'no_reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(5), shape=[],
+            dtype=dtypes.int32)
+      return v * x
+
+    def increment_variable_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(6), shape=[],
+            dtype=dtypes.int32)
+      return v.assign_add(x)
+
+    g = wrap_function.WrappedGraph()
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+    add = g.wrap_function(add_v1, signature)
+    subtract = g.wrap_function(subtract_v1, signature)
+    different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
+    increment_variable = g.wrap_function(increment_variable_v1, signature)
+
+    self.assertEqual(10, add(constant_op.constant(7)).numpy())
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    # The shared variable has a starting value of 3 because add_v1 was wrapped
+    # first.
+    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
+    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())
+
+    # Check that the variable update is visible to the other functions that
+    # share it.
+    self.assertEqual(17, add(constant_op.constant(7)).numpy())
+    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())
+
+    # Sanity check - result from this function shouldn't change.
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    self.assertAllEqual({'reuse/v:0', 'no_reuse/v:0'},
+                        set([v.name for v in g.variables]))
+
+  def testShareVariablesDifferentGraphs(self):
+
+    def add_v1(x):
+      v = variables.Variable(3, name='v')
+      return v + x
+
+    def subtract_v1(x):
+      v = variables.Variable(4, name='v')
+      return v - x
+
+    def different_variable_fn_v1(x):
+      with ops.name_scope('different_scope'):
+        v = variables.Variable(5, name='v')
+      return v * x
+
+    def increment_variable_v1(x):
+      v = variables.Variable(6, name='v')
+      return v.assign_add(x)
+
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+    vh = wrap_function.VariableHolder(share_variables=True)
+    new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)
+
+    add = new_graph().wrap_function(add_v1, signature)
+    subtract = new_graph().wrap_function(subtract_v1, signature)
+    different_variable_fn = new_graph().wrap_function(
+        different_variable_fn_v1, signature)
+    increment_variable = new_graph().wrap_function(
+        increment_variable_v1, signature)
+
+    self.assertEqual(10, add(constant_op.constant(7)).numpy())
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    # Because the variable in add_v1 was created first, its starting value is 3
+    # instead of the values defined in subtract_v1 or increment_variable_v1.
+    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
+    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())
+
+    # Check that the variable update is visible across the wrapped graphs that
+    # share it.
+    self.assertEqual(17, add(constant_op.constant(7)).numpy())
+    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())
+
+    # Sanity check - result from this function shouldn't change.
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    self.assertAllEqual({'v:0', 'different_scope/v:0'},
+                        set([v.name for v in vh.variables]))
+
+
 if __name__ == '__main__':
   ops.enable_eager_execution()
   test.main()