Functions returned by the get_concrete_function
method of tf.Function
objects can now be called with arguments consistent with the original arguments or type specs passed to get_concrete_function
. In particular:
* If a composite tensor (such as RaggedTensor or SparseTensor) was passed to `get_concrete_function`, then the returned function will accept a composite tensor of the same type for that argument. * If a nested structure (such as a list or dict) was passed to `get_concrete_function`, then the returned function will accept a value with the same nesting structure. Each tensor or composite tensor value must have the same type as was used in the original argument; and each non-Tensor value (such as bools or ints) must be equal value that was used in the original argument. * If a non-tensor value (such as a bool or int) was passed to `get_concrete_function`, then the returned function no longer deletes that argument; instead, it updates the argument's default value to the value that was passed to `get_concrete_function`. Passing in any other value will raise an exception. * Arguments are not renamed based on `TensorSpec.name`. For backwards compatibility, the functions returned by `get_concrete_function` will continue to accept arguments with the existing calling conventions (where nested structures and composite tensors are flattened; non-tensor arguments are deleted; suffixes are automatically added to disambiguate arguments with the same name; and TensorSpec.name is used to rename arguments). However, the preferred calling convention is the one that is consistent with the original arguments or type specs passed to `get_concrete_function`. PiperOrigin-RevId: 307398918 Change-Id: Ie4685b32d9f151c82f6c79a6c41379faa96b5ee8
This commit is contained in:
parent
0887fedd2d
commit
f39aab3092
tensorflow/python
@ -830,6 +830,13 @@ class Function(object):
|
||||
def function_spec(self):
|
||||
return self._function_spec
|
||||
|
||||
def pretty_printed_concrete_signatures(self, verbose=True):
|
||||
joiner = "\n\n" if verbose else "\n"
|
||||
return joiner.join([
|
||||
c.pretty_printed_signature(verbose=verbose)
|
||||
for c in self._list_all_concrete_functions()
|
||||
])
|
||||
|
||||
def _initialize_uninitialized_variables(self, initializers):
|
||||
"""Make and call a `ConcreteFunction` which initializes variables."""
|
||||
|
||||
@ -913,12 +920,8 @@ class Function(object):
|
||||
|
||||
return initialize_variables.get_concrete_function()
|
||||
|
||||
def _list_all_concrete_functions_for_serialization(self):
|
||||
"""Returns all concrete functions for serialization.
|
||||
|
||||
Returns:
|
||||
A list of instances of `ConcreteFunction`.
|
||||
"""
|
||||
def _list_all_concrete_functions(self):
|
||||
"""Returns all concrete functions."""
|
||||
if self.input_signature is not None:
|
||||
self.get_concrete_function()
|
||||
concrete_functions = []
|
||||
@ -930,6 +933,15 @@ class Function(object):
|
||||
concrete_functions.extend(
|
||||
self._stateless_fn._function_cache.all_values())
|
||||
# pylint: enable=protected-access
|
||||
return concrete_functions
|
||||
|
||||
def _list_all_concrete_functions_for_serialization(self):
|
||||
"""Returns all concrete functions for serialization.
|
||||
|
||||
Returns:
|
||||
A list of instances of `ConcreteFunction`.
|
||||
"""
|
||||
concrete_functions = self._list_all_concrete_functions()
|
||||
seen_signatures = []
|
||||
for concrete_function in concrete_functions:
|
||||
signature = concrete_function.structured_input_signature
|
||||
|
@ -53,6 +53,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import custom_gradient
|
||||
@ -340,7 +341,7 @@ class _InterpolateFunctionError(object):
|
||||
if t.name == compat.as_str(self._func.name):
|
||||
g = self._func.graph
|
||||
elif g:
|
||||
next_func = g._get_function(t.name)
|
||||
next_func = g._get_function(t.name) # pylint: disable=protected-access
|
||||
if next_func is not None and isinstance(next_func,
|
||||
_EagerDefinedFunction):
|
||||
g = next_func.graph
|
||||
@ -1499,6 +1500,12 @@ class _ForwardBackwardCall(object):
|
||||
flat_outputs, self._inference_args, self._input_tangents)
|
||||
|
||||
|
||||
# Sentinel value used by with ConcreteFunction's structured signature to
|
||||
# indicate that a non-tensor parameter should use the value that was
|
||||
# specified when the concrete function was created.
|
||||
_BOUND_VALUE = object()
|
||||
|
||||
|
||||
class ConcreteFunction(object):
|
||||
"""Callable object encapsulating a function definition and its gradient.
|
||||
|
||||
@ -1506,7 +1513,11 @@ class ConcreteFunction(object):
|
||||
is differentiable under `tf.GradientTape` objects.
|
||||
"""
|
||||
|
||||
def __init__(self, func_graph, attrs=None, shared_func_graph=True):
|
||||
def __init__(self,
|
||||
func_graph,
|
||||
attrs=None,
|
||||
shared_func_graph=True,
|
||||
function_spec=None):
|
||||
"""Initialize a `ConcreteFunction`.
|
||||
|
||||
Args:
|
||||
@ -1517,16 +1528,25 @@ class ConcreteFunction(object):
|
||||
shared_func_graph: If False, the ConcreteFunction takes ownership of
|
||||
`func_graph` and will break reference cycles when it is deleted. This
|
||||
makes the FuncGraph inoperable.
|
||||
function_spec: FunctionSpec for the original function. If not specified,
|
||||
then this ConcreteFunction may only be called using the flat signature.
|
||||
|
||||
Raises:
|
||||
ValueError: If number of input_placeholders is not equal to the number
|
||||
of function inputs.
|
||||
"""
|
||||
# _arg_keywords and _num_positional_args define the flat signature. They
|
||||
# are assigned after construction.
|
||||
self._arg_keywords = None
|
||||
self._num_positional_args = None
|
||||
|
||||
self._func_graph = func_graph
|
||||
self._captured_inputs = self._func_graph.external_captures
|
||||
self._captured_closures = self._func_graph.deferred_external_captures
|
||||
|
||||
# function_spec defines the structured signature.
|
||||
self._set_function_spec(function_spec)
|
||||
|
||||
if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
|
||||
# The alternative is to silently drop "implements" tag
|
||||
# but it seems likely it would lead to hard to catch bugs.
|
||||
@ -1576,6 +1596,52 @@ class ConcreteFunction(object):
|
||||
# building gradients.
|
||||
self._inference_function = self._delayed_rewrite_functions.forward()
|
||||
|
||||
def _set_function_spec(self, function_spec):
|
||||
"""Enables the structured signature by supplying a function_spec."""
|
||||
self._function_spec = None
|
||||
self._pre_initialized_function_spec = function_spec
|
||||
|
||||
# Note: when ConcreteFunctions are built by recreate_function() in
|
||||
# function_deserialization.py, they don't have a structured_input_signature
|
||||
# yet. In that case, _initialize_function_spec() gets called by
|
||||
# _setup_functions_structures() in load.py.
|
||||
if (function_spec is not None and
|
||||
self.structured_input_signature is not None):
|
||||
self._initialize_function_spec()
|
||||
|
||||
def _initialize_function_spec(self):
|
||||
"""Updates `self._function_spec` to include varargs and bound variables.
|
||||
|
||||
Adds new positional arguments for any varargs (i.e., for args that are
|
||||
in `structured_input_signature`, but not in the original fullargspec.args).
|
||||
|
||||
Replaces `defaults` and `kwonlydefaults` with the `_BOUND_VALUE`, for
|
||||
all args and kwargs in `structured_input_signature`.
|
||||
|
||||
Sets `varkw` and `varargs` to None.
|
||||
"""
|
||||
if self._pre_initialized_function_spec is None:
|
||||
return # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
|
||||
assert not self._function_spec, "already initialized"
|
||||
function_spec = self._pre_initialized_function_spec
|
||||
args = function_spec.fullargspec.args
|
||||
arg_specs, kwarg_specs = self.structured_input_signature
|
||||
fullargspec = tf_inspect.FullArgSpec(
|
||||
args=list(args) +
|
||||
["<arg{}>".format(i + 1) for i in range(len(args), len(arg_specs))],
|
||||
varargs=None,
|
||||
varkw=None,
|
||||
defaults=[_BOUND_VALUE] * len(arg_specs),
|
||||
kwonlyargs=list(sorted(kwarg_specs)),
|
||||
kwonlydefaults=dict((k, _BOUND_VALUE) for k in kwarg_specs),
|
||||
annotations=function_spec.fullargspec.annotations)
|
||||
self._function_spec = FunctionSpec(
|
||||
fullargspec,
|
||||
function_spec.is_method,
|
||||
function_spec.input_signature,
|
||||
function_spec.is_pure,
|
||||
name=self._func_graph.name)
|
||||
|
||||
@property
|
||||
def variables(self):
|
||||
"""Sequence of variables for this function."""
|
||||
@ -1589,15 +1655,44 @@ class ConcreteFunction(object):
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Executes the wrapped function.
|
||||
|
||||
ConcreteFunctions have two signatures:
|
||||
|
||||
* The signature of the original function wrapped by this ConcreteFunction.
|
||||
* A flat signature, where each argument accepts a single Tensor.
|
||||
|
||||
The original function signature is generally preferred, but the flat input
|
||||
signature is supported for backward compatibility.
|
||||
|
||||
### Original Function Signature
|
||||
|
||||
When calling a ConcreteFunction with the signature of the original function,
|
||||
each argument must match the type or value that was used when the
|
||||
ConcreteFunction's graph was traced. In particular:
|
||||
|
||||
* Tensor arguments (including CompositeTensors, such as RaggedTensor) must
|
||||
have matching `TypeSpec`s.
|
||||
* Non-Tensor arguments (such as booleans or ints) must have equal values.
|
||||
* Nested arguments (such as lists, tuples, or dictionaries) must have the
|
||||
same nesting structure; and each nested value must have a matching type
|
||||
or value.
|
||||
|
||||
The default value for any arguments that were traced with non-Tensor values
|
||||
is the value that was used in the trace. Arguments that were traced with
|
||||
tensor arguments do not have a default value (even if the original function
|
||||
had a default value for that argument).
|
||||
|
||||
### Flat Signature
|
||||
|
||||
When calling a ConcreteFunction with the flat signature, the arguments
|
||||
correspond to the flattened component tensors of the arguments that were
|
||||
used to construct the ConcreteFunction. Parameter names are assigned based
|
||||
on `TensorSpec.name` (when specified) or the original argument names (with
|
||||
suffixes automatically added for nested arguments or composite tensors with
|
||||
multiple components).
|
||||
|
||||
Args:
|
||||
*args: Tensors or Variables. Positional arguments are only accepted when
|
||||
they correspond one-to-one with arguments of the traced Python function.
|
||||
**kwargs: Tensors or Variables specified by name. When
|
||||
`get_concrete_function` was called to create this `ConcreteFunction`,
|
||||
each Tensor input was given a name, defaulting to the name of the Python
|
||||
function's argument but possibly overridden by the `name=` argument to
|
||||
`tf.TensorSpec`. These names become the argument names for the concrete
|
||||
function.
|
||||
*args: Positional arguments to the concrete function.
|
||||
**kwargs: Keyword arguments to the concrete function.
|
||||
|
||||
Returns:
|
||||
The result of applying the TF function on the given Tensors.
|
||||
@ -1605,9 +1700,7 @@ class ConcreteFunction(object):
|
||||
Raises:
|
||||
AssertionError: If this `ConcreteFunction` was not created through
|
||||
`get_concrete_function`.
|
||||
ValueError: If arguments contains anything other than Tensors or
|
||||
Variables.
|
||||
TypeError: For invalid positional/keyword argument combinations.
|
||||
TypeError: If the arguments do not match the function's signature.
|
||||
"""
|
||||
return self._call_impl(args, kwargs)
|
||||
|
||||
@ -1615,40 +1708,174 @@ class ConcreteFunction(object):
|
||||
"""See `__call__` for details."""
|
||||
with traceme.TraceMe(self._func_graph.name,
|
||||
tf_function_call="concrete"):
|
||||
if self._arg_keywords is None or self._num_positional_args is None:
|
||||
raise AssertionError(
|
||||
"Tried to call a concrete function obtained from an internal API "
|
||||
"through the public interface. Use get_concrete_function instead.")
|
||||
if len(args) > self._num_positional_args:
|
||||
raise TypeError(
|
||||
("Expected at most {} positional arguments (and the rest keywords, "
|
||||
"of {}), got {}. When calling a concrete function, positional "
|
||||
"arguments may not be bound to Tensors within nested structures."
|
||||
).format(self._num_positional_args, self._arg_keywords, args))
|
||||
args = list(args)
|
||||
for keyword in self._arg_keywords[len(args):]:
|
||||
# Construct the list of input tensors: check if the structured signature
|
||||
# applies first; and if not, then use the flat signature.
|
||||
if self._function_spec is not None:
|
||||
try:
|
||||
args.append(kwargs.pop(compat.as_str(keyword)))
|
||||
except KeyError:
|
||||
specified_keywords = (list(self._arg_keywords[:len(args)])
|
||||
+ list(kwargs.keys()))
|
||||
raise TypeError(
|
||||
"Expected argument names {} but got values for {}. Missing: {}."
|
||||
.format(
|
||||
list(self._arg_keywords),
|
||||
specified_keywords,
|
||||
list(set(self._arg_keywords) - set(specified_keywords))))
|
||||
if kwargs:
|
||||
positional_arg_keywords = set(self._arg_keywords[:len(args)])
|
||||
for unused_key in kwargs:
|
||||
if unused_key in positional_arg_keywords:
|
||||
raise TypeError("Got two values for keyword '{}'.".format(
|
||||
unused_key))
|
||||
raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
|
||||
list(kwargs.keys()), list(self._arg_keywords)))
|
||||
return self._call_flat(args, self.captured_inputs, cancellation_manager)
|
||||
return self._call_with_structured_signature(args, kwargs,
|
||||
cancellation_manager)
|
||||
except TypeError as structured_err:
|
||||
try:
|
||||
return self._call_with_flat_signature(args, kwargs,
|
||||
cancellation_manager)
|
||||
except TypeError:
|
||||
raise structured_err
|
||||
|
||||
def _filtered_call(self, args, kwargs):
|
||||
return self._call_with_flat_signature(args, kwargs, cancellation_manager)
|
||||
|
||||
def _call_with_flat_signature(self, args, kwargs, cancellation_manager):
|
||||
"""Executes the wrapped function with the flat signature.
|
||||
|
||||
Args:
|
||||
args: Positional arguments to the concrete function.
|
||||
kwargs: Keyword arguments to the concrete function.
|
||||
cancellation_manager: A `CancellationManager` that can be used to cancel
|
||||
function invocation.
|
||||
|
||||
Returns:
|
||||
The result of applying the function on the Tensors/Variables contained in
|
||||
`args` and `kwargs`.
|
||||
Raises:
|
||||
TypeError: if `args` and `kwargs` do not match the flat signature of this
|
||||
`ConcreteFunction`.
|
||||
"""
|
||||
if len(args) > self._num_positional_args:
|
||||
raise TypeError(
|
||||
"{} takes {} positional arguments but {} were given".format(
|
||||
self._flat_signature_summary(), self._num_positional_args,
|
||||
len(args)))
|
||||
args = list(args)
|
||||
kwargs = dict(kwargs)
|
||||
for keyword in self._arg_keywords[len(args):]:
|
||||
try:
|
||||
args.append(kwargs.pop(compat.as_str(keyword)))
|
||||
except KeyError:
|
||||
specified_keywords = (
|
||||
list(self._arg_keywords[:len(args)]) + list(kwargs.keys()))
|
||||
raise TypeError("{} missing required arguments: {}".format(
|
||||
self._flat_signature_summary(), ", ".join(
|
||||
sorted(set(self._arg_keywords) - set(specified_keywords)))))
|
||||
if kwargs:
|
||||
positional_arg_keywords = set(self._arg_keywords[:len(args)])
|
||||
for unused_key in kwargs:
|
||||
if unused_key in positional_arg_keywords:
|
||||
raise TypeError("{} got two values for argument '{}'".format(
|
||||
self._flat_signature_summary(), unused_key))
|
||||
raise TypeError("{} got unexpected keyword arguments: {}.".format(
|
||||
self._flat_signature_summary(), ", ".join(sorted(kwargs))))
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
if not isinstance(
|
||||
arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
|
||||
raise TypeError("{}: expected argument #{}(zero-based) to be a Tensor; "
|
||||
"got {} ({})".format(self._flat_signature_summary(), i,
|
||||
type(arg).__name__, str(arg)))
|
||||
return self._call_flat(args, self.captured_inputs, cancellation_manager)
|
||||
|
||||
def _call_with_structured_signature(self, args, kwargs, cancellation_manager):
|
||||
"""Executes the wrapped function with the structured signature.
|
||||
|
||||
Args:
|
||||
args: Positional arguments to the concrete function.
|
||||
kwargs: Keyword arguments to the concrete function.
|
||||
cancellation_manager: A `CancellationManager` that can be used to cancel
|
||||
function invocation.
|
||||
|
||||
Returns:
|
||||
The result of applying the function on the Tensors/Variables contained in
|
||||
`args` and `kwargs`.
|
||||
Raises:
|
||||
TypeError: if `args` and `kwargs` do not match the structured signature
|
||||
of this `ConcreteFunction`.
|
||||
"""
|
||||
args, kwargs = self._function_spec.canonicalize_function_inputs(
|
||||
*args, **kwargs)
|
||||
self._structured_signature_check_missing_args(args, kwargs)
|
||||
self._structured_signature_check_unexpected_args(args, kwargs)
|
||||
self._structured_signature_check_arg_types(args, kwargs)
|
||||
return self._filtered_call(args, kwargs, cancellation_manager)
|
||||
|
||||
def _structured_signature_check_missing_args(self, args, kwargs):
|
||||
"""Raises a TypeError if any args are missing."""
|
||||
arg_specs, kwarg_specs = self.structured_input_signature
|
||||
missing_arguments = []
|
||||
for i, (arg, spec) in enumerate(zip(args, arg_specs)):
|
||||
if arg is _BOUND_VALUE and _contains_type_spec(spec):
|
||||
missing_arguments.append(self._function_spec.arg_names[i])
|
||||
for (name, arg) in kwargs.items():
|
||||
if arg is _BOUND_VALUE and _contains_type_spec(kwarg_specs[name]):
|
||||
missing_arguments.append(name)
|
||||
if missing_arguments:
|
||||
raise TypeError("{} missing required arguments: {}".format(
|
||||
self._structured_signature_summary(),
|
||||
", ".join(sorted(missing_arguments))))
|
||||
|
||||
def _structured_signature_check_unexpected_args(self, args, kwargs):
|
||||
"""Raises a TypeError if there are any extra args."""
|
||||
arg_specs, kwarg_specs = self.structured_input_signature
|
||||
if len(args) > len(arg_specs):
|
||||
raise TypeError(
|
||||
"{} takes {} positional arguments but {} were given".format(
|
||||
self._structured_signature_summary(),
|
||||
len(self._function_spec.arg_names), len(args)))
|
||||
if len(kwargs) > len(kwarg_specs):
|
||||
extra_args = set(kwargs) - set(kwarg_specs)
|
||||
raise TypeError("{} got unexpected keyword arguments: {}".format(
|
||||
self._structured_signature_summary(), ", ".join(extra_args)))
|
||||
|
||||
def _structured_signature_check_arg_types(self, args, kwargs):
|
||||
"""Raises a TypeError if any args have the wrong type."""
|
||||
# Check argument types
|
||||
arg_specs, kwarg_specs = self.structured_input_signature
|
||||
for i, (arg, spec) in enumerate(zip(args, arg_specs)):
|
||||
name = self._function_spec.arg_names[i]
|
||||
self._structured_signature_check_arg_type(arg, spec, name)
|
||||
for (name, arg) in kwargs.items():
|
||||
self._structured_signature_check_arg_type(arg, kwarg_specs[name], name)
|
||||
|
||||
def _structured_signature_check_arg_type(self, arg, spec, name):
|
||||
"""Raise TypeError if `arg`'s type doesn't match `spec`."""
|
||||
if arg is _BOUND_VALUE:
|
||||
return
|
||||
|
||||
# Check the overall nested structure of the argument.
|
||||
try:
|
||||
nest.assert_same_structure(arg, spec, expand_composites=True)
|
||||
except (ValueError, TypeError):
|
||||
try:
|
||||
nest.assert_same_structure(arg, spec, expand_composites=False)
|
||||
expected, got = spec, arg
|
||||
except (ValueError, TypeError):
|
||||
expected, got = _structure_summary(spec), _structure_summary(arg)
|
||||
raise TypeError("{}: argument {} had incorrect type\n"
|
||||
" expected: {}\n got: {}".format(
|
||||
self._structured_signature_summary(), name, expected,
|
||||
got))
|
||||
|
||||
# Check the type for each leaf in the nested structure.
|
||||
arg_pieces = nest.flatten(arg, expand_composites=True)
|
||||
spec_pieces = nest.flatten(spec, expand_composites=True)
|
||||
for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces):
|
||||
if isinstance(spec_piece, tensor_spec.DenseSpec):
|
||||
# TODO(edloper): Consider calling convert_to_tensor on non-tensor
|
||||
# values here. That would match the behavior of
|
||||
# _call_concrete_function() in function_deserialization.py. If
|
||||
# we do, then we need to change the nest assert_same_structure and
|
||||
# flatten calls above to use shallow variants.
|
||||
tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
|
||||
if not isinstance(arg_piece, tensor_types):
|
||||
raise TypeError(
|
||||
"{} expected a Tensor in {}, but got {} value {}".format(
|
||||
self._structured_signature_summary(), name,
|
||||
type(arg_piece).__name__, arg_piece))
|
||||
elif arg_piece is not _BOUND_VALUE and arg_piece != spec_piece:
|
||||
raise TypeError("ConcreteFunction {} was constructed with {} value "
|
||||
"{} in {}, but was called with {} value {}".format(
|
||||
self._structured_signature_summary(),
|
||||
type(spec_piece).__name__, spec_piece, name,
|
||||
type(arg_piece).__name__, arg_piece))
|
||||
|
||||
def _filtered_call(self, args, kwargs, cancellation_manager=None):
|
||||
"""Executes the function, filtering arguments from the Python function.
|
||||
|
||||
Objects aside from Tensors, CompositeTensors, and Variables are ignored.
|
||||
@ -1657,6 +1884,8 @@ class ConcreteFunction(object):
|
||||
Args:
|
||||
args: Canonicalized positional arguments of the Python function.
|
||||
kwargs: Canonicalized keyword arguments of the Python function.
|
||||
cancellation_manager: (Optional.) A `CancellationManager` that can be
|
||||
used to cancel function invocation.
|
||||
|
||||
Returns:
|
||||
The result of applying the function on the Tensors/Variables contained in
|
||||
@ -1666,7 +1895,8 @@ class ConcreteFunction(object):
|
||||
(t for t in nest.flatten((args, kwargs), expand_composites=True)
|
||||
if isinstance(t, (ops.Tensor,
|
||||
resource_variable_ops.BaseResourceVariable))),
|
||||
self.captured_inputs)
|
||||
captured_inputs=self.captured_inputs,
|
||||
cancellation_manager=cancellation_manager)
|
||||
|
||||
def _call_flat(self, args, captured_inputs, cancellation_manager=None):
|
||||
"""Executes the wrapped function.
|
||||
@ -1795,7 +2025,26 @@ class ConcreteFunction(object):
|
||||
|
||||
@property
|
||||
def structured_input_signature(self):
|
||||
"""Returns structured signature of the original function."""
|
||||
"""Returns structured signature for this concrete function.
|
||||
|
||||
Returns:
|
||||
A tuple `(args, kwargs)`, where:
|
||||
|
||||
* `args` is a tuple that specifies the expected type or value each for
|
||||
positional argument.
|
||||
* `kwargs` is a dictionary that specifies the expected type or value
|
||||
for each keyword-only argument.
|
||||
|
||||
The type or value for each argument is specified using one of the
|
||||
following:
|
||||
|
||||
* A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
|
||||
value is expected.
|
||||
* A Python value, such as an integer, indicating that an equal value
|
||||
is expected.
|
||||
* A nested structure of `tf.TypeSpec`s and Python values, indicating
|
||||
that a corresponding nested structure is expected.
|
||||
"""
|
||||
return self._func_graph.structured_input_signature
|
||||
|
||||
@property
|
||||
@ -1982,6 +2231,103 @@ class ConcreteFunction(object):
|
||||
ret.attr[name].CopyFrom(value)
|
||||
return ret
|
||||
|
||||
def _structured_signature_summary(self, default_values=False):
|
||||
"""Returns a string summarizing this function's structured signature.
|
||||
|
||||
Args:
|
||||
default_values: If true, then include default values in the signature.
|
||||
|
||||
Returns:
|
||||
A `string`.
|
||||
"""
|
||||
# Note: we can't just use self._funcion_spec.signature_summary(), because
|
||||
# that would show "_BOUND_VALUE" as the default value for all arguments.
|
||||
assert self._function_spec is not None
|
||||
arg_specs, kwarg_specs = self.structured_input_signature
|
||||
arg_names = list(self._function_spec.arg_names)
|
||||
if default_values:
|
||||
for i in range(len(arg_names)):
|
||||
if not _contains_type_spec(arg_specs[i]):
|
||||
arg_names[i] += "={}".format(arg_specs[i])
|
||||
if kwarg_specs:
|
||||
arg_names.append("*")
|
||||
for name, spec in kwarg_specs.items():
|
||||
arg_names.append(name)
|
||||
if default_values and not _contains_type_spec(spec):
|
||||
arg_names[-1] += "={}".format(spec)
|
||||
signature = "{}({})".format(self._func_graph.name, ", ".join(arg_names))
|
||||
|
||||
return signature
|
||||
|
||||
def _flat_signature_summary(self):
|
||||
"""Returns a string summarizing this function's flat signature."""
|
||||
assert self._arg_keywords is not None
|
||||
assert self._num_positional_args is not None
|
||||
arg_names = self._arg_keywords
|
||||
if self._num_positional_args > len(arg_names):
|
||||
arg_names.extend(
|
||||
"<arg{}>".format(i + 1)
|
||||
for i in range(len(arg_names), self._num_positional_args))
|
||||
return "{}({})".format(self._func_graph.name, ", ".join(arg_names))
|
||||
|
||||
def pretty_printed_signature(self, verbose=True):
|
||||
"""Returns a string summarizing the signature of this concrete function."""
|
||||
if not verbose:
|
||||
return self._structured_signature_summary(default_values=True)
|
||||
|
||||
def pretty_print_spec(spec):
|
||||
"""Returns a string describing the spec for a single argument."""
|
||||
if isinstance(spec, tensor_spec.TensorSpec):
|
||||
return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
|
||||
elif nest.is_sequence(spec):
|
||||
pieces = nest.flatten(spec, expand_composites=False)
|
||||
markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
|
||||
structure = nest.pack_sequence_as(spec, markers)
|
||||
result = "{}".format(structure)
|
||||
for (marker, piece) in zip(markers, pieces):
|
||||
result += "\n {}: {}".format(marker, pretty_print_spec(piece))
|
||||
return result
|
||||
else:
|
||||
return repr(spec)
|
||||
|
||||
lines = [self._structured_signature_summary(default_values=True)]
|
||||
arg_specs, kwarg_specs = self.structured_input_signature
|
||||
names = list(self._function_spec.arg_names)
|
||||
names.extend(sorted(kwarg_specs))
|
||||
specs = list(arg_specs) + list(kwarg_specs.values())
|
||||
# note: we can skip bound args, since we already displayed thier bound
|
||||
# value in the signature summary.
|
||||
arg_details = []
|
||||
for (name, spec) in zip(names, specs):
|
||||
if _contains_type_spec(spec):
|
||||
arg_details.append(" {}: {}".format(name, pretty_print_spec(spec)))
|
||||
if arg_details:
|
||||
lines.append(" Args:")
|
||||
lines.extend(arg_details)
|
||||
lines.append(" Returns:")
|
||||
lines.append(" {}".format(
|
||||
pretty_print_spec(
|
||||
nest.map_structure(type_spec.type_spec_from_value,
|
||||
self.structured_outputs))))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def __repr__(self):
|
||||
if self._function_spec is not None:
|
||||
return "<ConcreteFunction {} at 0x{:X}>".format(
|
||||
self.pretty_printed_signature(verbose=False), id(self))
|
||||
elif not (self._num_positional_args is None or self._arg_keywords is None):
|
||||
return "<ConcreteFunction {} at 0x{:X}>".format(
|
||||
self._flat_signature_summary(), id(self))
|
||||
else:
|
||||
return object.__repr__(self)
|
||||
|
||||
def __str__(self):
|
||||
if self._function_spec is not None:
|
||||
return "ConcreteFunction {}".format(self.pretty_printed_signature())
|
||||
else:
|
||||
return self.__repr__()
|
||||
|
||||
|
||||
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
|
||||
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
|
||||
@ -2075,17 +2421,37 @@ class FunctionSpec(object):
|
||||
kwonlydefaults={},
|
||||
annotations=fullargspec.annotations)
|
||||
is_method = tf_inspect.ismethod(python_function)
|
||||
return FunctionSpec(fullargspec, is_method, [], {}, input_signature,
|
||||
is_pure=is_pure)
|
||||
|
||||
def __init__(self, fullargspec, is_method, args_to_prepend, kwargs_to_include,
|
||||
input_signature, is_pure=False):
|
||||
# Get the function's name. Remove functools.partial wrappers if necessary.
|
||||
while isinstance(python_function, functools.partial):
|
||||
python_function = python_function.func
|
||||
name = getattr(python_function, "__name__", "f")
|
||||
|
||||
return FunctionSpec(
|
||||
fullargspec, is_method, input_signature, is_pure=is_pure, name=name)
|
||||
|
||||
def __init__(self,
|
||||
fullargspec,
|
||||
is_method,
|
||||
input_signature,
|
||||
is_pure=False,
|
||||
name=None):
|
||||
"""Constructs a FunctionSpec describing a python function.
|
||||
|
||||
Args:
|
||||
fullargspec: `tf_inspect.FullArgSpec` object describing the function.
|
||||
is_method: True if the function is a method.
|
||||
input_signature: a signature of the function (None, if variable)
|
||||
is_pure: if True all input arguments (including variables and constants)
|
||||
will be converted to tensors and no variable changes allowed.
|
||||
name: Name of the function
|
||||
"""
|
||||
self._fullargspec = fullargspec
|
||||
self._is_method = is_method
|
||||
self._is_pure = is_pure
|
||||
del args_to_prepend
|
||||
del kwargs_to_include
|
||||
self._default_values = fullargspec.defaults
|
||||
|
||||
# TODO(edloper): Include name when serializing for SavedModel?
|
||||
self._name = name or "f"
|
||||
|
||||
if self._is_method:
|
||||
# Remove `self`: default arguments shouldn't be matched to it.
|
||||
@ -2098,21 +2464,21 @@ class FunctionSpec(object):
|
||||
# A cache mapping from argument name to index, for canonicalizing
|
||||
# arguments that are called in a keyword-like fashion.
|
||||
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
|
||||
self.arg_names = args
|
||||
self.vararg_name = fullargspec.varargs
|
||||
self._arg_names = args
|
||||
|
||||
# A cache mapping from arg index to default value, for canonicalization.
|
||||
offset = len(args) - len(self._default_values or [])
|
||||
default_values = fullargspec.defaults
|
||||
offset = len(args) - len(default_values or [])
|
||||
self._arg_indices_to_default_values = {
|
||||
offset + index: default
|
||||
for index, default in enumerate(self._default_values or [])
|
||||
for index, default in enumerate(default_values or [])
|
||||
}
|
||||
if input_signature is None:
|
||||
self._input_signature = None
|
||||
else:
|
||||
if fullargspec.kwonlyargs:
|
||||
if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()):
|
||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||
"function with keyword arguments when "
|
||||
"function with keyword-only arguments when "
|
||||
"input_signature is provided.")
|
||||
|
||||
if not isinstance(input_signature, (tuple, list)):
|
||||
@ -2132,8 +2498,8 @@ class FunctionSpec(object):
|
||||
return self._is_method
|
||||
|
||||
@property
|
||||
def args_to_prepend(self):
|
||||
return self._args_to_prepend
|
||||
def args_to_indices(self):
|
||||
return self._args_to_indices
|
||||
|
||||
@property
|
||||
def kwargs_to_include(self):
|
||||
@ -2147,6 +2513,43 @@ class FunctionSpec(object):
|
||||
def flat_input_signature(self):
|
||||
return self._flat_input_signature
|
||||
|
||||
@property
|
||||
def is_pure(self):
|
||||
return self._is_pure
|
||||
|
||||
@property
|
||||
def arg_names(self):
|
||||
return self._arg_names
|
||||
|
||||
@property
|
||||
def vararg_name(self):
|
||||
return self._fullargspec.varargs
|
||||
|
||||
@property
|
||||
def varkw_name(self):
|
||||
return self._fullargspec.varkw
|
||||
|
||||
def signature_summary(self, default_values=False):
|
||||
"""Returns a string summarizing this function's signature.
|
||||
|
||||
Args:
|
||||
default_values: If true, then include default values in the signature.
|
||||
|
||||
Returns:
|
||||
A `string`.
|
||||
"""
|
||||
args = list(self._arg_names)
|
||||
if default_values:
|
||||
for (i, default) in self._arg_indices_to_default_values.items():
|
||||
args[i] += "={}".format(default)
|
||||
if self._fullargspec.kwonlyargs:
|
||||
args.append("*")
|
||||
for arg_name in self._fullargspec.kwonlyargs:
|
||||
args.append(arg_name)
|
||||
if default_values and arg_name in self._fullargspec.kwonlydefaults:
|
||||
args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
|
||||
return "{}({})".format(self._name, ", ".join(args))
|
||||
|
||||
def _convert_variables_to_tensors(self, args, kwargs):
|
||||
args = [ops.convert_to_tensor(x) for x in args]
|
||||
kwargs = {kw: ops.convert_to_tensor(x) for kw, x in kwargs.items()}
|
||||
@ -2159,7 +2562,13 @@ class FunctionSpec(object):
|
||||
instance. In particular, we parse the varags and kwargs that the
|
||||
original function was called with into a tuple corresponding to the
|
||||
Python function's positional (named) arguments and a dictionary
|
||||
corresponding to its kwargs.
|
||||
corresponding to its kwargs. Missing default arguments are added.
|
||||
|
||||
If this `FunctionSpec` has an input signature, then it is used to convert
|
||||
arguments to tensors; otherwise, any inputs containing numpy arrays are
|
||||
converted to tensors.
|
||||
|
||||
Additionally, any inputs containing numpy arrays are converted to Tensors.
|
||||
|
||||
Args:
|
||||
*args: The varargs this object was called with.
|
||||
@ -2180,29 +2589,38 @@ class FunctionSpec(object):
|
||||
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
||||
if self._input_signature is not None:
|
||||
if len(args) > len(self._input_signature):
|
||||
raise TypeError(
|
||||
"When input_signature is provided, only pass arguments "
|
||||
"covered by it. Received %d argument(s)." % len(args))
|
||||
raise TypeError("{} takes {} positional arguments (as specified by the "
|
||||
"input_signature) but {} were given".format(
|
||||
self.signature_summary(),
|
||||
len(self._input_signature), len(args)))
|
||||
for arg in six.iterkeys(kwargs):
|
||||
index = self._args_to_indices.get(arg, None)
|
||||
if index is None:
|
||||
raise TypeError(
|
||||
"Function got an unexpected keyword argument %s" % arg)
|
||||
raise TypeError("{} got unexpected keyword argument `{}`".format(
|
||||
self.signature_summary(), arg))
|
||||
if index >= len(self._input_signature):
|
||||
raise TypeError(
|
||||
"When input_signature is provided, only pass arguments "
|
||||
"covered by it. Received argument %s." % arg)
|
||||
"{} got keyword argument `{}` that was not included in "
|
||||
"input_signature".format(self.signature_summary(), arg))
|
||||
|
||||
if not kwargs:
|
||||
inputs = args
|
||||
default_keys = sorted(self._arg_indices_to_default_values.keys())
|
||||
if default_keys:
|
||||
assert min(default_keys) <= len(
|
||||
args), "Not enough arguments (%s, %s, %s)" % (args, default_keys,
|
||||
self.arg_names)
|
||||
for index in default_keys:
|
||||
if index >= len(args):
|
||||
inputs += (self._arg_indices_to_default_values[index],)
|
||||
if self._arg_indices_to_default_values:
|
||||
try:
|
||||
inputs += tuple(
|
||||
self._arg_indices_to_default_values[i]
|
||||
for i in range(len(args), len(self._arg_names)))
|
||||
except KeyError:
|
||||
missing_args = [
|
||||
self._arg_names[i]
|
||||
for i in range(len(args), len(self._arg_names))
|
||||
if i not in self._arg_indices_to_default_values
|
||||
]
|
||||
raise TypeError("{} missing required arguments: {}".format(
|
||||
self.signature_summary(), ", ".join(missing_args)))
|
||||
|
||||
if self._fullargspec.kwonlydefaults:
|
||||
kwargs.update(self._fullargspec.kwonlydefaults)
|
||||
else:
|
||||
# Maps from index of arg to its corresponding value, according to `args`
|
||||
# and `kwargs`; seeded with the default values for the named args that
|
||||
@ -2215,18 +2633,28 @@ class FunctionSpec(object):
|
||||
for arg, value in six.iteritems(kwargs):
|
||||
index = self._args_to_indices.get(arg, None)
|
||||
if index is not None:
|
||||
if index < len(args):
|
||||
raise TypeError("{} got two values for argument '{}'".format(
|
||||
self.signature_summary(), arg))
|
||||
arg_indices_to_values[index] = value
|
||||
consumed_args.append(arg)
|
||||
elif self._input_signature is not None:
|
||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||
"function with keyword arguments when "
|
||||
"input_signature is provided.")
|
||||
for arg in consumed_args:
|
||||
# After this loop, `kwargs` will only contain true keyword arguments, as
|
||||
# opposed to named arguments called in a keyword-like fashion.
|
||||
# After this loop, `kwargs` will only contain keyword_only arguments,
|
||||
# and all positional_or_keyword arguments have been moved to `inputs`.
|
||||
kwargs.pop(arg)
|
||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||
|
||||
if kwargs and self._input_signature is not None:
|
||||
raise TypeError(
|
||||
"{} got unexpected keyword arguments: {}\n(Cannot define a "
|
||||
"TensorFlow function from a Python function with keyword arguments "
|
||||
"when input_signature is provided.)".format(
|
||||
self.signature_summary(), ", ".join(kwargs)))
|
||||
|
||||
if self._fullargspec.kwonlydefaults:
|
||||
for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
|
||||
kwargs.setdefault(kwarg, default)
|
||||
|
||||
if self._input_signature is None:
|
||||
inputs = _convert_numpy_inputs(inputs)
|
||||
kwargs = _convert_numpy_inputs(kwargs)
|
||||
@ -2447,6 +2875,7 @@ class Function(object):
|
||||
graph_function, _, _ = self._maybe_define_function(args, kwargs)
|
||||
return graph_function
|
||||
|
||||
# XX TODO: make sure we fix up this path as well!?
|
||||
def _get_concrete_function_internal(self, *args, **kwargs):
|
||||
"""Bypasses error checking when getting a graph function."""
|
||||
graph_function = self._get_concrete_function_internal_garbage_collected(
|
||||
@ -2664,6 +3093,7 @@ class Function(object):
|
||||
override_flat_arg_shapes=override_flat_arg_shapes,
|
||||
capture_by_value=self._capture_by_value),
|
||||
self._function_attributes,
|
||||
function_spec=self.function_spec,
|
||||
# Tell the ConcreteFunction to clean up its graph once it goes out of
|
||||
# scope. This is not the default behavior since it gets used in some
|
||||
# places (like Keras) where the FuncGraph lives longer than the
|
||||
@ -3350,3 +3780,30 @@ class ConcreteFunctionGarbageCollector(object):
|
||||
func_graph_module.dismantle_func_graph(self._func_graph)
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
|
||||
class _Marker(object):
|
||||
"""Markers used to pretty-print nested args in function signatures."""
|
||||
|
||||
def __init__(self, s):
|
||||
self._s = s
|
||||
|
||||
def __repr__(self):
|
||||
return str(self._s)
|
||||
|
||||
|
||||
def _structure_summary(structure):
|
||||
"""Displays a summary of the nesting structure of the given value."""
|
||||
|
||||
def type_name(x):
|
||||
if isinstance(x, type_spec.TypeSpec):
|
||||
return x.value_type.__name__
|
||||
else:
|
||||
return type(x).__name__
|
||||
|
||||
markers = [_Marker(type_name(v)) for v in nest.flatten(structure)]
|
||||
return str(nest.pack_sequence_as(structure, markers))
|
||||
|
||||
|
||||
def _contains_type_spec(value):
|
||||
return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value))
|
||||
|
@ -36,6 +36,7 @@ from tensorflow.python.eager import cancellation
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -50,6 +51,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.layers import convolutional
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
@ -65,6 +67,7 @@ from tensorflow.python.ops import list_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
@ -99,6 +102,16 @@ def _example_indexed_slices_without_dense_shape():
|
||||
constant_op.constant([1, 2]), constant_op.constant([0, 1]))
|
||||
|
||||
|
||||
def _spec_for_value(value):
|
||||
"""Returns the (nested) TypeSpec for a value."""
|
||||
if nest.is_sequence(value):
|
||||
return nest.map_structure(_spec_for_value, value)
|
||||
elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
|
||||
return type_spec.type_spec_from_value(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -1789,6 +1802,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, 'incompatible'):
|
||||
func([['wrong dtype']])
|
||||
|
||||
def testNoKeywordOnlyArgumentsWithInputSignature(self):
|
||||
if sys.version_info[0] < 3:
|
||||
self.skipTest('keyword_only arguments only exist in Python 3.')
|
||||
|
||||
func = eval('lambda x, *, y: x') # pylint: disable=eval-used
|
||||
signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'Cannot define a TensorFlow function from a Python '
|
||||
'function with keyword-only arguments when input_signature is '
|
||||
'provided.'):
|
||||
def_function.function(func, signature)
|
||||
|
||||
def testNestedInputSignatures(self):
|
||||
|
||||
def expected_foo(a, b):
|
||||
@ -1905,7 +1930,9 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
defined(array_ops.ones([2, 1]))
|
||||
|
||||
# Wrong number of arguments.
|
||||
with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, r'takes 1 positional arguments \(as specified by the '
|
||||
r'input_signature\) but 2 were given'):
|
||||
defined(array_ops.ones([2]), array_ops.ones([2]))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'Structure of Python function inputs.*'):
|
||||
@ -1946,10 +1973,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
return -1.0 * a
|
||||
|
||||
x = constant_op.constant(1.0)
|
||||
with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, 'got keyword argument `training` '
|
||||
'that was not included in input_signature'):
|
||||
foo(x, training=True)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
|
||||
with self.assertRaisesRegexp(
|
||||
TypeError, 'got keyword argument `training` '
|
||||
'that was not included in input_signature'):
|
||||
foo(x, training=False)
|
||||
|
||||
self.assertAllEqual(x.numpy(), foo(x).numpy())
|
||||
@ -2472,8 +2503,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
return x
|
||||
|
||||
graph_function = foo.get_concrete_function(constant_op.constant(1.0))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, 'All inputs to `ConcreteFunction`s must be Tensors;.*'):
|
||||
with self.assertRaises((TypeError, ValueError)):
|
||||
graph_function('Not a Tensor.')
|
||||
|
||||
def testSwapImplementationWithGrapplerPlugin(self):
|
||||
@ -3148,6 +3178,432 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||
function.clear_function_callbacks()
|
||||
self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConcreteFunctionWithNestedTensorInputs(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x, y):
|
||||
return (x['a'] + x['b'], y[0] + y[1])
|
||||
|
||||
a = constant_op.constant(1000)
|
||||
b = constant_op.constant(200)
|
||||
c = constant_op.constant(30)
|
||||
d = {'a': a, 'b': b}
|
||||
e = (c, 4)
|
||||
|
||||
# Test different argument signatures when constructing the concrete func.
|
||||
for cf in [
|
||||
f.get_concrete_function(d, e),
|
||||
f.get_concrete_function(d, y=e),
|
||||
f.get_concrete_function(y=e, x=d),
|
||||
f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
|
||||
f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
|
||||
f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
|
||||
]:
|
||||
# Test different calling conventions when calling the concrete func.
|
||||
for output in [
|
||||
cf(d, e), # structured signature
|
||||
cf(d, y=e), # structured signature w/ kwarg
|
||||
cf(y=e, x=d), # structured signature w/ 2 kwargs
|
||||
cf(a, b, c), # flat signature
|
||||
cf(x=a, x_1=b, y=c) # flat signature w/ kwargs
|
||||
]:
|
||||
self.assertIsInstance(output, tuple)
|
||||
self.assertLen(output, 2)
|
||||
self.assertAllEqual(output[0], 1200)
|
||||
self.assertAllEqual(output[1], 34)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConcreteFunctionWithNestedNonTensorInputs(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x, y):
|
||||
return (x['a'] + x['b'], y[0] + y[1])
|
||||
|
||||
a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
|
||||
b = (50, 3)
|
||||
|
||||
for cf in [ # argument y is bound to non-Tensor value (50, 3).
|
||||
f.get_concrete_function(a, b),
|
||||
f.get_concrete_function(a, y=b),
|
||||
f.get_concrete_function(x=a, y=b)
|
||||
]:
|
||||
for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
|
||||
self.assertAllEqual(output[0] + output[1], 1253)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x, y):
|
||||
return (x['a'] + x['b'], y[0] + y[1])
|
||||
|
||||
a = {'a': 3000, 'b': 200, 'c': 9000}
|
||||
b = (constant_op.constant(30), 4)
|
||||
|
||||
for cf in [ # argument x is bound to non-tensor value `a`
|
||||
f.get_concrete_function(a, b),
|
||||
f.get_concrete_function(a, y=b),
|
||||
f.get_concrete_function(x=a, y=b)
|
||||
]:
|
||||
for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
|
||||
self.assertAllEqual(output[0] + output[1], 3234)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
|
||||
|
||||
@def_function.function
|
||||
def f(x, y):
|
||||
return (x['a'] + x['b'], y[0] + y[1])
|
||||
|
||||
a = {'a': 5000, 'b': 500}
|
||||
b = (50, 5)
|
||||
|
||||
cf = f.get_concrete_function(a, b)
|
||||
for output in [cf(), cf(a), cf(y=b)]:
|
||||
self.assertAllEqual(output[0] + output[1], 5555)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConcreteFunctionStructuredSignatureKeywordOrder(self):
|
||||
# Check that keyword-only arguments are sorted appropriately, so that they
|
||||
# feed the right tensor into each input.
|
||||
@def_function.function
|
||||
def g(**kwargs):
|
||||
return string_ops.reduce_join(
|
||||
string_ops.reduce_join(
|
||||
ops.convert_to_tensor(sorted(kwargs.items())),
|
||||
axis=1,
|
||||
separator='='),
|
||||
axis=0,
|
||||
separator=', ')
|
||||
|
||||
s = constant_op.constant('s')
|
||||
g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
|
||||
self.assertAllEqual(
|
||||
g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
|
||||
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
|
||||
self.assertAllEqual(
|
||||
g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
|
||||
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
|
||||
self.assertAllEqual(
|
||||
g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
|
||||
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
@parameterized.named_parameters([
|
||||
dict(
|
||||
testcase_name='MissingArg',
|
||||
conc_args=lambda: (1, constant_op.constant(2)),
|
||||
call_args=lambda: (1,),
|
||||
error=r'func\(x, y\) missing required arguments: y'),
|
||||
dict(
|
||||
testcase_name='MissingVararg',
|
||||
conc_args=lambda: (1, 2, constant_op.constant(1.0)),
|
||||
call_args=lambda: (1, 2),
|
||||
error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
|
||||
dict(
|
||||
testcase_name='ExtraPositionalArg',
|
||||
conc_args=lambda: (1, 2),
|
||||
call_args=lambda: (1, 2, 3),
|
||||
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
|
||||
dict(
|
||||
testcase_name='MissingKeywordOnlyArg',
|
||||
conc_args=lambda: (1, 2),
|
||||
conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
|
||||
call_args=lambda: (1, 2),
|
||||
error=r'func\(x, y, \*, c\) missing required arguments: c'),
|
||||
dict(
|
||||
testcase_name='ExtraKeywordArg',
|
||||
conc_args=lambda: (1, 2),
|
||||
call_args=lambda: (1, 2),
|
||||
call_kwargs=lambda: {'c': constant_op.constant(1.0)},
|
||||
error=r'func\(x, y\) got unexpected keyword arguments: c'),
|
||||
dict(
|
||||
testcase_name='ExpectedRaggedGotNest',
|
||||
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
|
||||
call_args=lambda: ({
|
||||
'a': constant_op.constant([1, 2, 3])
|
||||
},),
|
||||
error=r'func\(x, y\): argument x had incorrect type\n'
|
||||
r' expected: RaggedTensor\n'
|
||||
r" got: {'a': (Eager)?Tensor}"),
|
||||
dict(
|
||||
testcase_name='WrongRaggedRank',
|
||||
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
|
||||
call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
|
||||
error=r'func\(x, y\): argument x had incorrect type\n'),
|
||||
dict(
|
||||
testcase_name='WrongRaggedDType',
|
||||
conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
|
||||
call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
|
||||
error=r'func\(x, y\): argument x had incorrect type\n'),
|
||||
dict(
|
||||
testcase_name='ExpectedDictGotTensor',
|
||||
conc_args=lambda: ({
|
||||
'a': constant_op.constant(1),
|
||||
'b': constant_op.constant(1)
|
||||
},),
|
||||
call_args=lambda: (constant_op.constant(1),),
|
||||
error=r'func\(x, y\): argument x had incorrect type\n'),
|
||||
dict(
|
||||
testcase_name='ExpectedTupleGotTensor',
|
||||
conc_args=lambda:
|
||||
((constant_op.constant(1), constant_op.constant(2)),),
|
||||
call_args=lambda: (constant_op.constant(1),),
|
||||
error=r'func\(x, y\): argument x had incorrect type\n'),
|
||||
dict(
|
||||
testcase_name='WrongDType',
|
||||
conc_args=lambda: (constant_op.constant(1),),
|
||||
call_args=lambda: (constant_op.constant(1.0),),
|
||||
exception=(ValueError, errors.InvalidArgumentError,
|
||||
# on xla_gpu, we get InternalError instead.
|
||||
errors.InternalError)),
|
||||
dict(
|
||||
testcase_name='ExpectedTensorGotInt',
|
||||
conc_args=lambda: (constant_op.constant(1),),
|
||||
call_args=lambda: (5,),
|
||||
error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
|
||||
dict(
|
||||
testcase_name='ExpectedIntGotDifferentInt',
|
||||
conc_args=lambda: (5,),
|
||||
call_args=lambda: (8,),
|
||||
error=r'ConcreteFunction func\(x, y\) was constructed with int '
|
||||
r'value 5 in x, but was called with int value 8'),
|
||||
dict(
|
||||
testcase_name='ExpectedIntGotTensor',
|
||||
conc_args=lambda: (5,),
|
||||
call_args=lambda: (constant_op.constant(6),),
|
||||
error=r'ConcreteFunction func\(x, y\) was constructed with int '
|
||||
'value 5 in x, but was called with (Eager)?Tensor value .*'),
|
||||
dict(
|
||||
testcase_name='TwoValuesForArgument',
|
||||
conc_args=lambda: (1, 2),
|
||||
call_args=lambda: (1, 2),
|
||||
call_kwargs=lambda: {'x': 3},
|
||||
error=r"func\(x, y\) got two values for argument 'x'"),
|
||||
])
|
||||
# pylint: enable=g-long-lambda
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConcreteFunctionStructuredSignatureError(self,
|
||||
conc_args=(),
|
||||
conc_kwargs=None,
|
||||
call_args=(),
|
||||
call_kwargs=None,
|
||||
error='.*',
|
||||
exception=TypeError):
|
||||
"""Tests for errors in the structrued signature.
|
||||
|
||||
Args:
|
||||
conc_args: Positional arguments used for get_concrete_function.
|
||||
conc_kwargs: Keyword arguments used for get_concrete_function.
|
||||
call_args: Positional arguments used to call the function.
|
||||
call_kwargs: Keyword arguments used to call the function.
|
||||
error: Expected exception message.
|
||||
exception: Expected exception type.
|
||||
"""
|
||||
conc_args = conc_args() if callable(conc_args) else conc_args
|
||||
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
|
||||
call_args = call_args() if callable(call_args) else call_args
|
||||
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
|
||||
self.assertIsInstance(conc_args, tuple)
|
||||
self.assertIsInstance(call_args, tuple)
|
||||
self.assertIsInstance(conc_kwargs, dict)
|
||||
self.assertIsInstance(call_kwargs, dict)
|
||||
|
||||
@def_function.function
|
||||
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
|
||||
del y, varargs, kwargs
|
||||
return x
|
||||
|
||||
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
|
||||
with self.assertRaisesRegexp(exception, error):
|
||||
self.evaluate(conc(*call_args, **call_kwargs))
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
@parameterized.named_parameters([
|
||||
dict(
|
||||
testcase_name='MissingArg',
|
||||
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
call_args=lambda: (constant_op.constant(1),),
|
||||
error=r'func\(x, y\) missing required arguments: y'),
|
||||
dict(
|
||||
testcase_name='TwoValuesForArg',
|
||||
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
call_args=lambda: (constant_op.constant(1),),
|
||||
call_kwargs=lambda: {
|
||||
'x': constant_op.constant(1),
|
||||
'y': constant_op.constant(1)
|
||||
},
|
||||
error=r"func\(x, y\) got two values for argument 'x'"),
|
||||
dict(
|
||||
testcase_name='ExtraPositionalArg',
|
||||
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
|
||||
constant_op.constant(3)),
|
||||
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
|
||||
dict(
|
||||
testcase_name='UnexpectedKeywordArg',
|
||||
conc_args=lambda: (constant_op.constant(1),),
|
||||
call_args=lambda: (constant_op.constant(1),),
|
||||
call_kwargs=lambda: {'c': constant_op.constant(1)},
|
||||
error=r'func\(x\) got unexpected keyword arguments: c'),
|
||||
dict(
|
||||
testcase_name='MissingVararg',
|
||||
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
|
||||
constant_op.constant(3)),
|
||||
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
error=r'func\(x, y, varargs_0\) missing required '
|
||||
r'arguments: varargs_0'),
|
||||
dict(
|
||||
testcase_name='MissingKeywordArg',
|
||||
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
conc_kwargs=lambda: {'c': constant_op.constant(1)},
|
||||
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
error=r'func\(x, y, c\) missing required arguments: c'),
|
||||
dict(
|
||||
testcase_name='ExpectedTensorGotInt',
|
||||
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
call_args=lambda: (5, constant_op.constant(2)),
|
||||
error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
|
||||
r'a Tensor; got int \(5\)'),
|
||||
dict(
|
||||
testcase_name='WrongDType',
|
||||
conc_args=lambda: (constant_op.constant(1),),
|
||||
call_args=lambda: (constant_op.constant(1.0),),
|
||||
exception=(ValueError, errors.InvalidArgumentError,
|
||||
# on xla_gpu, we get InternalError instead.
|
||||
errors.InternalError)),
|
||||
dict(
|
||||
testcase_name='MissingKeywordArgNestPiece',
|
||||
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
|
||||
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||
call_kwargs=lambda: {'c': constant_op.constant(1)},
|
||||
error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
|
||||
])
|
||||
# pylint: enable=g-long-lambda
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConcreteFunctionFlatSignatureError(self,
|
||||
conc_args=(),
|
||||
conc_kwargs=None,
|
||||
call_args=(),
|
||||
call_kwargs=None,
|
||||
error='.*',
|
||||
exception=TypeError):
|
||||
"""Tests for errors in the flat signature.
|
||||
|
||||
Args:
|
||||
conc_args: Positional arguments used for get_concrete_function.
|
||||
conc_kwargs: Keyword arguments used for get_concrete_function.
|
||||
call_args: Positional arguments used to call the function.
|
||||
call_kwargs: Keyword arguments used to call the function.
|
||||
error: Expected exception message.
|
||||
exception: Expected exception type.
|
||||
"""
|
||||
conc_args = conc_args() if callable(conc_args) else conc_args
|
||||
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
|
||||
call_args = call_args() if callable(call_args) else call_args
|
||||
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
|
||||
self.assertIsInstance(conc_args, tuple)
|
||||
self.assertIsInstance(call_args, tuple)
|
||||
self.assertIsInstance(conc_kwargs, dict)
|
||||
self.assertIsInstance(call_kwargs, dict)
|
||||
|
||||
@def_function.function
|
||||
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
|
||||
del y, varargs, kwargs
|
||||
return x
|
||||
|
||||
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
|
||||
|
||||
# Remove _function_spec, to disable the structured signature.
|
||||
conc._set_function_spec(None) # pylint: disable=protected-access
|
||||
|
||||
with self.assertRaisesRegexp(exception, error):
|
||||
self.evaluate(conc(*call_args, **call_kwargs))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testConcreteFunctionAmbiguousSignature(self):
|
||||
# When both the flat & structured signatures are applicable, but they
|
||||
# give different results, we use the structured signature. Note: we expect
|
||||
# this to be extremely rare.
|
||||
@def_function.function
|
||||
def f(x, y):
|
||||
return x * 10 + y
|
||||
|
||||
conc = f.get_concrete_function(
|
||||
x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
|
||||
y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))
|
||||
|
||||
result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
|
||||
self.assertAllEqual(result, 56)
|
||||
|
||||
def testPrettyPrintedSignature(self):
|
||||
|
||||
@def_function.function
|
||||
def func(x, kangaroo=None, octopus=7):
|
||||
del octopus, kangaroo
|
||||
return x
|
||||
|
||||
scalar = constant_op.constant(5)
|
||||
vector = constant_op.constant([10, 10, 20])
|
||||
ragged = ragged_factory_ops.constant([[10, 20], [40]])
|
||||
|
||||
c1 = func.get_concrete_function(scalar, vector)
|
||||
c1_summary = r'func\(x, kangaroo, octopus=7\)'
|
||||
c1_details = (r' Args:\n'
|
||||
r' x: int32 Tensor, shape=\(\)\n'
|
||||
r' kangaroo: int32 Tensor, shape=\(3,\)\n'
|
||||
r' Returns:\n'
|
||||
r' int32 Tensor, shape=\(\)')
|
||||
self.assertRegexpMatches(
|
||||
c1.pretty_printed_signature(verbose=False), c1_summary)
|
||||
self.assertRegexpMatches(
|
||||
c1.pretty_printed_signature(verbose=True),
|
||||
c1_summary + '\n' + c1_details)
|
||||
self.assertRegexpMatches(
|
||||
repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
|
||||
self.assertRegexpMatches(
|
||||
str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))
|
||||
|
||||
c2 = func.get_concrete_function(scalar, ragged, 3)
|
||||
c2_summary = r'func\(x, kangaroo, octopus=3\)'
|
||||
c2_details = (r' Args:\n'
|
||||
r' x: int32 Tensor, shape=\(\)\n'
|
||||
r' kangaroo: RaggedTensorSpec\(.*\)\n'
|
||||
r' Returns:\n'
|
||||
r' int32 Tensor, shape=\(\)')
|
||||
self.assertRegexpMatches(c2.pretty_printed_signature(),
|
||||
c2_summary + '\n' + c2_details)
|
||||
|
||||
c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
|
||||
c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
|
||||
c3_details = (r' Args:\n'
|
||||
r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
|
||||
r' <1>: int32 Tensor, shape=\(\)\n'
|
||||
r' <2>: RaggedTensorSpec\(.*\)\n'
|
||||
r' <3>: RaggedTensorSpec\(.*\)\n'
|
||||
r' Returns:\n'
|
||||
r" {'a': <1>, 'b': \[<2>, <3>\]}\n"
|
||||
r' <1>: int32 Tensor, shape=\(\)\n'
|
||||
r' <2>: RaggedTensorSpec\(.*\)\n'
|
||||
r' <3>: RaggedTensorSpec\(.*\)')
|
||||
self.assertRegexpMatches(c3.pretty_printed_signature(),
|
||||
c3_summary + '\n' + c3_details)
|
||||
|
||||
# pylint: disable=keyword-arg-before-vararg
|
||||
@def_function.function
|
||||
def func2(x, y=3, *args, **kwargs):
|
||||
return (x, y, args, kwargs)
|
||||
|
||||
c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
|
||||
c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
|
||||
self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)
|
||||
|
||||
c5 = func2.get_concrete_function(8, vector)
|
||||
c5_summary = 'func2(x=8, y)'
|
||||
self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)
|
||||
|
||||
|
||||
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -43,6 +43,8 @@ def _is_tensor(t):
|
||||
return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
|
||||
|
||||
|
||||
# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
|
||||
# structured signature.
|
||||
def _call_concrete_function(function, inputs):
|
||||
"""Calls a restored Function with structured inputs.
|
||||
|
||||
@ -137,8 +139,6 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
|
||||
input_signature = coder.decode_proto(function_spec_proto.input_signature)
|
||||
return function_lib.FunctionSpec(fullargspec=fullargspec,
|
||||
is_method=False,
|
||||
args_to_prepend=[],
|
||||
kwargs_to_include={},
|
||||
input_signature=input_signature)
|
||||
|
||||
|
||||
@ -191,6 +191,8 @@ def recreate_function(saved_function, concrete_functions):
|
||||
Args:
|
||||
saved_function: `SavedFunction` proto.
|
||||
concrete_functions: map from function name to `ConcreteFunction`.
|
||||
As a side effect of this function, the `FunctionSpec` from
|
||||
`saved_function` is added to each `ConcreteFunction` in this map.
|
||||
|
||||
Returns:
|
||||
A `Function`.
|
||||
@ -254,6 +256,9 @@ def recreate_function(saved_function, concrete_functions):
|
||||
for concrete_function_name in saved_function.concrete_functions:
|
||||
concrete_function_objects.append(concrete_functions[concrete_function_name])
|
||||
|
||||
for cf in concrete_function_objects:
|
||||
cf._set_function_spec(function_spec) # pylint: disable=protected-access
|
||||
|
||||
restored_function = RestoredFunction(
|
||||
restored_function_body,
|
||||
restored_function_body.__name__,
|
||||
@ -317,6 +322,11 @@ def load_function_def_library(library, load_shared_name_suffix=None):
|
||||
|
||||
for dep in _list_function_deps(fdef, library_function_names):
|
||||
functions[dep].add_to_graph(func_graph)
|
||||
|
||||
# We do not initialize the new ConcreteFunction's function_spec or
|
||||
# arg_keywords here (which are used to parse the structured and flat
|
||||
# signatures, respectively). function_spec is set up later by
|
||||
# recreate_function(); and arg_keywords by setup_bare_concrete_function().
|
||||
func = function_lib.ConcreteFunction(func_graph)
|
||||
func.add_to_graph(graph)
|
||||
|
||||
|
@ -77,6 +77,10 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
|
||||
|
||||
def serialize_bare_concrete_function(concrete_function, name_map):
|
||||
"""Build a SavedBareConcreteFunction."""
|
||||
# TODO(edloper): Currently, bare concrete functions don't have access to a
|
||||
# function_spec, so they can't be called with the structured signature.
|
||||
# Update the serialization to include a function_spec.
|
||||
|
||||
# pylint: disable=protected-access
|
||||
name = name_map.get(compat.as_text(concrete_function.name),
|
||||
concrete_function.name)
|
||||
@ -151,7 +155,8 @@ def wrap_cached_variables(concrete_function):
|
||||
func_graph_module.func_graph_from_py_func(
|
||||
None, wrap_function, args=tuple(args), kwargs={},
|
||||
func_graph=outer_graph)
|
||||
fn = defun.ConcreteFunction(outer_graph)
|
||||
fn = defun.ConcreteFunction(
|
||||
outer_graph, function_spec=concrete_function._function_spec) # pylint: disable=protected-access
|
||||
fn._arg_keywords = concrete_function._arg_keywords # pylint: disable=protected-access
|
||||
fn._num_positional_args = concrete_function._num_positional_args # pylint: disable=protected-access
|
||||
|
||||
|
@ -173,12 +173,12 @@ class Loader(object):
|
||||
# The original_outputs here had Tensors converted to TensorSpecs, so
|
||||
# the restored function's structured_outputs field will not be
|
||||
# exactly the same. Fortunately the repacking logic cares only about
|
||||
# the structure.
|
||||
# TODO(vbardiovsky): Should we just replicate the structures, with
|
||||
# Nones instead of real objects?
|
||||
# the structure; and the unpacking logic cares only about structure
|
||||
# and types.
|
||||
concrete_function._func_graph.structured_outputs = original_outputs # pylint: disable=protected-access
|
||||
concrete_function._func_graph.structured_input_signature = ( # pylint: disable=protected-access
|
||||
coder.decode_proto(proto.canonicalized_input_signature))
|
||||
concrete_function._initialize_function_spec() # pylint: disable=protected-access
|
||||
|
||||
def _setup_functions_captures(self):
|
||||
"""Setup captures and variables in restored functions."""
|
||||
|
Loading…
Reference in New Issue
Block a user