Make func_graph captures private to class
- Modify callers to use helper methods instead of directly touching the captures dictionary. - Change the captures property to return an ordered list of tuples instead of a dict. PiperOrigin-RevId: 260634556
This commit is contained in:
parent
ace39d8e0a
commit
bdbd1d27dc
@ -721,9 +721,7 @@ class _TapeGradientFunctions(object):
|
||||
"""Forward+backward functions where the backward function sees `outputs`."""
|
||||
# First figure out which of `outputs` are trainable. We'll accept gradients
|
||||
# for each of these in the backward function.
|
||||
handles_to_variables = {self._func_graph.captures[v.handle]: v
|
||||
for v in self._func_graph.variables
|
||||
if v.handle in self._func_graph.captures}
|
||||
handles_to_variables = self._func_graph.variable_captures
|
||||
trainable_outputs = []
|
||||
for output in outputs:
|
||||
if gradients_util.IsTrainable(output):
|
||||
@ -753,8 +751,9 @@ class _TapeGradientFunctions(object):
|
||||
src_graph=self._func_graph)
|
||||
|
||||
captures_from_forward = [
|
||||
c for c in backwards_graph.captures.keys() if
|
||||
not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
|
||||
c for c in backwards_graph.external_captures
|
||||
if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
|
||||
]
|
||||
existing_outputs = set(self._func_graph.outputs)
|
||||
for capture in captures_from_forward:
|
||||
if capture not in existing_outputs:
|
||||
@ -770,7 +769,7 @@ class _TapeGradientFunctions(object):
|
||||
# `backward_function` correspond to outputs (including
|
||||
# side outputs) of `self._tape_forward_function`.
|
||||
backwards_graph.inputs = (
|
||||
gradients_wrt_outputs + list(backwards_graph.captures.values()))
|
||||
gradients_wrt_outputs + backwards_graph.internal_captures)
|
||||
backwards_graph.outputs.extend(
|
||||
grad
|
||||
for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
|
||||
@ -980,9 +979,8 @@ class ConcreteFunction(object):
|
||||
self._arg_keywords = None
|
||||
self._num_positional_args = None
|
||||
self._func_graph = func_graph
|
||||
self._captured_inputs = list(self._func_graph.captures.keys())
|
||||
self._captured_closures = [
|
||||
x[0] for x in self._func_graph.deferred_captures.values()]
|
||||
self._captured_inputs = self._func_graph.external_captures
|
||||
self._captured_closures = self._func_graph.deferred_external_captures
|
||||
self._output_shapes = tuple(
|
||||
output.shape for output in self._func_graph.outputs)
|
||||
attrs = _parse_func_attrs(attrs or {})
|
||||
|
@ -20,7 +20,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import six
|
||||
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
@ -288,11 +287,14 @@ def lift_to_graph(init_tensors,
|
||||
|
||||
# When lifting from one FuncGraph to another, we will need to capture the
|
||||
# relevant tensors as well.
|
||||
captures = collections.OrderedDict()
|
||||
captures = []
|
||||
inverse_captures = {}
|
||||
internal_captures = []
|
||||
if (isinstance(base_graph, func_graph.FuncGraph) and
|
||||
isinstance(graph, func_graph.FuncGraph)):
|
||||
captures = base_graph.captures
|
||||
inverse_captures = {v: k for k, v in captures.items()}
|
||||
inverse_captures = {v: k for k, v in captures}
|
||||
internal_captures = base_graph.internal_captures
|
||||
|
||||
# ops_to_copy now holds a reverse topologically sorted list of ops which
|
||||
# ends in the initializer. We copy those to the outermost graph and
|
||||
@ -302,7 +304,7 @@ def lift_to_graph(init_tensors,
|
||||
}) # Pass through variables.
|
||||
source_ops = set()
|
||||
# Add the sources in the same order as the original graph.
|
||||
for s in six.itervalues(captures):
|
||||
for s in internal_captures:
|
||||
if s in sources:
|
||||
sources.remove(s)
|
||||
source_ops.add(s.op)
|
||||
|
@ -41,7 +41,7 @@ class LiftToGraphTest(test.TestCase):
|
||||
return v1 + v2 + v3
|
||||
|
||||
concrete_fn = fn.get_concrete_function()
|
||||
original_captures = concrete_fn.graph.captures
|
||||
original_captures = concrete_fn.graph.internal_captures
|
||||
outputs = concrete_fn.graph.outputs
|
||||
|
||||
for _ in range(100):
|
||||
@ -49,11 +49,10 @@ class LiftToGraphTest(test.TestCase):
|
||||
|
||||
lift_to_graph.lift_to_graph(
|
||||
outputs, g, add_sources=True, handle_captures=True)
|
||||
lifted_captures = g.captures
|
||||
lifted_captures = g.internal_captures
|
||||
self.assertLen(lifted_captures, 3)
|
||||
for original_capture, lifted_capture in zip(original_captures.values(),
|
||||
lifted_captures.values()):
|
||||
self.assertEqual(original_capture.name, lifted_capture.name)
|
||||
for original, lifted in zip(original_captures, lifted_captures):
|
||||
self.assertEqual(original.name, lifted.name)
|
||||
|
||||
def testClassAttrsRemoved(self):
|
||||
"""Tests that _class attrs (from colocate_with()) are removed."""
|
||||
|
@ -116,8 +116,7 @@ def _lift_single_variable(old_variable, graph, variable_holder):
|
||||
trainable=old_variable.trainable,
|
||||
extra_handle_data=old_variable.handle)
|
||||
new_variable._initializer_op = old_variable._initializer_op # pylint: disable=protected-access
|
||||
graph.inputs.append(old_variable.handle)
|
||||
graph.captures[new_variable.handle] = old_variable.handle
|
||||
graph.add_capture(new_variable.handle, old_variable.handle)
|
||||
# Now that we've added the new variable to graph.captures,
|
||||
# graph.capture will use that cached value and do some post-processing
|
||||
# on the capture like recording it on the tape.
|
||||
@ -311,10 +310,9 @@ class WrappedFunction(function.ConcreteFunction):
|
||||
pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
|
||||
pruned_graph.control_outputs.extend(
|
||||
[lift_map[operation] for operation in operation_fetches])
|
||||
for external_capture, internal_capture in self.graph.captures.items():
|
||||
pruned_graph.captures[external_capture] = lift_map[internal_capture]
|
||||
pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
|
||||
pruned_graph.inputs.extend(pruned_graph.captures.values())
|
||||
for external_capture, internal_capture in self.graph.captures:
|
||||
pruned_graph.add_capture(external_capture, lift_map[internal_capture])
|
||||
for ti in tensor_infos:
|
||||
if ti.WhichOneof("encoding") == "name": # Dense tensors only
|
||||
t = pruned_graph.as_graph_element(ti.name)
|
||||
|
@ -182,7 +182,7 @@ def _get_tensor_data(func):
|
||||
}
|
||||
|
||||
# Iterates through all captures which are represented as Placeholders.
|
||||
for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures.items()):
|
||||
for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
|
||||
tensor_name = _get_tensor_name(name_tensor.name)
|
||||
is_variable = idx in map_index_to_variable
|
||||
if is_variable:
|
||||
@ -352,7 +352,7 @@ def _construct_concrete_function(func, output_graph_def,
|
||||
ConcreteFunction.
|
||||
"""
|
||||
# Create a ConcreteFunction from the new GraphDef.
|
||||
input_tensors = list(func.graph.captures.values())
|
||||
input_tensors = func.graph.internal_captures
|
||||
converted_inputs = set(
|
||||
[input_tensors[index] for index in converted_input_indices])
|
||||
not_converted_inputs = set(func.inputs).difference(converted_inputs)
|
||||
|
@ -152,9 +152,6 @@ class FuncGraph(ops.Graph):
|
||||
or the global default Graph.
|
||||
captures: Maps external tensor -> internal tensor (i.e. input placeholder).
|
||||
The entries are in the order they were captured.
|
||||
deferred_captures: Maps arbitrary key -> (closure, nest of placeholders),
|
||||
where at function call time the value of closure() will be used to feed
|
||||
the nest of placeholders.
|
||||
control_captures: Set of external ops on which this graph has a control
|
||||
dependency.
|
||||
seed: The graph-level random seed.
|
||||
@ -193,12 +190,15 @@ class FuncGraph(ops.Graph):
|
||||
self._weak_variables = []
|
||||
self._watched_variables = weakref.WeakSet()
|
||||
self.outer_graph = ops.get_default_graph()
|
||||
self.captures = py_collections.OrderedDict()
|
||||
self._captures = py_collections.OrderedDict()
|
||||
# If not None, records the names of output args of this function. Used to
|
||||
# preserve the output names in the signature of a serialized+deserialized
|
||||
# function. Private at the moment mostly because it's often out of date.
|
||||
self._output_names = None
|
||||
self.deferred_captures = py_collections.OrderedDict()
|
||||
# Maps arbitrary key -> (closure, nest of placeholders), where at function
|
||||
# call time the value of closure() will be used to feed the nest of
|
||||
# placeholders.
|
||||
self._deferred_captures = py_collections.OrderedDict()
|
||||
# Inherit capture-by-value from outer graph.
|
||||
if capture_by_value is not None:
|
||||
self.capture_by_value = capture_by_value
|
||||
@ -273,7 +273,7 @@ class FuncGraph(ops.Graph):
|
||||
"""
|
||||
if key is None:
|
||||
key = object()
|
||||
if key not in self.deferred_captures:
|
||||
if key not in self._deferred_captures:
|
||||
|
||||
def convert_to_placeholder(s):
|
||||
if not isinstance(s, tensor_spec.TensorSpec):
|
||||
@ -296,8 +296,8 @@ class FuncGraph(ops.Graph):
|
||||
# pylint: enable=protected-access
|
||||
return nest.flatten(y, expand_composites=True)
|
||||
|
||||
self.deferred_captures[key] = (wrapped_closure, placeholder)
|
||||
return self.deferred_captures[key][1]
|
||||
self._deferred_captures[key] = (wrapped_closure, placeholder)
|
||||
return self._deferred_captures[key][1]
|
||||
|
||||
def control_dependencies(self, control_inputs):
|
||||
"""Handles control dependencies.
|
||||
@ -439,7 +439,7 @@ class FuncGraph(ops.Graph):
|
||||
op_def=None,
|
||||
compute_device=True):
|
||||
# When capturing by value, do the read outside
|
||||
reverse_captures = dict((v, k) for k, v in self.captures.items())
|
||||
reverse_captures = dict((v, k) for k, v in self.captures)
|
||||
uncaptured_inputs = [reverse_captures.get(t, t) for t in inputs]
|
||||
with ops.init_scope():
|
||||
if context.executing_eagerly():
|
||||
@ -584,31 +584,82 @@ class FuncGraph(ops.Graph):
|
||||
return tensor
|
||||
|
||||
def _capture_helper(self, tensor, name):
|
||||
captured_tensor = self.captures.get(tensor, None)
|
||||
if captured_tensor is None:
|
||||
captured_tensor = _create_substitute_placeholder(tensor, name=name,
|
||||
dtype=tensor.dtype)
|
||||
self.captures[tensor] = captured_tensor
|
||||
self.inputs.append(captured_tensor)
|
||||
tape.record_operation("captured_value", [captured_tensor], [tensor],
|
||||
placeholder = self._captures.get(tensor, None)
|
||||
if placeholder is None:
|
||||
placeholder = _create_substitute_placeholder(
|
||||
tensor, name=name, dtype=tensor.dtype)
|
||||
self.add_capture(tensor, placeholder)
|
||||
tape.record_operation("captured_value", [placeholder], [tensor],
|
||||
lambda x: [x])
|
||||
return captured_tensor
|
||||
return placeholder
|
||||
|
||||
@property
|
||||
def captures(self):
|
||||
"""Ordered list of (external tensor, internal placeholder) capture tuples."""
|
||||
return self._captures.items()
|
||||
|
||||
def add_capture(self, tensor, placeholder):
|
||||
"""Capture a specific tensor and utilize the provided placeholder.
|
||||
|
||||
Args:
|
||||
tensor: Tensor to capture.
|
||||
placeholder: Provided placeholder for the tensor.
|
||||
"""
|
||||
self._captures[tensor] = placeholder
|
||||
self.inputs.append(placeholder)
|
||||
|
||||
def reset_captures(self, capture_list):
|
||||
"""Replace the captures with the provided list of (tensor, placeholder) pairs."""
|
||||
self._captures = py_collections.OrderedDict(capture_list)
|
||||
|
||||
def pop_capture(self, tensor):
|
||||
"""Remove the capture and return the generated placeholder."""
|
||||
return self._captures.pop(tensor, None)
|
||||
|
||||
def clear_captures(self):
|
||||
# TODO(b/115366440): Delete this method when a custom OrderedDict is added.
|
||||
# Clearing captures using clear() leaves some cycles around.
|
||||
while self._captures:
|
||||
self._captures.popitem()
|
||||
memory.dismantle_ordered_dict(self._captures)
|
||||
while self._deferred_captures:
|
||||
self._deferred_captures.popitem()
|
||||
memory.dismantle_ordered_dict(self._deferred_captures)
|
||||
|
||||
def capture_distributed_variable(self, variable, placeholder):
|
||||
"""Add given distributed variable to captures with given placeholder."""
|
||||
self.captures[variable] = placeholder
|
||||
self._captures[variable] = placeholder
|
||||
tape.record_operation("captured_value", [placeholder], [variable],
|
||||
lambda x: [x])
|
||||
|
||||
@property
|
||||
def external_captures(self):
|
||||
"""External tensors captured by this function."""
|
||||
return list(self.captures.keys())
|
||||
return list(self._captures.keys())
|
||||
|
||||
@property
|
||||
def internal_captures(self):
|
||||
"""Placeholders in this function corresponding to captured tensors."""
|
||||
return list(self.captures.values())
|
||||
return list(self._captures.values())
|
||||
|
||||
@property
|
||||
def deferred_external_captures(self):
|
||||
"""Ordered list of closures whose values will be fed to placeholders at call time."""
|
||||
return [c[0] for c in self._deferred_captures.values()]
|
||||
|
||||
@property
|
||||
def deferred_internal_captures(self):
|
||||
"""List of nests of placeholders which will be fed at call time."""
|
||||
return [c[1] for c in self._deferred_captures.values()]
|
||||
|
||||
@property
|
||||
def variable_captures(self):
|
||||
"""Map from the placeholder of each captured variable handle to its variable."""
|
||||
return {
|
||||
self._captures[v.handle]: v
|
||||
for v in self.variables
|
||||
if v.handle in self._captures
|
||||
}
|
||||
|
||||
|
||||
def func_graph_from_py_func(name,
|
||||
@ -813,7 +864,7 @@ def func_graph_from_py_func(name,
|
||||
# Even if an argument variable was not used in the function, we've
|
||||
# already manually captured the resource Tensor when creating argument
|
||||
# placeholders.
|
||||
resource_placeholder = func_graph.captures.pop(arg.handle, None)
|
||||
resource_placeholder = func_graph.pop_capture(arg.handle)
|
||||
if resource_placeholder is None:
|
||||
continue
|
||||
arg_variables.add(arg)
|
||||
@ -822,12 +873,8 @@ def func_graph_from_py_func(name,
|
||||
inputs.append(arg)
|
||||
variables = [v for v in graph_variables if v not in arg_variables]
|
||||
func_graph.inputs = (
|
||||
inputs +
|
||||
list(func_graph.captures.values()) +
|
||||
nest.flatten(
|
||||
[x[1] for x in func_graph.deferred_captures.values()],
|
||||
expand_composites=True))
|
||||
|
||||
inputs + func_graph.internal_captures + nest.flatten(
|
||||
func_graph.deferred_internal_captures, expand_composites=True))
|
||||
func_graph.structured_outputs = func_outputs
|
||||
# Returning a closed-over tensor does not trigger convert_to_tensor.
|
||||
func_graph.outputs.extend(
|
||||
@ -854,7 +901,7 @@ def maybe_captured(tensor):
|
||||
"""
|
||||
if (not isinstance(tensor, ops.EagerTensor) and
|
||||
tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
|
||||
for input_t, placeholder_t in tensor.op.graph.captures.items():
|
||||
for input_t, placeholder_t in tensor.op.graph.captures:
|
||||
if tensor == placeholder_t:
|
||||
return maybe_captured(input_t)
|
||||
# pylint: enable=protected-access
|
||||
@ -1064,12 +1111,5 @@ def dismantle_func_graph(func_graph):
|
||||
func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
|
||||
after this function.
|
||||
"""
|
||||
# TODO(b/115366440): Delete this method when a custom OrderedDict is added.
|
||||
# Clearing captures using clear() leaves some cycles around.
|
||||
while func_graph.captures:
|
||||
func_graph.captures.popitem()
|
||||
memory.dismantle_ordered_dict(func_graph.captures)
|
||||
while func_graph.deferred_captures:
|
||||
func_graph.deferred_captures.popitem()
|
||||
memory.dismantle_ordered_dict(func_graph.deferred_captures)
|
||||
func_graph.clear_captures()
|
||||
ops.dismantle_graph(func_graph)
|
||||
|
@ -3457,8 +3457,7 @@ class EagerExecutionFunction(object):
|
||||
with ops.control_dependencies(updates_ops):
|
||||
self.outputs[0] = array_ops.identity(self.outputs[0])
|
||||
|
||||
exec_graph.inputs = self._input_references + list(
|
||||
exec_graph.captures.values())
|
||||
exec_graph.inputs = self._input_references + exec_graph.internal_captures
|
||||
exec_graph.outputs = self.outputs
|
||||
graph_fn = eager_function.ConcreteFunction(exec_graph)
|
||||
|
||||
|
@ -275,8 +275,7 @@ def get_func_graphs(op):
|
||||
fdef, input_shapes)
|
||||
for external_t, internal_t in zip(inputs, func_graph.inputs):
|
||||
custom_gradient.copy_handle_data(external_t, internal_t)
|
||||
func_graph.captures = collections.OrderedDict(zip(inputs,
|
||||
func_graph.inputs))
|
||||
func_graph.reset_captures(zip(inputs, func_graph.inputs))
|
||||
# Link the op so that the gradient code can use it.
|
||||
func_graph._forward_cond = op
|
||||
return func_graph
|
||||
@ -482,8 +481,7 @@ def _make_inputs_match(branch_graphs, branch_inputs):
|
||||
branch_graph.inputs = input_list
|
||||
|
||||
# Rewrite the FuncGraphs' state to reflect the new inputs.
|
||||
branch_graph.captures = collections.OrderedDict(
|
||||
zip(new_inputs, branch_graph.inputs))
|
||||
branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))
|
||||
|
||||
return new_inputs
|
||||
|
||||
@ -751,7 +749,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
|
||||
if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
|
||||
# XLA does not yet support optionals, so capture intermediates directly.
|
||||
# TODO(skyewm,jpienaar): can XLA support optionals?
|
||||
if tensor not in self.captures:
|
||||
if tensor not in self.external_captures:
|
||||
self.xla_intermediates.append(tensor)
|
||||
self.op_needs_rewrite = True
|
||||
return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
|
||||
|
@ -400,7 +400,7 @@ def _Captures(func_graph):
|
||||
return func_graph.captures
|
||||
else:
|
||||
assert isinstance(func_graph, framework_function._FuncGraph) # pylint: disable=protected-access
|
||||
return func_graph._captured # pylint: disable=protected-access
|
||||
return func_graph._captured.items() # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _MaybeCaptured(t):
|
||||
@ -415,7 +415,7 @@ def _MaybeCaptured(t):
|
||||
# pylint: disable=protected-access
|
||||
if (not isinstance(t, ops.EagerTensor) and
|
||||
_IsFunction(t.op.graph) and t.op.type == "Placeholder"):
|
||||
for input_t, placeholder_t in _Captures(t.op.graph).items():
|
||||
for input_t, placeholder_t in _Captures(t.op.graph):
|
||||
if t == placeholder_t:
|
||||
return _MaybeCaptured(input_t)
|
||||
# pylint: enable=protected-access
|
||||
@ -481,7 +481,7 @@ def _Consumers(t, func_graphs):
|
||||
"""
|
||||
consumers = t.consumers()
|
||||
for func in func_graphs:
|
||||
for input_t, placeholder in _Captures(func).items():
|
||||
for input_t, placeholder in _Captures(func):
|
||||
if input_t == t:
|
||||
consumers.extend(_Consumers(placeholder, func_graphs))
|
||||
return consumers
|
||||
|
@ -203,7 +203,7 @@ def while_loop(cond,
|
||||
assert (cond_graph.external_captures ==
|
||||
body_graph.external_captures[:num_cond_captures])
|
||||
for body_capture in body_graph.external_captures[num_cond_captures:]:
|
||||
assert body_capture not in cond_graph.captures
|
||||
assert body_capture not in cond_graph.external_captures
|
||||
cond_graph.capture(body_capture)
|
||||
|
||||
# Make sure that the shapes of the loop outputs are compatible with the
|
||||
@ -497,7 +497,7 @@ def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
|
||||
# `popped_tensor_lists` by `_WhileBodyGradFuncGraph`.
|
||||
# 2. Resources, which are output as is.
|
||||
# 3. Forward graph loop invariants, which are output as is.
|
||||
for external_capture, internal_capture in grad_func_graph.captures.items():
|
||||
for external_capture, internal_capture in grad_func_graph.captures:
|
||||
if internal_capture in grad_func_graph.popped_tensor_lists:
|
||||
new_output = grad_func_graph.popped_tensor_lists[internal_capture]
|
||||
elif (internal_capture.dtype == dtypes.resource or _is_loop_invariant(
|
||||
|
@ -180,7 +180,7 @@ class Loader(object):
|
||||
concrete_function.graph.capture_distributed_variable(
|
||||
bound_input, internal_capture)
|
||||
else:
|
||||
concrete_function.graph.captures[bound_input] = internal_capture
|
||||
concrete_function.graph._captures[bound_input] = internal_capture # pylint: disable=protected-access
|
||||
if internal_capture.dtype == dtypes.resource:
|
||||
if resource_variable_ops.is_resource_variable(bound_input):
|
||||
try:
|
||||
|
@ -874,14 +874,15 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
root = Root()
|
||||
self.assertIn(root.v.handle,
|
||||
root.use_v.get_concrete_function().graph.captures)
|
||||
root.use_v.get_concrete_function().graph.external_captures)
|
||||
for _ in range(cycles):
|
||||
root = self.cycle(root, 1, signatures=root.use_v.get_concrete_function())
|
||||
func_captures = root.use_v.get_concrete_function().graph.captures
|
||||
func_captures = root.use_v.get_concrete_function().graph.external_captures
|
||||
self.assertLen(func_captures, 2)
|
||||
self.assertIn(root.v.handle, func_captures)
|
||||
self.assertIn(root.v1.handle, func_captures)
|
||||
signature_captures = root.signatures["serving_default"].graph.captures
|
||||
signature_captures = root.signatures[
|
||||
"serving_default"].graph.external_captures
|
||||
self.assertLen(signature_captures, 2)
|
||||
self.assertIn(root.v.handle, signature_captures)
|
||||
self.assertIn(root.v1.handle, signature_captures)
|
||||
|
@ -322,7 +322,7 @@ def _map_captures_to_created_tensors(
|
||||
`resource_map`.
|
||||
"""
|
||||
export_captures = []
|
||||
for exterior, interior in original_captures.items():
|
||||
for exterior, interior in original_captures:
|
||||
mapped_resource = resource_map.get(exterior, None)
|
||||
if mapped_resource is None:
|
||||
raise AssertionError(
|
||||
|
Loading…
Reference in New Issue
Block a user