Functions returned by the get_concrete_function method of tf.Function objects can now be called with arguments consistent with the original arguments or type specs passed to get_concrete_function. In particular:

* If a composite tensor (such as RaggedTensor or SparseTensor) was passed to `get_concrete_function`, then the returned function will accept a composite tensor of the same type for that argument.

* If a nested structure (such as a list or dict) was passed to `get_concrete_function`, then the returned function will accept a value with the same nesting structure.  Each tensor or composite tensor value must have the same type as was used in the original argument; and each non-Tensor value (such as bools or ints) must be equal value that was used in the original argument.

* If a non-tensor value (such as a bool or int) was passed to `get_concrete_function`, then the returned function no longer deletes that argument; instead, it updates the argument's default value to the value that was passed to `get_concrete_function`.  Passing in any other value will raise an exception.

* Arguments are not renamed based on `TensorSpec.name`.

For backwards compatibility, the functions returned by `get_concrete_function` will continue to accept arguments with the existing calling conventions (where nested structures and composite tensors are flattened; non-tensor arguments are deleted; suffixes are automatically added to disambiguate arguments with the same name; and TensorSpec.name is used to rename arguments).  However, the preferred calling convention is the one that is consistent with the original arguments or type specs passed to `get_concrete_function`.

PiperOrigin-RevId: 307398918
Change-Id: Ie4685b32d9f151c82f6c79a6c41379faa96b5ee8
This commit is contained in:
Edward Loper 2020-04-20 08:00:45 -07:00 committed by TensorFlower Gardener
parent 0887fedd2d
commit f39aab3092
6 changed files with 1041 additions and 101 deletions

View File

@ -830,6 +830,13 @@ class Function(object):
def function_spec(self):
return self._function_spec
def pretty_printed_concrete_signatures(self, verbose=True):
joiner = "\n\n" if verbose else "\n"
return joiner.join([
c.pretty_printed_signature(verbose=verbose)
for c in self._list_all_concrete_functions()
])
def _initialize_uninitialized_variables(self, initializers):
"""Make and call a `ConcreteFunction` which initializes variables."""
@ -913,12 +920,8 @@ class Function(object):
return initialize_variables.get_concrete_function()
def _list_all_concrete_functions_for_serialization(self):
"""Returns all concrete functions for serialization.
Returns:
A list of instances of `ConcreteFunction`.
"""
def _list_all_concrete_functions(self):
"""Returns all concrete functions."""
if self.input_signature is not None:
self.get_concrete_function()
concrete_functions = []
@ -930,6 +933,15 @@ class Function(object):
concrete_functions.extend(
self._stateless_fn._function_cache.all_values())
# pylint: enable=protected-access
return concrete_functions
def _list_all_concrete_functions_for_serialization(self):
"""Returns all concrete functions for serialization.
Returns:
A list of instances of `ConcreteFunction`.
"""
concrete_functions = self._list_all_concrete_functions()
seen_signatures = []
for concrete_function in concrete_functions:
signature = concrete_function.structured_input_signature

View File

