From 50713af4c39f0740e0cf4cea56bdff2b2347f540 Mon Sep 17 00:00:00 2001 From: Katherine Wu Date: Wed, 13 Mar 2019 17:30:36 -0700 Subject: [PATCH] Convert multiple v1 functions into ConcreteFunctions while sharing variables. Uses name-based matching to reuse variables. PiperOrigin-RevId: 238345493 --- tensorflow/python/eager/wrap_function.py | 128 +++++++++++++-- tensorflow/python/eager/wrap_function_test.py | 150 ++++++++++++++++++ 2 files changed, 268 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index 8d42cc15ba0..b4ece94848c 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -37,21 +37,40 @@ from tensorflow.python.util.tf_export import tf_export class VariableHolder(object): """Holds variables for a python function.""" - def __init__(self, fn): + def __init__(self, fn=None, share_variables=False): self._fn = fn + self._variables = [] + self._share_variables = share_variables + self._variables_by_name = {} + + @property + def variables(self): + return self._variables + def variable_creator_scope(self, next_creator, **kwargs): """Creates variables & adds them to collections to match legacy code.""" - v = next_creator(**kwargs) - self._variables.append(v) + collections = kwargs.pop("collections", None) + v = None - collections = kwargs.get("collections") - trainable = v.trainable + # Get expected variable name. + name = kwargs.get("name", None) + with ops.name_scope(name, "Variable") as name_scope: + name = name_scope + + if self._share_variables: + v = self._variables_by_name.get(name, None) + + if v is None: + v = next_creator(**kwargs) + self._variables.append(v) + if self._share_variables: + self._variables_by_name[name] = v if collections is None: collections = [ops.GraphKeys.GLOBAL_VARIABLES] - if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: + if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] ops.add_to_collections(collections, v) @@ -59,8 +78,13 @@ class VariableHolder(object): return v def __call__(self, *args, **kwargs): - with variable_scope.variable_creator_scope(self.variable_creator_scope): - return self._fn(*args, **kwargs) + return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs) + + def call_with_variable_creator_scope(self, fn): + def wrapped(*args, **kwargs): + with variable_scope.variable_creator_scope(self.variable_creator_scope): + return fn(*args, **kwargs) + return wrapped # TODO(allenl): make this trackable @@ -120,7 +144,8 @@ class WrappedFunction(function.ConcreteFunction): for index, current in enumerate(mutable_collection): mutable_collection[index] = lifted_variables.get(current, current) - def prune(self, feeds, fetches): + def prune(self, feeds, fetches, name=None): + name = name or "pruned" flat_feeds, flat_fetches = nest.flatten(feeds), nest.flatten(fetches) for f in flat_feeds: if not isinstance(f, ops.Tensor): @@ -148,7 +173,7 @@ class WrappedFunction(function.ConcreteFunction): "are from this graph (%s). 
Tensor %s from graph %s" % (
                self._func_graph, f, f.graph))
     with self._func_graph.as_default():
-      pruned_graph = func_graph.FuncGraph("pruned")
+      pruned_graph = func_graph.FuncGraph(name)
       with ops.control_dependencies(operation_fetches):
         if tensor_fetches:
           identity_fetches = array_ops.identity_n(tensor_fetches)
@@ -187,6 +212,89 @@ class WrappedFunction(function.ConcreteFunction):
   return pruned_fn
 
 
+class WrappedGraph(object):
+  """Class for wrapping multiple TF 1.X functions in a single graph.
+
+  Maintains a dictionary mapping names to wrapped functions. See
+  `tf.compat.v1.wrap_function` to learn more about wrapping V1 functions.
+
+  Functions wrapped using this class have access to variables and collections
+  created in other wrapped functions, using the standard TF 1.X API
+  (`tf.compat.v1.get_variable` or
+  `tf.compat.v1.get_default_graph().get_collection(...)`).
+
+  Outside a function, variables and collections may be accessed using the
+  `variables` and `graph` properties.
+
+  Example:
+
+  ```
+  def add_v1(x):
+    with tf.compat.v1.variable_scope('vars', reuse=tf.AUTO_REUSE):
+      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
+    return v + x
+
+  def increment_var_v1(x):
+    with tf.compat.v1.variable_scope('vars', reuse=tf.AUTO_REUSE):
+      v = tf.compat.v1.get_variable('v', shape=[], dtype=tf.int32)
+    return v.assign_add(x)
+
+  g = WrappedGraph()
+  add = g.wrap_function(add_v1, [tf.TensorSpec([], tf.int32)])
+  increment_var = g.wrap_function(increment_var_v1,
+                                  [tf.TensorSpec([], tf.int32)])
+
+  assert len(g.variables) == 1
+  assert g.variables[0].numpy() == 0
+  increment_var(tf.constant(5))
+  assert g.variables[0].numpy() == 5
+
+  ```
+  """
+
+  def __init__(self, variable_holder=None, **kwargs):
+    self._variable_holder = (
+        variable_holder or VariableHolder(share_variables=True))
+
+    name = kwargs.pop("name", "wrapped_function_graph")
+    # Always start with empty collections, unless otherwise specified. Setting
+    # `collections=None` will copy the collections from the outer graph.
+    collections = kwargs.pop("collections", {})
+    self.graph = func_graph.FuncGraph(name, collections=collections, **kwargs)
+
+    self._wrapped_function = WrappedFunction(self.graph, self._variable_holder)
+    self._functions = {}
+
+  @property
+  def functions(self):
+    return self._functions
+
+  @property
+  def variables(self):
+    return self._variable_holder.variables
+
+  def wrap_function(self, fn, signature, name=None):
+    """Wraps a TF 1.X function and saves it in the functions dictionary."""
+    func_graph.func_graph_from_py_func(
+        None,  # Name is unused.
+        self._variable_holder.call_with_variable_creator_scope(fn),
+        args=None, kwargs=None, signature=signature,
+        add_control_dependencies=False,
+        func_graph=self.graph)
+
+    # This code relies on questionable behavior from `func_graph_from_py_func`.
+    # If an existing FuncGraph is passed into the `func_graph` arg, the inputs
+    # and structured outputs are overwritten. Pretty sure this is a bug, because
+    # the structured outputs don't match up with the actual outputs...
+    fn_inputs = self.graph.inputs[:-len(self.graph.captures)]
+    fn_outputs = self.graph.structured_outputs
+
+    wrapped_function = self._wrapped_function.prune(fn_inputs, fn_outputs)
+    name = name or fn.__name__
+    self._functions[name] = wrapped_function
+    return wrapped_function
+
+
 @tf_export(v1=["wrap_function"])
 def wrap_function(fn, signature, name=None):
   """Wraps the TF 1.x function fn into a graph function.
diff --git a/tensorflow/python/eager/wrap_function_test.py b/tensorflow/python/eager/wrap_function_test.py index 79404ce563c..fa3d5823d9d 100644 --- a/tensorflow/python/eager/wrap_function_test.py +++ b/tensorflow/python/eager/wrap_function_test.py @@ -26,6 +26,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -257,6 +259,154 @@ class WrapFunctionTest(test.TestCase): self.assertEqual(2., revived_function(constant_op.constant(1.)).numpy()) +class WrappedGraphTest(test.TestCase): + + def testAddFunction(self): + + def fn(x): + v = variables.Variable(3, name='v') + v2 = variable_scope.get_variable( + 'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32) + return v + v2 + x + + with self.cached_session() as sess: + result = fn(constant_op.constant(5)) + sess.run(variables.global_variables_initializer()) + expected = sess.run(result) + + g = wrap_function.WrappedGraph() + signature = [tensor_spec.TensorSpec([], dtypes.int32)] + wrapped_fn = g.wrap_function(fn, signature) + self.assertEqual(expected, wrapped_fn(constant_op.constant(5)).numpy()) + + def testCollections(self): + + def fn(x): + v = variables.VariableV1(3, name='v', trainable=False, collections=['a']) + v2 = variable_scope.get_variable( + 'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32, + collections=['a', 'b']) + return v + v2 + x + + def assert_collections(graph): + self.assertLen(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES), 1) + self.assertLen(graph.get_collection('a'), 2) + self.assertLen(graph.get_collection('b'), 1) + + g = wrap_function.WrappedGraph() + g.wrap_function(fn, [tensor_spec.TensorSpec([], dtypes.int32)]) + assert_collections(g.graph) + + def assert_fn(): + assert_collections(ops.get_default_graph()) + return 1 # Return is required + + # Assert that collections are accessible within a wrapped function. 
+    g.wrap_function(assert_fn, [])
+
+  def testShareVariablesSameGraph(self):
+
+    def add_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(3), shape=[], dtype=dtypes.int32)
+      return v + x
+
+    def subtract_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(4), shape=[], dtype=dtypes.int32)
+      return v - x
+
+    def different_variable_fn_v1(x):
+      with variable_scope.variable_scope(
+          'no_reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(5), shape=[], dtype=dtypes.int32)
+      return v * x
+
+    def increment_variable_v1(x):
+      with variable_scope.variable_scope(
+          'reuse', reuse=variable_scope.AUTO_REUSE):
+        v = variable_scope.get_variable(
+            'v', initializer=init_ops.Constant(6), shape=[], dtype=dtypes.int32)
+      return v.assign_add(x)
+
+    g = wrap_function.WrappedGraph()
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+    add = g.wrap_function(add_v1, signature)
+    subtract = g.wrap_function(subtract_v1, signature)
+    different_variable_fn = g.wrap_function(different_variable_fn_v1, signature)
+    increment_variable = g.wrap_function(increment_variable_v1, signature)
+
+    self.assertEqual(10, add(constant_op.constant(7)).numpy())
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    # The shared variable has a starting value of 3 because add_v1 was wrapped
+    # first.
+    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
+    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())
+
+    # Check that variable updates are visible to the other wrapped functions.
+    self.assertEqual(17, add(constant_op.constant(7)).numpy())
+    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())
+
+    # Sanity check - result from this function shouldn't change.
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    self.assertAllEqual({'reuse/v:0', 'no_reuse/v:0'},
+                        set([v.name for v in g.variables]))
+
+  def testShareVariablesDifferentGraphs(self):
+
+    def add_v1(x):
+      v = variables.Variable(3, name='v')
+      return v + x
+
+    def subtract_v1(x):
+      v = variables.Variable(4, name='v')
+      return v - x
+
+    def different_variable_fn_v1(x):
+      with ops.name_scope('different_scope'):
+        v = variables.Variable(5, name='v')
+      return v * x
+
+    def increment_variable_v1(x):
+      v = variables.Variable(6, name='v')
+      return v.assign_add(x)
+
+    signature = [tensor_spec.TensorSpec([], dtypes.int32)]
+    vh = wrap_function.VariableHolder(share_variables=True)
+    new_graph = lambda: wrap_function.WrappedGraph(variable_holder=vh)
+
+    add = new_graph().wrap_function(add_v1, signature)
+    subtract = new_graph().wrap_function(subtract_v1, signature)
+    different_variable_fn = new_graph().wrap_function(
+        different_variable_fn_v1, signature)
+    increment_variable = new_graph().wrap_function(
+        increment_variable_v1, signature)
+
+    self.assertEqual(10, add(constant_op.constant(7)).numpy())
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    # Because the variable in add_v1 was created first, its starting value is 3
+    # instead of the values defined in subtract_v1 or increment_variable_v1.
+    self.assertEqual(-4, subtract(constant_op.constant(7)).numpy())
+    self.assertEqual(10, increment_variable(constant_op.constant(7)).numpy())
+
+    # Check that variable updates made in one graph are visible in the others.
+    self.assertEqual(17, add(constant_op.constant(7)).numpy())
+    self.assertEqual(3, subtract(constant_op.constant(7)).numpy())
+
+    # Sanity check - result from this function shouldn't change.
+    self.assertEqual(35, different_variable_fn(constant_op.constant(7)).numpy())
+
+    self.assertAllEqual({'v:0', 'different_scope/v:0'},
+                        set([v.name for v in vh.variables]))
+
+
 if __name__ == '__main__':
   ops.enable_eager_execution()
   test.main()
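
For orientation, here is a minimal usage sketch of the API added by this patch (not part of the change itself). It exercises `WrappedGraph.wrap_function` with name-based variable sharing inside one wrapped graph, the optional `name` argument, and the `functions` / `variables` properties. The function names `scale_v1` / `rescale_v1` and the variable name `'scale'` are illustrative assumptions only, and eager execution is assumed to be enabled, as in the tests above.

```
# Minimal usage sketch of WrappedGraph (assumes eager execution is enabled;
# the names scale_v1 / rescale_v1 / 'scale' are purely illustrative).
import tensorflow.compat.v1 as tf
from tensorflow.python.eager import wrap_function


def scale_v1(x):
  # Standard TF 1.X variable sharing; both functions resolve to 'vars/scale'.
  with tf.variable_scope('vars', reuse=tf.AUTO_REUSE):
    v = tf.get_variable('scale', shape=[], dtype=tf.float32,
                        initializer=tf.constant_initializer(2.0))
  return v * x


def rescale_v1(new_scale):
  with tf.variable_scope('vars', reuse=tf.AUTO_REUSE):
    v = tf.get_variable('scale', shape=[], dtype=tf.float32,
                        initializer=tf.constant_initializer(2.0))
  return v.assign(new_scale)


g = wrap_function.WrappedGraph()
signature = [tf.TensorSpec([], tf.float32)]

# An explicit name overrides the default key (the Python function's __name__).
scale = g.wrap_function(scale_v1, signature, name='scale')
rescale = g.wrap_function(rescale_v1, signature)

assert set(g.functions) == {'scale', 'rescale_v1'}
assert len(g.variables) == 1          # the variable is shared by name

print(scale(tf.constant(3.0)).numpy())   # 6.0 (scale initialized to 2.0)
rescale(tf.constant(10.0))
print(scale(tf.constant(3.0)).numpy())   # 30.0 (the update is visible here)
```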