Make func_graph captures private to class
- Modify callers to use helper methods instead of directly touching the captures dictionary.
- Change `captures` to return a list of ordered tuples instead of a dict.

PiperOrigin-RevId: 260634556
parent ace39d8e0a
commit bdbd1d27dc
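
For context, a minimal before/after sketch of the caller-facing API change (the FuncGraph method and property names are taken from the diff below; the graph g, tensor, and placeholder objects are hypothetical stand-ins):

    # Before: captures was a public OrderedDict that callers read and mutated.
    external = list(g.captures.keys())    # external tensors
    internal = list(g.captures.values())  # internal placeholders
    g.captures[tensor] = placeholder      # direct mutation

    # After: captures is a property yielding ordered (external, internal)
    # tuples, and mutation goes through helper methods.
    for external_t, internal_t in g.captures:
      pass
    g.add_capture(tensor, placeholder)         # also appends placeholder to g.inputs
    resource_ph = g.pop_capture(tensor)        # returns None when nothing was captured
    g.reset_captures(zip(external, internal))  # replace the mapping wholesale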
@@ -721,9 +721,7 @@ class _TapeGradientFunctions(object):
     """Forward+backward functions where the backward function sees `outputs`."""
     # First figure out which of `outputs` are trainable. We'll accept gradients
     # for each of these in the backward function.
-    handles_to_variables = {self._func_graph.captures[v.handle]: v
-                            for v in self._func_graph.variables
-                            if v.handle in self._func_graph.captures}
+    handles_to_variables = self._func_graph.variable_captures
     trainable_outputs = []
     for output in outputs:
       if gradients_util.IsTrainable(output):
@@ -753,8 +751,9 @@ class _TapeGradientFunctions(object):
           src_graph=self._func_graph)

       captures_from_forward = [
-          c for c in backwards_graph.captures.keys() if
-          not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph]
+          c for c in backwards_graph.external_captures
+          if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph
+      ]
       existing_outputs = set(self._func_graph.outputs)
       for capture in captures_from_forward:
         if capture not in existing_outputs:
@@ -770,7 +769,7 @@ class _TapeGradientFunctions(object):
       # `backward_function` correspond to outputs (including
       # side outputs) of `self._tape_forward_function`.
       backwards_graph.inputs = (
-          gradients_wrt_outputs + list(backwards_graph.captures.values()))
+          gradients_wrt_outputs + backwards_graph.internal_captures)
       backwards_graph.outputs.extend(
           grad
           for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True)
@@ -980,9 +979,8 @@ class ConcreteFunction(object):
     self._arg_keywords = None
     self._num_positional_args = None
     self._func_graph = func_graph
-    self._captured_inputs = list(self._func_graph.captures.keys())
-    self._captured_closures = [
-        x[0] for x in self._func_graph.deferred_captures.values()]
+    self._captured_inputs = self._func_graph.external_captures
+    self._captured_closures = self._func_graph.deferred_external_captures
     self._output_shapes = tuple(
         output.shape for output in self._func_graph.outputs)
     attrs = _parse_func_attrs(attrs or {})
|
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function

 import collections
-import six

 from tensorflow.python.framework import func_graph
 from tensorflow.python.framework import ops
@@ -288,11 +287,14 @@ def lift_to_graph(init_tensors,

   # When lifting from one FuncGraph to another, we will need to capture the
   # relevant tensors as well.
-  captures = collections.OrderedDict()
+  captures = []
+  inverse_captures = {}
+  internal_captures = []
   if (isinstance(base_graph, func_graph.FuncGraph) and
       isinstance(graph, func_graph.FuncGraph)):
     captures = base_graph.captures
-    inverse_captures = {v: k for k, v in captures.items()}
+    inverse_captures = {v: k for k, v in captures}
+    internal_captures = base_graph.internal_captures

   # ops_to_copy now holds a reverse topologically sorted list of ops which
   # ends in the initializer. We copy those to the outermost graph and
@@ -302,7 +304,7 @@ def lift_to_graph(init_tensors,
   })  # Pass through variables.
   source_ops = set()
   # Add the sources in the same order as the original graph.
-  for s in six.itervalues(captures):
+  for s in internal_captures:
     if s in sources:
       sources.remove(s)
       source_ops.add(s.op)
|
@@ -41,7 +41,7 @@ class LiftToGraphTest(test.TestCase):
       return v1 + v2 + v3

     concrete_fn = fn.get_concrete_function()
-    original_captures = concrete_fn.graph.captures
+    original_captures = concrete_fn.graph.internal_captures
     outputs = concrete_fn.graph.outputs

     for _ in range(100):
@@ -49,11 +49,10 @@ class LiftToGraphTest(test.TestCase):

       lift_to_graph.lift_to_graph(
           outputs, g, add_sources=True, handle_captures=True)
-      lifted_captures = g.captures
+      lifted_captures = g.internal_captures
       self.assertLen(lifted_captures, 3)
-      for original_capture, lifted_capture in zip(original_captures.values(),
-                                                  lifted_captures.values()):
-        self.assertEqual(original_capture.name, lifted_capture.name)
+      for original, lifted in zip(original_captures, lifted_captures):
+        self.assertEqual(original.name, lifted.name)

   def testClassAttrsRemoved(self):
     """Tests that _class attrs (from colocate_with()) are removed."""
|
@@ -116,8 +116,7 @@ def _lift_single_variable(old_variable, graph, variable_holder):
       trainable=old_variable.trainable,
       extra_handle_data=old_variable.handle)
   new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
-  graph.inputs.append(old_variable.handle)
-  graph.captures[new_variable.handle] = old_variable.handle
+  graph.add_capture(new_variable.handle, old_variable.handle)
   # Now that we've added the new variable to graph.captures,
   # graph.capture will use that cached value and do some post-processing
   # on the capture like recording it on the tape.
@@ -311,10 +310,9 @@ class WrappedFunction(function.ConcreteFunction):
     pruned_graph.outputs.extend(lift_map[x] for x in tensor_fetches)
     pruned_graph.control_outputs.extend(
         [lift_map[operation] for operation in operation_fetches])
-    for external_capture, internal_capture in self.graph.captures.items():
-      pruned_graph.captures[external_capture] = lift_map[internal_capture]
     pruned_graph.inputs.extend(lift_map[x] for x in flat_feeds)
-    pruned_graph.inputs.extend(pruned_graph.captures.values())
+    for external_capture, internal_capture in self.graph.captures:
+      pruned_graph.add_capture(external_capture, lift_map[internal_capture])
     for ti in tensor_infos:
       if ti.WhichOneof("encoding") == "name":  # Dense tensors only
         t = pruned_graph.as_graph_element(ti.name)
|
@@ -182,7 +182,7 @@ def _get_tensor_data(func):
   }

   # Iterates through all captures which are represented as Placeholders.
-  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures.items()):
+  for idx, (val_tensor, name_tensor) in enumerate(func.graph.captures):
     tensor_name = _get_tensor_name(name_tensor.name)
     is_variable = idx in map_index_to_variable
     if is_variable:
@@ -352,7 +352,7 @@ def _construct_concrete_function(func, output_graph_def,
       ConcreteFunction.
   """
   # Create a ConcreteFunction from the new GraphDef.
-  input_tensors = list(func.graph.captures.values())
+  input_tensors = func.graph.internal_captures
   converted_inputs = set(
       [input_tensors[index] for index in converted_input_indices])
   not_converted_inputs = set(func.inputs).difference(converted_inputs)
|
@@ -152,9 +152,6 @@ class FuncGraph(ops.Graph):
       or the global default Graph.
     captures: Maps external tensor -> internal tensor (i.e. input placeholder).
       The entries are in the order they were captured.
-    deferred_captures: Maps arbitrary key -> (closure, nest of placeholders),
-      where at function call time the value of closure() will be used to feed
-      the nest of placeholders.
     control_captures: Set of external ops on which this graph has a control
       dependency.
     seed: The graph-level random seed.
@@ -193,12 +190,15 @@ class FuncGraph(ops.Graph):
     self._weak_variables = []
     self._watched_variables = weakref.WeakSet()
     self.outer_graph = ops.get_default_graph()
-    self.captures = py_collections.OrderedDict()
+    self._captures = py_collections.OrderedDict()
     # If not None, records the names of output args of this function. Used to
     # preserve the output names in the signature of a serialized+deserialized
     # function. Private at the moment mostly because it's often out of date.
     self._output_names = None
-    self.deferred_captures = py_collections.OrderedDict()
+    # Maps arbitrary key -> (closure, nest of placeholders), where at function
+    # call time the value of closure() will be used to feed the nest of
+    # placeholders.
+    self._deferred_captures = py_collections.OrderedDict()
     # Inherit capture-by-value from outer graph.
     if capture_by_value is not None:
       self.capture_by_value = capture_by_value
@@ -273,7 +273,7 @@ class FuncGraph(ops.Graph):
     """
     if key is None:
       key = object()
-    if key not in self.deferred_captures:
+    if key not in self._deferred_captures:

       def convert_to_placeholder(s):
         if not isinstance(s, tensor_spec.TensorSpec):
@@ -296,8 +296,8 @@ class FuncGraph(ops.Graph):
         # pylint: enable=protected-access
         return nest.flatten(y, expand_composites=True)

-      self.deferred_captures[key] = (wrapped_closure, placeholder)
-    return self.deferred_captures[key][1]
+      self._deferred_captures[key] = (wrapped_closure, placeholder)
+    return self._deferred_captures[key][1]

   def control_dependencies(self, control_inputs):
     """Handles control dependencies.
@@ -439,7 +439,7 @@ class FuncGraph(ops.Graph):
                          op_def=None,
                          compute_device=True):
     # When capturing by value, do the read outside
-    reverse_captures = dict((v, k) for k, v in self.captures.items())
+    reverse_captures = dict((v, k) for k, v in self.captures)
     uncaptured_inputs = [reverse_captures.get(t, t) for t in inputs]
     with ops.init_scope():
       if context.executing_eagerly():
@@ -584,31 +584,82 @@ class FuncGraph(ops.Graph):
     return tensor

   def _capture_helper(self, tensor, name):
-    captured_tensor = self.captures.get(tensor, None)
-    if captured_tensor is None:
-      captured_tensor = _create_substitute_placeholder(tensor, name=name,
-                                                       dtype=tensor.dtype)
-      self.captures[tensor] = captured_tensor
-      self.inputs.append(captured_tensor)
-      tape.record_operation("captured_value", [captured_tensor], [tensor],
+    placeholder = self._captures.get(tensor, None)
+    if placeholder is None:
+      placeholder = _create_substitute_placeholder(
+          tensor, name=name, dtype=tensor.dtype)
+      self.add_capture(tensor, placeholder)
+      tape.record_operation("captured_value", [placeholder], [tensor],
                             lambda x: [x])
-    return captured_tensor
+    return placeholder
+
+  @property
+  def captures(self):
+    """Ordered list of tuples containing external and internal captures."""
+    return self._captures.items()
+
+  def add_capture(self, tensor, placeholder):
+    """Capture a specific tensor and utilize the provided placeholder.
+
+    Args:
+      tensor: Tensor to capture.
+      placeholder: Provided placeholder for the tensor.
+    """
+    self._captures[tensor] = placeholder
+    self.inputs.append(placeholder)
+
+  def reset_captures(self, capture_list):
+    """Set the captures with the provided list of captures & placeholders."""
+    self._captures = py_collections.OrderedDict(capture_list)
+
+  def pop_capture(self, tensor):
+    """Remove the capture and return the generated placeholder."""
+    return self._captures.pop(tensor, None)
+
+  def clear_captures(self):
+    # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+    # Clearing captures using clear() leaves some cycles around.
+    while self._captures:
+      self._captures.popitem()
+    memory.dismantle_ordered_dict(self._captures)
+    while self._deferred_captures:
+      self._deferred_captures.popitem()
+    memory.dismantle_ordered_dict(self._deferred_captures)

   def capture_distributed_variable(self, variable, placeholder):
     """Add given distributed variable to captures with given placeholder."""
-    self.captures[variable] = placeholder
+    self._captures[variable] = placeholder
     tape.record_operation("captured_value", [placeholder], [variable],
                           lambda x: [x])

   @property
   def external_captures(self):
     """External tensors captured by this function."""
-    return list(self.captures.keys())
+    return list(self._captures.keys())

   @property
   def internal_captures(self):
     """Placeholders in this function corresponding to captured tensors."""
-    return list(self.captures.values())
+    return list(self._captures.values())
+
+  @property
+  def deferred_external_captures(self):
+    """Ordered nest of tensors whose placeholders will be fed at call time."""
+    return [c[0] for c in self._deferred_captures.values()]
+
+  @property
+  def deferred_internal_captures(self):
+    """List of nests of placeholders which will be fed at call time."""
+    return [c[1] for c in self._deferred_captures.values()]
+
+  @property
+  def variable_captures(self):
+    """Map of variable handles to variables that are in the list of captures."""
+    return {
+        self._captures[v.handle]: v
+        for v in self.variables
+        if v.handle in self._captures
+    }


 def func_graph_from_py_func(name,
@@ -813,7 +864,7 @@ def func_graph_from_py_func(name,
       # Even if an argument variable was not used in the function, we've
       # already manually captured the resource Tensor when creating argument
      # placeholders.
-      resource_placeholder = func_graph.captures.pop(arg.handle, None)
+      resource_placeholder = func_graph.pop_capture(arg.handle)
       if resource_placeholder is None:
         continue
       arg_variables.add(arg)
@@ -822,12 +873,8 @@ def func_graph_from_py_func(name,
       inputs.append(arg)
     variables = [v for v in graph_variables if v not in arg_variables]
     func_graph.inputs = (
-        inputs +
-        list(func_graph.captures.values()) +
-        nest.flatten(
-            [x[1] for x in func_graph.deferred_captures.values()],
-            expand_composites=True))
+        inputs + func_graph.internal_captures + nest.flatten(
+            func_graph.deferred_internal_captures, expand_composites=True))

     func_graph.structured_outputs = func_outputs
     # Returning a closed-over tensor does not trigger convert_to_tensor.
     func_graph.outputs.extend(
@@ -854,7 +901,7 @@ def maybe_captured(tensor):
   """
   if (not isinstance(tensor, ops.EagerTensor) and
       tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
-    for input_t, placeholder_t in tensor.op.graph.captures.items():
+    for input_t, placeholder_t in tensor.op.graph.captures:
       if tensor == placeholder_t:
         return maybe_captured(input_t)
   # pylint: enable=protected-access
@@ -1064,12 +1111,5 @@ def dismantle_func_graph(func_graph):
     func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
       after this function.
   """
-  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
-  # Clearing captures using clear() leaves some cycles around.
-  while func_graph.captures:
-    func_graph.captures.popitem()
-  memory.dismantle_ordered_dict(func_graph.captures)
-  while func_graph.deferred_captures:
-    func_graph.deferred_captures.popitem()
-  memory.dismantle_ordered_dict(func_graph.deferred_captures)
+  func_graph.clear_captures()
   ops.dismantle_graph(func_graph)
|
@@ -3457,8 +3457,7 @@ class EagerExecutionFunction(object):
       with ops.control_dependencies(updates_ops):
        self.outputs[0] = array_ops.identity(self.outputs[0])

-    exec_graph.inputs = self._input_references + list(
-        exec_graph.captures.values())
+    exec_graph.inputs = self._input_references + exec_graph.internal_captures
     exec_graph.outputs = self.outputs
     graph_fn = eager_function.ConcreteFunction(exec_graph)

|
@@ -275,8 +275,7 @@ def get_func_graphs(op):
         fdef, input_shapes)
     for external_t, internal_t in zip(inputs, func_graph.inputs):
       custom_gradient.copy_handle_data(external_t, internal_t)
-    func_graph.captures = collections.OrderedDict(zip(inputs,
-                                                      func_graph.inputs))
+    func_graph.reset_captures(zip(inputs, func_graph.inputs))
     # Link the op so that the gradient code can use it.
     func_graph._forward_cond = op
     return func_graph
@@ -482,8 +481,7 @@ def _make_inputs_match(branch_graphs, branch_inputs):
     branch_graph.inputs = input_list

     # Rewrite the FuncGraphs' state to reflect the new inputs.
-    branch_graph.captures = collections.OrderedDict(
-        zip(new_inputs, branch_graph.inputs))
+    branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))

   return new_inputs

@@ -751,7 +749,7 @@ class _CondGradFuncGraph(util.CondBranchFuncGraph):
     if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
       # XLA does not yet support optionals, so capture intermediates directly.
       # TODO(skyewm,jpienaar): can XLA support optionals?
-      if tensor not in self.captures:
+      if tensor not in self.external_captures:
         self.xla_intermediates.append(tensor)
         self.op_needs_rewrite = True
     return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)
|
@@ -400,7 +400,7 @@ def _Captures(func_graph):
     return func_graph.captures
   else:
     assert isinstance(func_graph, framework_function._FuncGraph)  # pylint: disable=protected-access
-    return func_graph._captured  # pylint: disable=protected-access
+    return func_graph._captured.items()  # pylint: disable=protected-access


 def _MaybeCaptured(t):
@@ -415,7 +415,7 @@ def _MaybeCaptured(t):
   # pylint: disable=protected-access
   if (not isinstance(t, ops.EagerTensor) and
       _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
-    for input_t, placeholder_t in _Captures(t.op.graph).items():
+    for input_t, placeholder_t in _Captures(t.op.graph):
       if t == placeholder_t:
         return _MaybeCaptured(input_t)
   # pylint: enable=protected-access
@@ -481,7 +481,7 @@ def _Consumers(t, func_graphs):
   """
   consumers = t.consumers()
   for func in func_graphs:
-    for input_t, placeholder in _Captures(func).items():
+    for input_t, placeholder in _Captures(func):
       if input_t == t:
         consumers.extend(_Consumers(placeholder, func_graphs))
   return consumers
|
@@ -203,7 +203,7 @@ def while_loop(cond,
     assert (cond_graph.external_captures ==
             body_graph.external_captures[:num_cond_captures])
     for body_capture in body_graph.external_captures[num_cond_captures:]:
-      assert body_capture not in cond_graph.captures
+      assert body_capture not in cond_graph.external_captures
       cond_graph.capture(body_capture)

     # Make sure that the shapes of the loop outputs are compatible with the
@@ -497,7 +497,7 @@ def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
   #    `popped_tensor_lists` by `_WhileBodyGradFuncGraph`.
   # 2. Resources, which are output as is.
   # 3. Forward graph loop invariants, which are output as is.
-  for external_capture, internal_capture in grad_func_graph.captures.items():
+  for external_capture, internal_capture in grad_func_graph.captures:
     if internal_capture in grad_func_graph.popped_tensor_lists:
       new_output = grad_func_graph.popped_tensor_lists[internal_capture]
     elif (internal_capture.dtype == dtypes.resource or _is_loop_invariant(
|
@@ -180,7 +180,7 @@ class Loader(object):
         concrete_function.graph.capture_distributed_variable(
             bound_input, internal_capture)
       else:
-        concrete_function.graph.captures[bound_input] = internal_capture
+        concrete_function.graph._captures[bound_input] = internal_capture  # pylint: disable=protected-access
       if internal_capture.dtype == dtypes.resource:
         if resource_variable_ops.is_resource_variable(bound_input):
           try:
|
@@ -874,14 +874,15 @@ class LoadTest(test.TestCase, parameterized.TestCase):

     root = Root()
     self.assertIn(root.v.handle,
-                  root.use_v.get_concrete_function().graph.captures)
+                  root.use_v.get_concrete_function().graph.external_captures)
     for _ in range(cycles):
       root = self.cycle(root, 1, signatures=root.use_v.get_concrete_function())
-    func_captures = root.use_v.get_concrete_function().graph.captures
+    func_captures = root.use_v.get_concrete_function().graph.external_captures
     self.assertLen(func_captures, 2)
     self.assertIn(root.v.handle, func_captures)
     self.assertIn(root.v1.handle, func_captures)
-    signature_captures = root.signatures["serving_default"].graph.captures
+    signature_captures = root.signatures[
+        "serving_default"].graph.external_captures
     self.assertLen(signature_captures, 2)
     self.assertIn(root.v.handle, signature_captures)
     self.assertIn(root.v1.handle, signature_captures)
|
@@ -322,7 +322,7 @@ def _map_captures_to_created_tensors(
       `resource_map`.
   """
   export_captures = []
-  for exterior, interior in original_captures.items():
+  for exterior, interior in original_captures:
     mapped_resource = resource_map.get(exterior, None)
     if mapped_resource is None:
       raise AssertionError(
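
For reference, a sketch of the accessor surface this commit leaves on FuncGraph, as exercised by the callers above (g is a hypothetical FuncGraph instance; the names are from the diff):

    pairs    = g.captures                    # ordered (external, internal) tuples
    external = g.external_captures           # captured external tensors
    internal = g.internal_captures           # matching internal placeholders
    closures = g.deferred_external_captures  # closures evaluated at call time
    nests    = g.deferred_internal_captures  # placeholder nests those closures feed
    var_map  = g.variable_captures           # internal placeholder -> captured variable
    g.clear_captures()                       # teardown, used by dismantle_func_graph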