Remove 'graph' argument from FuncGraph.__init__.
With the old behavior you'd have to supply a graph argument in order to inherit state from the eager context, which doesn't make much sense. To keep it simple, I removed the argument and made it unconditionally inherit state from the current context or default graph. In addition, this reuses the current graph's collections instead of copying them. This allows collections to be modified from inside a FuncGraph. This is necessary for cond_v2 (at least until collections go away?). PiperOrigin-RevId: 209649802
This commit is contained in:
parent
75adfa5ca1
commit
1c00380977
@ -101,7 +101,7 @@ class CapturingGraph(ops.Graph):
|
||||
The entries are in the order they were captured.
|
||||
"""
|
||||
|
||||
def __init__(self, graph=None):
|
||||
def __init__(self):
|
||||
super(CapturingGraph, self).__init__()
|
||||
|
||||
self.captures = collections.OrderedDict()
|
||||
@ -197,15 +197,16 @@ class FuncGraph(CapturingGraph):
|
||||
seed: The graph-level random seed.
|
||||
"""
|
||||
|
||||
def __init__(self, name, graph=None):
|
||||
def __init__(self, name):
|
||||
"""Construct a new FuncGraph.
|
||||
|
||||
The graph will inherit its graph key, collections, seed, and distribution
|
||||
strategy stack from the current context or graph.
|
||||
|
||||
Args:
|
||||
name: the name of the function.
|
||||
graph: if specified, this FuncGraph will inherit its graph key,
|
||||
collections, and seed from `graph`.
|
||||
"""
|
||||
super(FuncGraph, self).__init__(graph=graph)
|
||||
super(FuncGraph, self).__init__()
|
||||
|
||||
self.name = name
|
||||
self.inputs = []
|
||||
@ -213,30 +214,29 @@ class FuncGraph(CapturingGraph):
|
||||
self.structured_outputs = None
|
||||
self.variables = []
|
||||
|
||||
if graph is not None:
|
||||
# Inherit the graph key, since this is used for matching variables in
|
||||
# optimizers.
|
||||
self._graph_key = graph._graph_key # pylint: disable=protected-access
|
||||
|
||||
# Copy the graph collections to ensure summaries and other things work.
|
||||
# This lets the function access (but not mutate) collections of the
|
||||
# containing graph, such as the global step and the summary writer
|
||||
# collections.
|
||||
for collection in graph.collections:
|
||||
self.get_collection_ref(collection)[:] = graph.get_collection(
|
||||
collection)
|
||||
|
||||
# Copy distribution strategy scope from the containing graph as well.
|
||||
self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access
|
||||
|
||||
if context.executing_eagerly():
|
||||
self.seed = context.global_seed()
|
||||
self._xla_compile = (context.context().device_spec.device_type == "TPU")
|
||||
else:
|
||||
graph = ops.get_default_graph()
|
||||
# Inherit the graph key, since this is used for matching variables in
|
||||
# optimizers.
|
||||
self._graph_key = graph._graph_key # pylint: disable=protected-access
|
||||
self.seed = graph.seed
|
||||
self._xla_compile = getattr(graph, "_xla_compile", False)
|
||||
else:
|
||||
self._xla_compile = False
|
||||
|
||||
graph = ops.get_default_graph()
|
||||
# TODO(b/112165328, b/112906995): summaries depend on inheriting collections
|
||||
# from the default graph even in eager mode. It'd be nice to not have a
|
||||
# default graph with eager execution, so hopefully this will go away when we
|
||||
# remove collections.
|
||||
# pylint: disable=protected-access
|
||||
self._collections = graph._collections
|
||||
# TODO(b/112906995): distribution strategy depends on inheriting this stack
|
||||
# from the default graph even in eager mode. Maybe it should be part of the
|
||||
# eager context?
|
||||
self._distribution_strategy_stack = graph._distribution_strategy_stack
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def capture(self, tensor, name=None):
|
||||
"""Calls CapturingGraph.capture and updates self.inputs if necessary."""
|
||||
@ -476,8 +476,7 @@ class GraphCallable(object):
|
||||
|
||||
def _construct_backprop_function(self):
|
||||
"""Constructs the backprop function object for this function."""
|
||||
backwards_graph = FuncGraph(
|
||||
_backward_name(self._func_graph.name), self._func_graph)
|
||||
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
|
||||
with backwards_graph.as_default():
|
||||
gradients_wrt_outputs = [
|
||||
graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
|
||||
@ -697,7 +696,7 @@ def _func_graph_from_py_func(name, python_func, args, kwds, signature=None):
|
||||
TypeError: If any of `python_func`'s return values is neither `None` nor a
|
||||
`Tensor`.
|
||||
"""
|
||||
func_graph = FuncGraph(name, graph=ops.get_default_graph())
|
||||
func_graph = FuncGraph(name)
|
||||
with func_graph.as_default(), AutomaticControlDependencies() as a:
|
||||
variable_scope.get_variable_scope().set_use_resource(True)
|
||||
|
||||
|
@ -1614,6 +1614,33 @@ class CollectionTest(test_util.TensorFlowTestCase):
|
||||
# Collections are ordered.
|
||||
self.assertEqual([90, 100], ops.get_collection("key"))
|
||||
|
||||
def test_defun(self):
|
||||
with context.eager_mode():
|
||||
|
||||
@eager_function.defun
|
||||
def defun():
|
||||
ops.add_to_collection("int", 1)
|
||||
ops.add_to_collection("tensor", constant_op.constant(2))
|
||||
|
||||
@eager_function.defun
|
||||
def inner_defun():
|
||||
self.assertEqual(ops.get_collection("int"), [1])
|
||||
three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
|
||||
ops.add_to_collection("int", 2)
|
||||
self.assertEqual(ops.get_collection("int"), [1, 2])
|
||||
ops.add_to_collection("foo", "bar")
|
||||
self.assertEqual(ops.get_collection("foo"), ["bar"])
|
||||
return three
|
||||
|
||||
self.assertEqual(ops.get_collection("int"), [1])
|
||||
three = inner_defun()
|
||||
self.assertEqual(ops.get_collection("int"), [1, 2])
|
||||
self.assertEqual(ops.get_collection("foo"), ["bar"])
|
||||
return three
|
||||
|
||||
three = defun()
|
||||
self.assertEqual(three.numpy(), 3)
|
||||
|
||||
|
||||
ops.NotDifferentiable("FloatOutput")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user