Convert multiple v1 functions into ConcreteFunctions while sharing variables.
Uses name-based matching to reuse variables.

PiperOrigin-RevId: 238345493
parent c94aab2f47
commit 50713af4c3
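Roughly, the new sharing behavior can be exercised as in the sketch below, which mirrors the `testShareVariablesDifferentGraphs` test added in this change. It is illustrative only: it assumes TF 2.x with eager execution and reaches into the internal `tensorflow.python.eager.wrap_function` module rather than a public API.

```
import tensorflow as tf
from tensorflow.python.eager import wrap_function

# One holder shared across two separately wrapped TF 1.x functions, so that
# variables with the same name are matched and reused.
holder = wrap_function.VariableHolder(share_variables=True)

def add_v1(x):
  v = tf.Variable(3, name='v')
  return v + x

def subtract_v1(x):
  v = tf.Variable(4, name='v')
  return v - x

signature = [tf.TensorSpec([], tf.int32)]
add = wrap_function.WrappedGraph(variable_holder=holder).wrap_function(
    add_v1, signature)
subtract = wrap_function.WrappedGraph(variable_holder=holder).wrap_function(
    subtract_v1, signature)

# Both functions resolve to the same variable 'v', created (and initialized
# to 3) when add_v1 was wrapped first.
print(add(tf.constant(7)).numpy())       # 10
print(subtract(tf.constant(7)).numpy())  # -4
```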
@@ -37,21 +37,40 @@ from tensorflow.python.util.tf_export import tf_export
 class VariableHolder(object):
   """Holds variables for a python function."""
 
-  def __init__(self, fn):
+  def __init__(self, fn=None, share_variables=False):
     self._fn = fn
+
     self._variables = []
 
+    self._share_variables = share_variables
+    self._variables_by_name = {}
+
   @property
   def variables(self):
     return self._variables
 
   def variable_creator_scope(self, next_creator, **kwargs):
     """Creates variables & adds them to collections to match legacy code."""
-    v = next_creator(**kwargs)
-    self._variables.append(v)
+    collections = kwargs.pop("collections", None)
+    v = None
 
-    collections = kwargs.get("collections")
-    trainable = v.trainable
+    # Get expected variable name.
+    name = kwargs.get("name", None)
+    with ops.name_scope(name, "Variable") as name_scope:
+      name = name_scope
+
+    if self._share_variables:
+      v = self._variables_by_name.get(name, None)
+
+    if v is None:
+      v = next_creator(**kwargs)
+      self._variables.append(v)
+      if self._share_variables:
+        self._variables_by_name[name] = v
 
     if collections is None:
       collections = [ops.GraphKeys.GLOBAL_VARIABLES]
-    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
+    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
 
     ops.add_to_collections(collections, v)
@@ -59,8 +78,13 @@ class VariableHolder(object):
     return v
 
   def __call__(self, *args, **kwargs):
-    with variable_scope.variable_creator_scope(self.variable_creator_scope):
-      return self._fn(*args, **kwargs)
+    return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)
+
+  def call_with_variable_creator_scope(self, fn):
+    def wrapped(*args, **kwargs):
+      with variable_scope.variable_creator_scope(self.variable_creator_scope):
+        return fn(*args, **kwargs)
+    return wrapped
 
 
 # TODO(allenl): make this trackable
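In plain terms, the holder now keys each created variable by its resolved name scope and hands back the existing variable when the same name is requested again. A toy, plain-Python sketch of just that control flow (TF collection handling and the real `ops.name_scope` resolution elided; `SharedVariableHolder` is a made-up name, not the real class):

```
class SharedVariableHolder(object):
  """Toy stand-in for the name-keyed reuse added above."""

  def __init__(self):
    self._variables_by_name = {}

  def get_or_create(self, name, creator):
    # `name` plays the role of the resolved `ops.name_scope(...)` string.
    v = self._variables_by_name.get(name)
    if v is None:
      v = creator()                      # first request: actually create
      self._variables_by_name[name] = v
    return v                             # later requests: reuse by name


holder = SharedVariableHolder()
a = holder.get_or_create('vars/v/', lambda: [3])  # a list stands in for a variable
b = holder.get_or_create('vars/v/', lambda: [4])  # second creator is ignored
assert a is b and a == [3]
```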
@@ -120,7 +144,8 @@ class WrappedFunction(function.ConcreteFunction):
       for index, current in enumerate(mutable_collection):
         mutable_collection[index] = lifted_variables.get(current, current)
 
-  def prune(self, feeds, fetches):
+  def prune(self, feeds, fetches, name=None):
+    name = name or "pruned"
     flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches)
     for f in flat_feeds:
       if not isinstance(f, ops.Tensor):
@@ -148,7 +173,7 @@ class WrappedFunction(function.ConcreteFunction):
             "are from this graph (%s). Tensor %s from graph %s" % (
                 self._func_graph, f, f.graph))
     with self._func_graph.as_default():
-      pruned_graph = func_graph.FuncGraph("pruned")
+      pruned_graph = func_graph.FuncGraph(name)
     with ops.control_dependencies(operation_fetches):
       if tensor_fetches:
         identity_fetches = array_ops.identity_n(tensor_fetches)
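Illustrative usage of the new `name` argument (the `wrapped_fn` object and the tensor names below are hypothetical, not taken from this diff):

```
# `wrapped_fn` is assumed to be a WrappedFunction, e.g. the result of
# tf.compat.v1.wrap_function(...); 'x:0' and 'y:0' are made-up tensor names.
x = wrapped_fn.graph.get_tensor_by_name('x:0')
y = wrapped_fn.graph.get_tensor_by_name('y:0')
pruned = wrapped_fn.prune(feeds=x, fetches=y, name='x_to_y')
# The pruned ConcreteFunction's FuncGraph is named 'x_to_y' instead of the
# previously hard-coded "pruned".
```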
@@ -187,6 +212,89 @@ class WrappedFunction(function.ConcreteFunction):
     return pruned_fn
 
 
+class WrappedGraph(object):
+  """Class for wrapping multiple TF 1.X functions in a single graph.
+
+  Maintains a dictionary mapping names to wrapped functions. See
+  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.
+
+  Functions wrapped using this class have access to variables and collections
+  created in other wrapped functions, using the standard TF 1.X API (
+  `tf.compat.v1.get_variable` or
+  `tf.compat.v1.get_default_graph().get_collection(...)`)
+
+  Outside a function, variables and collections may be accessed using the
+  `variables` and `graph` properties.
+
+  Example:
+
+  ```
+  def add_v1(x):
+    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
+      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
+    return v + x
+
+  def increment_var_v1(x):
+    with tf.compat.v1.variable_scope('vars', reuse=tf.compat.v1.AUTO_REUSE):
+      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
+    return v.assign_add(x)
+
+  g = WrappedGraph()
+  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
+  increment_var = g.wrap_function(increment_var_v1,
+                                  [tf.TensorSpec([], tf.int32)])
+
+  assert len(g.variables) == 1
+  assert g.variables[0].numpy() == 0
+  increment_var(tf.constant(5))
+  assert g.variables[0].numpy() == 5
+  ```
+  """
+
+  def __init__(self, variable_holder=None, **kwargs):
+    self._variable_holder = (
+        variable_holder or VariableHolder(share_variables=True))
+
+    name = kwargs.pop("name", "wrapped_function_graph")
+    # Always start with empty collections, unless otherwise specified. Setting
+    # `collections=None` will copy the collections from the outer graph.
+    collections = kwargs.pop("collections", {})
+    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)
+
+    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
+    self._functions = {}
+
+  @property
+  def functions(self):
+    return self._functions
+
+  @property
+  def variables(self):
+    return self._variable_holder.variables
+
+  def wrap_function(self, fn, signature, name=None):
+    """Wrap a TF 1.X function and save to functions dictionary."""
+    func_graph.func_graph_from_py_func(
+        None,  # Name is unused.
+        self._variable_holder.call_with_variable_creator_scope(fn),
+        args=None, kwargs=None, signature=signature,
+        add_control_dependencies=False,
+        func_graph=self.graph)
+
+    # This code relies on questionable behavior from `func_graph_from_py_func`.
+    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
+    # and structured outputs are overwritten. Pretty sure this is a bug,
+    # because the structured outputs don't match up with the outputs...
+    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]
+    fn_outputs = self.graph.structured_outputs
+
+    wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs)
+    name = name or fn.__name__
+    self._functions[name] = wrapped_function
+    return wrapped_function
+
+
 @tf_export(v1=["wrap_function"])
 def wrap_function(fn, signature, name=None):
   """Wraps the TF 1.x function fn into a graph function.
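For context, `tf.compat.v1.wrap_function` (whose docstring is cut off above) is used roughly as in the example from the public API documentation. The sketch below is background, adapted from those docs, and is not part of this diff:

```
import tensorflow as tf

def f(x, do_add):
  v = tf.Variable(5.0)
  if do_add:
    op = v.assign_add(x)
  else:
    op = v.assign_sub(x)
  with tf.control_dependencies([op]):
    return v.read_value()

f_add = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), True])

assert float(f_add(1.0)) == 6.0
assert float(f_add(1.0)) == 7.0

# Wrapping again creates a second, independent variable.
f_sub = tf.compat.v1.wrap_function(f, [tf.TensorSpec((), tf.float32), False])
assert float(f_sub(1.0)) == 4.0
```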
@@ -26,6 +26,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
@@ -257,6 +259,154 @@ class WrapFunctionTest(test.TestCase):
     self.assertEqual(2., revived_function(constant_op.constant(1.)).numpy())
 
 
+class WrappedGraphTest(test.TestCase):
+
+  def testAddFunction(self):
+
+    def fn(x):
+      v = variables.Variable(3, name='v')
+      v2 = variable_scope.get_variable(
+          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
+      return v + v2 + x
+
+    with self.cached_session() as sess:
+      result = fn(constant_op.constant(5))
+      sess.run(variables.global_variables_initializer())
+      expected = sess.run(result)
+
+    g = wrap_function.WrappedGraph()
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+    wrapped_fn = g.wrap_function(fn, signature)
+    self.assertEqual(expected, wrapped_fn(constant_op.constant(5)).numpy())
+
+  def testCollections(self):
+
+    def fn(x):
+      v = variables.VariableV1(3, name='v', trainable=False, collections=['a'])
+      v2 = variable_scope.get_variable(
+          'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32,
+          collections=['a', 'b'])
+      return v + v2 + x
+
+    def assert_collections(graph):
+      self.assertLen(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 1)
+      self.assertLen(graph.get_collection('a'), 2)
+      self.assertLen(graph.get_collection('b'), 1)
+
+    g = wrap_function.WrappedGraph()
+    g.wrap_function(fn, [tensor_spec.TensorSpec([], dtypes.int32)])
+    assert_collections(g.graph)
+
+    def assert_fn():
+      assert_collections(ops.get_default_graph())
+      return 1  # Return is required
+
+    # Assert that collections are accessible within a wrapped function.
+    g.wrap_function(assert_fn, [])
+
+  def testShareVariablesSameGraph(self):
+
+    def add_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(3), shape=[], dtype=dtypes.int32)
+      return v + x
+
+    def subtract_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
+      return v - x
+
+    def different_variable_fn_v1(x):
+      with variable_scope.variable_scope(
+          'no_reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(5), shape=[], dtype=dtypes.int32)
+      return v * x
+
+    def increment_variable_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(6), shape=[], dtype=dtypes.int32)
+      return v.assign_add(x)
+
+    g = wrap_function.WrappedGraph()
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+    add = g.wrap_function(add_v1, signature)
+    subtract = g.wrap_function(subtract_v1, signature)
+    different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
+    increment_variable = g.wrap_function(increment_variable_v1, signature)
+
+    self.assertEqual(10, add(constant_op.constant(7)).numpy())
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    # The shared variable has a starting value of 3 because add_v1 was wrapped
+    # first.
+    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
+    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())
+
+    # Check that variable updates persist across calls.
+    self.assertEqual(17, add(constant_op.constant(7)).numpy())
+    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())
+
+    # Sanity check - result from this function shouldn't change.
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    self.assertAllEqual({'reuse/v:0', 'no_reuse/v:0'},
+                        set([v.name for v in g.variables]))
+
+  def testShareVariablesDifferentGraphs(self):
+
+    def add_v1(x):
+      v = variables.Variable(3, name='v')
+      return v + x
+
+    def subtract_v1(x):
+      v = variables.Variable(4, name='v')
+      return v - x
+
+    def different_variable_fn_v1(x):
+      with ops.name_scope('different_scope'):
+        v = variables.Variable(5, name='v')
+      return v * x
+
+    def increment_variable_v1(x):
+      v = variables.Variable(6, name='v')
+      return v.assign_add(x)
+
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+    vh = wrap_function.VariableHolder(share_variables=True)
+    new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)
+
+    add = new_graph().wrap_function(add_v1, signature)
+    subtract = new_graph().wrap_function(subtract_v1, signature)
+    different_variable_fn = new_graph().wrap_function(
+        different_variable_fn_v1, signature)
+    increment_variable = new_graph().wrap_function(
+        increment_variable_v1, signature)
+
+    self.assertEqual(10, add(constant_op.constant(7)).numpy())
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    # Because the variable in add_v1 was created first, its starting value is 3
+    # instead of the values defined in subtract_v1 or increment_variable_v1.
+    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
+    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())
+
+    # Check that variable updates persist across calls.
+    self.assertEqual(17, add(constant_op.constant(7)).numpy())
+    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())
+
+    # Sanity check - result from this function shouldn't change.
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    self.assertAllEqual({'v:0', 'different_scope/v:0'},
+                        set([v.name for v in vh.variables]))
+
+
 if __name__ == '__main__':
   ops.enable_eager_execution()
   test.main()