Remove 'graph' argument from FuncGraph.__init__.

With the old behavior you'd have to supply a graph argument in order
to inherit state from the eager context, which doesn't make much
sense. To keep it simple, I removed the argument and made it
unconditionally inherit state from the current context or default
graph.

In addition, this reuses the current graph's collections instead of
copying them. This allows collections to be modified from inside a
FuncGraph. This is necessary for cond_v2 (at least until collections
go away?).

PiperOrigin-RevId: 209649802
This commit is contained in:
Skye Wanderman-Milne 2018-08-21 13:29:05 -07:00 committed by TensorFlower Gardener
parent 75adfa5ca1
commit 1c00380977
2 changed files with 54 additions and 28 deletions

View File

@@ -101,7 +101,7 @@ class CapturingGraph(ops.Graph):
The entries are in the order they were captured.
"""
def __init__(self, graph=None):
def __init__(self):
super(CapturingGraph, self).__init__()
self.captures = collections.OrderedDict()
@@ -197,15 +197,16 @@ class FuncGraph(CapturingGraph):
seed: The graph-level random seed.
"""
def __init__(self, name, graph=None):
def __init__(self, name):
"""Construct a new FuncGraph.
The graph will inherit its graph key, collections, seed, and distribution
strategy stack from the current context or graph.
Args:
name: the name of the function.
graph: if specified, this FuncGraph will inherit its graph key,
collections, and seed from `graph`.
"""
super(FuncGraph, self).__init__(graph=graph)
super(FuncGraph, self).__init__()
self.name = name
self.inputs = []
@@ -213,30 +214,29 @@ class FuncGraph(CapturingGraph):
self.structured_outputs = None
self.variables = []
if graph is not None:
# Inherit the graph key, since this is used for matching variables in
# optimizers.
self._graph_key = graph._graph_key # pylint: disable=protected-access
# Copy the graph collections to ensure summaries and other things work.
# This lets the function access (but not mutate) collections of the
# containing graph, such as the global step and the summary writer
# collections.
for collection in graph.collections:
self.get_collection_ref(collection)[:] = graph.get_collection(
collection)
# Copy distribution strategy scope from the containing graph as well.
self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access
if context.executing_eagerly():
self.seed = context.global_seed()
self._xla_compile = (context.context().device_spec.device_type == "TPU")
else:
graph = ops.get_default_graph()
# Inherit the graph key, since this is used for matching variables in
# optimizers.
self._graph_key = graph._graph_key # pylint: disable=protected-access
self.seed = graph.seed
self._xla_compile = getattr(graph, "_xla_compile", False)
else:
self._xla_compile = False
graph = ops.get_default_graph()
# TODO(b/112165328, b/112906995): summaries depend on inheriting collections
# from the default graph even in eager mode. It'd be nice to not have a
# default graph with eager execution, so hopefully this will go away when we
# remove collections.
# pylint: disable=protected-access
self._collections = graph._collections
# TODO(b/112906995): distribution strategy depends on inheriting this stack
# from the default graph even in eager mode. Maybe it should be part of the
# eager context?
self._distribution_strategy_stack = graph._distribution_strategy_stack
# pylint: enable=protected-access
def capture(self, tensor, name=None):
"""Calls CapturingGraph.capture and updates self.inputs if necessary."""
@@ -476,8 +476,7 @@ class GraphCallable(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
backwards_graph = FuncGraph(
_backward_name(self._func_graph.name), self._func_graph)
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
with backwards_graph.as_default():
gradients_wrt_outputs = [
graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
@@ -697,7 +696,7 @@ def _func_graph_from_py_func(name, python_func, args, kwds, signature=None):
TypeError: If any of `python_func`'s return values is neither `None` nor a
`Tensor`.
"""
func_graph = FuncGraph(name, graph=ops.get_default_graph())
func_graph = FuncGraph(name)
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)

View File

@@ -1614,6 +1614,33 @@ class CollectionTest(test_util.TensorFlowTestCase):
# Collections are ordered.
self.assertEqual([90, 100], ops.get_collection("key"))
def test_defun(self):
  # Collections are shared (not copied) between a FuncGraph and its outer
  # context, so additions made inside a defun — and inside a defun nested
  # within it — must be observable both in the nested function and after
  # the call returns. NOTE(review): nesting reconstructed from the two
  # `return` statements; original indentation was lost in extraction.
  with context.eager_mode():

    @eager_function.defun
    def defun():
      ops.add_to_collection("int", 1)
      ops.add_to_collection("tensor", constant_op.constant(2))

      @eager_function.defun
      def inner_defun():
        # The outer defun's additions are visible here via the shared
        # collections.
        self.assertEqual(ops.get_collection("int"), [1])
        total = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
        # Mutations made here must propagate back to the outer defun.
        ops.add_to_collection("int", 2)
        self.assertEqual(ops.get_collection("int"), [1, 2])
        ops.add_to_collection("foo", "bar")
        self.assertEqual(ops.get_collection("foo"), ["bar"])
        return total

      self.assertEqual(ops.get_collection("int"), [1])
      result = inner_defun()
      # inner_defun's mutations are visible after it returns.
      self.assertEqual(ops.get_collection("int"), [1, 2])
      self.assertEqual(ops.get_collection("foo"), ["bar"])
      return result

    three = defun()
    self.assertEqual(three.numpy(), 3)
ops.NotDifferentiable("FloatOutput")