Allow init-from-placeholder in tf.function + v1-style Graphs
PiperOrigin-RevId: 233459500
This commit is contained in:
parent
4bdb229d2b
commit
049848467b
@ -96,7 +96,7 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
|
||||
shape and `validate_shape` is `True`.
|
||||
RuntimeError: If called outside of a function definition.
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
if not ops.inside_function():
|
||||
# If we've been init_scope()d out of the function definition nothing to do
|
||||
# here; we can't really do the capturing or conditional logic.
|
||||
resource_variable_ops.ResourceVariable.__init__(
|
||||
@ -156,8 +156,14 @@ class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
|
||||
if self._in_graph_mode:
|
||||
with ops.init_scope():
|
||||
outer_graph = ops.get_default_graph()
|
||||
func_graph = ops.get_default_graph()
|
||||
function_placeholders = (
|
||||
func_graph.inputs + func_graph.internal_captures)
|
||||
placeholder_ops = set(
|
||||
[tensor.op for tensor in function_placeholders])
|
||||
lifted_initializer = lift_to_graph.lift_to_graph(
|
||||
initial_value, outer_graph)[initial_value]
|
||||
initial_value, outer_graph,
|
||||
disallowed_placeholders=placeholder_ops)[initial_value]
|
||||
with ops.init_scope():
|
||||
self._initial_value = lifted_initializer
|
||||
with ops.name_scope("IsInitialized"):
|
||||
|
@ -451,6 +451,24 @@ class DefFunctionTest(test.TestCase):
|
||||
func._decorate(decorator)
|
||||
self.assertEqual(func().numpy(), 2)
|
||||
|
||||
def testLiftPlaceholderInitializedVariable(self):
|
||||
with ops.Graph().as_default():
|
||||
var_list = []
|
||||
|
||||
@def_function.function
|
||||
def use_variable():
|
||||
if not var_list:
|
||||
initial_value = array_ops.placeholder(shape=[], dtype=dtypes.float32)
|
||||
v = variables.Variable(initial_value)
|
||||
var_list.append(v)
|
||||
return var_list[0] + 1.
|
||||
|
||||
var_plus_one = use_variable()
|
||||
with self.session() as session:
|
||||
init_op = var_list[0].initializer
|
||||
session.run(init_op, feed_dict={init_op.inputs[1]: 2.})
|
||||
self.assertEqual(3., session.run(var_plus_one))
|
||||
|
||||
def testDecorate_rejectedAfterTrace(self):
|
||||
func = def_function.function(lambda: 1)
|
||||
self.assertEqual(func().numpy(), 1)
|
||||
|
@ -40,8 +40,24 @@ class UnliftableError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def lift_to_graph(init_tensor, graph, sources=None):
|
||||
"""Copies the tensor and all its inputs recursively to the outer graph."""
|
||||
def lift_to_graph(init_tensor, graph, sources=None,
|
||||
disallowed_placeholders=None):
|
||||
"""Copies the tensor and all its inputs recursively to the outer graph.
|
||||
|
||||
Args:
|
||||
init_tensor: The Tensor to lift.
|
||||
graph: The graph to lift to.
|
||||
sources: Optional sequence of nodes to start from. If omitted the whole
|
||||
subgraph which feeds into `init_tensor` is lifted.
|
||||
disallowed_placeholders: An optional set of ops which may not appear in the
|
||||
lifted graph. Defaults to all placeholders.
|
||||
|
||||
Returns:
|
||||
A mapping from ops in the current default graph to ops in `graph`.
|
||||
|
||||
Raises:
|
||||
UnliftableError: If a placeholder blocks lifting.
|
||||
"""
|
||||
# Check that the initializer does not depend on any placeholders.
|
||||
if sources is None:
|
||||
sources = set([])
|
||||
@ -53,10 +69,8 @@ def lift_to_graph(init_tensor, graph, sources=None):
|
||||
if op in visited_ops:
|
||||
continue
|
||||
visited_ops.add(op)
|
||||
# TODO(apassos) distinguish arg placeholders, capture placeholders,
|
||||
# and placeholders the user might directly use to initialize
|
||||
# variables.
|
||||
if op.type == "Placeholder":
|
||||
if ((disallowed_placeholders is not None and op in disallowed_placeholders)
|
||||
or (disallowed_placeholders is None and op.type == "Placeholder")):
|
||||
raise UnliftableError(
|
||||
"Unable to lift tensor", init_tensor,
|
||||
"because it depends transitively on placeholder ", op)
|
||||
|
@ -545,12 +545,19 @@ def func_graph_from_py_func(name,
|
||||
convert_structure_to_signature(func_args, arg_names),
|
||||
convert_structure_to_signature(func_kwargs))
|
||||
|
||||
flat_func_args = nest.flatten(func_args)
|
||||
flat_func_kwargs = nest.flatten(func_kwargs)
|
||||
# Temporarily set inputs to allow graph building code to inspect
|
||||
# them. Reassigned below.
|
||||
func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
|
||||
if isinstance(arg, ops.Tensor)]
|
||||
|
||||
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
|
||||
# Variables to help check whether mutation happens in calling the function
|
||||
# Copy the recursive list, tuple and map structure, but not base objects
|
||||
func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
|
||||
func_args_before = nest.pack_sequence_as(func_args, flat_func_args)
|
||||
func_kwargs_before = nest.pack_sequence_as(
|
||||
func_kwargs, nest.flatten(func_kwargs))
|
||||
func_kwargs, flat_func_kwargs)
|
||||
|
||||
def convert(x):
|
||||
"""Converts a function output to a Tensor."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user