From 049848467bb65e8c31c19781bed27845847772ff Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 11 Feb 2019 12:46:01 -0800 Subject: [PATCH] Allow init-from-placeholder in tf.function + v1-style Graphs PiperOrigin-RevId: 233459500 --- tensorflow/python/eager/def_function.py | 10 ++++++-- tensorflow/python/eager/def_function_test.py | 18 ++++++++++++++ tensorflow/python/eager/lift_to_graph.py | 26 +++++++++++++++----- tensorflow/python/framework/func_graph.py | 11 +++++++-- 4 files changed, 55 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index f5bb8cfd57a..7b8ceb979b2 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -96,7 +96,7 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable): shape and `validate_shape` is `True`. RuntimeError: If called outside of a function definition. """ - if context.executing_eagerly(): + if not ops.inside_function(): # If we've been init_scope()d out of the function definition nothing to do # here; we can't really do the capturing or conditional logic. resource_variable_ops.ResourceVariable.__init__( @@ -156,8 +156,14 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable): if self._in_graph_mode: with ops.init_scope(): outer_graph = ops.get_default_graph() + func_graph = ops.get_default_graph() + function_placeholders = ( + func_graph.inputs + func_graph.internal_captures) + placeholder_ops = set( + [tensor.op for tensor in function_placeholders]) lifted_initializer = lift_to_graph.lift_to_graph( - initial_value, outer_graph)[initial_value] + initial_value, outer_graph, + disallowed_placeholders=placeholder_ops)[initial_value] with ops.init_scope(): self._initial_value = lifted_initializer with ops.name_scope("IsInitialized"): diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py index 6680f0b3e65..462aa8aa0a9 100644 --- a/tensorflow/python/eager/def_function_test.py +++ b/tensorflow/python/eager/def_function_test.py @@ -451,6 +451,24 @@ class DefFunctionTest(test.TestCase): func._decorate(decorator) self.assertEqual(func().numpy(), 2) + def testLiftPlaceholderInitializedVariable(self): + with ops.Graph().as_default(): + var_list = [] + + @def_function.function + def use_variable(): + if not var_list: + initial_value = array_ops.placeholder(shape=[], dtype=dtypes.float32) + v = variables.Variable(initial_value) + var_list.append(v) + return var_list[0] + 1. + + var_plus_one = use_variable() + with self.session() as session: + init_op = var_list[0].initializer + session.run(init_op, feed_dict={init_op.inputs[1]: 2.}) + self.assertEqual(3., session.run(var_plus_one)) + def testDecorate_rejectedAfterTrace(self): func = def_function.function(lambda: 1) self.assertEqual(func().numpy(), 1) diff --git a/tensorflow/python/eager/lift_to_graph.py b/tensorflow/python/eager/lift_to_graph.py index ad62e6d10ac..e48001eddc1 100644 --- a/tensorflow/python/eager/lift_to_graph.py +++ b/tensorflow/python/eager/lift_to_graph.py @@ -40,8 +40,24 @@ class UnliftableError(Exception): pass -def lift_to_graph(init_tensor, graph, sources=None): - """Copies the tensor and all its inputs recursively to the outer graph.""" +def lift_to_graph(init_tensor, graph, sources=None, + disallowed_placeholders=None): + """Copies the tensor and all its inputs recursively to the outer graph. + + Args: + init_tensor: The Tensor to lift. + graph: The graph to lift to. + sources: Optional sequence of nodes to start from. If omitted the whole + subgraph which feeds into `init_tensor` is lifted. + disallowed_placeholders: An optional set of ops which may not appear in the + lifted graph. Defaults to all placeholders. + + Returns: + A mapping from ops in the current default graph to ops in `graph`. + + Raises: + UnliftableError: If a placeholder blocks lifting. + """ # Check that the initializer does not depend on any placeholders. if sources is None: sources = set([]) @@ -53,10 +69,8 @@ def lift_to_graph(init_tensor, graph, sources=None): if op in visited_ops: continue visited_ops.add(op) - # TODO(apassos) distinguish arg placeholders, capture placeholders, - # and placeholders the user might directly use to initialize - # variables. - if op.type == "Placeholder": + if ((disallowed_placeholders is not None and op in disallowed_placeholders) + or (disallowed_placeholders is None and op.type == "Placeholder")): raise UnliftableError( "Unable to lift tensor", init_tensor, "because it depends transitively on placeholder ", op) diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index f60cf9350cf..f83fc74cf91 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -545,12 +545,19 @@ def func_graph_from_py_func(name, convert_structure_to_signature(func_args, arg_names), convert_structure_to_signature(func_kwargs)) + flat_func_args = nest.flatten(func_args) + flat_func_kwargs = nest.flatten(func_kwargs) + # Temporarily set inputs to allow graph building code to inspect + # them. Reassigned below. + func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs + if isinstance(arg, ops.Tensor)] + # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. # Variables to help check whether mutation happens in calling the function # Copy the recursive list, tuple and map structure, but not base objects - func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args)) + func_args_before = nest.pack_sequence_as(func_args, flat_func_args) func_kwargs_before = nest.pack_sequence_as( - func_kwargs, nest.flatten(func_kwargs)) + func_kwargs, flat_func_kwargs) def convert(x): """Converts a function output to a Tensor."""