Fold CapturingGraph into FuncGraph.

There's no longer any need for two separate classes. This also cleans up some other parts of the interface (the sketch after this list illustrates the consolidated capture flow):
* Removes clear_resource_control_flow_state, which isn't used anywhere.
* Makes capture_value a private method of FuncGraph (_capture_helper).
* Makes create_substitute_placeholder private.
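
To make the consolidated behavior concrete, here is a minimal, self-contained sketch of the capture flow that now lives entirely on FuncGraph. MockGraph and MockTensor are hypothetical stand-ins for illustration only, not TensorFlow classes; the real implementation substitutes placeholder ops and records each capture on the gradient tape.

    import collections


    class MockTensor(object):
      """Hypothetical stand-in for a tensor; only tracks its owning graph."""

      def __init__(self, graph, name):
        self.graph = graph
        self.name = name


    class MockGraph(object):
      """Hypothetical stand-in sketching FuncGraph's capturing behavior."""

      def __init__(self):
        self.captures = collections.OrderedDict()  # external tensor -> placeholder
        self.inputs = []  # placeholders become extra function inputs

      def capture(self, tensor, name=None):
        """Returns `tensor` if it belongs to this graph, else a placeholder."""
        if tensor.graph is self:
          return tensor
        return self._capture_helper(tensor, name or tensor.name)

      def _capture_helper(self, tensor, name):
        placeholder = self.captures.get(tensor)
        if placeholder is None:
          # First capture: substitute a placeholder and expose it as an input.
          placeholder = MockTensor(self, name + '_placeholder')
          self.captures[tensor] = placeholder
          self.inputs.append(placeholder)
        return placeholder


    outer = MockGraph()
    inner = MockGraph()
    x = MockTensor(outer, 'x')
    p1 = inner.capture(x)
    p2 = inner.capture(x)  # repeated captures reuse the same placeholder
    assert p1 is p2 and p1.graph is inner
    assert list(inner.captures) == [x] and inner.inputs == [p1]

The same object serves as both the lookup key (self.captures) and the function input list entry, which is why folding the capture map into FuncGraph lets capture update self.inputs directly instead of FuncGraph second-guessing a superclass.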

PiperOrigin-RevId: 211707906
Skye Wanderman-Milne, 2018-09-05 15:13:16 -07:00, committed by TensorFlower Gardener
parent 59c43f26de
commit 99fe2f6034
3 changed files with 111 additions and 104 deletions

tensorflow/python/eager/function.py

@@ -59,7 +59,7 @@ cond_v2_impl._function = sys.modules[__name__]  # pylint: disable=protected-acce
 gradients_impl._function = sys.modules[__name__]  # pylint: disable=protected-access
 
 
-def create_substitute_placeholder(value, name, dtype=None):
+def _create_substitute_placeholder(value, name, dtype=None):
   """Creates a placeholder for `value` and propagates shape info to it."""
   # Note: setting ops.control_dependencies(None) ensures we always put
   # capturing placeholders outside of any control flow context.
@@ -91,100 +91,6 @@ def create_substitute_placeholder(value, name, dtype=None):
   return placeholder
 
 
-def capture_value(tensor_map, value, dtype, name):
-  """Capture a value from outside the function, to pass in as an extra arg."""
-  captured_value = tensor_map.get(value, None)
-  if captured_value is None:
-    captured_value = create_substitute_placeholder(value, name=name,
-                                                   dtype=dtype)
-    tensor_map[value] = captured_value
-  tape.record_operation("captured_value", [captured_value], [value],
-                        lambda x: [x])
-  return captured_value
-
-
-class CapturingGraph(ops.Graph):
-  """Graph that can capture tensors from other graphs.
-
-  Attributes:
-    captures: Maps external tensor -> internal tensor (e.g. input placeholder).
-      The entries are in the order they were captured.
-  """
-
-  def __init__(self):
-    super(CapturingGraph, self).__init__()
-    self.captures = collections.OrderedDict()
-    self._building_function = True
-    # Map from resource tensor name to last op (in program order) which uses
-    # this tensor. Used to enforce that execution order matches program order
-    # for resource tensors.
-    self._last_op_using_resource_tensor = {}
-
-  def clear_resource_control_flow_state(self):
-    self._last_op_using_resource_tensor = {}
-
-  # TODO(skyewm): get rid of name and use the name of `tensor`.
-  def capture(self, tensor, name=None):
-    """Capture `tensor` if it's external to this graph.
-
-    If `tensor` is from a different graph, returns a placeholder for it.
-    `tensor` and the placeholder will also appear in self.captures. Multiple
-    calls to this method with the same `tensor` argument will return the same
-    placeholder. If `tensor` is from this graph, returns `tensor`.
-
-    Args:
-      tensor: Tensor. May be from this FuncGraph or a different graph.
-      name: Optional name if a placeholder is created.
-
-    Returns:
-      Tensor from this FuncGraph.
-    """
-    if isinstance(tensor, ops.EagerTensor):
-      if name is None:
-        name = str(ops.uid())
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    if tensor.graph is not self:
-      if name is None:
-        name = tensor.op.name
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    return tensor
-
-  def create_op(
-      self,
-      op_type,
-      inputs,
-      dtypes,  # pylint: disable=redefined-outer-name
-      input_types=None,
-      name=None,
-      attrs=None,
-      op_def=None,
-      compute_shapes=True,
-      compute_device=True):
-    """Captures external inputs before calling Graph.create_op."""
-    # This capturing logic interacts poorly with control flow contexts which
-    # want to replace inputs of ops far too late in the process. This can lead
-    # the context to get confused and try to create an Enter for an Enter. We
-    # can detect this here and skip the additional Enter which can confuse loop
-    # validation logic.
-    if op_type == "Enter" and inputs[0].op.type == "Enter":
-      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
-        return inputs[0].op
-    # Calling AddValue on the control flow contexts to force creation of the
-    # backward accumulators in the original graph before we create placeholders
-    # to capture the inputs.
-    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
-    for i, inp in enumerate(inputs):
-      if ctxt is not None and hasattr(ctxt, "AddValue"):
-        inp = ctxt.AddValue(inp)
-      inp = self.capture(inp)
-      inputs[i] = inp
-    return super(CapturingGraph, self).create_op(
-        op_type, inputs, dtypes, input_types, name, attrs, op_def,
-        compute_device=compute_device)
-
-
 def _get_device_functions(ctx, graph):
   """Returns a tuple of device functions representing the device stack."""
   if ctx.executing_eagerly():
@@ -193,7 +99,7 @@ def _get_device_functions(ctx, graph):
   return tuple(graph._device_functions_outer_to_inner)  # pylint: disable=protected-access
 
 
-class FuncGraph(CapturingGraph):
+class FuncGraph(ops.Graph):
   """Graph representing a function body.
 
   Attributes:
@@ -210,6 +116,8 @@ class FuncGraph(CapturingGraph):
     variables: Variables that should be watched during function execution.
     outer_graph: The graph this function is defined in. May be another FuncGraph
       or the global default Graph.
+    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
+      The entries are in the order they were captured.
     seed: The graph-level random seed.
   """
@@ -230,6 +138,13 @@ class FuncGraph(CapturingGraph):
     self.structured_outputs = None
     self.variables = []
     self.outer_graph = ops.get_default_graph()
+    self.captures = collections.OrderedDict()
+    self._building_function = True
+    # Map from resource tensor name to last op (in program order) which uses
+    # this tensor. Used to enforce that execution order matches program order
+    # for resource tensors.
+    self._last_op_using_resource_tensor = {}
 
     graph = self.outer_graph
@@ -258,15 +173,107 @@
     self._graph_key = graph._graph_key
     # pylint: enable=protected-access
 
+  def create_op(
+      self,
+      op_type,
+      inputs,
+      dtypes,
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_shapes=True,
+      compute_device=True):
+    """Like Graph.create_op, except handles external input tensors.
+
+    This overload adds functionality to create_op to "capture" any external
+    input tensors, i.e. tensors from the eager context or outer function graphs
+    if this is a nested function. See `capture` for more information.
+
+    Args:
+      op_type: The `Operation` type to create. This corresponds to the
+        `OpDef.name` field for the proto that defines the operation.
+      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+      dtypes: A list of `DType` objects that will be the types of the tensors
+        that the operation produces.
+      input_types: (Optional.) A list of `DType`s that will be the types of
+        the tensors that the operation consumes. By default, uses the base
+        `DType` of each input in `inputs`. Operations that expect
+        reference-typed inputs must specify `input_types` explicitly.
+      name: (Optional.) A string name for the operation. If not specified, a
+        name is generated based on `op_type`.
+      attrs: (Optional.) A dictionary where the key is the attribute name (a
+        string) and the value is the respective `attr` attribute of the
+        `NodeDef` proto that will represent the operation (an `AttrValue`
+        proto).
+      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+        the operation will have.
+      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+        computed).
+      compute_device: (Optional.) If True, device functions will be executed
+        to compute the device property of the Operation.
+
+    Returns:
+      An `Operation` object.
+    """
+    # This capturing logic interacts poorly with control flow contexts which
+    # want to replace inputs of ops far too late in the process. This can lead
+    # the context to get confused and try to create an Enter for an Enter. We
+    # can detect this here and skip the additional Enter which can confuse loop
+    # validation logic.
+    if op_type == "Enter" and inputs[0].op.type == "Enter":
+      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
+        return inputs[0].op
+    # Calling AddValue on the control flow contexts to force creation of the
+    # backward accumulators in the original graph before we create placeholders
+    # to capture the inputs.
+    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
+    for i, inp in enumerate(inputs):
+      # TPU Estimator defines a control flow context with no AddValue method.
+      if ctxt is not None and hasattr(ctxt, "AddValue"):
+        inp = ctxt.AddValue(inp)
+      inp = self.capture(inp)
+      inputs[i] = inp
+    return super(FuncGraph, self).create_op(
+        op_type, inputs, dtypes, input_types, name, attrs, op_def,
+        compute_device=compute_device)
+
   def capture(self, tensor, name=None):
-    """Calls CapturingGraph.capture and updates self.inputs if necessary."""
-    new_capture = tensor not in self.captures
-    internal_tensor = super(FuncGraph, self).capture(tensor, name)
+    """Captures `tensor` if it's external to this graph.
 
-    if new_capture and tensor is not internal_tensor:
-      self.inputs.append(internal_tensor)
+    If `tensor` is from a different graph, returns a placeholder for it.
+    `tensor` and the placeholder will appear in self.captures, and the
+    placeholder will appear in self.inputs. Multiple calls to this method with
+    the same `tensor` argument will return the same placeholder. If `tensor` is
+    from this graph, returns `tensor`.
 
-    return internal_tensor
+    Args:
+      tensor: Tensor. May be from this FuncGraph or a different graph.
+      name: Optional name if a placeholder is created.
+
+    Returns:
+      Tensor from this FuncGraph.
+    """
+    if isinstance(tensor, ops.EagerTensor):
+      if name is None:
+        name = str(ops.uid())
+      return self._capture_helper(tensor, name)
+    if tensor.graph is not self:
+      if name is None:
+        name = tensor.op.name
+      return self._capture_helper(tensor, name)
+    return tensor
+
+  def _capture_helper(self, tensor, name):
+    captured_tensor = self.captures.get(tensor, None)
+    if captured_tensor is None:
+      captured_tensor = _create_substitute_placeholder(tensor, name=name,
+                                                       dtype=tensor.dtype)
+      self.captures[tensor] = captured_tensor
+      self.inputs.append(captured_tensor)
+    tape.record_operation("captured_value", [captured_tensor], [tensor],
+                          lambda x: [x])
+    return captured_tensor
 
   @property
   def external_captures(self):

tensorflow/python/keras/engine/base_layer.py

@@ -1001,7 +1001,7 @@ class Layer(checkpointable.CheckpointableBase):
         self.build(input_shape)
 
       with context.graph_mode():
-        graph = eager_function.CapturingGraph()
+        graph = eager_function.FuncGraph('graph')
         with graph.as_default():
           if isinstance(input_shape, list):
             inputs = [generate_placeholders_from_shape(shape)
tensorflow/python/keras/engine/network.py

@@ -770,7 +770,7 @@ class Network(base_layer.Layer):
       # and graph building, the variables created after building the model in
      # a Graph are still valid when executing eagerly.
       with context.graph_mode():
-        graph = eager_function.CapturingGraph()
+        graph = eager_function.FuncGraph('graph')
         with graph.as_default():
           if isinstance(input_shape, list):
             x = [base_layer.generate_placeholders_from_shape(shape)
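
For call sites, the migration is mechanical, as the two Keras hunks above show: construct a FuncGraph (which takes a name) wherever a CapturingGraph was constructed before. A minimal sketch of the updated pattern, assuming this revision's internal modules (tensorflow.python.eager.context and tensorflow.python.eager.function):

    from tensorflow.python.eager import context
    from tensorflow.python.eager import function as eager_function

    with context.graph_mode():
      # FuncGraph requires a name; CapturingGraph took no constructor arguments.
      graph = eager_function.FuncGraph('graph')
      with graph.as_default():
        # Build placeholders and ops here, as the Keras call sites above do.
        pass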