Fold CapturingGraph into FuncGraph.
There's no need for the two separate classes anymore. This also cleans up some other parts of the interface:
* Removes the clear_resource_control_flow_state method, which isn't used anywhere
* Makes capture_value a private method of FuncGraph (_capture_helper)
* Makes create_substitute_placeholder private

PiperOrigin-RevId: 211707906
This commit is contained in:
parent
59c43f26de
commit
99fe2f6034
@ -59,7 +59,7 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
|
||||
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
|
||||
|
||||
|
||||
def create_substitute_placeholder(value, name, dtype=None):
|
||||
def _create_substitute_placeholder(value, name, dtype=None):
|
||||
"""Creates a placeholder for `value` and propagates shape info to it."""
|
||||
# Note: setting ops.control_dependencies(None) ensures we always put
|
||||
# capturing placeholders outside of any control flow context.
|
||||
@ -91,100 +91,6 @@ def create_substitute_placeholder(value, name, dtype=None):
|
||||
return placeholder
|
||||
|
||||
|
||||
def capture_value(tensor_map, value, dtype, name):
|
||||
"""Capture a value from outside the function, to pass in as an extra arg."""
|
||||
captured_value = tensor_map.get(value, None)
|
||||
if captured_value is None:
|
||||
captured_value = create_substitute_placeholder(value, name=name,
|
||||
dtype=dtype)
|
||||
tensor_map[value] = captured_value
|
||||
tape.record_operation("captured_value", [captured_value], [value],
|
||||
lambda x: [x])
|
||||
return captured_value
|
||||
|
||||
|
||||
class CapturingGraph(ops.Graph):
|
||||
"""Graph that can capture tensors from other graphs.
|
||||
|
||||
Attributes:
|
||||
captures: Maps external tensor -> internal tensor (e.g. input placeholder).
|
||||
The entries are in the order they were captured.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(CapturingGraph, self).__init__()
|
||||
|
||||
self.captures = collections.OrderedDict()
|
||||
self._building_function = True
|
||||
|
||||
# Map from resource tensor name to last op (in program order) which uses
|
||||
# this tensor. Used to enforce that execution order matches program order
|
||||
# for resource tensors.
|
||||
self._last_op_using_resource_tensor = {}
|
||||
|
||||
def clear_resource_control_flow_state(self):
|
||||
self._last_op_using_resource_tensor = {}
|
||||
|
||||
# TODO(skyewm): get rid of name and use the name of `tensor`.
|
||||
def capture(self, tensor, name=None):
|
||||
"""Capture `tensor` if it's external to this graph.
|
||||
|
||||
If `tensor` is from a different graph, returns a placeholder for it.
|
||||
`tensor` and the placeholder will also appear in self.captures. Multiple
|
||||
calls to this method with the same `tensor` argument will return the same
|
||||
placeholder. If `tensor` is from this graph, returns `tensor`.
|
||||
|
||||
Args:
|
||||
tensor: Tensor. May be from this FuncGraph or a different graph.
|
||||
name: Optional name if a placeholder is created.
|
||||
|
||||
Returns:
|
||||
Tensor from this FuncGraph.
|
||||
"""
|
||||
if isinstance(tensor, ops.EagerTensor):
|
||||
if name is None:
|
||||
name = str(ops.uid())
|
||||
return capture_value(self.captures, tensor, tensor.dtype, name)
|
||||
if tensor.graph is not self:
|
||||
if name is None:
|
||||
name = tensor.op.name
|
||||
return capture_value(self.captures, tensor, tensor.dtype, name)
|
||||
return tensor
|
||||
|
||||
def create_op(
|
||||
self,
|
||||
op_type,
|
||||
inputs,
|
||||
dtypes, # pylint: disable=redefined-outer-name
|
||||
input_types=None,
|
||||
name=None,
|
||||
attrs=None,
|
||||
op_def=None,
|
||||
compute_shapes=True,
|
||||
compute_device=True):
|
||||
"""Captures any external inputs before calling Graph.create_op."""
|
||||
# This capturing logic interacts poorly with control flow contexts which
|
||||
# want to replace inputs of ops far too late in the process. This can lead
|
||||
# the context to get confused and try to create an Enter for an Enter. We
|
||||
# can detect this here and skip the additional Enter which can confuse loop
|
||||
# validation logic.
|
||||
if op_type == "Enter" and inputs[0].op.type == "Enter":
|
||||
if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
|
||||
return inputs[0].op
|
||||
# Calling AddValue on the control flow contexts to force creation of the
|
||||
# backward accumulators in the original graph before we create placeholders
|
||||
# to capture the inputs.
|
||||
ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
|
||||
for i, inp in enumerate(inputs):
|
||||
if ctxt is not None and hasattr(ctxt, "AddValue"):
|
||||
inp = ctxt.AddValue(inp)
|
||||
inp = self.capture(inp)
|
||||
inputs[i] = inp
|
||||
return super(CapturingGraph, self).create_op(
|
||||
op_type, inputs, dtypes, input_types, name, attrs, op_def,
|
||||
compute_device=compute_device)
|
||||
|
||||
|
||||
def _get_device_functions(ctx, graph):
|
||||
"""Returns a tuple of device functions representing the device stack."""
|
||||
if ctx.executing_eagerly():
|
||||
@ -193,7 +99,7 @@ def _get_device_functions(ctx, graph):
|
||||
return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
|
||||
|
||||
|
||||
class FuncGraph(CapturingGraph):
|
||||
class FuncGraph(ops.Graph):
|
||||
"""Graph representing a function body.
|
||||
|
||||
Attributes:
|
||||
@ -210,6 +116,8 @@ class FuncGraph(CapturingGraph):
|
||||
variables: Variables that should be watched during function execution.
|
||||
outer_graph: The graph this function is defined in. May be another FuncGraph
|
||||
or the global default Graph.
|
||||
captures: Maps external tensor -> internal tensor (i.e. input placeholder).
|
||||
The entries are in the order they were captured.
|
||||
seed: The graph-level random seed.
|
||||
"""
|
||||
|
||||
@ -230,6 +138,13 @@ class FuncGraph(CapturingGraph):
|
||||
self.structured_outputs = None
|
||||
self.variables = []
|
||||
self.outer_graph = ops.get_default_graph()
|
||||
self.captures = collections.OrderedDict()
|
||||
|
||||
self._building_function = True
|
||||
# Map from resource tensor name to last op (in program order) which uses
|
||||
# this tensor. Used to enforce that execution order matches program order
|
||||
# for resource tensors.
|
||||
self._last_op_using_resource_tensor = {}
|
||||
|
||||
graph = self.outer_graph
|
||||
|
||||
@ -258,15 +173,107 @@ class FuncGraph(CapturingGraph):
|
||||
self._graph_key = graph._graph_key
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def create_op(
|
||||
self,
|
||||
op_type,
|
||||
inputs,
|
||||
dtypes,
|
||||
input_types=None,
|
||||
name=None,
|
||||
attrs=None,
|
||||
op_def=None,
|
||||
compute_shapes=True,
|
||||
compute_device=True):
|
||||
"""Like Graph.create_op, except handles external input tensors.
|
||||
|
||||
This overload adds functionality to create_op to "capture" any external
|
||||
input tensors, i.e. tensors from the eager context or outer function graphs
|
||||
if this is a nested function. See `capture` for more information.
|
||||
|
||||
Args:
|
||||
op_type: The `Operation` type to create. This corresponds to the
|
||||
`OpDef.name` field for the proto that defines the operation.
|
||||
inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
|
||||
dtypes: A list of `DType` objects that will be the types of the tensors
|
||||
that the operation produces.
|
||||
input_types: (Optional.) A list of `DType`s that will be the types of
|
||||
the tensors that the operation consumes. By default, uses the base
|
||||
`DType` of each input in `inputs`. Operations that expect
|
||||
reference-typed inputs must specify `input_types` explicitly.
|
||||
name: (Optional.) A string name for the operation. If not specified, a
|
||||
name is generated based on `op_type`.
|
||||
attrs: (Optional.) A dictionary where the key is the attribute name (a
|
||||
string) and the value is the respective `attr` attribute of the
|
||||
`NodeDef` proto that will represent the operation (an `AttrValue`
|
||||
proto).
|
||||
op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
|
||||
the operation will have.
|
||||
compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
|
||||
computed).
|
||||
compute_device: (Optional.) If True, device functions will be executed
|
||||
to compute the device property of the Operation.
|
||||
|
||||
Returns:
|
||||
An `Operation` object.
|
||||
"""
|
||||
# This capturing logic interacts poorly with control flow contexts which
|
||||
# want to replace inputs of ops far too late in the process. This can lead
|
||||
# the context to get confused and try to create an Enter for an Enter. We
|
||||
# can detect this here and skip the additional Enter which can confuse loop
|
||||
# validation logic.
|
||||
if op_type == "Enter" and inputs[0].op.type == "Enter":
|
||||
if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
|
||||
return inputs[0].op
|
||||
# Calling AddValue on the control flow contexts to force creation of the
|
||||
# backward accumulators in the original graph before we create placeholders
|
||||
# to capture the inputs.
|
||||
ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access
|
||||
for i, inp in enumerate(inputs):
|
||||
# TPU Estimator defines a control flow context with no AddValue method.
|
||||
if ctxt is not None and hasattr(ctxt, "AddValue"):
|
||||
inp = ctxt.AddValue(inp)
|
||||
inp = self.capture(inp)
|
||||
inputs[i] = inp
|
||||
return super(FuncGraph, self).create_op(
|
||||
op_type, inputs, dtypes, input_types, name, attrs, op_def,
|
||||
compute_device=compute_device)
|
||||
|
||||
def capture(self, tensor, name=None):
|
||||
"""Calls CapturingGraph.capture and updates self.inputs if necessary."""
|
||||
new_capture = tensor not in self.captures
|
||||
internal_tensor = super(FuncGraph, self).capture(tensor, name)
|
||||
"""Captures `tensor` if it's external to this graph.
|
||||
|
||||
if new_capture and tensor is not internal_tensor:
|
||||
self.inputs.append(internal_tensor)
|
||||
If `tensor` is from a different graph, returns a placeholder for it.
|
||||
`tensor` and the placeholder will appear in self.captures, and the
|
||||
placeholder will appear in self.inputs. Multiple calls to this method with
|
||||
the same `tensor` argument will return the same placeholder. If `tensor` is
|
||||
from this graph, returns `tensor`.
|
||||
|
||||
return internal_tensor
|
||||
Args:
|
||||
tensor: Tensor. May be from this FuncGraph or a different graph.
|
||||
name: Optional name if a placeholder is created.
|
||||
|
||||
Returns:
|
||||
Tensor from this FuncGraph.
|
||||
"""
|
||||
if isinstance(tensor, ops.EagerTensor):
|
||||
if name is None:
|
||||
name = str(ops.uid())
|
||||
return self._capture_helper(tensor, name)
|
||||
if tensor.graph is not self:
|
||||
if name is None:
|
||||
name = tensor.op.name
|
||||
return self._capture_helper(tensor, name)
|
||||
return tensor
|
||||
|
||||
def _capture_helper(self, tensor, name):
|
||||
captured_tensor = self.captures.get(tensor, None)
|
||||
if captured_tensor is None:
|
||||
captured_tensor = _create_substitute_placeholder(tensor, name=name,
|
||||
dtype=tensor.dtype)
|
||||
self.captures[tensor] = captured_tensor
|
||||
self.inputs.append(captured_tensor)
|
||||
tape.record_operation("captured_value", [captured_tensor], [tensor],
|
||||
lambda x: [x])
|
||||
return captured_tensor
|
||||
|
||||
@property
|
||||
def external_captures(self):
|
||||
|
@ -1001,7 +1001,7 @@ class Layer(checkpointable.CheckpointableBase):
|
||||
self.build(input_shape)
|
||||
|
||||
with context.graph_mode():
|
||||
graph = eager_function.CapturingGraph()
|
||||
graph = eager_function.FuncGraph('graph')
|
||||
with graph.as_default():
|
||||
if isinstance(input_shape, list):
|
||||
inputs = [generate_placeholders_from_shape(shape)
|
||||
|
@ -770,7 +770,7 @@ class Network(base_layer.Layer):
|
||||
# and graph building, the variables created after building the model in
|
||||
# a Graph are still valid when executing eagerly.
|
||||
with context.graph_mode():
|
||||
graph = eager_function.CapturingGraph()
|
||||
graph = eager_function.FuncGraph('graph')
|
||||
with graph.as_default():
|
||||
if isinstance(input_shape, list):
|
||||
x = [base_layer.generate_placeholders_from_shape(shape)
|
||||
|
Loading…
Reference in New Issue
Block a user