@ -53,6 +53,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
@ -340,7 +341,7 @@ class _InterpolateFunctionError(object):
if t.name == compat.as_str(self._func.name):
g = self._func.graph
elif g:
next_func = g._get_function(t.name)
next_func = g._get_function(t.name) # pylint: disable=protected-access
if next_func is not None and isinstance(next_func,
_EagerDefinedFunction):
g = next_func.graph
@ -1499,6 +1500,12 @@ class _ForwardBackwardCall(object):
flat_outputs, self._inference_args, self._input_tangents)
# Sentinel value used by with ConcreteFunction's structured signature to
# indicate that a non-tensor parameter should use the value that was
# specified when the concrete function was created.
_BOUND_VALUE = object()
class ConcreteFunction(object):
"""Callable object encapsulating a function definition and its gradient.
@ -1506,7 +1513,11 @@ class ConcreteFunction(object):
is differentiable under `tf.GradientTape` objects.
"""
def __init__(self, func_graph, attrs=None, shared_func_graph=True):
def __init__(self,
func_graph,
attrs=None,
shared_func_graph=True,
function_spec=None):
"""Initialize a `ConcreteFunction`.
Args:
@ -1517,16 +1528,25 @@ class ConcreteFunction(object):
shared_func_graph: If False, the ConcreteFunction takes ownership of
`func_graph` and will break reference cycles when it is deleted. This
makes the FuncGraph inoperable.
function_spec: FunctionSpec for the original function. If not specified,
then this ConcreteFunction may only be called using the flat signature.
Raises:
ValueError: If number of input_placeholders is not equal to the number
of function inputs.
"""
# _arg_keywords and _num_positional_args define the flat signature. They
# are assigned after construction.
self._arg_keywords = None
self._num_positional_args = None
self._func_graph = func_graph
self._captured_inputs = self._func_graph.external_captures
self._captured_closures = self._func_graph.deferred_external_captures
# function_spec defines the structured signature.
self._set_function_spec(function_spec)
if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
# The alternative is to silently drop "implements" tag
# but it seems likely it would lead to hard to catch bugs.
@ -1576,6 +1596,52 @@ class ConcreteFunction(object):
# building gradients.
self._inference_function = self._delayed_rewrite_functions.forward()
def _set_function_spec(self, function_spec):
"""Enables the structured signature by supplying a function_spec."""
self._function_spec = None
self._pre_initialized_function_spec = function_spec
# Note: when ConcreteFunctions are built by recreate_function() in
# function_deserialization.py, they don't have a structured_input_signature
# yet. In that case, _initialize_function_spec() gets called by
# _setup_functions_structures() in load.py.
if (function_spec is not None and
self.structured_input_signature is not None):
self._initialize_function_spec()
def _initialize_function_spec(self):
"""Updates `self._function_spec` to include varargs and bound variables.
Adds new positional arguments for any varargs (i.e., for args that are
in `structured_input_signature`, but not in the original fullargspec.args).
Replaces `defaults` and `kwonlydefaults` with the `_BOUND_VALUE`, for
all args and kwargs in `structured_input_signature`.
Sets `varkw` and `varargs` to None.
"""
if self._pre_initialized_function_spec is None:
return # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
assert not self._function_spec, "already initialized"
function_spec = self._pre_initialized_function_spec
args = function_spec.fullargspec.args
arg_specs, kwarg_specs = self.structured_input_signature
fullargspec = tf_inspect.FullArgSpec(
args=list(args) +
["<arg{}>".format(i + 1) for i in range(len(args), len(arg_specs))],
varargs=None,
varkw=None,
defaults=[_BOUND_VALUE] * len(arg_specs),
kwonlyargs=list(sorted(kwarg_specs)),
kwonlydefaults=dict((k, _BOUND_VALUE) for k in kwarg_specs),
annotations=function_spec.fullargspec.annotations)
self._function_spec = FunctionSpec(
fullargspec,
function_spec.is_method,
function_spec.input_signature,
function_spec.is_pure,
name=self._func_graph.name)
@property
def variables(self):
"""Sequence of variables for this function."""
@ -1589,15 +1655,44 @@ class ConcreteFunction(object):
def __call__(self, *args, **kwargs):
"""Executes the wrapped function.
ConcreteFunctions have two signatures:
* The signature of the original function wrapped by this ConcreteFunction.
* A flat signature, where each argument accepts a single Tensor.
The original function signature is generally preferred, but the flat input
signature is supported for backward compatibility.
### Original Function Signature
When calling a ConcreteFunction with the signature of the original function,
each argument must match the type or value that was used when the
ConcreteFunction's graph was traced. In particular:
* Tensor arguments (including CompositeTensors, such as RaggedTensor) must
have matching `TypeSpec`s.
* Non-Tensor arguments (such as booleans or ints) must have equal values.
* Nested arguments (such as lists, tuples, or dictionaries) must have the
same nesting structure; and each nested value must have a matching type
or value.
The default value for any arguments that were traced with non-Tensor values
is the value that was used in the trace. Arguments that were traced with
tensor arguments do not have a default value (even if the original function
had a default value for that argument).
### Flat Signature
When calling a ConcreteFunction with the flat signature, the arguments
correspond to the flattened component tensors of the arguments that were
used to construct the ConcreteFunction. Parameter names are assigned based
on `TensorSpec.name` (when specified) or the original argument names (with
suffixes automatically added for nested arguments or composite tensors with
multiple components).
Args:
*args: Tensors or Variables. Positional arguments are only accepted when
they correspond one-to-one with arguments of the traced Python function.
**kwargs: Tensors or Variables specified by name. When
`get_concrete_function` was called to create this `ConcreteFunction`,
each Tensor input was given a name, defaulting to the name of the Python
function's argument but possibly overridden by the `name=` argument to
`tf.TensorSpec`. These names become the argument names for the concrete
function.
*args: Positional arguments to the concrete function.
**kwargs: Keyword arguments to the concrete function.
Returns:
The result of applying the TF function on the given Tensors.
@ -1605,9 +1700,7 @@ class ConcreteFunction(object):
Raises:
AssertionError: If this `ConcreteFunction` was not created through
`get_concrete_function`.
ValueError: If arguments contains anything other than Tensors or
Variables.
TypeError: For invalid positional/keyword argument combinations.
TypeError: If the arguments do not match the function's signature.
"""
return self._call_impl(args, kwargs)
@ -1615,40 +1708,174 @@ class ConcreteFunction(object):
"""See `__call__` for details."""
with traceme.TraceMe(self._func_graph.name,
tf_function_call="concrete"):
if self._arg_keywords is None or self._num_positional_args is None:
raise AssertionError(
"Tried to call a concrete function obtained from an internal API "
"through the public interface. Use get_concrete_function instead.")
if len(args) > self._num_positional_args:
raise TypeError(
("Expected at most {} positional arguments (and the rest keywords, "
"of {}), got {}. When calling a concrete function, positional "
"arguments may not be bound to Tensors within nested structures."
).format(self._num_positional_args, self._arg_keywords, args))
args = list(args)
for keyword in self._arg_keywords[len(args):]:
# Construct the list of input tensors: check if the structured signature
# applies first; and if not, then use the flat signature.
if self._function_spec is not None:
try:
args.append(kwargs.pop(compat.as_str(keyword)))
except KeyError:
specified_keywords = (list(self._arg_keywords[:len(args)])
+ list(kwargs.keys()))
raise TypeError(
"Expected argument names {} but got values for {}. Missing: {}."
.format(
list(self._arg_keywords),
specified_keywords,
list(set(self._arg_keywords) - set(specified_keywords))))
if kwargs:
positional_arg_keywords = set(self._arg_keywords[:len(args)])
for unused_key in kwargs:
if unused_key in positional_arg_keywords:
raise TypeError("Got two values for keyword '{}'.".format(
unused_key))
raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
list(kwargs.keys()), list(self._arg_keywords)))
return self._call_flat(args, self.captured_inputs, cancellation_manager)
return self._call_with_structured_signature(args, kwargs,
cancellation_manager)
except TypeError as structured_err:
try:
return self._call_with_flat_signature(args, kwargs,
cancellation_manager)
except TypeError:
raise structured_err
def _filtered_call(self, args, kwargs):
return self._call_with_flat_signature(args, kwargs, cancellation_manager)
def _call_with_flat_signature(self, args, kwargs, cancellation_manager):
"""Executes the wrapped function with the flat signature.
Args:
args: Positional arguments to the concrete function.
kwargs: Keyword arguments to the concrete function.
cancellation_manager: A `CancellationManager` that can be used to cancel
function invocation.
Returns:
The result of applying the function on the Tensors/Variables contained in
`args` and `kwargs`.
Raises:
TypeError: if `args` and `kwargs` do not match the flat signature of this
`ConcreteFunction`.
"""
if len(args) > self._num_positional_args:
raise TypeError(
"{} takes {} positional arguments but {} were given".format(
self._flat_signature_summary(), self._num_positional_args,
len(args)))
args = list(args)
kwargs = dict(kwargs)
for keyword in self._arg_keywords[len(args):]:
try:
args.append(kwargs.pop(compat.as_str(keyword)))
except KeyError:
specified_keywords = (
list(self._arg_keywords[:len(args)]) + list(kwargs.keys()))
raise TypeError("{} missing required arguments: {}".format(
self._flat_signature_summary(), ", ".join(
sorted(set(self._arg_keywords) - set(specified_keywords)))))
if kwargs:
positional_arg_keywords = set(self._arg_keywords[:len(args)])
for unused_key in kwargs:
if unused_key in positional_arg_keywords:
raise TypeError("{} got two values for argument '{}'".format(
self._flat_signature_summary(), unused_key))
raise TypeError("{} got unexpected keyword arguments: {}.".format(
self._flat_signature_summary(), ", ".join(sorted(kwargs))))
for i, arg in enumerate(args):
if not isinstance(
arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
raise TypeError("{}: expected argument #{}(zero-based) to be a Tensor; "
"got {} ({})".format(self._flat_signature_summary(), i,
type(arg).__name__, str(arg)))
return self._call_flat(args, self.captured_inputs, cancellation_manager)
def _call_with_structured_signature(self, args, kwargs, cancellation_manager):
"""Executes the wrapped function with the structured signature.
Args:
args: Positional arguments to the concrete function.
kwargs: Keyword arguments to the concrete function.
cancellation_manager: A `CancellationManager` that can be used to cancel
function invocation.
Returns:
The result of applying the function on the Tensors/Variables contained in
`args` and `kwargs`.
Raises:
TypeError: if `args` and `kwargs` do not match the structured signature
of this `ConcreteFunction`.
"""
args, kwargs = self._function_spec.canonicalize_function_inputs(
*args, **kwargs)
self._structured_signature_check_missing_args(args, kwargs)
self._structured_signature_check_unexpected_args(args, kwargs)
self._structured_signature_check_arg_types(args, kwargs)
return self._filtered_call(args, kwargs, cancellation_manager)
def _structured_signature_check_missing_args(self, args, kwargs):
"""Raises a TypeError if any args are missing."""
arg_specs, kwarg_specs = self.structured_input_signature
missing_arguments = []
for i, (arg, spec) in enumerate(zip(args, arg_specs)):
if arg is _BOUND_VALUE and _contains_type_spec(spec):
missing_arguments.append(self._function_spec.arg_names[i])
for (name, arg) in kwargs.items():
if arg is _BOUND_VALUE and _contains_type_spec(kwarg_specs[name]):
missing_arguments.append(name)
if missing_arguments:
raise TypeError("{} missing required arguments: {}".format(
self._structured_signature_summary(),
", ".join(sorted(missing_arguments))))
def _structured_signature_check_unexpected_args(self, args, kwargs):
"""Raises a TypeError if there are any extra args."""
arg_specs, kwarg_specs = self.structured_input_signature
if len(args) > len(arg_specs):
raise TypeError(
"{} takes {} positional arguments but {} were given".format(
self._structured_signature_summary(),
len(self._function_spec.arg_names), len(args)))
if len(kwargs) > len(kwarg_specs):
extra_args = set(kwargs) - set(kwarg_specs)
raise TypeError("{} got unexpected keyword arguments: {}".format(
self._structured_signature_summary(), ", ".join(extra_args)))
def _structured_signature_check_arg_types(self, args, kwargs):
"""Raises a TypeError if any args have the wrong type."""
# Check argument types
arg_specs, kwarg_specs = self.structured_input_signature
for i, (arg, spec) in enumerate(zip(args, arg_specs)):
name = self._function_spec.arg_names[i]
self._structured_signature_check_arg_type(arg, spec, name)
for (name, arg) in kwargs.items():
self._structured_signature_check_arg_type(arg, kwarg_specs[name], name)
def _structured_signature_check_arg_type(self, arg, spec, name):
"""Raise TypeError if `arg`'s type doesn't match `spec`."""
if arg is _BOUND_VALUE:
return
# Check the overall nested structure of the argument.
try:
nest.assert_same_structure(arg, spec, expand_composites=True)
except (ValueError, TypeError):
try:
nest.assert_same_structure(arg, spec, expand_composites=False)
expected, got = spec, arg
except (ValueError, TypeError):
expected, got = _structure_summary(spec), _structure_summary(arg)
raise TypeError("{}: argument {} had incorrect type\n"
" expected: {}\n got: {}".format(
self._structured_signature_summary(), name, expected,
got))
# Check the type for each leaf in the nested structure.
arg_pieces = nest.flatten(arg, expand_composites=True)
spec_pieces = nest.flatten(spec, expand_composites=True)
for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces):
if isinstance(spec_piece, tensor_spec.DenseSpec):
# TODO(edloper): Consider calling convert_to_tensor on non-tensor
# values here. That would match the behavior of
# _call_concrete_function() in function_deserialization.py. If
# we do, then we need to change the nest assert_same_structure and
# flatten calls above to use shallow variants.
tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
if not isinstance(arg_piece, tensor_types):
raise TypeError(
"{} expected a Tensor in {}, but got {} value {}".format(
self._structured_signature_summary(), name,
type(arg_piece).__name__, arg_piece))
elif arg_piece is not _BOUND_VALUE and arg_piece != spec_piece:
raise TypeError("ConcreteFunction {} was constructed with {} value "
"{} in {}, but was called with {} value {}".format(
self._structured_signature_summary(),
type(spec_piece).__name__, spec_piece, name,
type(arg_piece).__name__, arg_piece))
def _filtered_call(self, args, kwargs, cancellation_manager=None):
"""Executes the function, filtering arguments from the Python function.
Objects aside from Tensors, CompositeTensors, and Variables are ignored.
@ -1657,6 +1884,8 @@ class ConcreteFunction(object):
Args:
args: Canonicalized positional arguments of the Python function.
kwargs: Canonicalized keyword arguments of the Python function.
cancellation_manager: (Optional.) A `CancellationManager` that can be
used to cancel function invocation.
Returns:
The result of applying the function on the Tensors/Variables contained in
@ -1666,7 +1895,8 @@ class ConcreteFunction(object):
(t for t in nest.flatten((args, kwargs), expand_composites=True)
if isinstance(t, (ops.Tensor,
resource_variable_ops.BaseResourceVariable))),
self.captured_inputs)
captured_inputs=self.captured_inputs,
cancellation_manager=cancellation_manager)
def _call_flat(self, args, captured_inputs, cancellation_manager=None):
"""Executes the wrapped function.
@ -1795,7 +2025,26 @@ class ConcreteFunction(object):
@property
def structured_input_signature(self):
"""Returns structured signature of the original function."""
"""Returns structured signature for this concrete function.
Returns:
A tuple `(args, kwargs)`, where:
* `args` is a tuple that specifies the expected type or value each for
positional argument.
* `kwargs` is a dictionary that specifies the expected type or value
for each keyword-only argument.
The type or value for each argument is specified using one of the
following:
* A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
value is expected.
* A Python value, such as an integer, indicating that an equal value
is expected.
* A nested structure of `tf.TypeSpec`s and Python values, indicating
that a corresponding nested structure is expected.
"""
return self._func_graph.structured_input_signature
@property
@ -1982,6 +2231,103 @@ class ConcreteFunction(object):
ret.attr[name].CopyFrom(value)
return ret
def _structured_signature_summary(self, default_values=False):
"""Returns a string summarizing this function's structured signature.
Args:
default_values: If true, then include default values in the signature.
Returns:
A `string`.
"""
# Note: we can't just use self._funcion_spec.signature_summary(), because
# that would show "_BOUND_VALUE" as the default value for all arguments.
assert self._function_spec is not None
arg_specs, kwarg_specs = self.structured_input_signature
arg_names = list(self._function_spec.arg_names)
if default_values:
for i in range(len(arg_names)):
if not _contains_type_spec(arg_specs[i]):
arg_names[i] += "={}".format(arg_specs[i])
if kwarg_specs:
arg_names.append("*")
for name, spec in kwarg_specs.items():
arg_names.append(name)
if default_values and not _contains_type_spec(spec):
arg_names[-1] += "={}".format(spec)
signature = "{}({})".format(self._func_graph.name, ", ".join(arg_names))
return signature
def _flat_signature_summary(self):
"""Returns a string summarizing this function's flat signature."""
assert self._arg_keywords is not None
assert self._num_positional_args is not None
arg_names = self._arg_keywords
if self._num_positional_args > len(arg_names):
arg_names.extend(
"<arg{}>".format(i + 1)
for i in range(len(arg_names), self._num_positional_args))
return "{}({})".format(self._func_graph.name, ", ".join(arg_names))
def pretty_printed_signature(self, verbose=True):
"""Returns a string summarizing the signature of this concrete function."""
if not verbose:
return self._structured_signature_summary(default_values=True)
def pretty_print_spec(spec):
"""Returns a string describing the spec for a single argument."""
if isinstance(spec, tensor_spec.TensorSpec):
return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
elif nest.is_sequence(spec):
pieces = nest.flatten(spec, expand_composites=False)
markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
structure = nest.pack_sequence_as(spec, markers)
result = "{}".format(structure)
for (marker, piece) in zip(markers, pieces):
result += "\n {}: {}".format(marker, pretty_print_spec(piece))
return result
else:
return repr(spec)
lines = [self._structured_signature_summary(default_values=True)]
arg_specs, kwarg_specs = self.structured_input_signature
names = list(self._function_spec.arg_names)
names.extend(sorted(kwarg_specs))
specs = list(arg_specs) + list(kwarg_specs.values())
# note: we can skip bound args, since we already displayed thier bound
# value in the signature summary.
arg_details = []
for (name, spec) in zip(names, specs):
if _contains_type_spec(spec):
arg_details.append(" {}: {}".format(name, pretty_print_spec(spec)))
if arg_details:
lines.append(" Args:")
lines.extend(arg_details)
lines.append(" Returns:")
lines.append(" {}".format(
pretty_print_spec(
nest.map_structure(type_spec.type_spec_from_value,
self.structured_outputs))))
return "\n".join(lines)
def __repr__(self):
if self._function_spec is not None:
return "<ConcreteFunction {} at 0x{:X}>".format(
self.pretty_printed_signature(verbose=False), id(self))
elif not (self._num_positional_args is None or self._arg_keywords is None):
return "<ConcreteFunction {} at 0x{:X}>".format(
self._flat_signature_summary(), id(self))
else:
return object.__repr__(self)
def __str__(self):
if self._function_spec is not None:
return "ConcreteFunction {}".format(self.pretty_printed_signature())
else:
return self.__repr__()
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
@ -2075,17 +2421,37 @@ class FunctionSpec(object):
kwonlydefaults={},
annotations=fullargspec.annotations)
is_method = tf_inspect.ismethod(python_function)
return FunctionSpec(fullargspec, is_method, [], {}, input_signature,
is_pure=is_pure)
def __init__(self, fullargspec, is_method, args_to_prepend, kwargs_to_include,
input_signature, is_pure=False):
# Get the function's name. Remove functools.partial wrappers if necessary.
while isinstance(python_function, functools.partial):
python_function = python_function.func
name = getattr(python_function, "__name__", "f")
return FunctionSpec(
fullargspec, is_method, input_signature, is_pure=is_pure, name=name)
def __init__(self,
fullargspec,
is_method,
input_signature,
is_pure=False,
name=None):
"""Constructs a FunctionSpec describing a python function.
Args:
fullargspec: `tf_inspect.FullArgSpec` object describing the function.
is_method: True if the function is a method.
input_signature: a signature of the function (None, if variable)
is_pure: if True all input arguments (including variables and constants)
will be converted to tensors and no variable changes allowed.
name: Name of the function
"""
self._fullargspec = fullargspec
self._is_method = is_method
self._is_pure = is_pure
del args_to_prepend
del kwargs_to_include
self._default_values = fullargspec.defaults
# TODO(edloper): Include name when serializing for SavedModel?
self._name = name or "f"
if self._is_method:
# Remove `self`: default arguments shouldn't be matched to it.
@ -2098,21 +2464,21 @@ class FunctionSpec(object):
# A cache mapping from argument name to index, for canonicalizing
# arguments that are called in a keyword-like fashion.
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
self.arg_names = args
self.vararg_name = fullargspec.varargs
self._arg_names = args
# A cache mapping from arg index to default value, for canonicalization.
offset = len(args) - len(self._default_values or [])
default_values = fullargspec.defaults
offset = len(args) - len(default_values or [])
self._arg_indices_to_default_values = {
offset + index: default
for index, default in enumerate(self._default_values or [])
for index, default in enumerate(default_values or [])
}
if input_signature is None:
self._input_signature = None
else:
if fullargspec.kwonlyargs:
if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()):
raise ValueError("Cannot define a TensorFlow function from a Python "
"function with keyword arguments when "
"function with keyword-only arguments when "
"input_signature is provided.")
if not isinstance(input_signature, (tuple, list)):
@ -2132,8 +2498,8 @@ class FunctionSpec(object):
return self._is_method
@property
def args_to_prepend(self):
return self._args_to_prepend
def args_to_indices(self):
return self._args_to_indices
@property
def kwargs_to_include(self):
@ -2147,6 +2513,43 @@ class FunctionSpec(object):
def flat_input_signature(self):
return self._flat_input_signature
@property
def is_pure(self):
return self._is_pure
@property
def arg_names(self):
return self._arg_names
@property
def vararg_name(self):
return self._fullargspec.varargs
@property
def varkw_name(self):
return self._fullargspec.varkw
def signature_summary(self, default_values=False):
"""Returns a string summarizing this function's signature.
Args:
default_values: If true, then include default values in the signature.
Returns:
A `string`.
"""
args = list(self._arg_names)
if default_values:
for (i, default) in self._arg_indices_to_default_values.items():
args[i] += "={}".format(default)
if self._fullargspec.kwonlyargs:
args.append("*")
for arg_name in self._fullargspec.kwonlyargs:
args.append(arg_name)
if default_values and arg_name in self._fullargspec.kwonlydefaults:
args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
return "{}({})".format(self._name, ", ".join(args))
def _convert_variables_to_tensors(self, args, kwargs):
args = [ops.convert_to_tensor(x) for x in args]
kwargs = {kw: ops.convert_to_tensor(x) for kw, x in kwargs.items()}
@ -2159,7 +2562,13 @@ class FunctionSpec(object):
instance. In particular, we parse the varags and kwargs that the
original function was called with into a tuple corresponding to the
Python function's positional (named) arguments and a dictionary
corresponding to its kwargs.
corresponding to its kwargs. Missing default arguments are added.
If this `FunctionSpec` has an input signature, then it is used to convert
arguments to tensors; otherwise, any inputs containing numpy arrays are
converted to tensors.
Additionally, any inputs containing numpy arrays are converted to Tensors.
Args:
*args: The varargs this object was called with.
@ -2180,29 +2589,38 @@ class FunctionSpec(object):
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
if self._input_signature is not None:
if len(args) > len(self._input_signature):
raise TypeError(
"When input_signature is provided, only pass arguments "
"covered by it. Received %d argument(s)." % len(args))
raise TypeError("{} takes {} positional arguments (as specified by the "
"input_signature) but {} were given".format(
self.signature_summary(),
len(self._input_signature), len(args)))
for arg in six.iterkeys(kwargs):
index = self._args_to_indices.get(arg, None)
if index is None:
raise TypeError(
"Function got an unexpected keyword argument %s" % arg)
raise TypeError("{} got unexpected keyword argument `{}`".format(
self.signature_summary(), arg))
if index >= len(self._input_signature):
raise TypeError(
"When input_signature is provided, only pass arguments "
"covered by it. Received argument %s." % arg)
"{} got keyword argument `{}` that was not included in "
"input_signature".format(self.signature_summary(), arg))
if not kwargs:
inputs = args
default_keys = sorted(self._arg_indices_to_default_values.keys())
if default_keys:
assert min(default_keys) <= len(
args), "Not enough arguments (%s, %s, %s)" % (args, default_keys,
self.arg_names)
for index in default_keys:
if index >= len(args):
inputs += (self._arg_indices_to_default_values[index],)
if self._arg_indices_to_default_values:
try:
inputs += tuple(
self._arg_indices_to_default_values[i]
for i in range(len(args), len(self._arg_names)))
except KeyError:
missing_args = [
self._arg_names[i]
for i in range(len(args), len(self._arg_names))
if i not in self._arg_indices_to_default_values
]
raise TypeError("{} missing required arguments: {}".format(
self.signature_summary(), ", ".join(missing_args)))
if self._fullargspec.kwonlydefaults:
kwargs.update(self._fullargspec.kwonlydefaults)
else:
# Maps from index of arg to its corresponding value, according to `args`
# and `kwargs`; seeded with the default values for the named args that
@ -2215,18 +2633,28 @@ class FunctionSpec(object):
for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None)
if index is not None:
if index < len(args):
raise TypeError("{} got two values for argument '{}'".format(
self.signature_summary(), arg))
arg_indices_to_values[index] = value
consumed_args.append(arg)
elif self._input_signature is not None:
raise ValueError("Cannot define a TensorFlow function from a Python "
"function with keyword arguments when "
"input_signature is provided.")
for arg in consumed_args:
# After this loop, `kwargs` will only contain true keyword arguments, as
# opposed to named arguments called in a keyword-like fashion.
# After this loop, `kwargs` will only contain keyword_only arguments,
# and all positional_or_keyword arguments have been moved to `inputs`.
kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
if kwargs and self._input_signature is not None:
raise TypeError(
"{} got unexpected keyword arguments: {}\n(Cannot define a "
"TensorFlow function from a Python function with keyword arguments "
"when input_signature is provided.)".format(
self.signature_summary(), ", ".join(kwargs)))
if self._fullargspec.kwonlydefaults:
for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
kwargs.setdefault(kwarg, default)
if self._input_signature is None:
inputs = _convert_numpy_inputs(inputs)
kwargs = _convert_numpy_inputs(kwargs)
@ -2447,6 +2875,7 @@ class Function(object):
graph_function, _, _ = self._maybe_define_function(args, kwargs)
return graph_function
# XX TODO: make sure we fix up this path as well!?
def _get_concrete_function_internal(self, *args, **kwargs):
"""Bypasses error checking when getting a graph function."""
graph_function = self._get_concrete_function_internal_garbage_collected(
@ -2664,6 +3093,7 @@ class Function(object):
override_flat_arg_shapes=override_flat_arg_shapes,
capture_by_value=self._capture_by_value),
self._function_attributes,
function_spec=self.function_spec,
# Tell the ConcreteFunction to clean up its graph once it goes out of
# scope. This is not the default behavior since it gets used in some
# places (like Keras) where the FuncGraph lives longer than the
@ -3350,3 +3780,30 @@ class ConcreteFunctionGarbageCollector(object):
func_graph_module.dismantle_func_graph(self._func_graph)
except: # pylint: disable=bare-except
pass
class _Marker(object):
"""Markers used to pretty-print nested args in function signatures."""
def __init__(self, s):
self._s = s
def __repr__(self):
return str(self._s)
def _structure_summary(structure):
"""Displays a summary of the nesting structure of the given value."""
def type_name(x):
if isinstance(x, type_spec.TypeSpec):
return x.value_type.__name__
else:
return type(x).__name__
markers = [_Marker(type_name(v)) for v in nest.flatten(structure)]
return str(nest.pack_sequence_as(structure, markers))
def _contains_type_spec(value):
return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value))

