# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as py_collections
import itertools
import weakref

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import compat
from tensorflow.python.util import memory
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator

ALLOWLIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]


# Eager tensors with at most this many elements are captured inline as
# `Const` ops; larger eager tensors are captured via placeholders (see
# `FuncGraph.capture`).
_EAGER_CONST_THRESHOLD = 128


class UnknownArgument(object):
  """Signifies an argument which is not currently handled."""
  pass

def convert_structure_to_signature(structure, arg_names=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where the top-level collection is a list
      or a tuple.
    arg_names: Optional list of argument names, with the same number of
      elements as `structure`, used to name the corresponding TensorSpecs.

  Returns:
    An identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.
  """

  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except ValueError:
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = "/".join(str(p) for p in path)
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, resource_variable_ops.BaseResourceVariable):
      name = "/".join(str(p) for p in path)
      return resource_variable_ops.VariableSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." % arg_names)
    # Replace all top-level names with their actual arg_names. If a path
    # before was "(2, 'a', 1)", it will become "(arg_names[2], 'a', 1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)
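
# A commented sketch (not executed at import time) of how
# convert_structure_to_signature names the resulting TensorSpecs; the
# placeholder below is a hypothetical example:
#
#   with ops.Graph().as_default():
#     t = array_ops.placeholder(dtypes.float32, shape=[2])
#     convert_structure_to_signature(([t, 3],), arg_names=["a"])
#     # -> ([TensorSpec([2], tf.float32, name="a/0"), 3],)
#
# Flattened paths into the structure become the TensorSpec names, with the
# top-level index replaced by the argument name when arg_names is given.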

class FuncGraph(ops.Graph):
  """Graph representing a function body.

  Attributes:
    name: The name of the function.
    inputs: Placeholder tensors representing the inputs to this function. The
      tensors are in this FuncGraph. This represents "regular" inputs as well
      as captured inputs (i.e. the values of self.captures), with the regular
      inputs coming first.
    outputs: Tensors that will be returned by this function. The tensors are
      in this FuncGraph.
    control_outputs: Operations that must be executed before the function
      represented by this graph can be said to have been executed.
    structured_input_signature: A tuple of (args, kwargs), which are both
      possibly-nested python objects that were received by this function.
      Note that these structures might contain Python `None`s.
    structured_outputs: A possibly-nested python object which will be
      returned by this function. The Tensors in this structure are the same
      as those of self.outputs. Note that this structure might contain Python
      `None`s.
    variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another
      FuncGraph or the global default Graph.
    captures: Maps external tensor -> internal tensor (i.e. input
      placeholder). The entries are in the order they were captured.
    control_captures: Set of external ops on which this graph has a control
      dependency.
    seed: The graph-level random seed.
    capture_by_value: If True, the func graph will capture Variables by value
      instead of reference.
  """

  def __init__(self, name, collections=None, capture_by_value=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not allowlisted, and both
        read and write to the outer graph's collections that are allowlisted.
        The current allowlisted collections are the global variables, the
        local variables, and the trainable variables. Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit
        from outer graphs, and failing that will default to False.
    """
    super(FuncGraph, self).__init__()

    self.name = name
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.control_captures = set()
    self.structured_input_signature = None
    self.structured_outputs = None
    self._weak_variables = []
    self._watched_variables = object_identity.ObjectIdentityWeakSet()
    self.is_control_flow_graph = False

    outer_graph = ops.get_default_graph()
    self._weak_outer_graph = weakref.ref(outer_graph)
    while outer_graph.building_function:
      outer_graph = outer_graph.outer_graph
    # If self._weak_outer_graph is deleted, we revert to the outermost Graph
    # active when the FuncGraph was traced. This will not be a FuncGraph.
    self._fallback_outer_graph = outer_graph
    self._captures = py_collections.OrderedDict()
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    # Maps arbitrary key -> (closure, nest of placeholders), where at function
    # call time the value of closure() will be used to feed the nest of
    # placeholders.
    self._deferred_captures = py_collections.OrderedDict()
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(
        self.outer_graph, FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True
    # Map from resource tensor name to last op (in program order) which uses
    # this tensor. Used to enforce that execution order matches program order
    # for resource tensors.
    self._last_op_using_resource_tensor = {}

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf.data user migration from TF 1.x to 2.x] `_seed_used` keeps
      # track of any None op_seed for a random op in the function, in which
      # case we end up using the function seed, which could be unintended
      # behavior for the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in ALLOWLIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in ALLOWLIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
    # dependent functions as unsaveable.
    self._saveable = True
    self._saving_errors = set()

    # Keep track of callbacks to run when this graph exits default scope.
    self._scope_exit_callbacks = None

  def __str__(self):
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def watch_variable(self, v):
    """Marks the variable v as accessed while building this graph."""
    # Walk outward through any enclosing FuncGraphs so that outer function
    # graphs also record the access.
    while self is not None and isinstance(self, FuncGraph):
      self._watched_variables.add(v)
      self = self.outer_graph

  def capture_call_time_value(self, closure, spec, key=None):
    """Creates a placeholder which at call time has the value closure().

    Useful, for example, to respect TensorFlow context managers, which are
    often dynamically scoped.

    Args:
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to this method with the same
        key in the same graph will return the same placeholder, and the first
        closure will be used at function call time.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      ValueError: at function call time, if the return value of closure() is
        not compatible with `spec`.
    """
    if key is None:
      key = object()
    if key not in self._deferred_captures:

      def convert_to_placeholder(s):
        if not isinstance(s, tensor_spec.DenseSpec):
          raise TypeError(
              "Expected a nest of `TypeSpec` objects, found %s of type %s." %
              (s, type(s)))
        return array_ops.placeholder(dtype=s.dtype, shape=s.shape)

      placeholder = nest.map_structure(
          convert_to_placeholder, spec, expand_composites=True)

      def wrapped_closure():
        ret_nest = closure()
        nest.assert_same_structure(spec, ret_nest, expand_composites=True)
        # This uses the tensor dtype defined in `spec` when converting values
        # in `ret_nest` to tensors.
        # pylint: disable=protected-access
        y = nest.map_structure(lambda s, r: s._to_components(r), spec,
                               ret_nest, expand_composites=False)
        # pylint: enable=protected-access
        return nest.flatten(y, expand_composites=True)

      self._deferred_captures[key] = (wrapped_closure, placeholder)
    return self._deferred_captures[key][1]
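
  # A hedged usage sketch of the method above (commented so it is not run at
  # import time). `get_current_replica_id` is a hypothetical zero-argument
  # closure; any callable returning tensors matching `spec` works:
  #
  #   spec = tensor_spec.TensorSpec([], dtypes.int32)
  #   ph = func_graph.capture_call_time_value(
  #       get_current_replica_id, spec, key="replica_id")
  #
  # Ops built in `func_graph` may use `ph`; each call of the traced function
  # feeds it with the closure's value at that moment, rather than a value
  # frozen at trace time.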

  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which must be
        executed or computed before running the operations defined in the
        context. Can also be `None` to clear the control dependencies.

    Returns:
      A context manager that specifies control dependencies for all
      operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
        objects.
    """
    if control_inputs is None:
      return super(FuncGraph, self).control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable.
      if (isinstance(c, ops.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      if graph_element is not None and getattr(
          graph_element, "graph", None) is not self:
        self.control_captures.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super(FuncGraph, self).control_dependencies(filtered_control_inputs)
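
  # Commented illustration of the filtering above: control dependencies on
  # objects from outside this graph are recorded in `control_captures`
  # rather than passed to Graph.control_dependencies. `external_op` and
  # `internal_t` are hypothetical names:
  #
  #   with func_graph.as_default():
  #     with func_graph.control_dependencies([external_op, internal_t]):
  #       ...  # external_op lands in func_graph.control_captures;
  #            # internal_t becomes a regular control input.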

  def as_default(self):
    outer_cm = super(FuncGraph, self).as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      graph = ops.get_default_graph()
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)

      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's
      # placement, so that the same trace of the Python function may be
      # placed on several different devices and saved functions may be placed
      # on new devices when restored.
      # However, we need to preserve the outer device stack in the following
      # cases in a non-eager context:
      # 1. The device stack is callable.
      # 2. When using distribution strategy with legacy graph mode.
      old_device_stack = self._device_function_stack
      if (not context.executing_eagerly() and
          (device_stack_has_callable(graph._device_function_stack) or
           (self._distribution_strategy_stack and
            not ops.executing_eagerly_outside_functions()))):
        # Hard-code devices from device functions in the function body.
        self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key

    return inner_cm()

  @property
  def outer_graph(self):
    """The Graph this FuncGraph is nested in.

    Functions may capture Tensors from graphs they are nested in
    (transitively).

    Returns:
      A Graph object. Initially set to the current default graph when the
      FuncGraph was created. If the previous `outer_graph` was deleted because
      the function that owns it was deleted, `outer_graph` is reset to the
      outermost default graph active when the FuncGraph was created. This
      FuncGraph won't have captured anything from the new `outer_graph` (and
      likely not from the previous setting, since that would have created a
      strong reference), but it is returned so that FuncGraphs always have a
      parent.
    """
    current = self._weak_outer_graph()
    if current is None:
      return self._fallback_outer_graph
    return current

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    return [t.shape for t in self.outputs]

  @property
  def trainable_variables(self):
    """A sequence of trainable variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of trainable variables for this func graph.
    """
    return tuple(v for v in self.variables if v.trainable)

  @property
  def variables(self):
    """A sequence of variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of variables for this func graph.
    """

    def deref(weak_v):
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      return v

    return tuple(deref(v) for v in self._weak_variables)

  @variables.setter
  def variables(self, var_list):
    self._weak_variables = [weakref.ref(v) for v in var_list]
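
  # Commented sketch of the weak-reference behavior documented above. A
  # variable created only inside a traced function, with no strong reference
  # elsewhere, is garbage collected, and reading `graph.variables` afterwards
  # raises the AssertionError from `deref`. Keeping a strong reference (e.g.
  # on an object attribute, here via the public tf.Variable API) avoids this:
  #
  #   class Model(object):
  #     def __call__(self, x):
  #       if not hasattr(self, "v"):
  #         self.v = tf.Variable(1.0)  # strong reference outlives the trace
  #       return x + self.v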

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # When capturing by value, do the read outside the function graph.
    reverse_captures = dict((id(v), k) for k, v in self.captures)
    uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type,
            uncaptured_inputs,
            dtypes,
            input_types,
            name,
            attrs,
            op_def,
            compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This override adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function
    graphs if this is a nested function. See `capture` for more information.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the
        `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of
        the tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    if self.capture_by_value and op_type in ["ReadVariableOp",
                                             "ResourceGather"]:
      return self._capture_by_value(op_type, inputs, dtypes, input_types,
                                    name, attrs, op_def, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can
    # lead the context to get confused and try to create an Enter for an
    # Enter. We can detect this here and skip the additional Enter which can
    # confuse loop validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create
    # placeholders to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    # Use a different list to avoid modifying the original inputs list.
    captured_inputs = []
    for inp in inputs:
      # TPU Estimator defines a control flow context with no AddValue method.
      if ctxt is not None and hasattr(ctxt, "AddValue"):
        inp = ctxt.AddValue(inp)
      inp = self.capture(inp)
      captured_inputs.append(inp)
    return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
        op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)

  def capture(self, tensor, name=None, shape=None):
    """Captures `tensor` if it's external to this graph.

    If `tensor` is from a different graph, returns a placeholder for it.
    `tensor` and the placeholder will appear in self.captures, and the
    placeholder will appear in self.inputs. Multiple calls to this method
    with the same `tensor` argument will return the same placeholder. If
    `tensor` is from this graph, returns `tensor`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.
      shape: Optional shape if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      InaccessibleTensorError: if any tensors are accessed in a manner that
        bypasses the mechanisms required for the data dependencies to be
        correctly wired.
    """
    if isinstance(tensor, ops.EagerTensor):
      if name is None:
        name = str(ops.uid())

      # Small EagerTensors are captured with Const ops.
      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
          np.prod(tensor.shape) <= _EAGER_CONST_THRESHOLD):
        return self.capture_eager_tensor(tensor, name)

      # Large EagerTensors and resources are captured with Placeholder ops.
      return self._capture_helper(tensor, name, shape)
    if tensor.graph is not self:
      if name is None:
        name = tensor.op.name
      inner_graph = tensor.graph
      while inner_graph is not None and isinstance(inner_graph, FuncGraph):
        if inner_graph is self:
          raise errors.InaccessibleTensorError(
              "The tensor '%s' cannot be accessed here: it is defined"
              " in another function or code block. Use return values,"
              " explicit Python locals or TensorFlow collections to access"
              " it. Defined in: %s; accessed from: %s.\n"
              % (tensor, tensor.graph, self))
        inner_graph = inner_graph.outer_graph
      return self._capture_helper(tensor, name)
    return tensor
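
  # Commented sketch of the two eager capture paths above; the shapes are
  # arbitrary examples, assuming eager execution is enabled:
  #
  #   g = FuncGraph("example")
  #   small = constant_op.constant(1.0)   # <= _EAGER_CONST_THRESHOLD elements
  #   big = array_ops.zeros([1000])       # > threshold
  #   with g.as_default():
  #     g.capture(small)  # inlined into g as a Const op
  #     g.capture(big)    # captured via a Placeholder op
  #
  # Both calls record the (external, internal) pair in g.captures; repeated
  # captures of the same tensor return the same graph tensor.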

  def _capture_helper(self, tensor, name, shape=None):
    capture = self._captures.get(id(tensor))
    if capture is None:
      placeholder = _create_substitute_placeholder(
          tensor, name=name, dtype=tensor.dtype, shape=shape)
      # Record the composite device as an attribute of the placeholder. This
      # attribute will be propagated into the arg_attr of the FunctionDef.
      # Currently, a packed eager tensor is always placed on a
      # CompositeDevice.
      if isinstance(tensor, ops.EagerTensor) and tensor.is_packed:
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_composite_device",
            attr_value_pb2.AttrValue(s=compat.as_bytes(tensor.device)))
      self.add_capture(tensor, placeholder)
    else:
      placeholder = capture[1]
    tape.record_operation("captured_value", [placeholder], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return placeholder

  @property
  def captures(self):
    """Ordered list of tuples containing external and internal captures."""
    return self._captures.values()

  def add_capture(self, tensor, placeholder):
    """Capture a specific tensor and utilize the provided placeholder.

    Args:
      tensor: Tensor to capture.
      placeholder: Provided placeholder for the tensor.
    """
    self._captures[id(tensor)] = (tensor, placeholder)
    self.inputs.append(placeholder)

  def replace_capture(self, tensor, placeholder):
    """Replace an already existing capture."""
    self._captures[id(tensor)] = (tensor, placeholder)

  def reset_captures(self, capture_list):
    """Set the captures with the provided list of (tensor, placeholder)."""
    self._captures = py_collections.OrderedDict()
    for tensor, placeholder in capture_list:
      self._captures[id(tensor)] = (tensor, placeholder)

  def pop_capture(self, tensor):
    """Remove the capture and return the generated placeholder."""
    capture = self._captures.pop(id(tensor), None)
    if capture is None:
      return None

    return capture[1]

  def clear_captures(self):
    # TODO(b/115366440): Delete this method when a custom OrderedDict is
    # added. Clearing captures using clear() leaves some cycles around.
    while self._captures:
      self._captures.popitem()
    memory.dismantle_ordered_dict(self._captures)
    while self._deferred_captures:
      self._deferred_captures.popitem()
    memory.dismantle_ordered_dict(self._deferred_captures)

  def capture_distributed_variable(self, variable, placeholder):
    """Add given distributed variable to captures with given placeholder."""
    self._captures[id(variable)] = (variable, placeholder)
    tape.record_operation("captured_value", [placeholder], [variable],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])

  def capture_eager_tensor(self, tensor, name):
    capture = self._captures.get(id(tensor))
    if capture is None:
      # We clear all control dependencies and place the Const op on the same
      # device as the source tensor. The device placement may be relaxed at a
      # later date.
      with ops.control_dependencies(None), self.device(tensor.device):
        constant_value = tensor_util.constant_value(tensor)
        if constant_value is None:
          # Some eager tensors, e.g. parallel tensors, are not convertible to
          # a single constant. We'll use a placeholder for this case.
          return self._capture_helper(tensor, name)
        graph_const = constant_op.constant(constant_value, dtype=tensor.dtype,
                                           shape=tensor.shape, name=name)
      self.add_capture(tensor, graph_const)
    else:
      graph_const = capture[1]
    tape.record_operation("captured_value", [graph_const], [tensor],
                          backward_function=lambda x: [x],
                          forward_function=lambda x: [x])
    return graph_const

  def captured(self, tensor):
    """Check if the specified tensor has been captured."""
    return id(tensor) in self._captures

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return [c[0] for c in self._captures.values()]

  @property
  def internal_captures(self):
    """Placeholders in this function corresponding to captured tensors."""
    return [c[1] for c in self._captures.values()]

  @property
  def deferred_external_captures(self):
    """Ordered nest of tensors whose placeholders will be fed at call time."""
    return [c[0] for c in self._deferred_captures.values()]

  @property
  def deferred_internal_captures(self):
    """List of nests of placeholders which at call time will be fed."""
    return [c[1] for c in self._deferred_captures.values()]

  @property
  def variable_captures(self):
    """Map from python object ids of captured placeholders to variables."""
    return {
        id(self._captures[id(v)][1]): v
        for v in self.variables
        if id(v) in self._captures
    }

  def mark_as_unsaveable(self, error_message):
    """Marks this FuncGraph as unsaveable.

    Any attempts to export this FuncGraph will raise an error with the
    specified message.

    Args:
      error_message: List or string containing the error message to be raised
        when saving this FuncGraph to SavedModel.
    """
    self._saveable = False
    if isinstance(error_message, str):
      error_message = [error_message]
    self._saving_errors.update(error_message)

  @property
  def saveable(self):
    """Returns whether this FuncGraph is saveable."""
    return self._saveable

  @property
  def saving_errors(self):
    """Returns the set of errors preventing this FuncGraph from being saved."""
    return self._saving_errors

  def _add_scope_exit_callback(self, fn):
    """Add a function to call when this graph exits the default scope."""
    if not callable(fn):
      raise TypeError("fn is not callable: {}".format(fn))
    if self._scope_exit_callbacks is None:
      raise RuntimeError(
          "Attempting to add a scope exit callback, but the default graph is "
          "not the context scope graph. Did you forget to call "
          "'with graph.as_default(): ...'?")
    self._scope_exit_callbacks.append(fn)
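
  # Commented sketch of the saveability bookkeeping above; the op name is a
  # hypothetical example:
  #
  #   g = FuncGraph("uses_unsupported_op")
  #   g.mark_as_unsaveable("Op 'StatefulFoo' is not supported in SavedModel.")
  #   assert not g.saveable
  #   assert g.saving_errors  # contains the message passed above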


def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            override_flat_arg_shapes=None):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be
      called; ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpec`s specifying the
      shapes and dtypes of the arguments. When a signature is provided,
      `args` and `kwargs` are ignored, and `python_func` is traced with
      Tensors conforming to `signature`. If `None`, the shapes and dtypes are
      inferred from the inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph, else a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control
      dependencies to ensure program order matches execution order and
      stateful ops always execute.
    arg_names: Optional list of argument names, used to give input
      placeholders recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If
      not set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write
      to) the outer graph's collections that are not allowlisted, and both
      read and write to the outer graph's collections that are allowlisted.
      The current allowlisted collections are the global variables, the local
      variables, and the trainable variables. Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will
      capture Variables by value instead of reference. By default inherit
      from outer graphs, and failing that will default to False.
    override_flat_arg_shapes: An optional list of instances that are either
      `None` or `TensorShape`. The length must match that of
      `nest.flatten((args, kwargs), expand_composites=True)`. The entries
      containing value `None` must match entries in flattened arguments
      containing non-tensors, while entries containing a `TensorShape` must
      match entries in the flattened arguments containing tensors.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor
      a `Tensor`.
    ValueError: If both `signature` and `override_flat_arg_shapes` are passed
      in.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(name, collections=collections,
                           capture_by_value=capture_by_value)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    deps_control_manager = auto_control_deps.AutomaticControlDependencies()
  else:
    deps_control_manager = ops.NullContextmanager()

  with func_graph.as_default(), deps_control_manager as deps_ctx:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None and override_flat_arg_shapes is not None:
      raise ValueError(
          "Passed both signature and override_flat_arg_shapes: %s and %s."
          % (signature, override_flat_arg_shapes))

    if signature is not None:
      args = signature
      kwargs = {}

    # Creates and names placeholders for all arguments.
    if override_flat_arg_shapes is not None:
      flat_args = nest.flatten(args, expand_composites=True)
      arg_shapes = override_flat_arg_shapes[:len(flat_args)]
      kwarg_shapes = override_flat_arg_shapes[len(flat_args):]
    else:
      arg_shapes = None
      kwarg_shapes = None
    func_args = _get_defun_inputs_from_args(
        args, arg_names, flat_shapes=arg_shapes)
    func_kwargs = _get_defun_inputs_from_kwargs(
        kwargs, flat_shapes=kwarg_shapes)

    # Convert all Tensors into TensorSpecs before saving the structured
    # inputs. If storing pure concrete functions that are not called through
    # polymorphic functions, we don't have access to FunctionSpec, so we need
    # to call the TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(func_args, arg_names),
        convert_structure_to_signature(func_kwargs))

    flat_func_args = nest.flatten(func_args, expand_composites=True)
    flat_func_kwargs = nest.flatten(func_kwargs, expand_composites=True)
    # Temporarily set inputs to allow graph building code to inspect
    # them. Reassigned below.
    func_graph.inputs = [arg for arg in flat_func_args + flat_func_kwargs
                         if isinstance(arg, ops.Tensor)]

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the
    # function. Copy the recursive list, tuple and map structure, but not
    # base objects.
    func_args_before = nest.pack_sequence_as(func_args, flat_func_args,
                                             expand_composites=True)
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs, flat_func_kwargs, expand_composites=True)

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps,
        # so this won't work if x needs to be captured (i.e. if python_func
        # returns captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.eager.defun, Python functions "
              "must return zero or more Tensors; in compilation of %s, found "
              "return value of type %s, which is not a Tensor." %
              (str(python_func), type(x)))
      if add_control_dependencies:
        x = deps_ctx.mark_as_return(x)
      return x

    try:
      if autograph:
        from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
        _, original_func = tf_decorator.unwrap(python_func)

        def wrapper(*args, **kwargs):
          """Calls a converted version of original_func."""
          # TODO(mdan): Push this block higher in tf.function's call stack.
          try:
            return autograph.converted_call(
                original_func,
                args,
                kwargs,
                options=autograph.ConversionOptions(
                    recursive=True,
                    optional_features=autograph_options,
                    user_requested=True,
                ))
          except Exception as e:  # pylint:disable=broad-except
            if hasattr(e, "ag_error_metadata"):
              raise e.ag_error_metadata.to_exception(e)
            else:
              raise

        # Wrapping around a decorator allows checks like tf_inspect.getargspec
        # to be accurate.
        converted_func = tf_decorator.make_decorator(original_func, wrapper)
        python_func = tf_decorator.rewrap(python_func, original_func,
                                          converted_func)

      else:
        _, original_func = tf_decorator.unwrap(python_func)

      func_outputs = python_func(*func_args, **func_kwargs)

      # invariant: `func_outputs` contains only Tensors, CompositeTensors,
      # TensorArrays and `None`s.
      func_outputs = nest.map_structure(convert, func_outputs,
                                        expand_composites=True)

      check_mutation(func_args_before, func_args, original_func)
      check_mutation(func_kwargs_before, func_kwargs, original_func)
    finally:
      current_scope.set_use_resource(default_use_resource)

    # Variables in `func_args`, `func_kwargs` should be explicit inputs
    # to the function, not captured inputs.
    graph_variables = list(func_graph._watched_variables)  # pylint: disable=protected-access
    arg_variables = object_identity.ObjectIdentitySet()
    inputs = []
    for arg in (nest.flatten(func_args, expand_composites=True) +
                nest.flatten(func_kwargs, expand_composites=True)):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        resource_placeholder = func_graph.pop_capture(arg.handle)
        if resource_placeholder is None:
          continue
        arg_variables.add(arg)
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    variables = [v for v in graph_variables if v not in arg_variables]
    func_graph.inputs = (
        inputs + func_graph.internal_captures + nest.flatten(
            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = variables

  if add_control_dependencies:
    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
    func_graph.collective_manager_ids_used = (
        deps_control_manager.collective_manager_ids_used)

  return func_graph
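

# A hedged end-to-end sketch of the tracing entry point above (commented so
# it is not executed at import time). The function and argument values are
# arbitrary examples:
#
#   def my_func(x, y):
#     return x * y + 1.0
#
#   g = func_graph_from_py_func(
#       "my_func", my_func,
#       args=(constant_op.constant(2.0), constant_op.constant(3.0)),
#       kwargs={})
#   # g.inputs now holds two float32 placeholders, g.outputs holds the
#   # result tensor, and g.structured_input_signature holds the matching
#   # TensorSpecs.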


def maybe_captured(tensor):
  """If `tensor` is a captured-value placeholder, returns the captured value.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures:
      if tensor == placeholder_t:
        return maybe_captured(input_t)
  return tensor


def device_stack_has_callable(device_stack):
  """Checks whether a device stack contains a callable."""
  return any(callable(spec._device_name_or_function)  # pylint: disable=protected-access
             for spec in device_stack.peek_objs())


def check_mutation(n1, n2, func):
  """Checks that two structures of arguments are exactly the same."""
  func_name = getattr(func, "__name__", func)

  errmsg = ("{}() should not modify its Python input arguments."
            " Check if it modifies any lists or dicts passed as"
            " arguments. Modifying a copy is allowed.".format(func_name))
  try:
    # TODO(mdan): Compare more robustly so that argument names can be
    # reported.
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    raise ValueError(errmsg)

  for arg1, arg2 in zip(nest.flatten(n1, expand_composites=True),
                        nest.flatten(n2, expand_composites=True)):
    if arg1 is not arg2:
      raise ValueError(errmsg)
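

# Commented illustration of the check above: a traced function that mutates
# one of its structured arguments fails tracing. `some_tensor` is a
# hypothetical placeholder:
#
#   def bad(d):
#     d["extra"] = 1.0   # mutates the caller-visible dict
#     return d["x"]
#
#   # func_graph_from_py_func("bad", bad, args=({"x": some_tensor},),
#   #                         kwargs={})
#   # -> ValueError: bad() should not modify its Python input arguments...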


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def flatten(sequence):
  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A list of tensors.
  """
  flat_sequence = nest.flatten(sequence, expand_composites=True)
  return [
      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
      for item in flat_sequence]
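

# Commented sketch of the TensorArray special case handled by flatten() and
# pack_sequence_as() below: a TensorArray is represented by its `flow`
# tensor when flattened, and rebuilt from that flow when packed.
# `some_tensor` is a hypothetical example value:
#
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=2)
#   flat = flatten([ta, some_tensor])          # -> [ta.flow, some_tensor]
#   rebuilt = pack_sequence_as([ta, some_tensor], flat)
#   # rebuilt[0] is a TensorArray backed by the same flow tensor.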


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors.

  Returns:
    A nested structure.

  Raises:
    AssertionError if `structure` and `flat_sequence` are not compatible.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  for i in range(len(flat_sequence)):
    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=flattened_structure[i], flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence,
                               expand_composites=True)


def _create_substitute_placeholder(value, name=None, dtype=None, shape=None):
  """Creates a placeholder for `value` and propagates shape info to it."""
  # Note: setting ops.control_dependencies(None) ensures we always put
  # capturing placeholders outside of any control flow context.
  if shape is None:
    shape = value.shape
  with ops.control_dependencies(None):
    placeholder = graph_placeholder(
        dtype=dtype or value.dtype, shape=shape, name=name)
  custom_gradient.copy_handle_data(value, placeholder)
  return placeholder


def _get_defun_inputs_from_args(args, names, flat_shapes=None):
  """Maps Python function positional args to graph-construction inputs."""
  return _get_defun_inputs(
      args, names, structure=args, flat_shapes=flat_shapes)


def _get_composite_tensor_spec(x):
  """Returns the TypeSpec for x if it's a composite tensor, or x otherwise."""
  return (x._type_spec  # pylint: disable=protected-access
          if isinstance(x, composite_tensor.CompositeTensor) else x)


def _get_defun_inputs(args, names, structure, flat_shapes=None):
  """Maps python function args to graph-construction inputs.

  Args:
    args: A flat list of user-specified arguments.
    names: A list of strings with user-specified argument names, same length
      as `args`. May be `None`, in which case a generic name is used.
    structure: The original argument list or dictionary.
    flat_shapes: A flat list of values that are either `None` or instances of
      `TensorShape`. If provided, then length must match that of
      `nest.flatten(args, expand_composites=True)`; and locations where
      `args` are instances of `Tensor` must have a corresponding
      `TensorShape` in `flat_shapes`. May be `None`, in which case exact
      shapes are read directly from the args.

  Returns:
    Placeholders with the same structure as `structure`.

  Raises:
    RuntimeError: if `flat_shapes` is provided, but
      `len(flat_shapes) != len(nest.flatten(args, expand_composites=True))`.
    RuntimeError: if a shape from `flat_shapes` is not None for an argument
      that is not a `Tensor`, `TensorSpec`, or `ResourceVariable`.
  """
  func_graph = ops.get_default_graph()
  function_inputs = []
  if names is None:
    names = [None] * len(args)
  if flat_shapes is None:
    shapes_iter = itertools.repeat(None)
  else:
    len_flat_args = len(nest.flatten(args, expand_composites=True))
    if len_flat_args != len(flat_shapes):
      raise RuntimeError(
          "Length of fully flat shapes (%d) must match that of "
          "flatten(args) (%d). args: %s, flat_shapes: %s"
          % (len(flat_shapes),
             len_flat_args,
             args,
             flat_shapes))
    shapes_iter = iter(flat_shapes)
  for arg_value, name in zip(args, names):

    # Replace any composite tensors with their TypeSpecs. This is important
    # for ensuring that shape information that's not preserved by the
    # TypeSpec (such as the number of values in a SparseTensor) gets properly
    # masked.
    arg_value = nest.map_structure(_get_composite_tensor_spec, arg_value)

    flattened = nest.flatten(arg_value, expand_composites=True)

    for arg in flattened:
      # We have a shape entry for each arg, regardless of whether it's a real
      # Tensor or not. For non-tensor entries it should be None.
      shape = next(shapes_iter)
      if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec)):
        arg_is_spec = isinstance(arg, tensor_spec.TensorSpec)
        if arg_is_spec and arg.name:
          requested_name = arg.name
        else:
          requested_name = name
        placeholder_shape = shape if shape is not None else arg.shape
        try:
          placeholder = graph_placeholder(
              arg.dtype, placeholder_shape,
              name=requested_name)
        except ValueError:
          # Sometimes parameter names are not valid op names, so fall back to
          # unnamed placeholders.
          placeholder = graph_placeholder(arg.dtype, placeholder_shape)
        if not arg_is_spec:
          custom_gradient.copy_handle_data(arg, placeholder)
        if name is not None:
          # Record the requested/user-specified name in case it's different
          # than the uniquified name, for validation when exporting
          # signatures.
          placeholder.op._set_attr(  # pylint: disable=protected-access
              "_user_specified_name",
              attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
        function_inputs.append(placeholder)
      elif isinstance(arg, (resource_variable_ops.BaseResourceVariable,
                            resource_variable_ops.VariableSpec)):
        if isinstance(arg, resource_variable_ops.VariableSpec):
          name = arg.name or name
          with func_graph.outer_graph.as_default():
            placeholder = graph_placeholder(dtypes.resource, arg.shape,
                                            name=name)

            arg = resource_variable_ops.BaseResourceVariable(
                name=name,
                shape=arg.shape,
                dtype=arg.dtype,
                handle=placeholder,
                handle_name=name)
        # Capture arg variables to create placeholders for them. These will
        # be removed as captures after the function is traced (since
        # otherwise we'd just add it back with a new placeholder when the
        # variable was referenced).
        placeholder = func_graph.capture(arg.handle, name=name)
        placeholder.op._set_attr(  # pylint: disable=protected-access
            "_user_specified_name",
            attr_value_pb2.AttrValue(s=compat.as_bytes(name)))
        function_inputs.append(arg)
      else:
        if shape is not None:
          raise RuntimeError(
              "Expected provided shape override to be None for arg that isn't "
              "a Tensor, but saw arg: '%s', shape: '%s'. args: %s"
              % (arg, shape, args))
        function_inputs.append(arg)
  return nest.pack_sequence_as(structure, function_inputs,
                               expand_composites=True)


def _get_defun_inputs_from_kwargs(kwargs, flat_shapes):
  """Maps Python function keyword args to graph-construction inputs."""
  if kwargs:
    names, args = zip(*sorted(kwargs.items()))
  else:
    names = []
    args = []
  return _get_defun_inputs(
      args, names, structure=kwargs, flat_shapes=flat_shapes)


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph`, a FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when the
  FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph.clear_captures()
  ops.dismantle_graph(func_graph)


def override_func_graph_name_scope(func_graph, name_scope):
  func_graph._name_stack = name_scope  # pylint: disable=protected-access