Make func_graph captures private to class

- Modify callers to use helper methods instead of directly touching the
  captures dictionary.
- Change the captures property to return an ordered list of tuples instead of
  a dict (a usage sketch follows the change summary below).

PiperOrigin-RevId: 260634556
Gaurav Jain 2019-07-29 20:05:37 -07:00 committed by TensorFlower Gardener
parent ace39d8e0a
commit bdbd1d27dc
13 changed files with 113 additions and 78 deletions
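At its core, the change turns FuncGraph.captures from a publicly mutable
OrderedDict (external tensor -> internal placeholder) into a property that
yields ordered (external, internal) pairs, with all mutation funneled through
new helper methods. A minimal sketch of the new access pattern, assuming a
TensorFlow build that includes this change; the variable names are
illustrative only:

import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def f(x):
  return x + v

graph = f.get_concrete_function(tf.TensorSpec([], tf.float32)).graph

# `captures` now yields (external tensor, internal placeholder) pairs in
# capture order instead of exposing the underlying dict.
for external, internal in graph.captures:
  print(external.dtype, "->", internal.name)

# The split views used throughout this commit:
print(graph.external_captures)  # the captured tensors, e.g. [v.handle]
print(graph.internal_captures)  # the matching placeholders inside the graph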

View File

@@ -721,9 +721,7 @@ class _TapeGradientFunctions(object):
     """Forward+backward functions where the backward function sees `outputs`."""
     # First figure out which of `outputs` are trainable. We'll accept gradients
     # for each of these in the backward function.
-    handles_to_variables = {self._func_graph.captures[v.handle]: v
-                            for v in self._func_graph.variables
-                            if v.handle in self._func_graph.captures}
+    handles_to_variables = self._func_graph.variable_captures
     trainable_outputs = []
     for output in outputs:
       if gradients_util.IsTrainable(output):
@@ -753,8 +751,9 @@
         src_graph=self._func_graph)

     captures_from_forward = [
-        c for c in backwards_graph.captures.keys() if
-        not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
+        c for c in backwards_graph.external_captures
+        if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
+    ]
     existing_outputs = set(self._func_graph.outputs)
     for capture in captures_from_forward:
       if capture not in existing_outputs:
@@ -770,7 +769,7 @@
     # `backward_function` correspond to outputs (including
     # side outputs) of `self._tape_forward_function`.
     backwards_graph.inputs = (
-        gradients_wrt_outputs + list(backwards_graph.captures.values()))
+        gradients_wrt_outputs + backwards_graph.internal_captures)
     backwards_graph.outputs.extend(
         grad
         for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
@@ -980,9 +979,8 @@ class ConcreteFunction(object):
       self._arg_keywords = None
       self._num_positional_args = None
     self._func_graph = func_graph
-    self._captured_inputs = list(self._func_graph.captures.keys())
-    self._captured_closures = [
-        x[0] for x in self._func_graph.deferred_captures.values()]
+    self._captured_inputs = self._func_graph.external_captures
+    self._captured_closures = self._func_graph.deferred_external_captures
     self._output_shapes = tuple(
         output.shape for output in self._func_graph.outputs)
     attrs = _parse_func_attrs(attrs or {})
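The variable_captures property used in the first hunk above replaces the
inline dict comprehension it deletes: it maps the internal placeholder
created for each captured variable's handle back to the variable object. A
short inspection sketch under the same assumption of a build that includes
this change (names are illustrative):

import tensorflow as tf

v = tf.Variable(2.0)

@tf.function
def g():
  return v * 3.0

fg = g.get_concrete_function().graph

# The same placeholder -> variable mapping the removed comprehension in
# _TapeGradientFunctions built by hand.
for placeholder, variable in fg.variable_captures.items():
  print(placeholder.name, "->", variable)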

View File

@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function

 import collections
-import six

 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -288,11 +287,14 @@ def lift_to_graph(init_tensors,
   # When lifting from one FuncGraph to another, we will need to capture the
   # relevant tensors as well.
-  captures = collections.OrderedDict()
+  captures = []
   inverse_captures = {}
+  internal_captures = []
   if (isinstance(base_graph, func_graph.FuncGraph) and
       isinstance(graph, func_graph.FuncGraph)):
     captures = base_graph.captures
-    inverse_captures = {v: k for k, v in captures.items()}
+    inverse_captures = {v: k for k, v in captures}
+    internal_captures = base_graph.internal_captures

   # ops_to_copy now holds a reverse topologically sorted list of ops which
   # ends in the initializer. We copy those to the outermost graph and
@@ -302,7 +304,7 @@
   })  # Pass through variables.
   source_ops = set()
   # Add the sources in the same order as the original graph.
-  for s in six.itervalues(captures):
+  for s in internal_captures:
     if s in sources:
       sources.remove(s)
       source_ops.add(s.op)

View File

@@ -41,7 +41,7 @@ class LiftToGraphTest(test.TestCase):
       return v1 + v2 + v3

     concrete_fn = fn.get_concrete_function()
-    original_captures = concrete_fn.graph.captures
+    original_captures = concrete_fn.graph.internal_captures
     outputs = concrete_fn.graph.outputs

     for _ in range(100):
@@ -49,11 +49,10 @@
       lift_to_graph.lift_to_graph(
           outputs, g, add_sources=True, handle_captures=True)
-      lifted_captures = g.captures
+      lifted_captures = g.internal_captures
       self.assertLen(lifted_captures, 3)
-      for original_capture, lifted_capture in zip(original_captures.values(),
-                                                  lifted_captures.values()):
-        self.assertEqual(original_capture.name, lifted_capture.name)
+      for original, lifted in zip(original_captures, lifted_captures):
+        self.assertEqual(original.name, lifted.name)

   def testClassAttrsRemoved(self):
     """Tests that _class attrs (from colocate_with()) are removed."""

View File

@@ -116,8 +116,7 @@ def _lift_single_variable(old_variable, graph, variable_holder):
       trainable=old_variable.trainable,
       extra_handle_data=old_variable.handle)
   new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
-  graph.inputs.append(old_variable.handle)
-  graph.captures[new_variable.handle] = old_variable.handle
+  graph.add_capture(new_variable.handle, old_variable.handle)
   # Now that we've added the new variable to graph.captures,
   # graph.capture will use that cached value and do some post-processing
   # on the capture like recording it on the tape.
@@ -311,10 +310,9 @@ class WrappedFunction(function.ConcreteFunction):
     pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
     pruned_graph.control_outputs.extend(
         [lift_map[operation] for operation in operation_fetches])
-    for external_capture, internal_capture in self.graph.captures.items():
-      pruned_graph.captures[external_capture] = lift_map[internal_capture]
     pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
-    pruned_graph.inputs.extend(pruned_graph.captures.values())
+    for external_capture, internal_capture in self.graph.captures:
+      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
     for ti in tensor_infos:
       if ti.WhichOneof("encoding") == "name":  # Dense tensors only
         t = pruned_graph.as_graph_element(ti.name)

View File

@@ -182,7 +182,7 @@ def _get_tensor_data(func):
   }
   # Iterates through all captures which are represented as Placeholders.
-  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures.items()):
+  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
     tensor_name = _get_tensor_name(name_tensor.name)
     is_variable = idx in map_index_to_variable
     if is_variable:
@@ -352,7 +352,7 @@
     ConcreteFunction.
   """
   # Create a ConcreteFunction from the new GraphDef.
-  input_tensors = list(func.graph.captures.values())
+  input_tensors = func.graph.internal_captures
   converted_inputs = set(
       [input_tensors[index] for index in converted_input_indices])
   not_converted_inputs = set(func.inputs).difference(converted_inputs)

View File

@@ -152,9 +152,6 @@ class FuncGraph(ops.Graph):
       or the global default Graph.
     captures: Maps external tensor -> internal tensor (i.e. input placeholder).
       The entries are in the order they were captured.
-    deferred_captures: Maps arbitrary key -> (closure, nest of placeholders),
-      where at function call time the value of closure() will be used to feed
-      the nest of placeholders.
     control_captures: Set of external ops on which this graph has a control
       dependency.
     seed: The graph-level random seed.
@@ -193,12 +190,15 @@
     self._weak_variables = []
     self._watched_variables = weakref.WeakSet()
     self.outer_graph = ops.get_default_graph()
-    self.captures = py_collections.OrderedDict()
+    self._captures = py_collections.OrderedDict()
     # If not None, records the names of output args of this function. Used to
     # preserve the output names in the signature of a serialized+deserialized
     # function. Private at the moment mostly because it's often out of date.
     self._output_names = None
-    self.deferred_captures = py_collections.OrderedDict()
+    # Maps arbitrary key -> (closure, nest of placeholders), where at function
+    # call time the value of closure() will be used to feed the nest of
+    # placeholders.
+    self._deferred_captures = py_collections.OrderedDict()
     # Inherit capture-by-value from outer graph.
     if capture_by_value is not None:
       self.capture_by_value = capture_by_value
@@ -273,7 +273,7 @@
     """
     if key is None:
       key = object()
-    if key not in self.deferred_captures:
+    if key not in self._deferred_captures:

       def convert_to_placeholder(s):
         if not isinstance(s, tensor_spec.TensorSpec):
@@ -296,8 +296,8 @@
         # pylint: enable=protected-access
         return nest.flatten(y, expand_composites=True)

-      self.deferred_captures[key] = (wrapped_closure, placeholder)
-    return self.deferred_captures[key][1]
+      self._deferred_captures[key] = (wrapped_closure, placeholder)
+    return self._deferred_captures[key][1]

   def control_dependencies(self, control_inputs):
     """Handles control dependencies.
@@ -439,7 +439,7 @@
       op_def=None,
       compute_device=True):
     # When capturing by value, do the read outside
-    reverse_captures = dict((v, k) for k, v in self.captures.items())
+    reverse_captures = dict((v, k) for k, v in self.captures)
     uncaptured_inputs = [reverse_captures.get(t, t) for t in inputs]
     with ops.init_scope():
       if context.executing_eagerly():
@@ -584,31 +584,82 @@
     return tensor

   def _capture_helper(self, tensor, name):
-    captured_tensor = self.captures.get(tensor, None)
-    if captured_tensor is None:
-      captured_tensor = _create_substitute_placeholder(tensor, name=name,
-                                                       dtype=tensor.dtype)
-      self.captures[tensor] = captured_tensor
-      self.inputs.append(captured_tensor)
-      tape.record_operation("captured_value", [captured_tensor], [tensor],
+    placeholder = self._captures.get(tensor, None)
+    if placeholder is None:
+      placeholder = _create_substitute_placeholder(
+          tensor, name=name, dtype=tensor.dtype)
+      self.add_capture(tensor, placeholder)
+      tape.record_operation("captured_value", [placeholder], [tensor],
                             lambda x: [x])
-    return captured_tensor
+    return placeholder
+
+  @property
+  def captures(self):
+    """Ordered list of tuples containing external and internal captures."""
+    return self._captures.items()
+
+  def add_capture(self, tensor, placeholder):
+    """Capture a specific tensor and utilize the provided placeholder.
+
+    Args:
+      tensor: Tensor to capture.
+      placeholder: Provided placeholder for the tensor.
+    """
+    self._captures[tensor] = placeholder
+    self.inputs.append(placeholder)
+
+  def reset_captures(self, capture_list):
+    """Set the captures with the provided list of (tensor, placeholder) pairs."""
+    self._captures = py_collections.OrderedDict(capture_list)
+
+  def pop_capture(self, tensor):
+    """Remove the capture and return the generated placeholder."""
+    return self._captures.pop(tensor, None)
+
+  def clear_captures(self):
+    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+    # Clearing captures using clear() leaves some cycles around.
+    while self._captures:
+      self._captures.popitem()
+    memory.dismantle_ordered_dict(self._captures)
+    while self._deferred_captures:
+      self._deferred_captures.popitem()
+    memory.dismantle_ordered_dict(self._deferred_captures)

   def capture_distributed_variable(self, variable, placeholder):
     """Add given distributed variable to captures with given placeholder."""
-    self.captures[variable] = placeholder
+    self._captures[variable] = placeholder
     tape.record_operation("captured_value", [placeholder], [variable],
                           lambda x: [x])

   @property
   def external_captures(self):
     """External tensors captured by this function."""
-    return list(self.captures.keys())
+    return list(self._captures.keys())

   @property
   def internal_captures(self):
     """Placeholders in this function corresponding to captured tensors."""
-    return list(self.captures.values())
+    return list(self._captures.values())
+
+  @property
+  def deferred_external_captures(self):
+    """Ordered nest of tensors whose placeholders will be fed at call time."""
+    return [c[0] for c in self._deferred_captures.values()]
+
+  @property
+  def deferred_internal_captures(self):
+    """List of nests of placeholders which will be fed at call time."""
+    return [c[1] for c in self._deferred_captures.values()]
+
+  @property
+  def variable_captures(self):
+    """Map of variable handles to variables that are in the list of captures."""
+    return {
+        self._captures[v.handle]: v
+        for v in self.variables
+        if v.handle in self._captures
+    }


 def func_graph_from_py_func(name,
@@ -813,7 +864,7 @@
         # Even if an argument variable was not used in the function, we've
         # already manually captured the resource Tensor when creating argument
         # placeholders.
-        resource_placeholder = func_graph.captures.pop(arg.handle, None)
+        resource_placeholder = func_graph.pop_capture(arg.handle)
         if resource_placeholder is None:
           continue
         arg_variables.add(arg)
@@ -822,12 +873,8 @@
         inputs.append(arg)
     variables = [v for v in graph_variables if v not in arg_variables]
     func_graph.inputs = (
-        inputs +
-        list(func_graph.captures.values()) +
-        nest.flatten(
-            [x[1] for x in func_graph.deferred_captures.values()],
-            expand_composites=True))
-
+        inputs + func_graph.internal_captures + nest.flatten(
+            func_graph.deferred_internal_captures, expand_composites=True))
     func_graph.structured_outputs = func_outputs
     # Returning a closed-over tensor does not trigger convert_to_tensor.
     func_graph.outputs.extend(
@@ -854,7 +901,7 @@ def maybe_captured(tensor):
   """
   if (not isinstance(tensor, ops.EagerTensor) and
       tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
-    for input_t, placeholder_t in tensor.op.graph.captures.items():
+    for input_t, placeholder_t in tensor.op.graph.captures:
       if tensor == placeholder_t:
         return maybe_captured(input_t)
   # pylint: enable=protected-access
@@ -1064,12 +1111,5 @@ def dismantle_func_graph(func_graph):
     func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
       after this function.
   """
-  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
-  # Clearing captures using clear() leaves some cycles around.
-  while func_graph.captures:
-    func_graph.captures.popitem()
-  memory.dismantle_ordered_dict(func_graph.captures)
-  while func_graph.deferred_captures:
-    func_graph.deferred_captures.popitem()
-  memory.dismantle_ordered_dict(func_graph.deferred_captures)
+  func_graph.clear_captures()
   ops.dismantle_graph(func_graph)
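
Taken together, add_capture, pop_capture, reset_captures, and clear_captures
keep _captures and inputs consistent so callers never touch the dict
directly. Below is a standalone plain-Python model of that bookkeeping
contract, mirroring the diff above; CaptureBook is a hypothetical stand-in
for FuncGraph, not TensorFlow code:

from collections import OrderedDict


class CaptureBook(object):
  """Models the capture bookkeeping that FuncGraph's helpers encapsulate."""

  def __init__(self):
    self._captures = OrderedDict()  # external tensor -> internal placeholder
    self.inputs = []

  @property
  def captures(self):
    return list(self._captures.items())

  def add_capture(self, tensor, placeholder):
    # Mirrors FuncGraph.add_capture: record the pair and expose the
    # placeholder as a function input.
    self._captures[tensor] = placeholder
    self.inputs.append(placeholder)

  def pop_capture(self, tensor):
    return self._captures.pop(tensor, None)

  def reset_captures(self, capture_list):
    self._captures = OrderedDict(capture_list)

  def clear_captures(self):
    # The real implementation also calls memory.dismantle_ordered_dict to
    # break leftover reference cycles (see the TODO in the hunk above).
    while self._captures:
      self._captures.popitem()


book = CaptureBook()
book.add_capture("external_t", "placeholder_t")
assert book.captures == [("external_t", "placeholder_t")]
assert book.inputs == ["placeholder_t"]
assert book.pop_capture("external_t") == "placeholder_t"
assert book.pop_capture("external_t") is None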

View File

@@ -3457,8 +3457,7 @@ class EagerExecutionFunction(object):
       with ops.control_dependencies(updates_ops):
         self.outputs[0] = array_ops.identity(self.outputs[0])

-    exec_graph.inputs = self._input_references + list(
-        exec_graph.captures.values())
+    exec_graph.inputs = self._input_references + exec_graph.internal_captures
     exec_graph.outputs = self.outputs
     graph_fn = eager_function.ConcreteFunction(exec_graph)

View File

@@ -275,8 +275,7 @@ def get_func_graphs(op):
         fdef, input_shapes)
     for external_t, internal_t in zip(inputs, func_graph.inputs):
       custom_gradient.copy_handle_data(external_t, internal_t)
-    func_graph.captures = collections.OrderedDict(zip(inputs,
-                                                      func_graph.inputs))
+    func_graph.reset_captures(zip(inputs, func_graph.inputs))
     # Link the op so that the gradient code can use it.
     func_graph._forward_cond = op
     return func_graph
@@ -482,8 +481,7 @@ def _make_inputs_match(branch_graphs, branch_inputs):
     branch_graph.inputs = input_list

     # Rewrite the FuncGraphs' state to reflect the new inputs.
-    branch_graph.captures = collections.OrderedDict(
-        zip(new_inputs, branch_graph.inputs))
+    branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))

   return new_inputs
@@ -751,7 +749,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
     if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
       # XLA does not yet support optionals, so capture intermediates directly.
       # TODO(skyewm,jpienaar): can XLA support optionals?
-      if tensor not in self.captures:
+      if tensor not in self.external_captures:
         self.xla_intermediates.append(tensor)
         self.op_needs_rewrite = True
     return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

View File

@@ -400,7 +400,7 @@ def _Captures(func_graph):
     return func_graph.captures
   else:
     assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
-    return func_graph._captured  # pylint: disable=protected-access
+    return func_graph._captured.items()  # pylint: disable=protected-access


 def _MaybeCaptured(t):
@@ -415,7 +415,7 @@ def _MaybeCaptured(t):
   # pylint: disable=protected-access
   if (not isinstance(t, ops.EagerTensor) and
       _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
-    for input_t, placeholder_t in _Captures(t.op.graph).items():
+    for input_t, placeholder_t in _Captures(t.op.graph):
       if t == placeholder_t:
         return _MaybeCaptured(input_t)
   # pylint: enable=protected-access
@@ -481,7 +481,7 @@ def _Consumers(t, func_graphs):
   """
   consumers = t.consumers()
   for func in func_graphs:
-    for input_t, placeholder in _Captures(func).items():
+    for input_t, placeholder in _Captures(func):
       if input_t == t:
         consumers.extend(_Consumers(placeholder, func_graphs))
   return consumers

View File

@@ -203,7 +203,7 @@ def while_loop(cond,
     assert (cond_graph.external_captures ==
             body_graph.external_captures[:num_cond_captures])
     for body_capture in body_graph.external_captures[num_cond_captures:]:
-      assert body_capture not in cond_graph.captures
+      assert body_capture not in cond_graph.external_captures
       cond_graph.capture(body_capture)

     # Make sure that the shapes of the loop outputs are compatible with the
@@ -497,7 +497,7 @@ def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
   #    `popped_tensor_lists` by `_WhileBodyGradFuncGraph`.
   # 2. Resources, which are output as is.
   # 3. Forward graph loop invariants, which are output as is.
-  for external_capture, internal_capture in grad_func_graph.captures.items():
+  for external_capture, internal_capture in grad_func_graph.captures:
     if internal_capture in grad_func_graph.popped_tensor_lists:
       new_output = grad_func_graph.popped_tensor_lists[internal_capture]
     elif (internal_capture.dtype == dtypes.resource or _is_loop_invariant(

View File

@@ -180,7 +180,7 @@ class Loader(object):
          concrete_function.graph.capture_distributed_variable(
              bound_input, internal_capture)
        else:
-          concrete_function.graph.captures[bound_input] = internal_capture
+          concrete_function.graph._captures[bound_input] = internal_capture  # pylint: disable=protected-access
          if internal_capture.dtype == dtypes.resource:
            if resource_variable_ops.is_resource_variable(bound_input):
              try:

View File

@@ -874,14 +874,15 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     root = Root()
     self.assertIn(root.v.handle,
-                  root.use_v.get_concrete_function().graph.captures)
+                  root.use_v.get_concrete_function().graph.external_captures)
     for _ in range(cycles):
       root = self.cycle(root, 1, signatures=root.use_v.get_concrete_function())
-    func_captures = root.use_v.get_concrete_function().graph.captures
+    func_captures = root.use_v.get_concrete_function().graph.external_captures
     self.assertLen(func_captures, 2)
     self.assertIn(root.v.handle, func_captures)
     self.assertIn(root.v1.handle, func_captures)
-    signature_captures = root.signatures["serving_default"].graph.captures
+    signature_captures = root.signatures[
+        "serving_default"].graph.external_captures
     self.assertLen(signature_captures, 2)
     self.assertIn(root.v.handle, signature_captures)
     self.assertIn(root.v1.handle, signature_captures)

View File

@@ -322,7 +322,7 @@ def _map_captures_to_created_tensors(
       `resource_map`.
   """
   export_captures = []
-  for exterior, interior in original_captures.items():
+  for exterior, interior in original_captures:
     mapped_resource = resource_map.get(exterior, None)
     if mapped_resource is None:
       raise AssertionError(