View File

@ -36,6 +36,7 @@ from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -50,6 +51,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@ -65,6 +67,7 @@ from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops
@ -99,6 +102,16 @@ def _example_indexed_slices_without_dense_shape():
constant_op.constant([1, 2]), constant_op.constant([0, 1]))
def _spec_for_value(value):
"""Returns the (nested) TypeSpec for a value."""
if nest.is_sequence(value):
return nest.map_structure(_spec_for_value, value)
elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
return type_spec.type_spec_from_value(value)
else:
return value
class FunctionTest(test.TestCase, parameterized.TestCase):
def setUp(self):
@ -1789,6 +1802,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, 'incompatible'):
func([['wrong dtype']])
def testNoKeywordOnlyArgumentsWithInputSignature(self):
if sys.version_info[0] < 3:
self.skipTest('keyword_only arguments only exist in Python 3.')
func = eval('lambda x, *, y: x') # pylint: disable=eval-used
signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
with self.assertRaisesRegexp(
ValueError, 'Cannot define a TensorFlow function from a Python '
'function with keyword-only arguments when input_signature is '
'provided.'):
def_function.function(func, signature)
def testNestedInputSignatures(self):
def expected_foo(a, b):
@ -1905,7 +1930,9 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined(array_ops.ones([2, 1]))
# Wrong number of arguments.
with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
with self.assertRaisesRegexp(
TypeError, r'takes 1 positional arguments \(as specified by the '
r'input_signature\) but 2 were given'):
defined(array_ops.ones([2]), array_ops.ones([2]))
with self.assertRaisesRegexp(ValueError,
'Structure of Python function inputs.*'):
@ -1946,10 +1973,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
return -1.0 * a
x = constant_op.constant(1.0)
with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
with self.assertRaisesRegexp(
TypeError, 'got keyword argument `training` '
'that was not included in input_signature'):
foo(x, training=True)
with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
with self.assertRaisesRegexp(
TypeError, 'got keyword argument `training` '
'that was not included in input_signature'):
foo(x, training=False)
self.assertAllEqual(x.numpy(), foo(x).numpy())
@ -2472,8 +2503,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
return x
graph_function = foo.get_concrete_function(constant_op.constant(1.0))
with self.assertRaisesRegexp(
ValueError, 'All inputs to `ConcreteFunction`s must be Tensors;.*'):
with self.assertRaises((TypeError, ValueError)):
graph_function('Not a Tensor.')
def testSwapImplementationWithGrapplerPlugin(self):
@ -3148,6 +3178,432 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
function.clear_function_callbacks()
self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNestedTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = constant_op.constant(1000)
b = constant_op.constant(200)
c = constant_op.constant(30)
d = {'a': a, 'b': b}
e = (c, 4)
# Test different argument signatures when constructing the concrete func.
for cf in [
f.get_concrete_function(d, e),
f.get_concrete_function(d, y=e),
f.get_concrete_function(y=e, x=d),
f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
]:
# Test different calling conventions when calling the concrete func.
for output in [
cf(d, e), # structured signature
cf(d, y=e), # structured signature w/ kwarg
cf(y=e, x=d), # structured signature w/ 2 kwargs
cf(a, b, c), # flat signature
cf(x=a, x_1=b, y=c) # flat signature w/ kwargs
]:
self.assertIsInstance(output, tuple)
self.assertLen(output, 2)
self.assertAllEqual(output[0], 1200)
self.assertAllEqual(output[1], 34)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
b = (50, 3)
for cf in [ # argument y is bound to non-Tensor value (50, 3).
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
]:
for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 1253)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 3000, 'b': 200, 'c': 9000}
b = (constant_op.constant(30), 4)
for cf in [ # argument x is bound to non-tensor value `a`
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
]:
for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 3234)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 5000, 'b': 500}
b = (50, 5)
cf = f.get_concrete_function(a, b)
for output in [cf(), cf(a), cf(y=b)]:
self.assertAllEqual(output[0] + output[1], 5555)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionStructuredSignatureKeywordOrder(self):
# Check that keyword-only arguments are sorted appropriately, so that they
# feed the right tensor into each input.
@def_function.function
def g(**kwargs):
return string_ops.reduce_join(
string_ops.reduce_join(
ops.convert_to_tensor(sorted(kwargs.items())),
axis=1,
separator='='),
axis=0,
separator=', ')
s = constant_op.constant('s')
g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
self.assertAllEqual(
g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
self.assertAllEqual(
g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
self.assertAllEqual(
g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
# pylint: disable=g-long-lambda
@parameterized.named_parameters([
dict(
testcase_name='MissingArg',
conc_args=lambda: (1, constant_op.constant(2)),
call_args=lambda: (1,),
error=r'func\(x, y\) missing required arguments: y'),
dict(
testcase_name='MissingVararg',
conc_args=lambda: (1, 2, constant_op.constant(1.0)),
call_args=lambda: (1, 2),
error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
dict(
testcase_name='ExtraPositionalArg',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2, 3),
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
dict(
testcase_name='MissingKeywordOnlyArg',
conc_args=lambda: (1, 2),
conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
call_args=lambda: (1, 2),
error=r'func\(x, y, \*, c\) missing required arguments: c'),
dict(
testcase_name='ExtraKeywordArg',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'c': constant_op.constant(1.0)},
error=r'func\(x, y\) got unexpected keyword arguments: c'),
dict(
testcase_name='ExpectedRaggedGotNest',
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: ({
'a': constant_op.constant([1, 2, 3])
},),
error=r'func\(x, y\): argument x had incorrect type\n'
r' expected: RaggedTensor\n'
r" got: {'a': (Eager)?Tensor}"),
dict(
testcase_name='WrongRaggedRank',
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='WrongRaggedDType',
conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='ExpectedDictGotTensor',
conc_args=lambda: ({
'a': constant_op.constant(1),
'b': constant_op.constant(1)
},),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='ExpectedTupleGotTensor',
conc_args=lambda:
((constant_op.constant(1), constant_op.constant(2)),),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='WrongDType',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(ValueError, errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
errors.InternalError)),
dict(
testcase_name='ExpectedTensorGotInt',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (5,),
error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
dict(
testcase_name='ExpectedIntGotDifferentInt',
conc_args=lambda: (5,),
call_args=lambda: (8,),
error=r'ConcreteFunction func\(x, y\) was constructed with int '
r'value 5 in x, but was called with int value 8'),
dict(
testcase_name='ExpectedIntGotTensor',
conc_args=lambda: (5,),
call_args=lambda: (constant_op.constant(6),),
error=r'ConcreteFunction func\(x, y\) was constructed with int '
'value 5 in x, but was called with (Eager)?Tensor value .*'),
dict(
testcase_name='TwoValuesForArgument',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'x': 3},
error=r"func\(x, y\) got two values for argument 'x'"),
])
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionStructuredSignatureError(self,
conc_args=(),
conc_kwargs=None,
call_args=(),
call_kwargs=None,
error='.*',
exception=TypeError):
"""Tests for errors in the structrued signature.
Args:
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
"""
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
@def_function.function
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
with self.assertRaisesRegexp(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
# pylint: disable=g-long-lambda
@parameterized.named_parameters([
dict(
testcase_name='MissingArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\) missing required arguments: y'),
dict(
testcase_name='TwoValuesForArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {
'x': constant_op.constant(1),
'y': constant_op.constant(1)
},
error=r"func\(x, y\) got two values for argument 'x'"),
dict(
testcase_name='ExtraPositionalArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
constant_op.constant(3)),
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
dict(
testcase_name='UnexpectedKeywordArg',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x\) got unexpected keyword arguments: c'),
dict(
testcase_name='MissingVararg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
constant_op.constant(3)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, varargs_0\) missing required '
r'arguments: varargs_0'),
dict(
testcase_name='MissingKeywordArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': constant_op.constant(1)},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, c\) missing required arguments: c'),
dict(
testcase_name='ExpectedTensorGotInt',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (5, constant_op.constant(2)),
error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
r'a Tensor; got int \(5\)'),
dict(
testcase_name='WrongDType',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(ValueError, errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
errors.InternalError)),
dict(
testcase_name='MissingKeywordArgNestPiece',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
])
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionFlatSignatureError(self,
conc_args=(),
conc_kwargs=None,
call_args=(),
call_kwargs=None,
error='.*',
exception=TypeError):
"""Tests for errors in the flat signature.
Args:
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
"""
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
@def_function.function
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
# Remove _function_spec, to disable the structured signature.
conc._set_function_spec(None) # pylint: disable=protected-access
with self.assertRaisesRegexp(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionAmbiguousSignature(self):
# When both the flat & structured signatures are applicable, but they
# give different results, we use the structured signature. Note: we expect
# this to be extremely rare.
@def_function.function
def f(x, y):
return x * 10 + y
conc = f.get_concrete_function(
x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))
result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
self.assertAllEqual(result, 56)
def testPrettyPrintedSignature(self):
@def_function.function
def func(x, kangaroo=None, octopus=7):
del octopus, kangaroo
return x
scalar = constant_op.constant(5)
vector = constant_op.constant([10, 10, 20])
ragged = ragged_factory_ops.constant([[10, 20], [40]])
c1 = func.get_concrete_function(scalar, vector)
c1_summary = r'func\(x, kangaroo, octopus=7\)'
c1_details = (r' Args:\n'
r' x: int32 Tensor, shape=\(\)\n'
r' kangaroo: int32 Tensor, shape=\(3,\)\n'
r' Returns:\n'
r' int32 Tensor, shape=\(\)')
self.assertRegexpMatches(
c1.pretty_printed_signature(verbose=False), c1_summary)
self.assertRegexpMatches(
c1.pretty_printed_signature(verbose=True),
c1_summary + '\n' + c1_details)
self.assertRegexpMatches(
repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
self.assertRegexpMatches(
str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))
c2 = func.get_concrete_function(scalar, ragged, 3)
c2_summary = r'func\(x, kangaroo, octopus=3\)'
c2_details = (r' Args:\n'
r' x: int32 Tensor, shape=\(\)\n'
r' kangaroo: RaggedTensorSpec\(.*\)\n'
r' Returns:\n'
r' int32 Tensor, shape=\(\)')
self.assertRegexpMatches(c2.pretty_printed_signature(),
c2_summary + '\n' + c2_details)
c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
c3_details = (r' Args:\n'
r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
r' <1>: int32 Tensor, shape=\(\)\n'
r' <2>: RaggedTensorSpec\(.*\)\n'
r' <3>: RaggedTensorSpec\(.*\)\n'
r' Returns:\n'
r" {'a': <1>, 'b': \[<2>, <3>\]}\n"
r' <1>: int32 Tensor, shape=\(\)\n'
r' <2>: RaggedTensorSpec\(.*\)\n'
r' <3>: RaggedTensorSpec\(.*\)')
self.assertRegexpMatches(c3.pretty_printed_signature(),
c3_summary + '\n' + c3_details)
# pylint: disable=keyword-arg-before-vararg
@def_function.function
def func2(x, y=3, *args, **kwargs):
return (x, y, args, kwargs)
c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)
c5 = func2.get_concrete_function(8, vector)
c5_summary = 'func2(x=8, y)'
self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)
class MultiDeviceTest(test.TestCase, parameterized.TestCase):

