Remove 'graph' argument from FuncGraph.__init__.
With the old behavior you'd have to supply a graph argument in order to inherit state from the eager context, which doesn't make much sense. To keep it simple, I removed the argument and made it unconditionally inherit state from the current context or default graph. In addition, this reuses the current graph's collections instead of copying them. This allows collections to be modified from inside a FuncGraph. This is necessary for cond_v2 (at least until collections go away?).

PiperOrigin-RevId: 209649802
This commit is contained in:
parent
75adfa5ca1
commit
1c00380977
@ -101,7 +101,7 @@ class CapturingGraph(ops.Graph):
|
|||||||
The entries are in the order they were captured.
|
The entries are in the order they were captured.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, graph=None):
|
def __init__(self):
|
||||||
super(CapturingGraph, self).__init__()
|
super(CapturingGraph, self).__init__()
|
||||||
|
|
||||||
self.captures = collections.OrderedDict()
|
self.captures = collections.OrderedDict()
|
||||||
@ -197,15 +197,16 @@ class FuncGraph(CapturingGraph):
|
|||||||
seed: The graph-level random seed.
|
seed: The graph-level random seed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name, graph=None):
|
def __init__(self, name):
|
||||||
"""Construct a new FuncGraph.
|
"""Construct a new FuncGraph.
|
||||||
|
|
||||||
|
The graph will inherit its graph key, collections, seed, and distribution
|
||||||
|
strategy stack from the current context or graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: the name of the function.
|
name: the name of the function.
|
||||||
graph: if specified, this FuncGraph will inherit its graph key,
|
|
||||||
collections, and seed from `graph`.
|
|
||||||
"""
|
"""
|
||||||
super(FuncGraph, self).__init__(graph=graph)
|
super(FuncGraph, self).__init__()
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.inputs = []
|
self.inputs = []
|
||||||
@ -213,30 +214,29 @@ class FuncGraph(CapturingGraph):
|
|||||||
self.structured_outputs = None
|
self.structured_outputs = None
|
||||||
self.variables = []
|
self.variables = []
|
||||||
|
|
||||||
if graph is not None:
|
if context.executing_eagerly():
|
||||||
|
self.seed = context.global_seed()
|
||||||
|
self._xla_compile = (context.context().device_spec.device_type == "TPU")
|
||||||
|
else:
|
||||||
|
graph = ops.get_default_graph()
|
||||||
# Inherit the graph key, since this is used for matching variables in
|
# Inherit the graph key, since this is used for matching variables in
|
||||||
# optimizers.
|
# optimizers.
|
||||||
self._graph_key = graph._graph_key # pylint: disable=protected-access
|
self._graph_key = graph._graph_key # pylint: disable=protected-access
|
||||||
|
self.seed = graph.seed
|
||||||
|
self._xla_compile = getattr(graph, "_xla_compile", False)
|
||||||
|
|
||||||
# Copy the graph collections to ensure summaries and other things work.
|
graph = ops.get_default_graph()
|
||||||
# This lets the function access (but not mutate) collections of the
|
# TODO(b/112165328, b/112906995): summaries depend on inheriting collections
|
||||||
# containing graph, such as the global step and the summary writer
|
# from the default graph even in eager mode. It'd be nice to not have a
|
||||||
# collections.
|
# default graph with eager execution, so hopefully this will go away when we
|
||||||
for collection in graph.collections:
|
# remove collections.
|
||||||
self.get_collection_ref(collection)[:] = graph.get_collection(
|
# pylint: disable=protected-access
|
||||||
collection)
|
self._collections = graph._collections
|
||||||
|
# TODO(b/112906995): distribution strategy depends on inheriting this stack
|
||||||
# Copy distribution strategy scope from the containing graph as well.
|
# from the default graph even in eager mode. Maybe it should be part of the
|
||||||
self._distribution_strategy_stack = graph._distribution_strategy_stack # pylint: disable=protected-access
|
# eager context?
|
||||||
|
self._distribution_strategy_stack = graph._distribution_strategy_stack
|
||||||
if context.executing_eagerly():
|
# pylint: enable=protected-access
|
||||||
self.seed = context.global_seed()
|
|
||||||
self._xla_compile = (context.context().device_spec.device_type == "TPU")
|
|
||||||
else:
|
|
||||||
self.seed = graph.seed
|
|
||||||
self._xla_compile = getattr(graph, "_xla_compile", False)
|
|
||||||
else:
|
|
||||||
self._xla_compile = False
|
|
||||||
|
|
||||||
def capture(self, tensor, name=None):
|
def capture(self, tensor, name=None):
|
||||||
"""Calls CapturingGraph.capture and updates self.inputs if necessary."""
|
"""Calls CapturingGraph.capture and updates self.inputs if necessary."""
|
||||||
@ -476,8 +476,7 @@ class GraphCallable(object):
|
|||||||
|
|
||||||
def _construct_backprop_function(self):
|
def _construct_backprop_function(self):
|
||||||
"""Constructs the backprop function object for this function."""
|
"""Constructs the backprop function object for this function."""
|
||||||
backwards_graph = FuncGraph(
|
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
|
||||||
_backward_name(self._func_graph.name), self._func_graph)
|
|
||||||
with backwards_graph.as_default():
|
with backwards_graph.as_default():
|
||||||
gradients_wrt_outputs = [
|
gradients_wrt_outputs = [
|
||||||
graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
|
graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
|
||||||
@ -697,7 +696,7 @@ def _func_graph_from_py_func(name, python_func, args, kwds, signature=None):
|
|||||||
TypeError: If any of `python_func`'s return values is neither `None` nor a
|
TypeError: If any of `python_func`'s return values is neither `None` nor a
|
||||||
`Tensor`.
|
`Tensor`.
|
||||||
"""
|
"""
|
||||||
func_graph = FuncGraph(name, graph=ops.get_default_graph())
|
func_graph = FuncGraph(name)
|
||||||
with func_graph.as_default(), AutomaticControlDependencies() as a:
|
with func_graph.as_default(), AutomaticControlDependencies() as a:
|
||||||
variable_scope.get_variable_scope().set_use_resource(True)
|
variable_scope.get_variable_scope().set_use_resource(True)
|
||||||
|
|
||||||
|
@ -1614,6 +1614,33 @@ class CollectionTest(test_util.TensorFlowTestCase):
|
|||||||
# Collections are ordered.
|
# Collections are ordered.
|
||||||
self.assertEqual([90, 100], ops.get_collection("key"))
|
self.assertEqual([90, 100], ops.get_collection("key"))
|
||||||
|
|
||||||
|
def test_defun(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
|
||||||
|
@eager_function.defun
|
||||||
|
def defun():
|
||||||
|
ops.add_to_collection("int", 1)
|
||||||
|
ops.add_to_collection("tensor", constant_op.constant(2))
|
||||||
|
|
||||||
|
@eager_function.defun
|
||||||
|
def inner_defun():
|
||||||
|
self.assertEqual(ops.get_collection("int"), [1])
|
||||||
|
three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
|
||||||
|
ops.add_to_collection("int", 2)
|
||||||
|
self.assertEqual(ops.get_collection("int"), [1, 2])
|
||||||
|
ops.add_to_collection("foo", "bar")
|
||||||
|
self.assertEqual(ops.get_collection("foo"), ["bar"])
|
||||||
|
return three
|
||||||
|
|
||||||
|
self.assertEqual(ops.get_collection("int"), [1])
|
||||||
|
three = inner_defun()
|
||||||
|
self.assertEqual(ops.get_collection("int"), [1, 2])
|
||||||
|
self.assertEqual(ops.get_collection("foo"), ["bar"])
|
||||||
|
return three
|
||||||
|
|
||||||
|
three = defun()
|
||||||
|
self.assertEqual(three.numpy(), 3)
|
||||||
|
|
||||||
|
|
||||||
ops.NotDifferentiable("FloatOutput")
|
ops.NotDifferentiable("FloatOutput")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user