Allow init-from-placeholder in tf.function + v1-style Graphs

PiperOrigin-RevId: 233459500
This commit is contained in:
Allen Lavoie 2019-02-11 12:46:01 -08:00 committed by TensorFlower Gardener
parent 4bdb229d2b
commit 049848467b
4 changed files with 55 additions and 10 deletions

View File

@ -96,7 +96,7 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
shape and `validate_shape` is `True`.
RuntimeError: If called outside of a function definition.
"""
if context.executing_eagerly():
if not ops.inside_function():
# If we've been init_scope()d out of the function definition nothing to do
# here; we can't really do the capturing or conditional logic.
resource_variable_ops.ResourceVariable.__init__(
@ -156,8 +156,14 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
if self._in_graph_mode:
with ops.init_scope():
outer_graph = ops.get_default_graph()
func_graph = ops.get_default_graph()
function_placeholders = (
func_graph.inputs + func_graph.internal_captures)
placeholder_ops = set(
[tensor.op for tensor in function_placeholders])
lifted_initializer = lift_to_graph.lift_to_graph(
initial_value, outer_graph)[initial_value]
initial_value, outer_graph,
disallowed_placeholders=placeholder_ops)[initial_value]
with ops.init_scope():
self._initial_value = lifted_initializer
with ops.name_scope("IsInitialized"):

View File

@ -451,6 +451,24 @@ class DefFunctionTest(test.TestCase):
func._decorate(decorator)
self.assertEqual(func().numpy(), 2)
def testLiftPlaceholderInitializedVariable(self):
with ops.Graph().as_default():
var_list = []
@def_function.function
def use_variable():
if not var_list:
initial_value = array_ops.placeholder(shape=[], dtype=dtypes.float32)
v = variables.Variable(initial_value)
var_list.append(v)
return var_list[0] + 1.
var_plus_one = use_variable()
with self.session() as session:
init_op = var_list[0].initializer
session.run(init_op, feed_dict={init_op.inputs[1]: 2.})
self.assertEqual(3., session.run(var_plus_one))
def testDecorate_rejectedAfterTrace(self):
func = def_function.function(lambda: 1)
self.assertEqual(func().numpy(), 1)

View File

@ -40,8 +40,24 @@ class UnliftableError(Exception):
pass
def lift_to_graph(init_tensor, graph, sources=None):
"""Copies the tensor and all its inputs recursively to the outer graph."""
def lift_to_graph(init_tensor, graph, sources=None,
disallowed_placeholders=None):
"""Copies the tensor and all its inputs recursively to the outer graph.
Args:
init_tensor: The Tensor to lift.
graph: The graph to lift to.
sources: Optional sequence of nodes to start from. If omitted the whole
subgraph which feeds into `init_tensor` is lifted.
disallowed_placeholders: An optional set of ops which may not appear in the
lifted graph. Defaults to all placeholders.
Returns:
A mapping from ops in the current default graph to ops in `graph`.
Raises:
UnliftableError: If a placeholder blocks lifting.
"""
# Check that the initializer does not depend on any placeholders.
if sources is None:
sources = set([])
@ -53,10 +69,8 @@ def lift_to_graph(init_tensor, graph, sources=None):
if op in visited_ops:
continue
visited_ops.add(op)
# TODO(apassos) distinguish arg placeholders, capture placeholders,
# and placeholders the user might directly use to initialize
# variables.
if op.type == "Placeholder":
if ((disallowed_placeholders is not None and op in disallowed_placeholders)
or (disallowed_placeholders is None and op.type == "Placeholder")):
raise UnliftableError(
"Unable to lift tensor", init_tensor,
"because it depends transitively on placeholder ", op)

View File

@ -545,12 +545,19 @@ def func_graph_from_py_func(name,
convert_structure_to_signature(func_args, arg_names),
convert_structure_to_signature(func_kwargs))
flat_func_args = nest.flatten(func_args)
flat_func_kwargs = nest.flatten(func_kwargs)
# Temporarily set inputs to allow graph building code to inspect
# them. Reassigned below.
func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
if isinstance(arg, ops.Tensor)]
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
# Copy the recursive list, tuple and map structure, but not base objects
func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
func_args_before = nest.pack_sequence_as(func_args, flat_func_args)
func_kwargs_before = nest.pack_sequence_as(
func_kwargs, nest.flatten(func_kwargs))
func_kwargs, flat_func_kwargs)
def convert(x):
"""Converts a function output to a Tensor."""