diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index b7822bb4225..3380326f8a4 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -721,9 +721,7 @@ class _TapeGradientFunctions(object):
     """Forward+backward functions where the backward function sees `outputs`."""
     # First figure out which of `outputs` are trainable. We'll accept gradients
     # for each of these in the backward function.
-    handles_to_variables = {self._func_graph.captures[v.handle]: v
-                            for v in self._func_graph.variables
-                            if v.handle in self._func_graph.captures}
+    handles_to_variables = self._func_graph.variable_captures
     trainable_outputs = []
     for output in outputs:
       if gradients_util.IsTrainable(output):
@@ -753,8 +751,9 @@ class _TapeGradientFunctions(object):
         src_graph=self._func_graph)
 
     captures_from_forward = [
-        c for c in backwards_graph.captures.keys() if
-        not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
+        c for c in backwards_graph.external_captures
+        if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
+    ]
     existing_outputs = set(self._func_graph.outputs)
     for capture in captures_from_forward:
       if capture not in existing_outputs:
@@ -770,7 +769,7 @@ class _TapeGradientFunctions(object):
     # `backward_function` correspond to outputs (including
     # side outputs) of `self._tape_forward_function`.
     backwards_graph.inputs = (
-        gradients_wrt_outputs + list(backwards_graph.captures.values()))
+        gradients_wrt_outputs + backwards_graph.internal_captures)
     backwards_graph.outputs.extend(
         grad
         for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
@@ -980,9 +979,8 @@ class ConcreteFunction(object):
     self._arg_keywords = None
     self._num_positional_args = None
     self._func_graph = func_graph
-    self._captured_inputs = list(self._func_graph.captures.keys())
-    self._captured_closures = [
-        x[0] for x in self._func_graph.deferred_captures.values()]
+    self._captured_inputs = self._func_graph.external_captures
+    self._captured_closures = self._func_graph.deferred_external_captures
     self._output_shapes = tuple(
         output.shape for output in self._func_graph.outputs)
     attrs = _parse_func_attrs(attrs or {})
diff --git a/tensorflow/python/eager/lift_to_graph.py b/tensorflow/python/eager/lift_to_graph.py
index 2b6dfa72588..4a9a8f8d482 100644
--- a/tensorflow/python/eager/lift_to_graph.py
+++ b/tensorflow/python/eager/lift_to_graph.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import six
 
 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -288,11 +287,14 @@ def lift_to_graph(init_tensors,
 
   # When lifting from one FuncGraph to another, we will need to capture the
   # relevant tensors as well.
-  captures = collections.OrderedDict()
+  captures = []
+  inverse_captures = {}
+  internal_captures = []
   if (isinstance(base_graph, func_graph.FuncGraph) and
       isinstance(graph, func_graph.FuncGraph)):
     captures = base_graph.captures
-    inverse_captures = {v: k for k, v in captures.items()}
+    inverse_captures = {v: k for k, v in captures}
+    internal_captures = base_graph.internal_captures
 
   # ops_to_copy now holds a reverse topologically sorted list of ops which
   # ends in the initializer. We copy those to the outermost graph and
@@ -302,7 +304,7 @@
   })
   # Pass through variables.
   source_ops = set()
   # Add the sources in the same order as the original graph.
-  for s in six.itervalues(captures):
+  for s in internal_captures:
     if s in sources:
       sources.remove(s)
       source_ops.add(s.op)
diff --git a/tensorflow/python/eager/lift_to_graph_test.py b/tensorflow/python/eager/lift_to_graph_test.py
index 619b9dc4a7e..90db3ebb0f5 100644
--- a/tensorflow/python/eager/lift_to_graph_test.py
+++ b/tensorflow/python/eager/lift_to_graph_test.py
@@ -41,7 +41,7 @@ class LiftToGraphTest(test.TestCase):
       return v1 + v2 + v3
 
     concrete_fn = fn.get_concrete_function()
-    original_captures = concrete_fn.graph.captures
+    original_captures = concrete_fn.graph.internal_captures
     outputs = concrete_fn.graph.outputs
 
     for _ in range(100):
@@ -49,11 +49,10 @@ class LiftToGraphTest(test.TestCase):
       lift_to_graph.lift_to_graph(
           outputs, g, add_sources=True, handle_captures=True)
 
-      lifted_captures = g.captures
+      lifted_captures = g.internal_captures
       self.assertLen(lifted_captures, 3)
-      for original_capture, lifted_capture in zip(original_captures.values(),
-                                                  lifted_captures.values()):
-        self.assertEqual(original_capture.name, lifted_capture.name)
+      for original, lifted in zip(original_captures, lifted_captures):
+        self.assertEqual(original.name, lifted.name)
 
   def testClassAttrsRemoved(self):
     """Tests that _class attrs (from colocate_with()) are removed."""
diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py
index ad2f24edbe2..269ec344b75 100644
--- a/tensorflow/python/eager/wrap_function.py
+++ b/tensorflow/python/eager/wrap_function.py
@@ -116,8 +116,7 @@ def _lift_single_variable(old_variable, graph, variable_holder):
       trainable=old_variable.trainable,
       extra_handle_data=old_variable.handle)
   new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
-  graph.inputs.append(old_variable.handle)
-  graph.captures[new_variable.handle] = old_variable.handle
+  graph.add_capture(new_variable.handle, old_variable.handle)
   # Now that we've added the new variable to graph.captures,
   # graph.capture will use that cached value and do some post-processing
   # on the capture like recording it on the tape.
@@ -311,10 +310,9 @@ class WrappedFunction(function.ConcreteFunction):
     pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
     pruned_graph.control_outputs.extend(
         [lift_map[operation] for operation in operation_fetches])
-    for external_capture, internal_capture in self.graph.captures.items():
-      pruned_graph.captures[external_capture] = lift_map[internal_capture]
     pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
-    pruned_graph.inputs.extend(pruned_graph.captures.values())
+    for external_capture, internal_capture in self.graph.captures:
+      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
     for ti in tensor_infos:
       if ti.WhichOneof("encoding") == "name":  # Dense tensors only
         t = pruned_graph.as_graph_element(ti.name)
diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py
index 4e2e24ca6e4..791c63fdaf0 100644
--- a/tensorflow/python/framework/convert_to_constants.py
+++ b/tensorflow/python/framework/convert_to_constants.py
@@ -182,7 +182,7 @@ def _get_tensor_data(func):
   }
 
   # Iterates through all captures which are represented as Placeholders.
-  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures.items()):
+  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
     tensor_name = _get_tensor_name(name_tensor.name)
     is_variable = idx in map_index_to_variable
     if is_variable:
@@ -352,7 +352,7 @@ def _construct_concrete_function(func, output_graph_def,
     ConcreteFunction.
   """
   # Create a ConcreteFunction from the new GraphDef.
-  input_tensors = list(func.graph.captures.values())
+  input_tensors = func.graph.internal_captures
   converted_inputs = set(
       [input_tensors[index] for index in converted_input_indices])
   not_converted_inputs = set(func.inputs).difference(converted_inputs)
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index 2e6e1190488..24858c1bbde 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -152,9 +152,6 @@ class FuncGraph(ops.Graph):
       or the global default Graph.
     captures: Maps external tensor -> internal tensor (i.e. input placeholder).
       The entries are in the order they were captured.
-    deferred_captures: Maps arbitrary key -> (closure, nest of placeholders),
-      where at function call time the value of closure() will be used to feed
-      the nest of placeholders.
     control_captures: Set of external ops on which this graph has a control
       dependency.
     seed: The graph-level random seed.
@@ -193,12 +190,15 @@ class FuncGraph(ops.Graph):
     self._weak_variables = []
     self._watched_variables = weakref.WeakSet()
     self.outer_graph = ops.get_default_graph()
-    self.captures = py_collections.OrderedDict()
+    self._captures = py_collections.OrderedDict()
     # If not None, records the names of output args of this function. Used to
     # preserve the output names in the signature of a serialized+deserialized
     # function. Private at the moment mostly because it's often out of date.
     self._output_names = None
-    self.deferred_captures = py_collections.OrderedDict()
+    # Maps arbitrary key -> (closure, nest of placeholders), where at function
+    # call time the value of closure() will be used to feed the nest of
+    # placeholders.
+    self._deferred_captures = py_collections.OrderedDict()
     # Inherit capture-by-value from outer graph.
     if capture_by_value is not None:
       self.capture_by_value = capture_by_value
@@ -273,7 +273,7 @@ class FuncGraph(ops.Graph):
     """
     if key is None:
       key = object()
-    if key not in self.deferred_captures:
+    if key not in self._deferred_captures:
 
       def convert_to_placeholder(s):
         if not isinstance(s, tensor_spec.TensorSpec):
@@ -296,8 +296,8 @@ class FuncGraph(ops.Graph):
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

-      self.deferred_captures[key] = (wrapped_closure, placeholder)
-    return self.deferred_captures[key][1]
+      self._deferred_captures[key] = (wrapped_closure, placeholder)
+    return self._deferred_captures[key][1]
 
   def control_dependencies(self, control_inputs):
     """Handles control dependencies.
@@ -439,7 +439,7 @@ class FuncGraph(ops.Graph):
       op_def=None,
       compute_device=True):
     # When capturing by value, do the read outside
-    reverse_captures = dict((v, k) for k, v in self.captures.items())
+    reverse_captures = dict((v, k) for k, v in self.captures)
     uncaptured_inputs = [reverse_captures.get(t, t) for t in inputs]
     with ops.init_scope():
       if context.executing_eagerly():
@@ -584,31 +584,82 @@ class FuncGraph(ops.Graph):
     return tensor
 
   def _capture_helper(self, tensor, name):
-    captured_tensor = self.captures.get(tensor, None)
-    if captured_tensor is None:
-      captured_tensor = _create_substitute_placeholder(tensor, name=name,
-                                                       dtype=tensor.dtype)
-      self.captures[tensor] = captured_tensor
-      self.inputs.append(captured_tensor)
-      tape.record_operation("captured_value", [captured_tensor], [tensor],
+    placeholder = self._captures.get(tensor, None)
+    if placeholder is None:
+      placeholder = _create_substitute_placeholder(
+          tensor, name=name, dtype=tensor.dtype)
+      self.add_capture(tensor, placeholder)
+      tape.record_operation("captured_value", [placeholder], [tensor],
                             lambda x: [x])
-    return captured_tensor
+    return placeholder
+
+  @property
+  def captures(self):
+    """Ordered list of tuples containing external and internal captures."""
+    return self._captures.items()
+
+  def add_capture(self, tensor, placeholder):
+    """Capture a specific tensor using the provided placeholder.
+
+    Args:
+      tensor: Tensor to capture.
+      placeholder: Provided placeholder for the tensor.
+    """
+    self._captures[tensor] = placeholder
+    self.inputs.append(placeholder)
+
+  def reset_captures(self, capture_list):
+    """Set the captures with a list of (tensor, placeholder) pairs."""
+    self._captures = py_collections.OrderedDict(capture_list)
+
+  def pop_capture(self, tensor):
+    """Remove the capture and return the generated placeholder."""
+    return self._captures.pop(tensor, None)
+
+  def clear_captures(self):
+    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+    # Clearing captures using clear() leaves some cycles around.
+    while self._captures:
+      self._captures.popitem()
+    memory.dismantle_ordered_dict(self._captures)
+    while self._deferred_captures:
+      self._deferred_captures.popitem()
+    memory.dismantle_ordered_dict(self._deferred_captures)
 
   def capture_distributed_variable(self, variable, placeholder):
     """Add given distributed variable to captures with given placeholder."""
-    self.captures[variable] = placeholder
+    self._captures[variable] = placeholder
     tape.record_operation("captured_value", [placeholder], [variable],
                           lambda x: [x])
 
   @property
   def external_captures(self):
     """External tensors captured by this function."""
-    return list(self.captures.keys())
+    return list(self._captures.keys())
 
   @property
   def internal_captures(self):
     """Placeholders in this function corresponding captured tensors."""
-    return list(self.captures.values())
+    return list(self._captures.values())
+
+  @property
+  def deferred_external_captures(self):
+    """Ordered list of closures whose values will be fed at call time."""
+    return [c[0] for c in self._deferred_captures.values()]
+
+  @property
+  def deferred_internal_captures(self):
+    """List of nests of placeholders which will be fed at call time."""
+    return [c[1] for c in self._deferred_captures.values()]
+
+  @property
+  def variable_captures(self):
+    """Map of variable handles to variables that are in the list of captures."""
+    return {
+        self._captures[v.handle]: v
+        for v in self.variables
+        if v.handle in self._captures
+    }
 
 
 def func_graph_from_py_func(name,
@@ -813,7 +864,7 @@
         # Even if an argument variable was not used in the function, we've
         # already manually captured the resource Tensor when creating argument
         # placeholders.
-        resource_placeholder = func_graph.captures.pop(arg.handle, None)
+        resource_placeholder = func_graph.pop_capture(arg.handle)
         if resource_placeholder is None:
           continue
         arg_variables.add(arg)
@@ -822,12 +873,8 @@
        inputs.append(arg)
    variables = [v for v in graph_variables if v not in arg_variables]
    func_graph.inputs = (
-        inputs +
-        list(func_graph.captures.values()) +
-        nest.flatten(
-            [x[1] for x in func_graph.deferred_captures.values()],
-            expand_composites=True))
-
+        inputs + func_graph.internal_captures + nest.flatten(
+            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
@@ -854,7 +901,7 @@ def maybe_captured(tensor):
   """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
-    for input_t, placeholder_t in tensor.op.graph.captures.items():
+    for input_t, placeholder_t in tensor.op.graph.captures:
      if tensor == placeholder_t:
        return maybe_captured(input_t)
  # pylint: enable=protected-access
@@ -1064,12 +1111,5 @@ def dismantle_func_graph(func_graph):
     func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
       after this function.
   """
-  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
-  # Clearing captures using clear() leaves some cycles around.
-  while func_graph.captures:
-    func_graph.captures.popitem()
-  memory.dismantle_ordered_dict(func_graph.captures)
-  while func_graph.deferred_captures:
-    func_graph.deferred_captures.popitem()
-  memory.dismantle_ordered_dict(func_graph.deferred_captures)
+  func_graph.clear_captures()
   ops.dismantle_graph(func_graph)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 186b4f24639..e6b39e3668a 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3457,8 +3457,7 @@ class EagerExecutionFunction(object):
        with ops.control_dependencies(updates_ops):
          self.outputs[0] = array_ops.identity(self.outputs[0])
 
-    exec_graph.inputs = self._input_references + list(
-        exec_graph.captures.values())
+    exec_graph.inputs = self._input_references + exec_graph.internal_captures
    exec_graph.outputs = self.outputs
 
    graph_fn = eager_function.ConcreteFunction(exec_graph)
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 386aff3dd39..bd29649795b 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -275,8 +275,7 @@ def get_func_graphs(op):
          fdef, input_shapes)
    for external_t, internal_t in zip(inputs, func_graph.inputs):
      custom_gradient.copy_handle_data(external_t, internal_t)
-    func_graph.captures = collections.OrderedDict(zip(inputs,
-                                                      func_graph.inputs))
+    func_graph.reset_captures(zip(inputs, func_graph.inputs))
    # Link the op so that the gradient code can use it.
    func_graph._forward_cond = op
    return func_graph
@@ -482,8 +481,7 @@ def _make_inputs_match(branch_graphs, branch_inputs):
    branch_graph.inputs = input_list
 
    # Rewrite the FuncGraphs' state to reflect the new inputs.
-    branch_graph.captures = collections.OrderedDict(
-        zip(new_inputs, branch_graph.inputs))
+    branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))
 
  return new_inputs
 
@@ -751,7 +749,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
-      if tensor not in self.captures:
+      if tensor not in self.external_captures:
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
index 84a21d0bab5..e262f8405ea 100644
--- a/tensorflow/python/ops/gradients_util.py
+++ b/tensorflow/python/ops/gradients_util.py
@@ -400,7 +400,7 @@ def _Captures(func_graph):
    return func_graph.captures
  else:
    assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
-    return func_graph._captured  # pylint: disable=protected-access
+    return func_graph._captured.items()  # pylint: disable=protected-access
 
 
 def _MaybeCaptured(t):
@@ -415,7 +415,7 @@ def _MaybeCaptured(t):
  # pylint: disable=protected-access
  if (not isinstance(t, ops.EagerTensor) and
      _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
-    for input_t, placeholder_t in _Captures(t.op.graph).items():
+    for input_t, placeholder_t in _Captures(t.op.graph):
      if t == placeholder_t:
        return _MaybeCaptured(input_t)
  # pylint: enable=protected-access
@@ -481,7 +481,7 @@ def _Consumers(t, func_graphs):
   """
  consumers = t.consumers()
  for func in func_graphs:
-    for input_t, placeholder in _Captures(func).items():
+    for input_t, placeholder in _Captures(func):
      if input_t == t:
        consumers.extend(_Consumers(placeholder, func_graphs))
  return consumers
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 7527c5cfd3e..9c712874a2b 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -203,7 +203,7 @@ def while_loop(cond,
    assert (cond_graph.external_captures ==
            body_graph.external_captures[:num_cond_captures])
    for body_capture in body_graph.external_captures[num_cond_captures:]:
-      assert body_capture not in cond_graph.captures
+      assert body_capture not in cond_graph.external_captures
      cond_graph.capture(body_capture)
 
  # Make sure that the shapes of the loop outputs are compatible with the
@@ -497,7 +497,7 @@ def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
  #    `popped_tensor_lists` by `_WhileBodyGradFuncGraph`.
  # 2. Resources, which are output as is.
  # 3. Forward graph loop invariants, which are output as is.
-  for external_capture, internal_capture in grad_func_graph.captures.items():
+  for external_capture, internal_capture in grad_func_graph.captures:
    if internal_capture in grad_func_graph.popped_tensor_lists:
      new_output = grad_func_graph.popped_tensor_lists[internal_capture]
    elif (internal_capture.dtype == dtypes.resource or _is_loop_invariant(
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index f2994472aa1..190186c268e 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -180,7 +180,7 @@ class Loader(object):
        concrete_function.graph.capture_distributed_variable(
            bound_input, internal_capture)
      else:
-        concrete_function.graph.captures[bound_input] = internal_capture
+        concrete_function.graph._captures[bound_input] = internal_capture  # pylint: disable=protected-access
        if internal_capture.dtype == dtypes.resource:
          if resource_variable_ops.is_resource_variable(bound_input):
            try:
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index e28ee4b5546..24abd9c552b 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -874,14 +874,15 @@ class LoadTest(test.TestCase, parameterized.TestCase):
    root = Root()
 
    self.assertIn(root.v.handle,
-                  root.use_v.get_concrete_function().graph.captures)
+                  root.use_v.get_concrete_function().graph.external_captures)
    for _ in range(cycles):
      root = self.cycle(root, 1, signatures=root.use_v.get_concrete_function())
-    func_captures = root.use_v.get_concrete_function().graph.captures
+    func_captures = root.use_v.get_concrete_function().graph.external_captures
    self.assertLen(func_captures, 2)
    self.assertIn(root.v.handle, func_captures)
    self.assertIn(root.v1.handle, func_captures)
-    signature_captures = root.signatures["serving_default"].graph.captures
+    signature_captures = root.signatures[
+        "serving_default"].graph.external_captures
    self.assertLen(signature_captures, 2)
    self.assertIn(root.v.handle, signature_captures)
    self.assertIn(root.v1.handle, signature_captures)
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 9520b36a667..e12c03def80 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -322,7 +322,7 @@ def _map_captures_to_created_tensors(
    `resource_map`.
   """
  export_captures = []
-  for exterior, interior in original_captures.items():
+  for exterior, interior in original_captures:
    mapped_resource = resource_map.get(exterior, None)
    if mapped_resource is None:
      raise AssertionError(
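A minimal usage sketch of the capture API this change settles on, for reference; it is not part of the patch. The function `f` and variable `v` are hypothetical, and the only behavior assumed is what the diff introduces: `FuncGraph.captures` yields (external, internal) tuples, while `external_captures` and `internal_captures` return the matching flat lists.

    import tensorflow as tf

    v = tf.Variable(1.0)

    @tf.function
    def f(x):
      return x + v  # capturing `v` records its handle in f's FuncGraph

    graph = f.get_concrete_function(tf.TensorSpec([], tf.float32)).graph

    # `captures` iterates (external tensor, internal placeholder) pairs.
    for external, internal in graph.captures:
      print(external.dtype, "->", internal.name)

    # The flat list views line up with the pairwise view element for element.
    assert graph.external_captures == [ext for ext, _ in graph.captures]
    assert graph.internal_captures == [int_ for _, int_ in graph.captures]

Mutation now goes through the helpers added in func_graph.py (add_capture, pop_capture, reset_captures, clear_captures) rather than writes to the former OrderedDict; the one remaining direct `_captures` write, in saved_model/load.py above, is marked with a pylint protected-access disable for that reason.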