From 1c00380977893fefc1371906831bf43361eed0b1 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 21 Aug 2018 13:29:05 -0700 Subject: [PATCH] Remove 'graph' argument from FuncGraph.__init__. With the old behavior you'd have to supply a graph argument in order to inherit state from the eager context, which doesn't make much sense. To keep it simple, I removed the argument and made it unconditionally inherit state from the current context or default graph. In addition, this reuses the current graph's collections instead of copying them. This allows collections to be modified from inside a FuncGraph. This is necessary for cond_v2 (at least until collections go away?). PiperOrigin-RevId: 209649802 --- tensorflow/python/eager/function.py | 55 ++++++++++++------------- tensorflow/python/framework/ops_test.py | 27 ++++++++++++ 2 files changed, 54 insertions(+), 28 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 3f8dac0bd46..e04595f5ed8 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -101,7 +101,7 @@ class CapturingGraph(ops.Graph): The entries are in the order they were captured. """ - def __init__(self, graph=None): + def __init__(self): super(CapturingGraph, self).__init__() self.captures = collections.OrderedDict() @@ -197,15 +197,16 @@ class FuncGraph(CapturingGraph): seed: The graph-level random seed. """ - def __init__(self, name, graph=None): + def __init__(self, name): """Construct a new FuncGraph. + The graph will inherit its graph key, collections, seed, and distribution + strategy stack from the current context or graph. + Args: name: the name of the function. - graph: if specified, this FuncGraph will inherit its graph key, - collections, and seed from `graph`. """ - super(FuncGraph, self).__init__(graph=graph) + super(FuncGraph, self).__init__() self.name = name self.inputs = [] @@ -213,30 +214,29 @@ class FuncGraph(CapturingGraph): self.structured_outputs = None self.variables = [] - if graph is not None: + if context.executing_eagerly(): + self.seed = context.global_seed() + self._xla_compile = (context.context().device_spec.device_type == "TPU") + else: + graph = ops.get_default_graph() # Inherit the graph key, since this is used for matching variables in # optimizers. self._graph_key = graph._graph_key # pylint: disable=protected-access + self.seed = graph.seed + self._xla_compile = getattr(graph, "_xla_compile", False) - # Copy the graph collections to ensure summaries and other things work. - # This lets the function access (but not mutate) collections of the - # containing graph, such as the global step and the summary writer - # collections. - for collection in graph.collections: - self.get_collection_ref(collection)[:] = graph.get_collection( - collection) - - # Copy distribution strategy scope from the containing graph as well. - self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access - - if context.executing_eagerly(): - self.seed = context.global_seed() - self._xla_compile = (context.context().device_spec.device_type == "TPU") - else: - self.seed = graph.seed - self._xla_compile = getattr(graph, "_xla_compile", False) - else: - self._xla_compile = False + graph = ops.get_default_graph() + # TODO(b/112165328, b/112906995): summaries depend on inheriting collections + # from the default graph even in eager mode. It'd be nice to not have a + # default graph with eager execution, so hopefully this will go away when we + # remove collections. + # pylint: disable=protected-access + self._collections = graph._collections + # TODO(b/112906995): distribution strategy depends on inheriting this stack + # from the default graph even in eager mode. Maybe it should be part of the + # eager context? + self._distribution_strategy_stack = graph._distribution_strategy_stack + # pylint: enable=protected-access def capture(self, tensor, name=None): """Calls CapturingGraph.capture and updates self.inputs if necessary.""" @@ -476,8 +476,7 @@ class GraphCallable(object): def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" - backwards_graph = FuncGraph( - _backward_name(self._func_graph.name), self._func_graph) + backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) with backwards_graph.as_default(): gradients_wrt_outputs = [ graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs @@ -697,7 +696,7 @@ def _func_graph_from_py_func(name, python_func, args, kwds, signature=None): TypeError: If any of `python_func`'s return values is neither `None` nor a `Tensor`. """ - func_graph = FuncGraph(name, graph=ops.get_default_graph()) + func_graph = FuncGraph(name) with func_graph.as_default(), AutomaticControlDependencies() as a: variable_scope.get_variable_scope().set_use_resource(True) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 318387c61b2..9144a38dae6 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -1614,6 +1614,33 @@ class CollectionTest(test_util.TensorFlowTestCase): # Collections are ordered. self.assertEqual([90, 100], ops.get_collection("key")) + def test_defun(self): + with context.eager_mode(): + + @eager_function.defun + def defun(): + ops.add_to_collection("int", 1) + ops.add_to_collection("tensor", constant_op.constant(2)) + + @eager_function.defun + def inner_defun(): + self.assertEqual(ops.get_collection("int"), [1]) + three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] + ops.add_to_collection("int", 2) + self.assertEqual(ops.get_collection("int"), [1, 2]) + ops.add_to_collection("foo", "bar") + self.assertEqual(ops.get_collection("foo"), ["bar"]) + return three + + self.assertEqual(ops.get_collection("int"), [1]) + three = inner_defun() + self.assertEqual(ops.get_collection("int"), [1, 2]) + self.assertEqual(ops.get_collection("foo"), ["bar"]) + return three + + three = defun() + self.assertEqual(three.numpy(), 3) + ops.NotDifferentiable("FloatOutput")