Functions returned by the `get_concrete_function` method of `tf.Function` objects can now be called with arguments consistent with the original arguments or type specs passed to `get_concrete_function` (a short sketch follows the list below). In particular:

* If a composite tensor (such as RaggedTensor or SparseTensor) was passed to `get_concrete_function`, then the returned function will accept a composite tensor of the same type for that argument.

* If a nested structure (such as a list or dict) was passed to `get_concrete_function`, then the returned function will accept a value with the same nesting structure. Each tensor or composite tensor value must have the same type as was used in the original argument, and each non-tensor value (such as a bool or int) must be equal to the value that was used in the original argument.

* If a non-tensor value (such as a bool or int) was passed to `get_concrete_function`, then the returned function no longer deletes that argument; instead, it updates the argument's default value to the value that was passed to `get_concrete_function`.  Passing in any other value will raise an exception.

* Arguments are not renamed based on `TensorSpec.name`.
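As an illustration of the convention described above (this sketch is not part of the change itself; the function `f`, its arguments, and the values are hypothetical, and TensorFlow 2.x eager execution is assumed):

    import tensorflow as tf

    @tf.function
    def f(x, scale=1):
      # `x` is a nested dict containing a Tensor and a RaggedTensor.
      return scale * (tf.reduce_sum(x['a']) + tf.reduce_sum(x['b']))

    nested = {'a': tf.constant([1, 2]), 'b': tf.ragged.constant([[3], [4, 5]])}
    cf = f.get_concrete_function(nested, scale=10)

    # Accepted: the same nesting structure, with matching Tensor and
    # RaggedTensor type specs for each component.
    cf({'a': tf.constant([6, 7]), 'b': tf.ragged.constant([[8], [9, 10]])})

    # `scale` was traced with the non-Tensor value 10, so it now defaults to
    # 10; passing any other value for `scale` raises an exception.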

For backwards compatibility, the functions returned by `get_concrete_function` will continue to accept arguments with the existing calling conventions (where nested structures and composite tensors are flattened; non-tensor arguments are deleted; suffixes are automatically added to disambiguate arguments with the same name; and TensorSpec.name is used to rename arguments).  However, the preferred calling convention is the one that is consistent with the original arguments or type specs passed to `get_concrete_function`.
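For example (again an illustrative sketch with a hypothetical function `g`, not code from this change), the same concrete function can be called through either convention:

    import tensorflow as tf

    @tf.function
    def g(d):
      return d['a'] + d['b']

    cg = g.get_concrete_function({'a': tf.constant(1), 'b': tf.constant(2)})

    # Preferred: the structured signature, matching the original argument.
    cg({'a': tf.constant(10), 'b': tf.constant(20)})

    # Still accepted: the legacy flat signature, with one Tensor passed per
    # flattened component of the original nested argument.
    cg(tf.constant(10), tf.constant(20))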

PiperOrigin-RevId: 307398918
Change-Id: Ie4685b32d9f151c82f6c79a6c41379faa96b5ee8
Edward Loper 2020-04-20 08:00:45 -07:00 committed by TensorFlower Gardener
parent 0887fedd2d
commit f39aab3092
6 changed files with 1041 additions and 101 deletions


@ -830,6 +830,13 @@ class Function(object):
def function_spec(self):
return self._function_spec
+ def pretty_printed_concrete_signatures(self, verbose=True):
+ joiner = "\n\n" if verbose else "\n"
+ return joiner.join([
+ c.pretty_printed_signature(verbose=verbose)
+ for c in self._list_all_concrete_functions()
+ ])
def _initialize_uninitialized_variables(self, initializers):
"""Make and call a `ConcreteFunction` which initializes variables."""
@ -913,12 +920,8 @@ class Function(object):
return initialize_variables.get_concrete_function()
- def _list_all_concrete_functions_for_serialization(self):
- """Returns all concrete functions for serialization.
- Returns:
- A list of instances of `ConcreteFunction`.
- """
+ def _list_all_concrete_functions(self):
+ """Returns all concrete functions."""
if self.input_signature is not None:
self.get_concrete_function()
concrete_functions = []
@ -930,6 +933,15 @@ class Function(object):
concrete_functions.extend(
self._stateless_fn._function_cache.all_values())
# pylint: enable=protected-access
+ return concrete_functions
+ def _list_all_concrete_functions_for_serialization(self):
+ """Returns all concrete functions for serialization.
+ Returns:
+ A list of instances of `ConcreteFunction`.
+ """
+ concrete_functions = self._list_all_concrete_functions()
seen_signatures = []
for concrete_function in concrete_functions:
signature = concrete_function.structured_input_signature


@ -53,6 +53,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import custom_gradient
@ -340,7 +341,7 @@ class _InterpolateFunctionError(object):
if t.name == compat.as_str(self._func.name): if t.name == compat.as_str(self._func.name):
g = self._func.graph g = self._func.graph
elif g: elif g:
- next_func = g._get_function(t.name)
+ next_func = g._get_function(t.name)  # pylint: disable=protected-access
if next_func is not None and isinstance(next_func, if next_func is not None and isinstance(next_func,
_EagerDefinedFunction): _EagerDefinedFunction):
g = next_func.graph g = next_func.graph
@ -1499,6 +1500,12 @@ class _ForwardBackwardCall(object):
flat_outputs, self._inference_args, self._input_tangents) flat_outputs, self._inference_args, self._input_tangents)
# Sentinel value used by with ConcreteFunction's structured signature to
# indicate that a non-tensor parameter should use the value that was
# specified when the concrete function was created.
_BOUND_VALUE = object()
class ConcreteFunction(object): class ConcreteFunction(object):
"""Callable object encapsulating a function definition and its gradient. """Callable object encapsulating a function definition and its gradient.
@ -1506,7 +1513,11 @@ class ConcreteFunction(object):
is differentiable under `tf.GradientTape` objects.
"""
- def __init__(self, func_graph, attrs=None, shared_func_graph=True):
+ def __init__(self,
+ func_graph,
+ attrs=None,
+ shared_func_graph=True,
+ function_spec=None):
"""Initialize a `ConcreteFunction`. """Initialize a `ConcreteFunction`.
Args: Args:
@ -1517,16 +1528,25 @@ class ConcreteFunction(object):
shared_func_graph: If False, the ConcreteFunction takes ownership of shared_func_graph: If False, the ConcreteFunction takes ownership of
`func_graph` and will break reference cycles when it is deleted. This `func_graph` and will break reference cycles when it is deleted. This
makes the FuncGraph inoperable. makes the FuncGraph inoperable.
function_spec: FunctionSpec for the original function. If not specified,
then this ConcreteFunction may only be called using the flat signature.
Raises: Raises:
ValueError: If number of input_placeholders is not equal to the number ValueError: If number of input_placeholders is not equal to the number
of function inputs. of function inputs.
""" """
# _arg_keywords and _num_positional_args define the flat signature. They
# are assigned after construction.
self._arg_keywords = None self._arg_keywords = None
self._num_positional_args = None self._num_positional_args = None
self._func_graph = func_graph self._func_graph = func_graph
self._captured_inputs = self._func_graph.external_captures self._captured_inputs = self._func_graph.external_captures
self._captured_closures = self._func_graph.deferred_external_captures self._captured_closures = self._func_graph.deferred_external_captures
# function_spec defines the structured signature.
self._set_function_spec(function_spec)
if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs: if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
# The alternative is to silently drop "implements" tag # The alternative is to silently drop "implements" tag
# but it seems likely it would lead to hard to catch bugs. # but it seems likely it would lead to hard to catch bugs.
@ -1576,6 +1596,52 @@ class ConcreteFunction(object):
# building gradients. # building gradients.
self._inference_function = self._delayed_rewrite_functions.forward() self._inference_function = self._delayed_rewrite_functions.forward()
def _set_function_spec(self, function_spec):
"""Enables the structured signature by supplying a function_spec."""
self._function_spec = None
self._pre_initialized_function_spec = function_spec
# Note: when ConcreteFunctions are built by recreate_function() in
# function_deserialization.py, they don't have a structured_input_signature
# yet. In that case, _initialize_function_spec() gets called by
# _setup_functions_structures() in load.py.
if (function_spec is not None and
self.structured_input_signature is not None):
self._initialize_function_spec()
def _initialize_function_spec(self):
"""Updates `self._function_spec` to include varargs and bound variables.
Adds new positional arguments for any varargs (i.e., for args that are
in `structured_input_signature`, but not in the original fullargspec.args).
Replaces `defaults` and `kwonlydefaults` with the `_BOUND_VALUE`, for
all args and kwargs in `structured_input_signature`.
Sets `varkw` and `varargs` to None.
"""
if self._pre_initialized_function_spec is None:
return # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
assert not self._function_spec, "already initialized"
function_spec = self._pre_initialized_function_spec
args = function_spec.fullargspec.args
arg_specs, kwarg_specs = self.structured_input_signature
fullargspec = tf_inspect.FullArgSpec(
args=list(args) +
["<arg{}>".format(i + 1) for i in range(len(args), len(arg_specs))],
varargs=None,
varkw=None,
defaults=[_BOUND_VALUE] * len(arg_specs),
kwonlyargs=list(sorted(kwarg_specs)),
kwonlydefaults=dict((k, _BOUND_VALUE) for k in kwarg_specs),
annotations=function_spec.fullargspec.annotations)
self._function_spec = FunctionSpec(
fullargspec,
function_spec.is_method,
function_spec.input_signature,
function_spec.is_pure,
name=self._func_graph.name)
@property @property
def variables(self): def variables(self):
"""Sequence of variables for this function.""" """Sequence of variables for this function."""
@ -1589,15 +1655,44 @@ class ConcreteFunction(object):
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
"""Executes the wrapped function. """Executes the wrapped function.
ConcreteFunctions have two signatures:
* The signature of the original function wrapped by this ConcreteFunction.
* A flat signature, where each argument accepts a single Tensor.
The original function signature is generally preferred, but the flat input
signature is supported for backward compatibility.
### Original Function Signature
When calling a ConcreteFunction with the signature of the original function,
each argument must match the type or value that was used when the
ConcreteFunction's graph was traced. In particular:
* Tensor arguments (including CompositeTensors, such as RaggedTensor) must
have matching `TypeSpec`s.
* Non-Tensor arguments (such as booleans or ints) must have equal values.
* Nested arguments (such as lists, tuples, or dictionaries) must have the
same nesting structure; and each nested value must have a matching type
or value.
The default value for any arguments that were traced with non-Tensor values
is the value that was used in the trace. Arguments that were traced with
tensor arguments do not have a default value (even if the original function
had a default value for that argument).
### Flat Signature
When calling a ConcreteFunction with the flat signature, the arguments
correspond to the flattened component tensors of the arguments that were
used to construct the ConcreteFunction. Parameter names are assigned based
on `TensorSpec.name` (when specified) or the original argument names (with
suffixes automatically added for nested arguments or composite tensors with
multiple components).
Args:
- *args: Tensors or Variables. Positional arguments are only accepted when
- they correspond one-to-one with arguments of the traced Python function.
- **kwargs: Tensors or Variables specified by name. When
- `get_concrete_function` was called to create this `ConcreteFunction`,
- each Tensor input was given a name, defaulting to the name of the Python
- function's argument but possibly overridden by the `name=` argument to
- `tf.TensorSpec`. These names become the argument names for the concrete
- function.
+ *args: Positional arguments to the concrete function.
+ **kwargs: Keyword arguments to the concrete function.
Returns:
The result of applying the TF function on the given Tensors.
@ -1605,9 +1700,7 @@ class ConcreteFunction(object):
Raises:
AssertionError: If this `ConcreteFunction` was not created through
`get_concrete_function`.
- ValueError: If arguments contains anything other than Tensors or
- Variables.
- TypeError: For invalid positional/keyword argument combinations.
+ TypeError: If the arguments do not match the function's signature.
"""
return self._call_impl(args, kwargs)
@ -1615,40 +1708,174 @@ class ConcreteFunction(object):
"""See `__call__` for details.""" """See `__call__` for details."""
with traceme.TraceMe(self._func_graph.name, with traceme.TraceMe(self._func_graph.name,
tf_function_call="concrete"): tf_function_call="concrete"):
if self._arg_keywords is None or self._num_positional_args is None: # Construct the list of input tensors: check if the structured signature
raise AssertionError( # applies first; and if not, then use the flat signature.
"Tried to call a concrete function obtained from an internal API " if self._function_spec is not None:
"through the public interface. Use get_concrete_function instead.")
if len(args) > self._num_positional_args:
raise TypeError(
("Expected at most {} positional arguments (and the rest keywords, "
"of {}), got {}. When calling a concrete function, positional "
"arguments may not be bound to Tensors within nested structures."
).format(self._num_positional_args, self._arg_keywords, args))
args = list(args)
for keyword in self._arg_keywords[len(args):]:
try: try:
args.append(kwargs.pop(compat.as_str(keyword))) return self._call_with_structured_signature(args, kwargs,
except KeyError: cancellation_manager)
specified_keywords = (list(self._arg_keywords[:len(args)]) except TypeError as structured_err:
+ list(kwargs.keys())) try:
raise TypeError( return self._call_with_flat_signature(args, kwargs,
"Expected argument names {} but got values for {}. Missing: {}." cancellation_manager)
.format( except TypeError:
list(self._arg_keywords), raise structured_err
specified_keywords,
list(set(self._arg_keywords) - set(specified_keywords))))
if kwargs:
positional_arg_keywords = set(self._arg_keywords[:len(args)])
for unused_key in kwargs:
if unused_key in positional_arg_keywords:
raise TypeError("Got two values for keyword '{}'.".format(
unused_key))
raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
list(kwargs.keys()), list(self._arg_keywords)))
return self._call_flat(args, self.captured_inputs, cancellation_manager)
def _filtered_call(self, args, kwargs): return self._call_with_flat_signature(args, kwargs, cancellation_manager)
def _call_with_flat_signature(self, args, kwargs, cancellation_manager):
"""Executes the wrapped function with the flat signature.
Args:
args: Positional arguments to the concrete function.
kwargs: Keyword arguments to the concrete function.
cancellation_manager: A `CancellationManager` that can be used to cancel
function invocation.
Returns:
The result of applying the function on the Tensors/Variables contained in
`args` and `kwargs`.
Raises:
TypeError: if `args` and `kwargs` do not match the flat signature of this
`ConcreteFunction`.
"""
if len(args) > self._num_positional_args:
raise TypeError(
"{} takes {} positional arguments but {} were given".format(
self._flat_signature_summary(), self._num_positional_args,
len(args)))
args = list(args)
kwargs = dict(kwargs)
for keyword in self._arg_keywords[len(args):]:
try:
args.append(kwargs.pop(compat.as_str(keyword)))
except KeyError:
specified_keywords = (
list(self._arg_keywords[:len(args)]) + list(kwargs.keys()))
raise TypeError("{} missing required arguments: {}".format(
self._flat_signature_summary(), ", ".join(
sorted(set(self._arg_keywords) - set(specified_keywords)))))
if kwargs:
positional_arg_keywords = set(self._arg_keywords[:len(args)])
for unused_key in kwargs:
if unused_key in positional_arg_keywords:
raise TypeError("{} got two values for argument '{}'".format(
self._flat_signature_summary(), unused_key))
raise TypeError("{} got unexpected keyword arguments: {}.".format(
self._flat_signature_summary(), ", ".join(sorted(kwargs))))
for i, arg in enumerate(args):
if not isinstance(
arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
raise TypeError("{}: expected argument #{}(zero-based) to be a Tensor; "
"got {} ({})".format(self._flat_signature_summary(), i,
type(arg).__name__, str(arg)))
return self._call_flat(args, self.captured_inputs, cancellation_manager)
def _call_with_structured_signature(self, args, kwargs, cancellation_manager):
"""Executes the wrapped function with the structured signature.
Args:
args: Positional arguments to the concrete function.
kwargs: Keyword arguments to the concrete function.
cancellation_manager: A `CancellationManager` that can be used to cancel
function invocation.
Returns:
The result of applying the function on the Tensors/Variables contained in
`args` and `kwargs`.
Raises:
TypeError: if `args` and `kwargs` do not match the structured signature
of this `ConcreteFunction`.
"""
args, kwargs = self._function_spec.canonicalize_function_inputs(
*args, **kwargs)
self._structured_signature_check_missing_args(args, kwargs)
self._structured_signature_check_unexpected_args(args, kwargs)
self._structured_signature_check_arg_types(args, kwargs)
return self._filtered_call(args, kwargs, cancellation_manager)
def _structured_signature_check_missing_args(self, args, kwargs):
"""Raises a TypeError if any args are missing."""
arg_specs, kwarg_specs = self.structured_input_signature
missing_arguments = []
for i, (arg, spec) in enumerate(zip(args, arg_specs)):
if arg is _BOUND_VALUE and _contains_type_spec(spec):
missing_arguments.append(self._function_spec.arg_names[i])
for (name, arg) in kwargs.items():
if arg is _BOUND_VALUE and _contains_type_spec(kwarg_specs[name]):
missing_arguments.append(name)
if missing_arguments:
raise TypeError("{} missing required arguments: {}".format(
self._structured_signature_summary(),
", ".join(sorted(missing_arguments))))
def _structured_signature_check_unexpected_args(self, args, kwargs):
"""Raises a TypeError if there are any extra args."""
arg_specs, kwarg_specs = self.structured_input_signature
if len(args) > len(arg_specs):
raise TypeError(
"{} takes {} positional arguments but {} were given".format(
self._structured_signature_summary(),
len(self._function_spec.arg_names), len(args)))
if len(kwargs) > len(kwarg_specs):
extra_args = set(kwargs) - set(kwarg_specs)
raise TypeError("{} got unexpected keyword arguments: {}".format(
self._structured_signature_summary(), ", ".join(extra_args)))
def _structured_signature_check_arg_types(self, args, kwargs):
"""Raises a TypeError if any args have the wrong type."""
# Check argument types
arg_specs, kwarg_specs = self.structured_input_signature
for i, (arg, spec) in enumerate(zip(args, arg_specs)):
name = self._function_spec.arg_names[i]
self._structured_signature_check_arg_type(arg, spec, name)
for (name, arg) in kwargs.items():
self._structured_signature_check_arg_type(arg, kwarg_specs[name], name)
def _structured_signature_check_arg_type(self, arg, spec, name):
"""Raise TypeError if `arg`'s type doesn't match `spec`."""
if arg is _BOUND_VALUE:
return
# Check the overall nested structure of the argument.
try:
nest.assert_same_structure(arg, spec, expand_composites=True)
except (ValueError, TypeError):
try:
nest.assert_same_structure(arg, spec, expand_composites=False)
expected, got = spec, arg
except (ValueError, TypeError):
expected, got = _structure_summary(spec), _structure_summary(arg)
raise TypeError("{}: argument {} had incorrect type\n"
" expected: {}\n got: {}".format(
self._structured_signature_summary(), name, expected,
got))
# Check the type for each leaf in the nested structure.
arg_pieces = nest.flatten(arg, expand_composites=True)
spec_pieces = nest.flatten(spec, expand_composites=True)
for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces):
if isinstance(spec_piece, tensor_spec.DenseSpec):
# TODO(edloper): Consider calling convert_to_tensor on non-tensor
# values here. That would match the behavior of
# _call_concrete_function() in function_deserialization.py. If
# we do, then we need to change the nest assert_same_structure and
# flatten calls above to use shallow variants.
tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
if not isinstance(arg_piece, tensor_types):
raise TypeError(
"{} expected a Tensor in {}, but got {} value {}".format(
self._structured_signature_summary(), name,
type(arg_piece).__name__, arg_piece))
elif arg_piece is not _BOUND_VALUE and arg_piece != spec_piece:
raise TypeError("ConcreteFunction {} was constructed with {} value "
"{} in {}, but was called with {} value {}".format(
self._structured_signature_summary(),
type(spec_piece).__name__, spec_piece, name,
type(arg_piece).__name__, arg_piece))
def _filtered_call(self, args, kwargs, cancellation_manager=None):
"""Executes the function, filtering arguments from the Python function. """Executes the function, filtering arguments from the Python function.
Objects aside from Tensors, CompositeTensors, and Variables are ignored. Objects aside from Tensors, CompositeTensors, and Variables are ignored.
@ -1657,6 +1884,8 @@ class ConcreteFunction(object):
Args: Args:
args: Canonicalized positional arguments of the Python function. args: Canonicalized positional arguments of the Python function.
kwargs: Canonicalized keyword arguments of the Python function. kwargs: Canonicalized keyword arguments of the Python function.
cancellation_manager: (Optional.) A `CancellationManager` that can be
used to cancel function invocation.
Returns: Returns:
The result of applying the function on the Tensors/Variables contained in The result of applying the function on the Tensors/Variables contained in
@ -1666,7 +1895,8 @@ class ConcreteFunction(object):
(t for t in nest.flatten((args, kwargs), expand_composites=True)
if isinstance(t, (ops.Tensor,
resource_variable_ops.BaseResourceVariable))),
- self.captured_inputs)
+ captured_inputs=self.captured_inputs,
+ cancellation_manager=cancellation_manager)
def _call_flat(self, args, captured_inputs, cancellation_manager=None): def _call_flat(self, args, captured_inputs, cancellation_manager=None):
"""Executes the wrapped function. """Executes the wrapped function.
@ -1795,7 +2025,26 @@ class ConcreteFunction(object):
@property
def structured_input_signature(self):
- """Returns structured signature of the original function."""
+ """Returns structured signature for this concrete function.
Returns:
A tuple `(args, kwargs)`, where:
* `args` is a tuple that specifies the expected type or value for each
positional argument.
* `kwargs` is a dictionary that specifies the expected type or value
for each keyword-only argument.
The type or value for each argument is specified using one of the
following:
* A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
value is expected.
* A Python value, such as an integer, indicating that an equal value
is expected.
* A nested structure of `tf.TypeSpec`s and Python values, indicating
that a corresponding nested structure is expected.
"""
return self._func_graph.structured_input_signature return self._func_graph.structured_input_signature
@property @property
@ -1982,6 +2231,103 @@ class ConcreteFunction(object):
ret.attr[name].CopyFrom(value) ret.attr[name].CopyFrom(value)
return ret return ret
def _structured_signature_summary(self, default_values=False):
"""Returns a string summarizing this function's structured signature.
Args:
default_values: If true, then include default values in the signature.
Returns:
A `string`.
"""
# Note: we can't just use self._function_spec.signature_summary(), because
# that would show "_BOUND_VALUE" as the default value for all arguments.
assert self._function_spec is not None
arg_specs, kwarg_specs = self.structured_input_signature
arg_names = list(self._function_spec.arg_names)
if default_values:
for i in range(len(arg_names)):
if not _contains_type_spec(arg_specs[i]):
arg_names[i] += "={}".format(arg_specs[i])
if kwarg_specs:
arg_names.append("*")
for name, spec in kwarg_specs.items():
arg_names.append(name)
if default_values and not _contains_type_spec(spec):
arg_names[-1] += "={}".format(spec)
signature = "{}({})".format(self._func_graph.name, ", ".join(arg_names))
return signature
def _flat_signature_summary(self):
"""Returns a string summarizing this function's flat signature."""
assert self._arg_keywords is not None
assert self._num_positional_args is not None
arg_names = self._arg_keywords
if self._num_positional_args > len(arg_names):
arg_names.extend(
"<arg{}>".format(i + 1)
for i in range(len(arg_names), self._num_positional_args))
return "{}({})".format(self._func_graph.name, ", ".join(arg_names))
def pretty_printed_signature(self, verbose=True):
"""Returns a string summarizing the signature of this concrete function."""
if not verbose:
return self._structured_signature_summary(default_values=True)
def pretty_print_spec(spec):
"""Returns a string describing the spec for a single argument."""
if isinstance(spec, tensor_spec.TensorSpec):
return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
elif nest.is_sequence(spec):
pieces = nest.flatten(spec, expand_composites=False)
markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
structure = nest.pack_sequence_as(spec, markers)
result = "{}".format(structure)
for (marker, piece) in zip(markers, pieces):
result += "\n {}: {}".format(marker, pretty_print_spec(piece))
return result
else:
return repr(spec)
lines = [self._structured_signature_summary(default_values=True)]
arg_specs, kwarg_specs = self.structured_input_signature
names = list(self._function_spec.arg_names)
names.extend(sorted(kwarg_specs))
specs = list(arg_specs) + list(kwarg_specs.values())
# note: we can skip bound args, since we already displayed their bound
# value in the signature summary.
arg_details = []
for (name, spec) in zip(names, specs):
if _contains_type_spec(spec):
arg_details.append(" {}: {}".format(name, pretty_print_spec(spec)))
if arg_details:
lines.append(" Args:")
lines.extend(arg_details)
lines.append(" Returns:")
lines.append(" {}".format(
pretty_print_spec(
nest.map_structure(type_spec.type_spec_from_value,
self.structured_outputs))))
return "\n".join(lines)
def __repr__(self):
if self._function_spec is not None:
return "<ConcreteFunction {} at 0x{:X}>".format(
self.pretty_printed_signature(verbose=False), id(self))
elif not (self._num_positional_args is None or self._arg_keywords is None):
return "<ConcreteFunction {} at 0x{:X}>".format(
self._flat_signature_summary(), id(self))
else:
return object.__repr__(self)
def __str__(self):
if self._function_spec is not None:
return "ConcreteFunction {}".format(self.pretty_printed_signature())
else:
return self.__repr__()
_pywrap_utils.RegisterType("Tensor", ops.Tensor) _pywrap_utils.RegisterType("Tensor", ops.Tensor)
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) _pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
@ -2075,17 +2421,37 @@ class FunctionSpec(object):
kwonlydefaults={}, kwonlydefaults={},
annotations=fullargspec.annotations) annotations=fullargspec.annotations)
is_method = tf_inspect.ismethod(python_function) is_method = tf_inspect.ismethod(python_function)
- return FunctionSpec(fullargspec, is_method, [], {}, input_signature,
- is_pure=is_pure)
+ # Get the function's name. Remove functools.partial wrappers if necessary.
+ while isinstance(python_function, functools.partial):
+ python_function = python_function.func
+ name = getattr(python_function, "__name__", "f")
+ return FunctionSpec(
+ fullargspec, is_method, input_signature, is_pure=is_pure, name=name)
- def __init__(self, fullargspec, is_method, args_to_prepend, kwargs_to_include,
- input_signature, is_pure=False):
+ def __init__(self,
+ fullargspec,
+ is_method,
+ input_signature,
+ is_pure=False,
+ name=None):
"""Constructs a FunctionSpec describing a python function.
Args:
fullargspec: `tf_inspect.FullArgSpec` object describing the function.
is_method: True if the function is a method.
input_signature: a signature of the function (None, if variable)
is_pure: if True all input arguments (including variables and constants)
will be converted to tensors and no variable changes allowed.
name: Name of the function
"""
self._fullargspec = fullargspec
self._is_method = is_method
self._is_pure = is_pure
- del args_to_prepend
- del kwargs_to_include
- self._default_values = fullargspec.defaults
+ # TODO(edloper): Include name when serializing for SavedModel?
+ self._name = name or "f"
if self._is_method:
# Remove `self`: default arguments shouldn't be matched to it.
@ -2098,21 +2464,21 @@ class FunctionSpec(object):
# A cache mapping from argument name to index, for canonicalizing
# arguments that are called in a keyword-like fashion.
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
- self.arg_names = args
- self.vararg_name = fullargspec.varargs
+ self._arg_names = args
# A cache mapping from arg index to default value, for canonicalization.
- offset = len(args) - len(self._default_values or [])
+ default_values = fullargspec.defaults
+ offset = len(args) - len(default_values or [])
self._arg_indices_to_default_values = {
offset + index: default
- for index, default in enumerate(self._default_values or [])
+ for index, default in enumerate(default_values or [])
}
if input_signature is None:
self._input_signature = None
else:
- if fullargspec.kwonlyargs:
+ if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()):
raise ValueError("Cannot define a TensorFlow function from a Python "
- "function with keyword arguments when "
+ "function with keyword-only arguments when "
"input_signature is provided.")
if not isinstance(input_signature, (tuple, list)):
@ -2132,8 +2498,8 @@ class FunctionSpec(object):
return self._is_method
@property
- def args_to_prepend(self):
- return self._args_to_prepend
+ def args_to_indices(self):
+ return self._args_to_indices
@property
def kwargs_to_include(self):
@ -2147,6 +2513,43 @@ class FunctionSpec(object):
def flat_input_signature(self): def flat_input_signature(self):
return self._flat_input_signature return self._flat_input_signature
@property
def is_pure(self):
return self._is_pure
@property
def arg_names(self):
return self._arg_names
@property
def vararg_name(self):
return self._fullargspec.varargs
@property
def varkw_name(self):
return self._fullargspec.varkw
def signature_summary(self, default_values=False):
"""Returns a string summarizing this function's signature.
Args:
default_values: If true, then include default values in the signature.
Returns:
A `string`.
"""
args = list(self._arg_names)
if default_values:
for (i, default) in self._arg_indices_to_default_values.items():
args[i] += "={}".format(default)
if self._fullargspec.kwonlyargs:
args.append("*")
for arg_name in self._fullargspec.kwonlyargs:
args.append(arg_name)
if default_values and arg_name in self._fullargspec.kwonlydefaults:
args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
return "{}({})".format(self._name, ", ".join(args))
def _convert_variables_to_tensors(self, args, kwargs):
args = [ops.convert_to_tensor(x) for x in args]
kwargs = {kw: ops.convert_to_tensor(x) for kw, x in kwargs.items()}
@ -2159,7 +2562,13 @@ class FunctionSpec(object):
instance. In particular, we parse the varags and kwargs that the
original function was called with into a tuple corresponding to the
Python function's positional (named) arguments and a dictionary
- corresponding to its kwargs.
- Additionally, any inputs containing numpy arrays are converted to Tensors.
+ corresponding to its kwargs. Missing default arguments are added.
+ If this `FunctionSpec` has an input signature, then it is used to convert
+ arguments to tensors; otherwise, any inputs containing numpy arrays are
+ converted to tensors.
Args: Args:
*args: The varargs this object was called with. *args: The varargs this object was called with.
@ -2180,29 +2589,38 @@ class FunctionSpec(object):
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
if self._input_signature is not None:
if len(args) > len(self._input_signature):
- raise TypeError(
- "When input_signature is provided, only pass arguments "
- "covered by it. Received %d argument(s)." % len(args))
+ raise TypeError("{} takes {} positional arguments (as specified by the "
+ "input_signature) but {} were given".format(
+ self.signature_summary(),
+ len(self._input_signature), len(args)))
for arg in six.iterkeys(kwargs):
index = self._args_to_indices.get(arg, None)
if index is None:
- raise TypeError(
- "Function got an unexpected keyword argument %s" % arg)
+ raise TypeError("{} got unexpected keyword argument `{}`".format(
+ self.signature_summary(), arg))
if index >= len(self._input_signature):
raise TypeError(
- "When input_signature is provided, only pass arguments "
- "covered by it. Received argument %s." % arg)
+ "{} got keyword argument `{}` that was not included in "
+ "input_signature".format(self.signature_summary(), arg))
if not kwargs:
inputs = args
- default_keys = sorted(self._arg_indices_to_default_values.keys())
- if default_keys:
- assert min(default_keys) <= len(
- args), "Not enough arguments (%s, %s, %s)" % (args, default_keys,
- self.arg_names)
- for index in default_keys:
- if index >= len(args):
- inputs += (self._arg_indices_to_default_values[index],)
+ if self._arg_indices_to_default_values:
+ try:
+ inputs += tuple(
+ self._arg_indices_to_default_values[i]
+ for i in range(len(args), len(self._arg_names)))
+ except KeyError:
+ missing_args = [
+ self._arg_names[i]
+ for i in range(len(args), len(self._arg_names))
+ if i not in self._arg_indices_to_default_values
+ ]
+ raise TypeError("{} missing required arguments: {}".format(
+ self.signature_summary(), ", ".join(missing_args)))
+ if self._fullargspec.kwonlydefaults:
+ kwargs.update(self._fullargspec.kwonlydefaults)
else:
# Maps from index of arg to its corresponding value, according to `args` # Maps from index of arg to its corresponding value, according to `args`
# and `kwargs`; seeded with the default values for the named args that # and `kwargs`; seeded with the default values for the named args that
@ -2215,18 +2633,28 @@ class FunctionSpec(object):
for arg, value in six.iteritems(kwargs): for arg, value in six.iteritems(kwargs):
index = self._args_to_indices.get(arg, None) index = self._args_to_indices.get(arg, None)
if index is not None: if index is not None:
if index < len(args):
raise TypeError("{} got two values for argument '{}'".format(
self.signature_summary(), arg))
arg_indices_to_values[index] = value arg_indices_to_values[index] = value
consumed_args.append(arg) consumed_args.append(arg)
elif self._input_signature is not None:
raise ValueError("Cannot define a TensorFlow function from a Python "
"function with keyword arguments when "
"input_signature is provided.")
for arg in consumed_args:
- # After this loop, `kwargs` will only contain true keyword arguments, as
- # opposed to named arguments called in a keyword-like fashion.
+ # After this loop, `kwargs` will only contain keyword_only arguments,
+ # and all positional_or_keyword arguments have been moved to `inputs`.
kwargs.pop(arg)
inputs = args + _deterministic_dict_values(arg_indices_to_values)
if kwargs and self._input_signature is not None:
raise TypeError(
"{} got unexpected keyword arguments: {}\n(Cannot define a "
"TensorFlow function from a Python function with keyword arguments "
"when input_signature is provided.)".format(
self.signature_summary(), ", ".join(kwargs)))
if self._fullargspec.kwonlydefaults:
for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
kwargs.setdefault(kwarg, default)
if self._input_signature is None: if self._input_signature is None:
inputs = _convert_numpy_inputs(inputs) inputs = _convert_numpy_inputs(inputs)
kwargs = _convert_numpy_inputs(kwargs) kwargs = _convert_numpy_inputs(kwargs)
@ -2447,6 +2875,7 @@ class Function(object):
graph_function, _, _ = self._maybe_define_function(args, kwargs) graph_function, _, _ = self._maybe_define_function(args, kwargs)
return graph_function return graph_function
# XX TODO: make sure we fix up this path as well!?
def _get_concrete_function_internal(self, *args, **kwargs): def _get_concrete_function_internal(self, *args, **kwargs):
"""Bypasses error checking when getting a graph function.""" """Bypasses error checking when getting a graph function."""
graph_function = self._get_concrete_function_internal_garbage_collected( graph_function = self._get_concrete_function_internal_garbage_collected(
@ -2664,6 +3093,7 @@ class Function(object):
override_flat_arg_shapes=override_flat_arg_shapes, override_flat_arg_shapes=override_flat_arg_shapes,
capture_by_value=self._capture_by_value), capture_by_value=self._capture_by_value),
self._function_attributes, self._function_attributes,
function_spec=self.function_spec,
# Tell the ConcreteFunction to clean up its graph once it goes out of # Tell the ConcreteFunction to clean up its graph once it goes out of
# scope. This is not the default behavior since it gets used in some # scope. This is not the default behavior since it gets used in some
# places (like Keras) where the FuncGraph lives longer than the # places (like Keras) where the FuncGraph lives longer than the
@ -3350,3 +3780,30 @@ class ConcreteFunctionGarbageCollector(object):
func_graph_module.dismantle_func_graph(self._func_graph) func_graph_module.dismantle_func_graph(self._func_graph)
except: # pylint: disable=bare-except except: # pylint: disable=bare-except
pass pass
class _Marker(object):
"""Markers used to pretty-print nested args in function signatures."""
def __init__(self, s):
self._s = s
def __repr__(self):
return str(self._s)
def _structure_summary(structure):
"""Displays a summary of the nesting structure of the given value."""
def type_name(x):
if isinstance(x, type_spec.TypeSpec):
return x.value_type.__name__
else:
return type(x).__name__
markers = [_Marker(type_name(v)) for v in nest.flatten(structure)]
return str(nest.pack_sequence_as(structure, markers))
def _contains_type_spec(value):
return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value))


@ -36,6 +36,7 @@ from tensorflow.python.eager import cancellation
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.eager import def_function from tensorflow.python.eager import def_function
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import config from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -50,6 +51,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.layers import convolutional from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops from tensorflow.python.ops import check_ops
@ -65,6 +67,7 @@ from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_factory_ops from tensorflow.python.ops.ragged import ragged_factory_ops
@ -99,6 +102,16 @@ def _example_indexed_slices_without_dense_shape():
constant_op.constant([1, 2]), constant_op.constant([0, 1])) constant_op.constant([1, 2]), constant_op.constant([0, 1]))
def _spec_for_value(value):
"""Returns the (nested) TypeSpec for a value."""
if nest.is_sequence(value):
return nest.map_structure(_spec_for_value, value)
elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
return type_spec.type_spec_from_value(value)
else:
return value
class FunctionTest(test.TestCase, parameterized.TestCase): class FunctionTest(test.TestCase, parameterized.TestCase):
def setUp(self): def setUp(self):
@ -1789,6 +1802,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, 'incompatible'): with self.assertRaisesRegexp(ValueError, 'incompatible'):
func([['wrong dtype']]) func([['wrong dtype']])
def testNoKeywordOnlyArgumentsWithInputSignature(self):
if sys.version_info[0] < 3:
self.skipTest('keyword_only arguments only exist in Python 3.')
func = eval('lambda x, *, y: x') # pylint: disable=eval-used
signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
with self.assertRaisesRegexp(
ValueError, 'Cannot define a TensorFlow function from a Python '
'function with keyword-only arguments when input_signature is '
'provided.'):
def_function.function(func, signature)
def testNestedInputSignatures(self): def testNestedInputSignatures(self):
def expected_foo(a, b): def expected_foo(a, b):
@ -1905,7 +1930,9 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
defined(array_ops.ones([2, 1])) defined(array_ops.ones([2, 1]))
# Wrong number of arguments.
- with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
+ with self.assertRaisesRegexp(
+ TypeError, r'takes 1 positional arguments \(as specified by the '
+ r'input_signature\) but 2 were given'):
defined(array_ops.ones([2]), array_ops.ones([2]))
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegexp(ValueError,
'Structure of Python function inputs.*'): 'Structure of Python function inputs.*'):
@ -1946,10 +1973,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
return -1.0 * a return -1.0 * a
x = constant_op.constant(1.0) x = constant_op.constant(1.0)
- with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
+ with self.assertRaisesRegexp(
+ TypeError, 'got keyword argument `training` '
+ 'that was not included in input_signature'):
foo(x, training=True)
- with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
+ with self.assertRaisesRegexp(
+ TypeError, 'got keyword argument `training` '
+ 'that was not included in input_signature'):
foo(x, training=False)
self.assertAllEqual(x.numpy(), foo(x).numpy())
@ -2472,8 +2503,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
return x return x
graph_function = foo.get_concrete_function(constant_op.constant(1.0))
- with self.assertRaisesRegexp(
- ValueError, 'All inputs to `ConcreteFunction`s must be Tensors;.*'):
+ with self.assertRaises((TypeError, ValueError)):
graph_function('Not a Tensor.')
def testSwapImplementationWithGrapplerPlugin(self): def testSwapImplementationWithGrapplerPlugin(self):
@ -3148,6 +3178,432 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
function.clear_function_callbacks() function.clear_function_callbacks()
self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNestedTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = constant_op.constant(1000)
b = constant_op.constant(200)
c = constant_op.constant(30)
d = {'a': a, 'b': b}
e = (c, 4)
# Test different argument signatures when constructing the concrete func.
for cf in [
f.get_concrete_function(d, e),
f.get_concrete_function(d, y=e),
f.get_concrete_function(y=e, x=d),
f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
]:
# Test different calling conventions when calling the concrete func.
for output in [
cf(d, e), # structured signature
cf(d, y=e), # structured signature w/ kwarg
cf(y=e, x=d), # structured signature w/ 2 kwargs
cf(a, b, c), # flat signature
cf(x=a, x_1=b, y=c) # flat signature w/ kwargs
]:
self.assertIsInstance(output, tuple)
self.assertLen(output, 2)
self.assertAllEqual(output[0], 1200)
self.assertAllEqual(output[1], 34)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
b = (50, 3)
for cf in [ # argument y is bound to non-Tensor value (50, 3).
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
]:
for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 1253)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 3000, 'b': 200, 'c': 9000}
b = (constant_op.constant(30), 4)
for cf in [ # argument x is bound to non-tensor value `a`
f.get_concrete_function(a, b),
f.get_concrete_function(a, y=b),
f.get_concrete_function(x=a, y=b)
]:
for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
self.assertAllEqual(output[0] + output[1], 3234)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
@def_function.function
def f(x, y):
return (x['a'] + x['b'], y[0] + y[1])
a = {'a': 5000, 'b': 500}
b = (50, 5)
cf = f.get_concrete_function(a, b)
for output in [cf(), cf(a), cf(y=b)]:
self.assertAllEqual(output[0] + output[1], 5555)
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionStructuredSignatureKeywordOrder(self):
# Check that keyword-only arguments are sorted appropriately, so that they
# feed the right tensor into each input.
@def_function.function
def g(**kwargs):
return string_ops.reduce_join(
string_ops.reduce_join(
ops.convert_to_tensor(sorted(kwargs.items())),
axis=1,
separator='='),
axis=0,
separator=', ')
s = constant_op.constant('s')
g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
self.assertAllEqual(
g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
self.assertAllEqual(
g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
self.assertAllEqual(
g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
# pylint: disable=g-long-lambda
@parameterized.named_parameters([
dict(
testcase_name='MissingArg',
conc_args=lambda: (1, constant_op.constant(2)),
call_args=lambda: (1,),
error=r'func\(x, y\) missing required arguments: y'),
dict(
testcase_name='MissingVararg',
conc_args=lambda: (1, 2, constant_op.constant(1.0)),
call_args=lambda: (1, 2),
error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
dict(
testcase_name='ExtraPositionalArg',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2, 3),
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
dict(
testcase_name='MissingKeywordOnlyArg',
conc_args=lambda: (1, 2),
conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
call_args=lambda: (1, 2),
error=r'func\(x, y, \*, c\) missing required arguments: c'),
dict(
testcase_name='ExtraKeywordArg',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'c': constant_op.constant(1.0)},
error=r'func\(x, y\) got unexpected keyword arguments: c'),
dict(
testcase_name='ExpectedRaggedGotNest',
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: ({
'a': constant_op.constant([1, 2, 3])
},),
error=r'func\(x, y\): argument x had incorrect type\n'
r' expected: RaggedTensor\n'
r" got: {'a': (Eager)?Tensor}"),
dict(
testcase_name='WrongRaggedRank',
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='WrongRaggedDType',
conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='ExpectedDictGotTensor',
conc_args=lambda: ({
'a': constant_op.constant(1),
'b': constant_op.constant(1)
},),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='ExpectedTupleGotTensor',
conc_args=lambda:
((constant_op.constant(1), constant_op.constant(2)),),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\): argument x had incorrect type\n'),
dict(
testcase_name='WrongDType',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(ValueError, errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
errors.InternalError)),
dict(
testcase_name='ExpectedTensorGotInt',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (5,),
error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
dict(
testcase_name='ExpectedIntGotDifferentInt',
conc_args=lambda: (5,),
call_args=lambda: (8,),
error=r'ConcreteFunction func\(x, y\) was constructed with int '
r'value 5 in x, but was called with int value 8'),
dict(
testcase_name='ExpectedIntGotTensor',
conc_args=lambda: (5,),
call_args=lambda: (constant_op.constant(6),),
error=r'ConcreteFunction func\(x, y\) was constructed with int '
'value 5 in x, but was called with (Eager)?Tensor value .*'),
dict(
testcase_name='TwoValuesForArgument',
conc_args=lambda: (1, 2),
call_args=lambda: (1, 2),
call_kwargs=lambda: {'x': 3},
error=r"func\(x, y\) got two values for argument 'x'"),
])
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionStructuredSignatureError(self,
conc_args=(),
conc_kwargs=None,
call_args=(),
call_kwargs=None,
error='.*',
exception=TypeError):
"""Tests for errors in the structrued signature.
Args:
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
"""
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
@def_function.function
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
with self.assertRaisesRegexp(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
# pylint: disable=g-long-lambda
@parameterized.named_parameters([
dict(
testcase_name='MissingArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
error=r'func\(x, y\) missing required arguments: y'),
dict(
testcase_name='TwoValuesForArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {
'x': constant_op.constant(1),
'y': constant_op.constant(1)
},
error=r"func\(x, y\) got two values for argument 'x'"),
dict(
testcase_name='ExtraPositionalArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
constant_op.constant(3)),
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
dict(
testcase_name='UnexpectedKeywordArg',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1),),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x\) got unexpected keyword arguments: c'),
dict(
testcase_name='MissingVararg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
constant_op.constant(3)),
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, varargs_0\) missing required '
r'arguments: varargs_0'),
dict(
testcase_name='MissingKeywordArg',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': constant_op.constant(1)},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
error=r'func\(x, y, c\) missing required arguments: c'),
dict(
testcase_name='ExpectedTensorGotInt',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_args=lambda: (5, constant_op.constant(2)),
error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
r'a Tensor; got int \(5\)'),
dict(
testcase_name='WrongDType',
conc_args=lambda: (constant_op.constant(1),),
call_args=lambda: (constant_op.constant(1.0),),
exception=(ValueError, errors.InvalidArgumentError,
# on xla_gpu, we get InternalError instead.
errors.InternalError)),
dict(
testcase_name='MissingKeywordArgNestPiece',
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
call_kwargs=lambda: {'c': constant_op.constant(1)},
error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
])
# pylint: enable=g-long-lambda
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionFlatSignatureError(self,
conc_args=(),
conc_kwargs=None,
call_args=(),
call_kwargs=None,
error='.*',
exception=TypeError):
"""Tests for errors in the flat signature.
Args:
conc_args: Positional arguments used for get_concrete_function.
conc_kwargs: Keyword arguments used for get_concrete_function.
call_args: Positional arguments used to call the function.
call_kwargs: Keyword arguments used to call the function.
error: Expected exception message.
exception: Expected exception type.
"""
conc_args = conc_args() if callable(conc_args) else conc_args
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
call_args = call_args() if callable(call_args) else call_args
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
self.assertIsInstance(conc_args, tuple)
self.assertIsInstance(call_args, tuple)
self.assertIsInstance(conc_kwargs, dict)
self.assertIsInstance(call_kwargs, dict)
@def_function.function
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
del y, varargs, kwargs
return x
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
# Remove _function_spec, to disable the structured signature.
conc._set_function_spec(None) # pylint: disable=protected-access
with self.assertRaisesRegexp(exception, error):
self.evaluate(conc(*call_args, **call_kwargs))
@test_util.run_in_graph_and_eager_modes
def testConcreteFunctionAmbiguousSignature(self):
# When both the flat & structured signatures are applicable, but they
# give different results, we use the structured signature. Note: we expect
# this to be extremely rare.
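# (Illustration: under the flat signature the keyword arguments would be
# bound by TensorSpec name, which are swapped here, giving 65; the
# structured signature binds them by parameter name, giving 56 as
# asserted below.)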
@def_function.function
def f(x, y):
return x * 10 + y
conc = f.get_concrete_function(
x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))
result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
self.assertAllEqual(result, 56)
def testPrettyPrintedSignature(self):
@def_function.function
def func(x, kangaroo=None, octopus=7):
del octopus, kangaroo
return x
scalar = constant_op.constant(5)
vector = constant_op.constant([10, 10, 20])
ragged = ragged_factory_ops.constant([[10, 20], [40]])
c1 = func.get_concrete_function(scalar, vector)
c1_summary = r'func\(x, kangaroo, octopus=7\)'
c1_details = (r' Args:\n'
r' x: int32 Tensor, shape=\(\)\n'
r' kangaroo: int32 Tensor, shape=\(3,\)\n'
r' Returns:\n'
r' int32 Tensor, shape=\(\)')
self.assertRegexpMatches(
c1.pretty_printed_signature(verbose=False), c1_summary)
self.assertRegexpMatches(
c1.pretty_printed_signature(verbose=True),
c1_summary + '\n' + c1_details)
self.assertRegexpMatches(
repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
self.assertRegexpMatches(
str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))
c2 = func.get_concrete_function(scalar, ragged, 3)
c2_summary = r'func\(x, kangaroo, octopus=3\)'
c2_details = (r' Args:\n'
r' x: int32 Tensor, shape=\(\)\n'
r' kangaroo: RaggedTensorSpec\(.*\)\n'
r' Returns:\n'
r' int32 Tensor, shape=\(\)')
self.assertRegexpMatches(c2.pretty_printed_signature(),
c2_summary + '\n' + c2_details)
c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
c3_details = (r' Args:\n'
r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
r' <1>: int32 Tensor, shape=\(\)\n'
r' <2>: RaggedTensorSpec\(.*\)\n'
r' <3>: RaggedTensorSpec\(.*\)\n'
r' Returns:\n'
r" {'a': <1>, 'b': \[<2>, <3>\]}\n"
r' <1>: int32 Tensor, shape=\(\)\n'
r' <2>: RaggedTensorSpec\(.*\)\n'
r' <3>: RaggedTensorSpec\(.*\)')
self.assertRegexpMatches(c3.pretty_printed_signature(),
c3_summary + '\n' + c3_details)
# pylint: disable=keyword-arg-before-vararg
@def_function.function
def func2(x, y=3, *args, **kwargs):
return (x, y, args, kwargs)
c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)
c5 = func2.get_concrete_function(8, vector)
c5_summary = 'func2(x=8, y)'
self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)
class MultiDeviceTest(test.TestCase, parameterized.TestCase):

View File

@ -43,6 +43,8 @@ def _is_tensor(t):
return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
# structured signature.
def _call_concrete_function(function, inputs):
"""Calls a restored Function with structured inputs.
@ -137,8 +139,6 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
input_signature = coder.decode_proto(function_spec_proto.input_signature)
return function_lib.FunctionSpec(fullargspec=fullargspec,
is_method=False,
args_to_prepend=[],
kwargs_to_include={},
input_signature=input_signature)
@ -191,6 +191,8 @@ def recreate_function(saved_function, concrete_functions):
Args:
saved_function: `SavedFunction` proto.
concrete_functions: map from function name to `ConcreteFunction`.
As a side effect of this function, the `FunctionSpec` from
`saved_function` is added to each `ConcreteFunction` in this map.
Returns:
A `Function`.
@ -254,6 +256,9 @@ def recreate_function(saved_function, concrete_functions):
for concrete_function_name in saved_function.concrete_functions:
concrete_function_objects.append(concrete_functions[concrete_function_name])
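# Attach the deserialized FunctionSpec to each ConcreteFunction so it can be
# called with the structured signature (see the note added to the docstring
# above).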
for cf in concrete_function_objects:
cf._set_function_spec(function_spec) # pylint: disable=protected-access
restored_function = RestoredFunction(
restored_function_body,
restored_function_body.__name__,
@ -317,6 +322,11 @@ def load_function_def_library(library, load_shared_name_suffix=None):
for dep in _list_function_deps(fdef, library_function_names):
functions[dep].add_to_graph(func_graph)
# We do not initialize the new ConcreteFunction's function_spec or
# arg_keywords here (which are used to parse the structured and flat
# signatures, respectively). function_spec is set up later by
# recreate_function(); and arg_keywords by setup_bare_concrete_function().
func = function_lib.ConcreteFunction(func_graph)
func.add_to_graph(graph)

View File

@ -77,6 +77,10 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
def serialize_bare_concrete_function(concrete_function, name_map):
"""Build a SavedBareConcreteFunction."""
# TODO(edloper): Currently, bare concrete functions don't have access to a
# function_spec, so they can't be called with the structured signature.
# Update the serialization to include a function_spec.
# pylint: disable=protected-access
name = name_map.get(compat.as_text(concrete_function.name),
concrete_function.name)
@ -151,7 +155,8 @@ def wrap_cached_variables(concrete_function):
func_graph_module.func_graph_from_py_func(
None, wrap_function, args=tuple(args), kwargs={},
func_graph=outer_graph)
fn = defun.ConcreteFunction(outer_graph)
fn = defun.ConcreteFunction(
outer_graph, function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access

View File

@ -173,12 +173,12 @@ class Loader(object):
# The original_outputs here had Tensors converted to TensorSpecs, so
# the restored function's structured_outputs field will not be
# exactly the same. Fortunately the repacking logic cares only about
# the structure.
# TODO(vbardiovsky): Should we just replicate the structures, with
# Nones instead of real objects?
# the structure; and the unpacking logic cares only about structure
# and types.
concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
coder.decode_proto(proto.canonicalized_input_signature))
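# Re-derive the function spec from the restored structured input signature so
# the structured calling convention is available on the loaded function.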
concrete_function._initialize_function_spec() # pylint: disable=protected-access
def _setup_functions_captures(self):
"""Setup captures and variables in restored functions."""