diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 3f8dac0bd46..e04595f5ed8 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -101,7 +101,7 @@ class CapturingGraph(ops.Graph): The entries are in the order they were captured. """ - def __init__(self, graph=None): + def __init__(self): super(CapturingGraph, self).__init__() self.captures = collections.OrderedDict() @@ -197,15 +197,16 @@ class FuncGraph(CapturingGraph): seed: The graph-level random seed. """ - def __init__(self, name, graph=None): + def __init__(self, name): """Construct a new FuncGraph. + The graph will inherit its graph key, collections, seed, and distribution + strategy stack from the current context or graph. + Args: name: the name of the function. - graph: if specified, this FuncGraph will inherit its graph key, - collections, and seed from `graph`. """ - super(FuncGraph, self).__init__(graph=graph) + super(FuncGraph, self).__init__() self.name = name self.inputs = [] @@ -213,30 +214,29 @@ class FuncGraph(CapturingGraph): self.structured_outputs = None self.variables = [] - if graph is not None: + if context.executing_eagerly(): + self.seed = context.global_seed() + self._xla_compile = (context.context().device_spec.device_type == "TPU") + else: + graph = ops.get_default_graph() # Inherit the graph key, since this is used for matching variables in # optimizers. self._graph_key = graph._graph_key # pylint: disable=protected-access + self.seed = graph.seed + self._xla_compile = getattr(graph, "_xla_compile", False) - # Copy the graph collections to ensure summaries and other things work. - # This lets the function access (but not mutate) collections of the - # containing graph, such as the global step and the summary writer - # collections. - for collection in graph.collections: - self.get_collection_ref(collection)[:] = graph.get_collection( - collection) - - # Copy distribution strategy scope from the containing graph as well. - self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access - - if context.executing_eagerly(): - self.seed = context.global_seed() - self._xla_compile = (context.context().device_spec.device_type == "TPU") - else: - self.seed = graph.seed - self._xla_compile = getattr(graph, "_xla_compile", False) - else: - self._xla_compile = False + graph = ops.get_default_graph() + # TODO(b/112165328, b/112906995): summaries depend on inheriting collections + # from the default graph even in eager mode. It'd be nice to not have a + # default graph with eager execution, so hopefully this will go away when we + # remove collections. + # pylint: disable=protected-access + self._collections = graph._collections + # TODO(b/112906995): distribution strategy depends on inheriting this stack + # from the default graph even in eager mode. Maybe it should be part of the + # eager context? + self._distribution_strategy_stack = graph._distribution_strategy_stack + # pylint: enable=protected-access def capture(self, tensor, name=None): """Calls CapturingGraph.capture and updates self.inputs if necessary.""" @@ -476,8 +476,7 @@ class GraphCallable(object): def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" - backwards_graph = FuncGraph( - _backward_name(self._func_graph.name), self._func_graph) + backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) with backwards_graph.as_default(): gradients_wrt_outputs = [ graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs @@ -697,7 +696,7 @@ def _func_graph_from_py_func(name, python_func, args, kwds, signature=None): TypeError: If any of `python_func`'s return values is neither `None` nor a `Tensor`. """ - func_graph = FuncGraph(name, graph=ops.get_default_graph()) + func_graph = FuncGraph(name) with func_graph.as_default(), AutomaticControlDependencies() as a: variable_scope.get_variable_scope().set_use_resource(True) diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 318387c61b2..9144a38dae6 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -1614,6 +1614,33 @@ class CollectionTest(test_util.TensorFlowTestCase): # Collections are ordered. self.assertEqual([90, 100], ops.get_collection("key")) + def test_defun(self): + with context.eager_mode(): + + @eager_function.defun + def defun(): + ops.add_to_collection("int", 1) + ops.add_to_collection("tensor", constant_op.constant(2)) + + @eager_function.defun + def inner_defun(): + self.assertEqual(ops.get_collection("int"), [1]) + three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] + ops.add_to_collection("int", 2) + self.assertEqual(ops.get_collection("int"), [1, 2]) + ops.add_to_collection("foo", "bar") + self.assertEqual(ops.get_collection("foo"), ["bar"]) + return three + + self.assertEqual(ops.get_collection("int"), [1]) + three = inner_defun() + self.assertEqual(ops.get_collection("int"), [1, 2]) + self.assertEqual(ops.get_collection("foo"), ["bar"]) + return three + + three = defun() + self.assertEqual(three.numpy(), 3) + ops.NotDifferentiable("FloatOutput")