Fold CapturingGraph into FuncGraph.
There's no need for the two separate classes anymore. This also cleans up some other parts of the interface:

* Removes the clear_resource_control_flow_state, which isn't used anywhere
* Makes capture_value a private method of FuncGraph (_capture_helper)
* Makes create_substitute_placeholder private

PiperOrigin-RevId: 211707906
Commit 99fe2f6034 (parent 59c43f26de)
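Before the hunks, a minimal sketch (not part of the commit) of what the folded-together interface looks like to a caller. The FuncGraph('graph') constructor and the capture/captures/inputs behavior are taken from the diff below; the module paths and the surrounding setup are assumptions about this revision's internals:

# Minimal sketch, assuming this revision's internal API (not from the commit):
# capture() now lives directly on FuncGraph and does its own bookkeeping.
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op, ops

outer = ops.Graph()
with outer.as_default():
  x = constant_op.constant(1.0)  # a tensor external to the function graph

func_graph = eager_function.FuncGraph('example')
with func_graph.as_default():
  internal = func_graph.capture(x)  # x.graph is not func_graph -> placeholder

assert func_graph.captures[x] is internal  # external -> internal mapping
assert internal in func_graph.inputs       # placeholder added to the inputs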
@@ -59,7 +59,7 @@ cond_v2_impl._function = sys.modules[__name__]  # pylint: disable=protected-acce
 gradients_impl._function = sys.modules[__name__]  # pylint: disable=protected-access


-def create_substitute_placeholder(value, name, dtype=None):
+def _create_substitute_placeholder(value, name, dtype=None):
   """Creates a placeholder for `value` and propagates shape info to it."""
   # Note: setting ops.control_dependencies(None) ensures we always put
   # capturing placeholders outside of any control flow context.
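The rename above is mechanical, but the "propagates shape info" promise in the docstring is what callers rely on. A hedged sketch of that property (assumed usage, not from the commit), going through the FuncGraph.capture path that this diff introduces further down:

# Hedged sketch (not from the commit): the renamed helper still propagates
# shape and dtype to the substitute placeholder, reached here via capture().
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

outer = ops.Graph()
with outer.as_default():
  x = array_ops.zeros([2, 3])  # example shape, chosen arbitrarily

fg = eager_function.FuncGraph('shapes')
with fg.as_default():
  ph = fg.capture(x)  # internally calls _create_substitute_placeholder

assert ph.shape.as_list() == [2, 3]  # shape info propagated
assert ph.dtype == x.dtype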
@@ -91,100 +91,6 @@ def create_substitute_placeholder(value, name, dtype=None):
   return placeholder


-def capture_value(tensor_map, value, dtype, name):
-  """Capture a value from outside the function, to pass in as an extra arg."""
-  captured_value = tensor_map.get(value, None)
-  if captured_value is None:
-    captured_value = create_substitute_placeholder(value, name=name,
-                                                   dtype=dtype)
-    tensor_map[value] = captured_value
-  tape.record_operation("captured_value", [captured_value], [value],
-                        lambda x: [x])
-  return captured_value
-
-
-class CapturingGraph(ops.Graph):
-  """Graph that can capture tensors from other graphs.
-
-  Attributes:
-    captures: Maps external tensor -> internal tensor (e.g. input placeholder).
-      The entries are in the order they were captured.
-  """
-
-  def __init__(self):
-    super(CapturingGraph, self).__init__()
-
-    self.captures = collections.OrderedDict()
-    self._building_function = True
-
-    # Map from resource tensor name to last op (in program order) which uses
-    # this tensor. Used to enforce that execution order matches program order
-    # for resource tensors.
-    self._last_op_using_resource_tensor = {}
-
-  def clear_resource_control_flow_state(self):
-    self._last_op_using_resource_tensor = {}
-
-  # TODO(skyewm): get rid of name and use the name of `tensor`.
-  def capture(self, tensor, name=None):
-    """Capture `tensor` if it's external to this graph.
-
-    If `tensor` is from a different graph, returns a placeholder for it.
-    `tensor` and the placeholder will also appears in self.captures. Multiple
-    calls to this method with the same `tensor` argument will return the same
-    placeholder. If `tensor` is from this graph, returns `tensor`.
-
-    Args:
-      tensor: Tensor. May be from this FuncGraph or a different graph.
-      name: Optional name if a placeholder is created.
-
-    Returns:
-      Tensor from this FuncGraph.
-    """
-    if isinstance(tensor, ops.EagerTensor):
-      if name is None:
-        name = str(ops.uid())
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    if tensor.graph is not self:
-      if name is None:
-        name = tensor.op.name
-      return capture_value(self.captures, tensor, tensor.dtype, name)
-    return tensor
-
-  def create_op(
-      self,
-      op_type,
-      inputs,
-      dtypes,  # pylint: disable=redefined-outer-name
-      input_types=None,
-      name=None,
-      attrs=None,
-      op_def=None,
-      compute_shapes=True,
-      compute_device=True):
-    """Captures an external inputs before calling Graph.capture_op."""
-    # This capturing logic interacts poorly with control flow contexts which
-    # want to replace inputs of ops far too late in the process. This can lead
-    # the context to get confused and try to create an Enter for an Enter. We
-    # can detect this here and skip the additional Enter which can confuse loop
-    # validation logic.
-    if op_type == "Enter" and inputs[0].op.type == "Enter":
-      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
-        return inputs[0].op
-    # Calling AddValue on the control flow contexts to force creation of the
-    # backward accumulators in the original graph before we create placeholders
-    # to capture the inputs.
-    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
-    for i, inp in enumerate(inputs):
-      if ctxt is not None and hasattr(ctxt, "AddValue"):
-        inp = ctxt.AddValue(inp)
-      inp = self.capture(inp)
-      inputs[i] = inp
-    return super(CapturingGraph, self).create_op(
-        op_type, inputs, dtypes, input_types, name, attrs, op_def,
-        compute_device=compute_device)
-
-
 def _get_device_functions(ctx, graph):
   """Returns a tuple of device functions representing the device stack."""
   if ctx.executing_eagerly():
@@ -193,7 +99,7 @@ def _get_device_functions(ctx, graph):
   return tuple(graph._device_functions_outer_to_inner)  # pylint: disable=protected-access


-class FuncGraph(CapturingGraph):
+class FuncGraph(ops.Graph):
   """Graph representing a function body.

   Attributes:
@@ -210,6 +116,8 @@ class FuncGraph(CapturingGraph):
     variables: Variables that should be watched during function execution.
     outer_graph: The graph this function is defined in. May be another FuncGraph
       or the global default Graph.
+    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
+      The entries are in the order they were captured.
     seed: The graph-level random seed.
   """

@@ -230,6 +138,13 @@ class FuncGraph(CapturingGraph):
     self.structured_outputs = None
     self.variables = []
     self.outer_graph = ops.get_default_graph()
+    self.captures = collections.OrderedDict()
+
+    self._building_function = True
+    # Map from resource tensor name to last op (in program order) which uses
+    # this tensor. Used to enforce that execution order matches program order
+    # for resource tensors.
+    self._last_op_using_resource_tensor = {}

     graph = self.outer_graph

@@ -258,15 +173,107 @@ class FuncGraph(CapturingGraph):
     self._graph_key = graph._graph_key
     # pylint: enable=protected-access

+  def create_op(
+      self,
+      op_type,
+      inputs,
+      dtypes,
+      input_types=None,
+      name=None,
+      attrs=None,
+      op_def=None,
+      compute_shapes=True,
+      compute_device=True):
+    """Like Graph.create_op, except handles external input tensors.
+
+    This overload adds functionality to create_op to "capture" any external
+    input tensors, i.e. tensors from the eager context or outer function graphs
+    if this is a nested function. See `capture` for more information.
+
+    Args:
+      op_type: The `Operation` type to create. This corresponds to the
+        `OpDef.name` field for the proto that defines the operation.
+      inputs: A list of `Tensor` objects that will be inputs to the `Operation`.
+      dtypes: A list of `DType` objects that will be the types of the tensors
+        that the operation produces.
+      input_types: (Optional.) A list of `DType`s that will be the types of
+        the tensors that the operation consumes. By default, uses the base
+        `DType` of each input in `inputs`. Operations that expect
+        reference-typed inputs must specify `input_types` explicitly.
+      name: (Optional.) A string name for the operation. If not specified, a
+        name is generated based on `op_type`.
+      attrs: (Optional.) A dictionary where the key is the attribute name (a
+        string) and the value is the respective `attr` attribute of the
+        `NodeDef` proto that will represent the operation (an `AttrValue`
+        proto).
+      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
+        the operation will have.
+      compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always
+        computed).
+      compute_device: (Optional.) If True, device functions will be executed
+        to compute the device property of the Operation.
+
+    Returns:
+      An `Operation` object.
+    """
+    # This capturing logic interacts poorly with control flow contexts which
+    # want to replace inputs of ops far too late in the process. This can lead
+    # the context to get confused and try to create an Enter for an Enter. We
+    # can detect this here and skip the additional Enter which can confuse loop
+    # validation logic.
+    if op_type == "Enter" and inputs[0].op.type == "Enter":
+      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
+        return inputs[0].op
+    # Calling AddValue on the control flow contexts to force creation of the
+    # backward accumulators in the original graph before we create placeholders
+    # to capture the inputs.
+    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
+    for i, inp in enumerate(inputs):
+      # TPU Estimator defines a control flow context with no AddValue method.
+      if ctxt is not None and hasattr(ctxt, "AddValue"):
+        inp = ctxt.AddValue(inp)
+      inp = self.capture(inp)
+      inputs[i] = inp
+    return super(FuncGraph, self).create_op(
+        op_type, inputs, dtypes, input_types, name, attrs, op_def,
+        compute_device=compute_device)
+
   def capture(self, tensor, name=None):
-    """Calls CapturingGraph.capture and updates self.inputs if necessary."""
-    new_capture = tensor not in self.captures
-    internal_tensor = super(FuncGraph, self).capture(tensor, name)
-
-    if new_capture and tensor is not internal_tensor:
-      self.inputs.append(internal_tensor)
-
-    return internal_tensor
+    """Captures `tensor` if it's external to this graph.
+
+    If `tensor` is from a different graph, returns a placeholder for it.
+    `tensor` and the placeholder will appear in self.captures, and the
+    placeholder will appear in self.inputs. Multiple calls to this method with
+    the same `tensor` argument will return the same placeholder. If `tensor` is
+    from this graph, returns `tensor`.
+
+    Args:
+      tensor: Tensor. May be from this FuncGraph or a different graph.
+      name: Optional name if a placeholder is created.
+
+    Returns:
+      Tensor from this FuncGraph.
+    """
+    if isinstance(tensor, ops.EagerTensor):
+      if name is None:
+        name = str(ops.uid())
+      return self._capture_helper(tensor, name)
+    if tensor.graph is not self:
+      if name is None:
+        name = tensor.op.name
+      return self._capture_helper(tensor, name)
+    return tensor
+
+  def _capture_helper(self, tensor, name):
+    captured_tensor = self.captures.get(tensor, None)
+    if captured_tensor is None:
+      captured_tensor = _create_substitute_placeholder(tensor, name=name,
+                                                       dtype=tensor.dtype)
+      self.captures[tensor] = captured_tensor
+      self.inputs.append(captured_tensor)
+    tape.record_operation("captured_value", [captured_tensor], [tensor],
+                          lambda x: [x])
+    return captured_tensor

   @property
   def external_captures(self):
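A second hedged sketch of the implicit path above: ops built inside the FuncGraph route through the overridden create_op, which captures any external inputs and deduplicates repeated captures. math_ops.add stands in for any op here; the example is assumed usage, not from the commit:

# Hypothetical usage, assuming the same internal modules as above: create_op
# notices that `a` lives in another graph and swaps in a captured placeholder.
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op, ops
from tensorflow.python.ops import math_ops

outer = ops.Graph()
with outer.as_default():
  a = constant_op.constant(2.0)

fg = eager_function.FuncGraph('implicit_capture')
with fg.as_default():
  b = math_ops.add(a, a)  # both inputs are external -> captured via create_op

assert a in fg.captures      # the external tensor was recorded
assert len(fg.inputs) == 1   # one placeholder despite two uses (dedup)
assert b.graph is fg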
@@ -1001,7 +1001,7 @@ class Layer(checkpointable.CheckpointableBase):
         self.build(input_shape)

     with context.graph_mode():
-      graph = eager_function.CapturingGraph()
+      graph = eager_function.FuncGraph('graph')
       with graph.as_default():
         if isinstance(input_shape, list):
           inputs = [generate_placeholders_from_shape(shape)
@@ -770,7 +770,7 @@ class Network(base_layer.Layer):
     # and graph building, the variables created after building the model in
     # a Graph are still valid when executing eagerly.
     with context.graph_mode():
-      graph = eager_function.CapturingGraph()
+      graph = eager_function.FuncGraph('graph')
       with graph.as_default():
         if isinstance(input_shape, list):
           x = [base_layer.generate_placeholders_from_shape(shape)
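The two Keras call sites above change only their constructor call. A sketch of the migrated pattern (the placeholder shape and dtype are illustrative stand-ins for generate_placeholders_from_shape, not from the commit):

# Sketch of the updated Keras call-site pattern shown in the hunks above:
# CapturingGraph() becomes FuncGraph('graph'), used the same way.
from tensorflow.python.eager import context
from tensorflow.python.eager import function as eager_function
from tensorflow.python.ops import array_ops

with context.graph_mode():
  graph = eager_function.FuncGraph('graph')  # was eager_function.CapturingGraph()
  with graph.as_default():
    # stand-in for generate_placeholders_from_shape(shape)
    inputs = array_ops.placeholder(dtype='float32', shape=(None, 4))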