View File

@ -43,6 +43,8 @@ def _is_tensor(t):
return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
# structured signature.
def _call_concrete_function(function, inputs):
"""Calls a restored Function with structured inputs.
@ -137,8 +139,6 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
input_signature = coder.decode_proto(function_spec_proto.input_signature)
return function_lib.FunctionSpec(fullargspec=fullargspec,
is_method=False,
args_to_prepend=[],
kwargs_to_include={},
input_signature=input_signature)
@ -191,6 +191,8 @@ def recreate_function(saved_function, concrete_functions):
Args:
saved_function: `SavedFunction` proto.
concrete_functions: map from function name to `ConcreteFunction`.
As a side effect of this function, the `FunctionSpec` from
`saved_function` is added to each `ConcreteFunction` in this map.
Returns:
A `Function`.
@ -254,6 +256,9 @@ def recreate_function(saved_function, concrete_functions):
for concrete_function_name in saved_function.concrete_functions:
concrete_function_objects.append(concrete_functions[concrete_function_name])
for cf in concrete_function_objects:
cf._set_function_spec(function_spec) # pylint: disable=protected-access
restored_function = RestoredFunction(
restored_function_body,
restored_function_body.__name__,
@ -317,6 +322,11 @@ def load_function_def_library(library, load_shared_name_suffix=None):
for dep in _list_function_deps(fdef, library_function_names):
functions[dep].add_to_graph(func_graph)
# We do not initialize the new ConcreteFunction's function_spec or
# arg_keywords here (which are used to parse the structured and flat
# signatures, respectively). function_spec is set up later by
# recreate_function(); and arg_keywords by setup_bare_concrete_function().
func = function_lib.ConcreteFunction(func_graph)
func.add_to_graph(graph)

View File

@ -77,6 +77,10 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
def serialize_bare_concrete_function(concrete_function, name_map):
"""Build a SavedBareConcreteFunction."""
# TODO(edloper): Currently, bare concrete functions don't have access to a
# function_spec, so they can't be called with the structured signature.
# Update the serialization to include a function_spec.
# pylint: disable=protected-access
name = name_map.get(compat.as_text(concrete_function.name),
concrete_function.name)
@ -151,7 +155,8 @@ def wrap_cached_variables(concrete_function):
func_graph_module.func_graph_from_py_func(
None, wrap_function, args=tuple(args), kwargs={},
func_graph=outer_graph)
fn = defun.ConcreteFunction(outer_graph)
fn = defun.ConcreteFunction(
outer_graph, function_spec=concrete_function._function_spec) # pylint: disable=protected-access
fn._arg_keywords = concrete_function._arg_keywords # pylint: disable=protected-access
fn._num_positional_args = concrete_function._num_positional_args # pylint: disable=protected-access

View File

@ -173,12 +173,12 @@ class Loader(object):
# The original_outputs here had Tensors converted to TensorSpecs, so
# the restored function's structured_outputs field will not be
# exactly the same. Fortunately the repacking logic cares only about
# the structure.
# TODO(vbardiovsky): Should we just replicate the structures, with
# Nones instead of real objects?
# the structure; and the unpacking logic cares only about structure
# and types.
concrete_function._func_graph.structured_outputs = original_outputs # pylint: disable=protected-access
concrete_function._func_graph.structured_input_signature = ( # pylint: disable=protected-access
coder.decode_proto(proto.canonicalized_input_signature))
concrete_function._initialize_function_spec() # pylint: disable=protected-access
def _setup_functions_captures(self):
"""Setup captures and variables in restored functions."""