Functions returned by the get_concrete_function method of tf.Function objects can now be called with arguments consistent with the original arguments or type specs passed to get_concrete_function. In particular:
* If a composite tensor (such as RaggedTensor or SparseTensor) was passed to `get_concrete_function`, then the returned function will accept a composite tensor of the same type for that argument. * If a nested structure (such as a list or dict) was passed to `get_concrete_function`, then the returned function will accept a value with the same nesting structure. Each tensor or composite tensor value must have the same type as was used in the original argument; and each non-Tensor value (such as bools or ints) must be equal value that was used in the original argument. * If a non-tensor value (such as a bool or int) was passed to `get_concrete_function`, then the returned function no longer deletes that argument; instead, it updates the argument's default value to the value that was passed to `get_concrete_function`. Passing in any other value will raise an exception. * Arguments are not renamed based on `TensorSpec.name`. For backwards compatibility, the functions returned by `get_concrete_function` will continue to accept arguments with the existing calling conventions (where nested structures and composite tensors are flattened; non-tensor arguments are deleted; suffixes are automatically added to disambiguate arguments with the same name; and TensorSpec.name is used to rename arguments). However, the preferred calling convention is the one that is consistent with the original arguments or type specs passed to `get_concrete_function`. PiperOrigin-RevId: 307398918 Change-Id: Ie4685b32d9f151c82f6c79a6c41379faa96b5ee8
This commit is contained in:
parent
0887fedd2d
commit
f39aab3092
@ -830,6 +830,13 @@ class Function(object):
|
|||||||
def function_spec(self):
|
def function_spec(self):
|
||||||
return self._function_spec
|
return self._function_spec
|
||||||
|
|
||||||
|
def pretty_printed_concrete_signatures(self, verbose=True):
|
||||||
|
joiner = "\n\n" if verbose else "\n"
|
||||||
|
return joiner.join([
|
||||||
|
c.pretty_printed_signature(verbose=verbose)
|
||||||
|
for c in self._list_all_concrete_functions()
|
||||||
|
])
|
||||||
|
|
||||||
def _initialize_uninitialized_variables(self, initializers):
|
def _initialize_uninitialized_variables(self, initializers):
|
||||||
"""Make and call a `ConcreteFunction` which initializes variables."""
|
"""Make and call a `ConcreteFunction` which initializes variables."""
|
||||||
|
|
||||||
@ -913,12 +920,8 @@ class Function(object):
|
|||||||
|
|
||||||
return initialize_variables.get_concrete_function()
|
return initialize_variables.get_concrete_function()
|
||||||
|
|
||||||
def _list_all_concrete_functions_for_serialization(self):
|
def _list_all_concrete_functions(self):
|
||||||
"""Returns all concrete functions for serialization.
|
"""Returns all concrete functions."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of instances of `ConcreteFunction`.
|
|
||||||
"""
|
|
||||||
if self.input_signature is not None:
|
if self.input_signature is not None:
|
||||||
self.get_concrete_function()
|
self.get_concrete_function()
|
||||||
concrete_functions = []
|
concrete_functions = []
|
||||||
@ -930,6 +933,15 @@ class Function(object):
|
|||||||
concrete_functions.extend(
|
concrete_functions.extend(
|
||||||
self._stateless_fn._function_cache.all_values())
|
self._stateless_fn._function_cache.all_values())
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
return concrete_functions
|
||||||
|
|
||||||
|
def _list_all_concrete_functions_for_serialization(self):
|
||||||
|
"""Returns all concrete functions for serialization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of instances of `ConcreteFunction`.
|
||||||
|
"""
|
||||||
|
concrete_functions = self._list_all_concrete_functions()
|
||||||
seen_signatures = []
|
seen_signatures = []
|
||||||
for concrete_function in concrete_functions:
|
for concrete_function in concrete_functions:
|
||||||
signature = concrete_function.structured_input_signature
|
signature = concrete_function.structured_input_signature
|
||||||
|
|||||||
@ -53,6 +53,7 @@ from tensorflow.python.framework import func_graph as func_graph_module
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
from tensorflow.python.framework import type_spec
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import custom_gradient
|
from tensorflow.python.ops import custom_gradient
|
||||||
@ -340,7 +341,7 @@ class _InterpolateFunctionError(object):
|
|||||||
if t.name == compat.as_str(self._func.name):
|
if t.name == compat.as_str(self._func.name):
|
||||||
g = self._func.graph
|
g = self._func.graph
|
||||||
elif g:
|
elif g:
|
||||||
next_func = g._get_function(t.name)
|
next_func = g._get_function(t.name) # pylint: disable=protected-access
|
||||||
if next_func is not None and isinstance(next_func,
|
if next_func is not None and isinstance(next_func,
|
||||||
_EagerDefinedFunction):
|
_EagerDefinedFunction):
|
||||||
g = next_func.graph
|
g = next_func.graph
|
||||||
@ -1499,6 +1500,12 @@ class _ForwardBackwardCall(object):
|
|||||||
flat_outputs, self._inference_args, self._input_tangents)
|
flat_outputs, self._inference_args, self._input_tangents)
|
||||||
|
|
||||||
|
|
||||||
|
# Sentinel value used by with ConcreteFunction's structured signature to
|
||||||
|
# indicate that a non-tensor parameter should use the value that was
|
||||||
|
# specified when the concrete function was created.
|
||||||
|
_BOUND_VALUE = object()
|
||||||
|
|
||||||
|
|
||||||
class ConcreteFunction(object):
|
class ConcreteFunction(object):
|
||||||
"""Callable object encapsulating a function definition and its gradient.
|
"""Callable object encapsulating a function definition and its gradient.
|
||||||
|
|
||||||
@ -1506,7 +1513,11 @@ class ConcreteFunction(object):
|
|||||||
is differentiable under `tf.GradientTape` objects.
|
is differentiable under `tf.GradientTape` objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, func_graph, attrs=None, shared_func_graph=True):
|
def __init__(self,
|
||||||
|
func_graph,
|
||||||
|
attrs=None,
|
||||||
|
shared_func_graph=True,
|
||||||
|
function_spec=None):
|
||||||
"""Initialize a `ConcreteFunction`.
|
"""Initialize a `ConcreteFunction`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -1517,16 +1528,25 @@ class ConcreteFunction(object):
|
|||||||
shared_func_graph: If False, the ConcreteFunction takes ownership of
|
shared_func_graph: If False, the ConcreteFunction takes ownership of
|
||||||
`func_graph` and will break reference cycles when it is deleted. This
|
`func_graph` and will break reference cycles when it is deleted. This
|
||||||
makes the FuncGraph inoperable.
|
makes the FuncGraph inoperable.
|
||||||
|
function_spec: FunctionSpec for the original function. If not specified,
|
||||||
|
then this ConcreteFunction may only be called using the flat signature.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If number of input_placeholders is not equal to the number
|
ValueError: If number of input_placeholders is not equal to the number
|
||||||
of function inputs.
|
of function inputs.
|
||||||
"""
|
"""
|
||||||
|
# _arg_keywords and _num_positional_args define the flat signature. They
|
||||||
|
# are assigned after construction.
|
||||||
self._arg_keywords = None
|
self._arg_keywords = None
|
||||||
self._num_positional_args = None
|
self._num_positional_args = None
|
||||||
|
|
||||||
self._func_graph = func_graph
|
self._func_graph = func_graph
|
||||||
self._captured_inputs = self._func_graph.external_captures
|
self._captured_inputs = self._func_graph.external_captures
|
||||||
self._captured_closures = self._func_graph.deferred_external_captures
|
self._captured_closures = self._func_graph.deferred_external_captures
|
||||||
|
|
||||||
|
# function_spec defines the structured signature.
|
||||||
|
self._set_function_spec(function_spec)
|
||||||
|
|
||||||
if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
|
if attrs and IMPLEMENTS_ATTRIBUTE_NAME in attrs:
|
||||||
# The alternative is to silently drop "implements" tag
|
# The alternative is to silently drop "implements" tag
|
||||||
# but it seems likely it would lead to hard to catch bugs.
|
# but it seems likely it would lead to hard to catch bugs.
|
||||||
@ -1576,6 +1596,52 @@ class ConcreteFunction(object):
|
|||||||
# building gradients.
|
# building gradients.
|
||||||
self._inference_function = self._delayed_rewrite_functions.forward()
|
self._inference_function = self._delayed_rewrite_functions.forward()
|
||||||
|
|
||||||
|
def _set_function_spec(self, function_spec):
|
||||||
|
"""Enables the structured signature by supplying a function_spec."""
|
||||||
|
self._function_spec = None
|
||||||
|
self._pre_initialized_function_spec = function_spec
|
||||||
|
|
||||||
|
# Note: when ConcreteFunctions are built by recreate_function() in
|
||||||
|
# function_deserialization.py, they don't have a structured_input_signature
|
||||||
|
# yet. In that case, _initialize_function_spec() gets called by
|
||||||
|
# _setup_functions_structures() in load.py.
|
||||||
|
if (function_spec is not None and
|
||||||
|
self.structured_input_signature is not None):
|
||||||
|
self._initialize_function_spec()
|
||||||
|
|
||||||
|
def _initialize_function_spec(self):
|
||||||
|
"""Updates `self._function_spec` to include varargs and bound variables.
|
||||||
|
|
||||||
|
Adds new positional arguments for any varargs (i.e., for args that are
|
||||||
|
in `structured_input_signature`, but not in the original fullargspec.args).
|
||||||
|
|
||||||
|
Replaces `defaults` and `kwonlydefaults` with the `_BOUND_VALUE`, for
|
||||||
|
all args and kwargs in `structured_input_signature`.
|
||||||
|
|
||||||
|
Sets `varkw` and `varargs` to None.
|
||||||
|
"""
|
||||||
|
if self._pre_initialized_function_spec is None:
|
||||||
|
return # e.g., SavedBareConcreteFunction doesn't have function_spec yet.
|
||||||
|
assert not self._function_spec, "already initialized"
|
||||||
|
function_spec = self._pre_initialized_function_spec
|
||||||
|
args = function_spec.fullargspec.args
|
||||||
|
arg_specs, kwarg_specs = self.structured_input_signature
|
||||||
|
fullargspec = tf_inspect.FullArgSpec(
|
||||||
|
args=list(args) +
|
||||||
|
["<arg{}>".format(i + 1) for i in range(len(args), len(arg_specs))],
|
||||||
|
varargs=None,
|
||||||
|
varkw=None,
|
||||||
|
defaults=[_BOUND_VALUE] * len(arg_specs),
|
||||||
|
kwonlyargs=list(sorted(kwarg_specs)),
|
||||||
|
kwonlydefaults=dict((k, _BOUND_VALUE) for k in kwarg_specs),
|
||||||
|
annotations=function_spec.fullargspec.annotations)
|
||||||
|
self._function_spec = FunctionSpec(
|
||||||
|
fullargspec,
|
||||||
|
function_spec.is_method,
|
||||||
|
function_spec.input_signature,
|
||||||
|
function_spec.is_pure,
|
||||||
|
name=self._func_graph.name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def variables(self):
|
def variables(self):
|
||||||
"""Sequence of variables for this function."""
|
"""Sequence of variables for this function."""
|
||||||
@ -1589,15 +1655,44 @@ class ConcreteFunction(object):
|
|||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""Executes the wrapped function.
|
"""Executes the wrapped function.
|
||||||
|
|
||||||
|
ConcreteFunctions have two signatures:
|
||||||
|
|
||||||
|
* The signature of the original function wrapped by this ConcreteFunction.
|
||||||
|
* A flat signature, where each argument accepts a single Tensor.
|
||||||
|
|
||||||
|
The original function signature is generally preferred, but the flat input
|
||||||
|
signature is supported for backward compatibility.
|
||||||
|
|
||||||
|
### Original Function Signature
|
||||||
|
|
||||||
|
When calling a ConcreteFunction with the signature of the original function,
|
||||||
|
each argument must match the type or value that was used when the
|
||||||
|
ConcreteFunction's graph was traced. In particular:
|
||||||
|
|
||||||
|
* Tensor arguments (including CompositeTensors, such as RaggedTensor) must
|
||||||
|
have matching `TypeSpec`s.
|
||||||
|
* Non-Tensor arguments (such as booleans or ints) must have equal values.
|
||||||
|
* Nested arguments (such as lists, tuples, or dictionaries) must have the
|
||||||
|
same nesting structure; and each nested value must have a matching type
|
||||||
|
or value.
|
||||||
|
|
||||||
|
The default value for any arguments that were traced with non-Tensor values
|
||||||
|
is the value that was used in the trace. Arguments that were traced with
|
||||||
|
tensor arguments do not have a default value (even if the original function
|
||||||
|
had a default value for that argument).
|
||||||
|
|
||||||
|
### Flat Signature
|
||||||
|
|
||||||
|
When calling a ConcreteFunction with the flat signature, the arguments
|
||||||
|
correspond to the flattened component tensors of the arguments that were
|
||||||
|
used to construct the ConcreteFunction. Parameter names are assigned based
|
||||||
|
on `TensorSpec.name` (when specified) or the original argument names (with
|
||||||
|
suffixes automatically added for nested arguments or composite tensors with
|
||||||
|
multiple components).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: Tensors or Variables. Positional arguments are only accepted when
|
*args: Positional arguments to the concrete function.
|
||||||
they correspond one-to-one with arguments of the traced Python function.
|
**kwargs: Keyword arguments to the concrete function.
|
||||||
**kwargs: Tensors or Variables specified by name. When
|
|
||||||
`get_concrete_function` was called to create this `ConcreteFunction`,
|
|
||||||
each Tensor input was given a name, defaulting to the name of the Python
|
|
||||||
function's argument but possibly overridden by the `name=` argument to
|
|
||||||
`tf.TensorSpec`. These names become the argument names for the concrete
|
|
||||||
function.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The result of applying the TF function on the given Tensors.
|
The result of applying the TF function on the given Tensors.
|
||||||
@ -1605,9 +1700,7 @@ class ConcreteFunction(object):
|
|||||||
Raises:
|
Raises:
|
||||||
AssertionError: If this `ConcreteFunction` was not created through
|
AssertionError: If this `ConcreteFunction` was not created through
|
||||||
`get_concrete_function`.
|
`get_concrete_function`.
|
||||||
ValueError: If arguments contains anything other than Tensors or
|
TypeError: If the arguments do not match the function's signature.
|
||||||
Variables.
|
|
||||||
TypeError: For invalid positional/keyword argument combinations.
|
|
||||||
"""
|
"""
|
||||||
return self._call_impl(args, kwargs)
|
return self._call_impl(args, kwargs)
|
||||||
|
|
||||||
@ -1615,40 +1708,174 @@ class ConcreteFunction(object):
|
|||||||
"""See `__call__` for details."""
|
"""See `__call__` for details."""
|
||||||
with traceme.TraceMe(self._func_graph.name,
|
with traceme.TraceMe(self._func_graph.name,
|
||||||
tf_function_call="concrete"):
|
tf_function_call="concrete"):
|
||||||
if self._arg_keywords is None or self._num_positional_args is None:
|
# Construct the list of input tensors: check if the structured signature
|
||||||
raise AssertionError(
|
# applies first; and if not, then use the flat signature.
|
||||||
"Tried to call a concrete function obtained from an internal API "
|
if self._function_spec is not None:
|
||||||
"through the public interface. Use get_concrete_function instead.")
|
|
||||||
if len(args) > self._num_positional_args:
|
|
||||||
raise TypeError(
|
|
||||||
("Expected at most {} positional arguments (and the rest keywords, "
|
|
||||||
"of {}), got {}. When calling a concrete function, positional "
|
|
||||||
"arguments may not be bound to Tensors within nested structures."
|
|
||||||
).format(self._num_positional_args, self._arg_keywords, args))
|
|
||||||
args = list(args)
|
|
||||||
for keyword in self._arg_keywords[len(args):]:
|
|
||||||
try:
|
try:
|
||||||
args.append(kwargs.pop(compat.as_str(keyword)))
|
return self._call_with_structured_signature(args, kwargs,
|
||||||
except KeyError:
|
cancellation_manager)
|
||||||
specified_keywords = (list(self._arg_keywords[:len(args)])
|
except TypeError as structured_err:
|
||||||
+ list(kwargs.keys()))
|
try:
|
||||||
raise TypeError(
|
return self._call_with_flat_signature(args, kwargs,
|
||||||
"Expected argument names {} but got values for {}. Missing: {}."
|
cancellation_manager)
|
||||||
.format(
|
except TypeError:
|
||||||
list(self._arg_keywords),
|
raise structured_err
|
||||||
specified_keywords,
|
|
||||||
list(set(self._arg_keywords) - set(specified_keywords))))
|
|
||||||
if kwargs:
|
|
||||||
positional_arg_keywords = set(self._arg_keywords[:len(args)])
|
|
||||||
for unused_key in kwargs:
|
|
||||||
if unused_key in positional_arg_keywords:
|
|
||||||
raise TypeError("Got two values for keyword '{}'.".format(
|
|
||||||
unused_key))
|
|
||||||
raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
|
|
||||||
list(kwargs.keys()), list(self._arg_keywords)))
|
|
||||||
return self._call_flat(args, self.captured_inputs, cancellation_manager)
|
|
||||||
|
|
||||||
def _filtered_call(self, args, kwargs):
|
return self._call_with_flat_signature(args, kwargs, cancellation_manager)
|
||||||
|
|
||||||
|
def _call_with_flat_signature(self, args, kwargs, cancellation_manager):
|
||||||
|
"""Executes the wrapped function with the flat signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Positional arguments to the concrete function.
|
||||||
|
kwargs: Keyword arguments to the concrete function.
|
||||||
|
cancellation_manager: A `CancellationManager` that can be used to cancel
|
||||||
|
function invocation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of applying the function on the Tensors/Variables contained in
|
||||||
|
`args` and `kwargs`.
|
||||||
|
Raises:
|
||||||
|
TypeError: if `args` and `kwargs` do not match the flat signature of this
|
||||||
|
`ConcreteFunction`.
|
||||||
|
"""
|
||||||
|
if len(args) > self._num_positional_args:
|
||||||
|
raise TypeError(
|
||||||
|
"{} takes {} positional arguments but {} were given".format(
|
||||||
|
self._flat_signature_summary(), self._num_positional_args,
|
||||||
|
len(args)))
|
||||||
|
args = list(args)
|
||||||
|
kwargs = dict(kwargs)
|
||||||
|
for keyword in self._arg_keywords[len(args):]:
|
||||||
|
try:
|
||||||
|
args.append(kwargs.pop(compat.as_str(keyword)))
|
||||||
|
except KeyError:
|
||||||
|
specified_keywords = (
|
||||||
|
list(self._arg_keywords[:len(args)]) + list(kwargs.keys()))
|
||||||
|
raise TypeError("{} missing required arguments: {}".format(
|
||||||
|
self._flat_signature_summary(), ", ".join(
|
||||||
|
sorted(set(self._arg_keywords) - set(specified_keywords)))))
|
||||||
|
if kwargs:
|
||||||
|
positional_arg_keywords = set(self._arg_keywords[:len(args)])
|
||||||
|
for unused_key in kwargs:
|
||||||
|
if unused_key in positional_arg_keywords:
|
||||||
|
raise TypeError("{} got two values for argument '{}'".format(
|
||||||
|
self._flat_signature_summary(), unused_key))
|
||||||
|
raise TypeError("{} got unexpected keyword arguments: {}.".format(
|
||||||
|
self._flat_signature_summary(), ", ".join(sorted(kwargs))))
|
||||||
|
|
||||||
|
for i, arg in enumerate(args):
|
||||||
|
if not isinstance(
|
||||||
|
arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)):
|
||||||
|
raise TypeError("{}: expected argument #{}(zero-based) to be a Tensor; "
|
||||||
|
"got {} ({})".format(self._flat_signature_summary(), i,
|
||||||
|
type(arg).__name__, str(arg)))
|
||||||
|
return self._call_flat(args, self.captured_inputs, cancellation_manager)
|
||||||
|
|
||||||
|
def _call_with_structured_signature(self, args, kwargs, cancellation_manager):
|
||||||
|
"""Executes the wrapped function with the structured signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: Positional arguments to the concrete function.
|
||||||
|
kwargs: Keyword arguments to the concrete function.
|
||||||
|
cancellation_manager: A `CancellationManager` that can be used to cancel
|
||||||
|
function invocation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The result of applying the function on the Tensors/Variables contained in
|
||||||
|
`args` and `kwargs`.
|
||||||
|
Raises:
|
||||||
|
TypeError: if `args` and `kwargs` do not match the structured signature
|
||||||
|
of this `ConcreteFunction`.
|
||||||
|
"""
|
||||||
|
args, kwargs = self._function_spec.canonicalize_function_inputs(
|
||||||
|
*args, **kwargs)
|
||||||
|
self._structured_signature_check_missing_args(args, kwargs)
|
||||||
|
self._structured_signature_check_unexpected_args(args, kwargs)
|
||||||
|
self._structured_signature_check_arg_types(args, kwargs)
|
||||||
|
return self._filtered_call(args, kwargs, cancellation_manager)
|
||||||
|
|
||||||
|
def _structured_signature_check_missing_args(self, args, kwargs):
|
||||||
|
"""Raises a TypeError if any args are missing."""
|
||||||
|
arg_specs, kwarg_specs = self.structured_input_signature
|
||||||
|
missing_arguments = []
|
||||||
|
for i, (arg, spec) in enumerate(zip(args, arg_specs)):
|
||||||
|
if arg is _BOUND_VALUE and _contains_type_spec(spec):
|
||||||
|
missing_arguments.append(self._function_spec.arg_names[i])
|
||||||
|
for (name, arg) in kwargs.items():
|
||||||
|
if arg is _BOUND_VALUE and _contains_type_spec(kwarg_specs[name]):
|
||||||
|
missing_arguments.append(name)
|
||||||
|
if missing_arguments:
|
||||||
|
raise TypeError("{} missing required arguments: {}".format(
|
||||||
|
self._structured_signature_summary(),
|
||||||
|
", ".join(sorted(missing_arguments))))
|
||||||
|
|
||||||
|
def _structured_signature_check_unexpected_args(self, args, kwargs):
|
||||||
|
"""Raises a TypeError if there are any extra args."""
|
||||||
|
arg_specs, kwarg_specs = self.structured_input_signature
|
||||||
|
if len(args) > len(arg_specs):
|
||||||
|
raise TypeError(
|
||||||
|
"{} takes {} positional arguments but {} were given".format(
|
||||||
|
self._structured_signature_summary(),
|
||||||
|
len(self._function_spec.arg_names), len(args)))
|
||||||
|
if len(kwargs) > len(kwarg_specs):
|
||||||
|
extra_args = set(kwargs) - set(kwarg_specs)
|
||||||
|
raise TypeError("{} got unexpected keyword arguments: {}".format(
|
||||||
|
self._structured_signature_summary(), ", ".join(extra_args)))
|
||||||
|
|
||||||
|
def _structured_signature_check_arg_types(self, args, kwargs):
|
||||||
|
"""Raises a TypeError if any args have the wrong type."""
|
||||||
|
# Check argument types
|
||||||
|
arg_specs, kwarg_specs = self.structured_input_signature
|
||||||
|
for i, (arg, spec) in enumerate(zip(args, arg_specs)):
|
||||||
|
name = self._function_spec.arg_names[i]
|
||||||
|
self._structured_signature_check_arg_type(arg, spec, name)
|
||||||
|
for (name, arg) in kwargs.items():
|
||||||
|
self._structured_signature_check_arg_type(arg, kwarg_specs[name], name)
|
||||||
|
|
||||||
|
def _structured_signature_check_arg_type(self, arg, spec, name):
|
||||||
|
"""Raise TypeError if `arg`'s type doesn't match `spec`."""
|
||||||
|
if arg is _BOUND_VALUE:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check the overall nested structure of the argument.
|
||||||
|
try:
|
||||||
|
nest.assert_same_structure(arg, spec, expand_composites=True)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
try:
|
||||||
|
nest.assert_same_structure(arg, spec, expand_composites=False)
|
||||||
|
expected, got = spec, arg
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
expected, got = _structure_summary(spec), _structure_summary(arg)
|
||||||
|
raise TypeError("{}: argument {} had incorrect type\n"
|
||||||
|
" expected: {}\n got: {}".format(
|
||||||
|
self._structured_signature_summary(), name, expected,
|
||||||
|
got))
|
||||||
|
|
||||||
|
# Check the type for each leaf in the nested structure.
|
||||||
|
arg_pieces = nest.flatten(arg, expand_composites=True)
|
||||||
|
spec_pieces = nest.flatten(spec, expand_composites=True)
|
||||||
|
for (arg_piece, spec_piece) in zip(arg_pieces, spec_pieces):
|
||||||
|
if isinstance(spec_piece, tensor_spec.DenseSpec):
|
||||||
|
# TODO(edloper): Consider calling convert_to_tensor on non-tensor
|
||||||
|
# values here. That would match the behavior of
|
||||||
|
# _call_concrete_function() in function_deserialization.py. If
|
||||||
|
# we do, then we need to change the nest assert_same_structure and
|
||||||
|
# flatten calls above to use shallow variants.
|
||||||
|
tensor_types = (ops.Tensor, resource_variable_ops.BaseResourceVariable)
|
||||||
|
if not isinstance(arg_piece, tensor_types):
|
||||||
|
raise TypeError(
|
||||||
|
"{} expected a Tensor in {}, but got {} value {}".format(
|
||||||
|
self._structured_signature_summary(), name,
|
||||||
|
type(arg_piece).__name__, arg_piece))
|
||||||
|
elif arg_piece is not _BOUND_VALUE and arg_piece != spec_piece:
|
||||||
|
raise TypeError("ConcreteFunction {} was constructed with {} value "
|
||||||
|
"{} in {}, but was called with {} value {}".format(
|
||||||
|
self._structured_signature_summary(),
|
||||||
|
type(spec_piece).__name__, spec_piece, name,
|
||||||
|
type(arg_piece).__name__, arg_piece))
|
||||||
|
|
||||||
|
def _filtered_call(self, args, kwargs, cancellation_manager=None):
|
||||||
"""Executes the function, filtering arguments from the Python function.
|
"""Executes the function, filtering arguments from the Python function.
|
||||||
|
|
||||||
Objects aside from Tensors, CompositeTensors, and Variables are ignored.
|
Objects aside from Tensors, CompositeTensors, and Variables are ignored.
|
||||||
@ -1657,6 +1884,8 @@ class ConcreteFunction(object):
|
|||||||
Args:
|
Args:
|
||||||
args: Canonicalized positional arguments of the Python function.
|
args: Canonicalized positional arguments of the Python function.
|
||||||
kwargs: Canonicalized keyword arguments of the Python function.
|
kwargs: Canonicalized keyword arguments of the Python function.
|
||||||
|
cancellation_manager: (Optional.) A `CancellationManager` that can be
|
||||||
|
used to cancel function invocation.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The result of applying the function on the Tensors/Variables contained in
|
The result of applying the function on the Tensors/Variables contained in
|
||||||
@ -1666,7 +1895,8 @@ class ConcreteFunction(object):
|
|||||||
(t for t in nest.flatten((args, kwargs), expand_composites=True)
|
(t for t in nest.flatten((args, kwargs), expand_composites=True)
|
||||||
if isinstance(t, (ops.Tensor,
|
if isinstance(t, (ops.Tensor,
|
||||||
resource_variable_ops.BaseResourceVariable))),
|
resource_variable_ops.BaseResourceVariable))),
|
||||||
self.captured_inputs)
|
captured_inputs=self.captured_inputs,
|
||||||
|
cancellation_manager=cancellation_manager)
|
||||||
|
|
||||||
def _call_flat(self, args, captured_inputs, cancellation_manager=None):
|
def _call_flat(self, args, captured_inputs, cancellation_manager=None):
|
||||||
"""Executes the wrapped function.
|
"""Executes the wrapped function.
|
||||||
@ -1795,7 +2025,26 @@ class ConcreteFunction(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def structured_input_signature(self):
|
def structured_input_signature(self):
|
||||||
"""Returns structured signature of the original function."""
|
"""Returns structured signature for this concrete function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple `(args, kwargs)`, where:
|
||||||
|
|
||||||
|
* `args` is a tuple that specifies the expected type or value each for
|
||||||
|
positional argument.
|
||||||
|
* `kwargs` is a dictionary that specifies the expected type or value
|
||||||
|
for each keyword-only argument.
|
||||||
|
|
||||||
|
The type or value for each argument is specified using one of the
|
||||||
|
following:
|
||||||
|
|
||||||
|
* A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native
|
||||||
|
value is expected.
|
||||||
|
* A Python value, such as an integer, indicating that an equal value
|
||||||
|
is expected.
|
||||||
|
* A nested structure of `tf.TypeSpec`s and Python values, indicating
|
||||||
|
that a corresponding nested structure is expected.
|
||||||
|
"""
|
||||||
return self._func_graph.structured_input_signature
|
return self._func_graph.structured_input_signature
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -1982,6 +2231,103 @@ class ConcreteFunction(object):
|
|||||||
ret.attr[name].CopyFrom(value)
|
ret.attr[name].CopyFrom(value)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
def _structured_signature_summary(self, default_values=False):
|
||||||
|
"""Returns a string summarizing this function's structured signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_values: If true, then include default values in the signature.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `string`.
|
||||||
|
"""
|
||||||
|
# Note: we can't just use self._funcion_spec.signature_summary(), because
|
||||||
|
# that would show "_BOUND_VALUE" as the default value for all arguments.
|
||||||
|
assert self._function_spec is not None
|
||||||
|
arg_specs, kwarg_specs = self.structured_input_signature
|
||||||
|
arg_names = list(self._function_spec.arg_names)
|
||||||
|
if default_values:
|
||||||
|
for i in range(len(arg_names)):
|
||||||
|
if not _contains_type_spec(arg_specs[i]):
|
||||||
|
arg_names[i] += "={}".format(arg_specs[i])
|
||||||
|
if kwarg_specs:
|
||||||
|
arg_names.append("*")
|
||||||
|
for name, spec in kwarg_specs.items():
|
||||||
|
arg_names.append(name)
|
||||||
|
if default_values and not _contains_type_spec(spec):
|
||||||
|
arg_names[-1] += "={}".format(spec)
|
||||||
|
signature = "{}({})".format(self._func_graph.name, ", ".join(arg_names))
|
||||||
|
|
||||||
|
return signature
|
||||||
|
|
||||||
|
def _flat_signature_summary(self):
|
||||||
|
"""Returns a string summarizing this function's flat signature."""
|
||||||
|
assert self._arg_keywords is not None
|
||||||
|
assert self._num_positional_args is not None
|
||||||
|
arg_names = self._arg_keywords
|
||||||
|
if self._num_positional_args > len(arg_names):
|
||||||
|
arg_names.extend(
|
||||||
|
"<arg{}>".format(i + 1)
|
||||||
|
for i in range(len(arg_names), self._num_positional_args))
|
||||||
|
return "{}({})".format(self._func_graph.name, ", ".join(arg_names))
|
||||||
|
|
||||||
|
def pretty_printed_signature(self, verbose=True):
|
||||||
|
"""Returns a string summarizing the signature of this concrete function."""
|
||||||
|
if not verbose:
|
||||||
|
return self._structured_signature_summary(default_values=True)
|
||||||
|
|
||||||
|
def pretty_print_spec(spec):
|
||||||
|
"""Returns a string describing the spec for a single argument."""
|
||||||
|
if isinstance(spec, tensor_spec.TensorSpec):
|
||||||
|
return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape)
|
||||||
|
elif nest.is_sequence(spec):
|
||||||
|
pieces = nest.flatten(spec, expand_composites=False)
|
||||||
|
markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))]
|
||||||
|
structure = nest.pack_sequence_as(spec, markers)
|
||||||
|
result = "{}".format(structure)
|
||||||
|
for (marker, piece) in zip(markers, pieces):
|
||||||
|
result += "\n {}: {}".format(marker, pretty_print_spec(piece))
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
return repr(spec)
|
||||||
|
|
||||||
|
lines = [self._structured_signature_summary(default_values=True)]
|
||||||
|
arg_specs, kwarg_specs = self.structured_input_signature
|
||||||
|
names = list(self._function_spec.arg_names)
|
||||||
|
names.extend(sorted(kwarg_specs))
|
||||||
|
specs = list(arg_specs) + list(kwarg_specs.values())
|
||||||
|
# note: we can skip bound args, since we already displayed thier bound
|
||||||
|
# value in the signature summary.
|
||||||
|
arg_details = []
|
||||||
|
for (name, spec) in zip(names, specs):
|
||||||
|
if _contains_type_spec(spec):
|
||||||
|
arg_details.append(" {}: {}".format(name, pretty_print_spec(spec)))
|
||||||
|
if arg_details:
|
||||||
|
lines.append(" Args:")
|
||||||
|
lines.extend(arg_details)
|
||||||
|
lines.append(" Returns:")
|
||||||
|
lines.append(" {}".format(
|
||||||
|
pretty_print_spec(
|
||||||
|
nest.map_structure(type_spec.type_spec_from_value,
|
||||||
|
self.structured_outputs))))
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
if self._function_spec is not None:
|
||||||
|
return "<ConcreteFunction {} at 0x{:X}>".format(
|
||||||
|
self.pretty_printed_signature(verbose=False), id(self))
|
||||||
|
elif not (self._num_positional_args is None or self._arg_keywords is None):
|
||||||
|
return "<ConcreteFunction {} at 0x{:X}>".format(
|
||||||
|
self._flat_signature_summary(), id(self))
|
||||||
|
else:
|
||||||
|
return object.__repr__(self)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
if self._function_spec is not None:
|
||||||
|
return "ConcreteFunction {}".format(self.pretty_printed_signature())
|
||||||
|
else:
|
||||||
|
return self.__repr__()
|
||||||
|
|
||||||
|
|
||||||
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
|
_pywrap_utils.RegisterType("Tensor", ops.Tensor)
|
||||||
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
|
_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor)
|
||||||
@ -2075,17 +2421,37 @@ class FunctionSpec(object):
|
|||||||
kwonlydefaults={},
|
kwonlydefaults={},
|
||||||
annotations=fullargspec.annotations)
|
annotations=fullargspec.annotations)
|
||||||
is_method = tf_inspect.ismethod(python_function)
|
is_method = tf_inspect.ismethod(python_function)
|
||||||
return FunctionSpec(fullargspec, is_method, [], {}, input_signature,
|
|
||||||
is_pure=is_pure)
|
|
||||||
|
|
||||||
def __init__(self, fullargspec, is_method, args_to_prepend, kwargs_to_include,
|
# Get the function's name. Remove functools.partial wrappers if necessary.
|
||||||
input_signature, is_pure=False):
|
while isinstance(python_function, functools.partial):
|
||||||
|
python_function = python_function.func
|
||||||
|
name = getattr(python_function, "__name__", "f")
|
||||||
|
|
||||||
|
return FunctionSpec(
|
||||||
|
fullargspec, is_method, input_signature, is_pure=is_pure, name=name)
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
fullargspec,
|
||||||
|
is_method,
|
||||||
|
input_signature,
|
||||||
|
is_pure=False,
|
||||||
|
name=None):
|
||||||
|
"""Constructs a FunctionSpec describing a python function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fullargspec: `tf_inspect.FullArgSpec` object describing the function.
|
||||||
|
is_method: True if the function is a method.
|
||||||
|
input_signature: a signature of the function (None, if variable)
|
||||||
|
is_pure: if True all input arguments (including variables and constants)
|
||||||
|
will be converted to tensors and no variable changes allowed.
|
||||||
|
name: Name of the function
|
||||||
|
"""
|
||||||
self._fullargspec = fullargspec
|
self._fullargspec = fullargspec
|
||||||
self._is_method = is_method
|
self._is_method = is_method
|
||||||
self._is_pure = is_pure
|
self._is_pure = is_pure
|
||||||
del args_to_prepend
|
|
||||||
del kwargs_to_include
|
# TODO(edloper): Include name when serializing for SavedModel?
|
||||||
self._default_values = fullargspec.defaults
|
self._name = name or "f"
|
||||||
|
|
||||||
if self._is_method:
|
if self._is_method:
|
||||||
# Remove `self`: default arguments shouldn't be matched to it.
|
# Remove `self`: default arguments shouldn't be matched to it.
|
||||||
@ -2098,21 +2464,21 @@ class FunctionSpec(object):
|
|||||||
# A cache mapping from argument name to index, for canonicalizing
|
# A cache mapping from argument name to index, for canonicalizing
|
||||||
# arguments that are called in a keyword-like fashion.
|
# arguments that are called in a keyword-like fashion.
|
||||||
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
|
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
|
||||||
self.arg_names = args
|
self._arg_names = args
|
||||||
self.vararg_name = fullargspec.varargs
|
|
||||||
|
|
||||||
# A cache mapping from arg index to default value, for canonicalization.
|
# A cache mapping from arg index to default value, for canonicalization.
|
||||||
offset = len(args) - len(self._default_values or [])
|
default_values = fullargspec.defaults
|
||||||
|
offset = len(args) - len(default_values or [])
|
||||||
self._arg_indices_to_default_values = {
|
self._arg_indices_to_default_values = {
|
||||||
offset + index: default
|
offset + index: default
|
||||||
for index, default in enumerate(self._default_values or [])
|
for index, default in enumerate(default_values or [])
|
||||||
}
|
}
|
||||||
if input_signature is None:
|
if input_signature is None:
|
||||||
self._input_signature = None
|
self._input_signature = None
|
||||||
else:
|
else:
|
||||||
if fullargspec.kwonlyargs:
|
if set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ()):
|
||||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||||
"function with keyword arguments when "
|
"function with keyword-only arguments when "
|
||||||
"input_signature is provided.")
|
"input_signature is provided.")
|
||||||
|
|
||||||
if not isinstance(input_signature, (tuple, list)):
|
if not isinstance(input_signature, (tuple, list)):
|
||||||
@ -2132,8 +2498,8 @@ class FunctionSpec(object):
|
|||||||
return self._is_method
|
return self._is_method
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def args_to_prepend(self):
|
def args_to_indices(self):
|
||||||
return self._args_to_prepend
|
return self._args_to_indices
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kwargs_to_include(self):
|
def kwargs_to_include(self):
|
||||||
@ -2147,6 +2513,43 @@ class FunctionSpec(object):
|
|||||||
def flat_input_signature(self):
|
def flat_input_signature(self):
|
||||||
return self._flat_input_signature
|
return self._flat_input_signature
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_pure(self):
|
||||||
|
return self._is_pure
|
||||||
|
|
||||||
|
@property
|
||||||
|
def arg_names(self):
|
||||||
|
return self._arg_names
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vararg_name(self):
|
||||||
|
return self._fullargspec.varargs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def varkw_name(self):
|
||||||
|
return self._fullargspec.varkw
|
||||||
|
|
||||||
|
def signature_summary(self, default_values=False):
|
||||||
|
"""Returns a string summarizing this function's signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
default_values: If true, then include default values in the signature.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `string`.
|
||||||
|
"""
|
||||||
|
args = list(self._arg_names)
|
||||||
|
if default_values:
|
||||||
|
for (i, default) in self._arg_indices_to_default_values.items():
|
||||||
|
args[i] += "={}".format(default)
|
||||||
|
if self._fullargspec.kwonlyargs:
|
||||||
|
args.append("*")
|
||||||
|
for arg_name in self._fullargspec.kwonlyargs:
|
||||||
|
args.append(arg_name)
|
||||||
|
if default_values and arg_name in self._fullargspec.kwonlydefaults:
|
||||||
|
args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
|
||||||
|
return "{}({})".format(self._name, ", ".join(args))
|
||||||
|
|
||||||
def _convert_variables_to_tensors(self, args, kwargs):
|
def _convert_variables_to_tensors(self, args, kwargs):
|
||||||
args = [ops.convert_to_tensor(x) for x in args]
|
args = [ops.convert_to_tensor(x) for x in args]
|
||||||
kwargs = {kw: ops.convert_to_tensor(x) for kw, x in kwargs.items()}
|
kwargs = {kw: ops.convert_to_tensor(x) for kw, x in kwargs.items()}
|
||||||
@ -2159,7 +2562,13 @@ class FunctionSpec(object):
|
|||||||
instance. In particular, we parse the varags and kwargs that the
|
instance. In particular, we parse the varags and kwargs that the
|
||||||
original function was called with into a tuple corresponding to the
|
original function was called with into a tuple corresponding to the
|
||||||
Python function's positional (named) arguments and a dictionary
|
Python function's positional (named) arguments and a dictionary
|
||||||
corresponding to its kwargs.
|
corresponding to its kwargs. Missing default arguments are added.
|
||||||
|
|
||||||
|
If this `FunctionSpec` has an input signature, then it is used to convert
|
||||||
|
arguments to tensors; otherwise, any inputs containing numpy arrays are
|
||||||
|
converted to tensors.
|
||||||
|
|
||||||
|
Additionally, any inputs containing numpy arrays are converted to Tensors.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: The varargs this object was called with.
|
*args: The varargs this object was called with.
|
||||||
@ -2180,29 +2589,38 @@ class FunctionSpec(object):
|
|||||||
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
args, kwargs = self._convert_variables_to_tensors(args, kwargs)
|
||||||
if self._input_signature is not None:
|
if self._input_signature is not None:
|
||||||
if len(args) > len(self._input_signature):
|
if len(args) > len(self._input_signature):
|
||||||
raise TypeError(
|
raise TypeError("{} takes {} positional arguments (as specified by the "
|
||||||
"When input_signature is provided, only pass arguments "
|
"input_signature) but {} were given".format(
|
||||||
"covered by it. Received %d argument(s)." % len(args))
|
self.signature_summary(),
|
||||||
|
len(self._input_signature), len(args)))
|
||||||
for arg in six.iterkeys(kwargs):
|
for arg in six.iterkeys(kwargs):
|
||||||
index = self._args_to_indices.get(arg, None)
|
index = self._args_to_indices.get(arg, None)
|
||||||
if index is None:
|
if index is None:
|
||||||
raise TypeError(
|
raise TypeError("{} got unexpected keyword argument `{}`".format(
|
||||||
"Function got an unexpected keyword argument %s" % arg)
|
self.signature_summary(), arg))
|
||||||
if index >= len(self._input_signature):
|
if index >= len(self._input_signature):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"When input_signature is provided, only pass arguments "
|
"{} got keyword argument `{}` that was not included in "
|
||||||
"covered by it. Received argument %s." % arg)
|
"input_signature".format(self.signature_summary(), arg))
|
||||||
|
|
||||||
if not kwargs:
|
if not kwargs:
|
||||||
inputs = args
|
inputs = args
|
||||||
default_keys = sorted(self._arg_indices_to_default_values.keys())
|
if self._arg_indices_to_default_values:
|
||||||
if default_keys:
|
try:
|
||||||
assert min(default_keys) <= len(
|
inputs += tuple(
|
||||||
args), "Not enough arguments (%s, %s, %s)" % (args, default_keys,
|
self._arg_indices_to_default_values[i]
|
||||||
self.arg_names)
|
for i in range(len(args), len(self._arg_names)))
|
||||||
for index in default_keys:
|
except KeyError:
|
||||||
if index >= len(args):
|
missing_args = [
|
||||||
inputs += (self._arg_indices_to_default_values[index],)
|
self._arg_names[i]
|
||||||
|
for i in range(len(args), len(self._arg_names))
|
||||||
|
if i not in self._arg_indices_to_default_values
|
||||||
|
]
|
||||||
|
raise TypeError("{} missing required arguments: {}".format(
|
||||||
|
self.signature_summary(), ", ".join(missing_args)))
|
||||||
|
|
||||||
|
if self._fullargspec.kwonlydefaults:
|
||||||
|
kwargs.update(self._fullargspec.kwonlydefaults)
|
||||||
else:
|
else:
|
||||||
# Maps from index of arg to its corresponding value, according to `args`
|
# Maps from index of arg to its corresponding value, according to `args`
|
||||||
# and `kwargs`; seeded with the default values for the named args that
|
# and `kwargs`; seeded with the default values for the named args that
|
||||||
@ -2215,18 +2633,28 @@ class FunctionSpec(object):
|
|||||||
for arg, value in six.iteritems(kwargs):
|
for arg, value in six.iteritems(kwargs):
|
||||||
index = self._args_to_indices.get(arg, None)
|
index = self._args_to_indices.get(arg, None)
|
||||||
if index is not None:
|
if index is not None:
|
||||||
|
if index < len(args):
|
||||||
|
raise TypeError("{} got two values for argument '{}'".format(
|
||||||
|
self.signature_summary(), arg))
|
||||||
arg_indices_to_values[index] = value
|
arg_indices_to_values[index] = value
|
||||||
consumed_args.append(arg)
|
consumed_args.append(arg)
|
||||||
elif self._input_signature is not None:
|
|
||||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
|
||||||
"function with keyword arguments when "
|
|
||||||
"input_signature is provided.")
|
|
||||||
for arg in consumed_args:
|
for arg in consumed_args:
|
||||||
# After this loop, `kwargs` will only contain true keyword arguments, as
|
# After this loop, `kwargs` will only contain keyword_only arguments,
|
||||||
# opposed to named arguments called in a keyword-like fashion.
|
# and all positional_or_keyword arguments have been moved to `inputs`.
|
||||||
kwargs.pop(arg)
|
kwargs.pop(arg)
|
||||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||||
|
|
||||||
|
if kwargs and self._input_signature is not None:
|
||||||
|
raise TypeError(
|
||||||
|
"{} got unexpected keyword arguments: {}\n(Cannot define a "
|
||||||
|
"TensorFlow function from a Python function with keyword arguments "
|
||||||
|
"when input_signature is provided.)".format(
|
||||||
|
self.signature_summary(), ", ".join(kwargs)))
|
||||||
|
|
||||||
|
if self._fullargspec.kwonlydefaults:
|
||||||
|
for (kwarg, default) in self._fullargspec.kwonlydefaults.items():
|
||||||
|
kwargs.setdefault(kwarg, default)
|
||||||
|
|
||||||
if self._input_signature is None:
|
if self._input_signature is None:
|
||||||
inputs = _convert_numpy_inputs(inputs)
|
inputs = _convert_numpy_inputs(inputs)
|
||||||
kwargs = _convert_numpy_inputs(kwargs)
|
kwargs = _convert_numpy_inputs(kwargs)
|
||||||
@ -2447,6 +2875,7 @@ class Function(object):
|
|||||||
graph_function, _, _ = self._maybe_define_function(args, kwargs)
|
graph_function, _, _ = self._maybe_define_function(args, kwargs)
|
||||||
return graph_function
|
return graph_function
|
||||||
|
|
||||||
|
# XX TODO: make sure we fix up this path as well!?
|
||||||
def _get_concrete_function_internal(self, *args, **kwargs):
|
def _get_concrete_function_internal(self, *args, **kwargs):
|
||||||
"""Bypasses error checking when getting a graph function."""
|
"""Bypasses error checking when getting a graph function."""
|
||||||
graph_function = self._get_concrete_function_internal_garbage_collected(
|
graph_function = self._get_concrete_function_internal_garbage_collected(
|
||||||
@ -2664,6 +3093,7 @@ class Function(object):
|
|||||||
override_flat_arg_shapes=override_flat_arg_shapes,
|
override_flat_arg_shapes=override_flat_arg_shapes,
|
||||||
capture_by_value=self._capture_by_value),
|
capture_by_value=self._capture_by_value),
|
||||||
self._function_attributes,
|
self._function_attributes,
|
||||||
|
function_spec=self.function_spec,
|
||||||
# Tell the ConcreteFunction to clean up its graph once it goes out of
|
# Tell the ConcreteFunction to clean up its graph once it goes out of
|
||||||
# scope. This is not the default behavior since it gets used in some
|
# scope. This is not the default behavior since it gets used in some
|
||||||
# places (like Keras) where the FuncGraph lives longer than the
|
# places (like Keras) where the FuncGraph lives longer than the
|
||||||
@ -3350,3 +3780,30 @@ class ConcreteFunctionGarbageCollector(object):
|
|||||||
func_graph_module.dismantle_func_graph(self._func_graph)
|
func_graph_module.dismantle_func_graph(self._func_graph)
|
||||||
except: # pylint: disable=bare-except
|
except: # pylint: disable=bare-except
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _Marker(object):
|
||||||
|
"""Markers used to pretty-print nested args in function signatures."""
|
||||||
|
|
||||||
|
def __init__(self, s):
|
||||||
|
self._s = s
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self._s)
|
||||||
|
|
||||||
|
|
||||||
|
def _structure_summary(structure):
|
||||||
|
"""Displays a summary of the nesting structure of the given value."""
|
||||||
|
|
||||||
|
def type_name(x):
|
||||||
|
if isinstance(x, type_spec.TypeSpec):
|
||||||
|
return x.value_type.__name__
|
||||||
|
else:
|
||||||
|
return type(x).__name__
|
||||||
|
|
||||||
|
markers = [_Marker(type_name(v)) for v in nest.flatten(structure)]
|
||||||
|
return str(nest.pack_sequence_as(structure, markers))
|
||||||
|
|
||||||
|
|
||||||
|
def _contains_type_spec(value):
|
||||||
|
return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value))
|
||||||
|
|||||||
@ -36,6 +36,7 @@ from tensorflow.python.eager import cancellation
|
|||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
|
from tensorflow.python.framework import composite_tensor
|
||||||
from tensorflow.python.framework import config
|
from tensorflow.python.framework import config
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -50,6 +51,7 @@ from tensorflow.python.framework import tensor_shape
|
|||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_ops
|
from tensorflow.python.framework import test_ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.framework import type_spec
|
||||||
from tensorflow.python.layers import convolutional
|
from tensorflow.python.layers import convolutional
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
@ -65,6 +67,7 @@ from tensorflow.python.ops import list_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.ops import string_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
@ -99,6 +102,16 @@ def _example_indexed_slices_without_dense_shape():
|
|||||||
constant_op.constant([1, 2]), constant_op.constant([0, 1]))
|
constant_op.constant([1, 2]), constant_op.constant([0, 1]))
|
||||||
|
|
||||||
|
|
||||||
|
def _spec_for_value(value):
|
||||||
|
"""Returns the (nested) TypeSpec for a value."""
|
||||||
|
if nest.is_sequence(value):
|
||||||
|
return nest.map_structure(_spec_for_value, value)
|
||||||
|
elif isinstance(value, (ops.Tensor, composite_tensor.CompositeTensor)):
|
||||||
|
return type_spec.type_spec_from_value(value)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class FunctionTest(test.TestCase, parameterized.TestCase):
|
class FunctionTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -1789,6 +1802,18 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, 'incompatible'):
|
with self.assertRaisesRegexp(ValueError, 'incompatible'):
|
||||||
func([['wrong dtype']])
|
func([['wrong dtype']])
|
||||||
|
|
||||||
|
def testNoKeywordOnlyArgumentsWithInputSignature(self):
|
||||||
|
if sys.version_info[0] < 3:
|
||||||
|
self.skipTest('keyword_only arguments only exist in Python 3.')
|
||||||
|
|
||||||
|
func = eval('lambda x, *, y: x') # pylint: disable=eval-used
|
||||||
|
signature = [tensor_spec.TensorSpec(None, dtypes.int32)]
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'Cannot define a TensorFlow function from a Python '
|
||||||
|
'function with keyword-only arguments when input_signature is '
|
||||||
|
'provided.'):
|
||||||
|
def_function.function(func, signature)
|
||||||
|
|
||||||
def testNestedInputSignatures(self):
|
def testNestedInputSignatures(self):
|
||||||
|
|
||||||
def expected_foo(a, b):
|
def expected_foo(a, b):
|
||||||
@ -1905,7 +1930,9 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
defined(array_ops.ones([2, 1]))
|
defined(array_ops.ones([2, 1]))
|
||||||
|
|
||||||
# Wrong number of arguments.
|
# Wrong number of arguments.
|
||||||
with self.assertRaisesRegexp(TypeError, r'Received 2 argument\(s\)'):
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, r'takes 1 positional arguments \(as specified by the '
|
||||||
|
r'input_signature\) but 2 were given'):
|
||||||
defined(array_ops.ones([2]), array_ops.ones([2]))
|
defined(array_ops.ones([2]), array_ops.ones([2]))
|
||||||
with self.assertRaisesRegexp(ValueError,
|
with self.assertRaisesRegexp(ValueError,
|
||||||
'Structure of Python function inputs.*'):
|
'Structure of Python function inputs.*'):
|
||||||
@ -1946,10 +1973,14 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
return -1.0 * a
|
return -1.0 * a
|
||||||
|
|
||||||
x = constant_op.constant(1.0)
|
x = constant_op.constant(1.0)
|
||||||
with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, 'got keyword argument `training` '
|
||||||
|
'that was not included in input_signature'):
|
||||||
foo(x, training=True)
|
foo(x, training=True)
|
||||||
|
|
||||||
with self.assertRaisesRegexp(TypeError, 'only pass arguments'):
|
with self.assertRaisesRegexp(
|
||||||
|
TypeError, 'got keyword argument `training` '
|
||||||
|
'that was not included in input_signature'):
|
||||||
foo(x, training=False)
|
foo(x, training=False)
|
||||||
|
|
||||||
self.assertAllEqual(x.numpy(), foo(x).numpy())
|
self.assertAllEqual(x.numpy(), foo(x).numpy())
|
||||||
@ -2472,8 +2503,7 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
graph_function = foo.get_concrete_function(constant_op.constant(1.0))
|
graph_function = foo.get_concrete_function(constant_op.constant(1.0))
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaises((TypeError, ValueError)):
|
||||||
ValueError, 'All inputs to `ConcreteFunction`s must be Tensors;.*'):
|
|
||||||
graph_function('Not a Tensor.')
|
graph_function('Not a Tensor.')
|
||||||
|
|
||||||
def testSwapImplementationWithGrapplerPlugin(self):
|
def testSwapImplementationWithGrapplerPlugin(self):
|
||||||
@ -3148,6 +3178,432 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
|
|||||||
function.clear_function_callbacks()
|
function.clear_function_callbacks()
|
||||||
self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access
|
self.assertEmpty(function._function_callbacks) # pylint:disable=protected-access
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testConcreteFunctionWithNestedTensorInputs(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x, y):
|
||||||
|
return (x['a'] + x['b'], y[0] + y[1])
|
||||||
|
|
||||||
|
a = constant_op.constant(1000)
|
||||||
|
b = constant_op.constant(200)
|
||||||
|
c = constant_op.constant(30)
|
||||||
|
d = {'a': a, 'b': b}
|
||||||
|
e = (c, 4)
|
||||||
|
|
||||||
|
# Test different argument signatures when constructing the concrete func.
|
||||||
|
for cf in [
|
||||||
|
f.get_concrete_function(d, e),
|
||||||
|
f.get_concrete_function(d, y=e),
|
||||||
|
f.get_concrete_function(y=e, x=d),
|
||||||
|
f.get_concrete_function(_spec_for_value(d), _spec_for_value(e)),
|
||||||
|
f.get_concrete_function(_spec_for_value(d), y=_spec_for_value(e)),
|
||||||
|
f.get_concrete_function(y=_spec_for_value(e), x=_spec_for_value(d))
|
||||||
|
]:
|
||||||
|
# Test different calling conventions when calling the concrete func.
|
||||||
|
for output in [
|
||||||
|
cf(d, e), # structured signature
|
||||||
|
cf(d, y=e), # structured signature w/ kwarg
|
||||||
|
cf(y=e, x=d), # structured signature w/ 2 kwargs
|
||||||
|
cf(a, b, c), # flat signature
|
||||||
|
cf(x=a, x_1=b, y=c) # flat signature w/ kwargs
|
||||||
|
]:
|
||||||
|
self.assertIsInstance(output, tuple)
|
||||||
|
self.assertLen(output, 2)
|
||||||
|
self.assertAllEqual(output[0], 1200)
|
||||||
|
self.assertAllEqual(output[1], 34)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testConcreteFunctionWithNestedNonTensorInputs(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x, y):
|
||||||
|
return (x['a'] + x['b'], y[0] + y[1])
|
||||||
|
|
||||||
|
a = {'a': constant_op.constant(1000), 'b': constant_op.constant(200)}
|
||||||
|
b = (50, 3)
|
||||||
|
|
||||||
|
for cf in [ # argument y is bound to non-Tensor value (50, 3).
|
||||||
|
f.get_concrete_function(a, b),
|
||||||
|
f.get_concrete_function(a, y=b),
|
||||||
|
f.get_concrete_function(x=a, y=b)
|
||||||
|
]:
|
||||||
|
for output in [cf(a), cf(x=a), cf(a, b), cf(x=a, y=b)]:
|
||||||
|
self.assertAllEqual(output[0] + output[1], 1253)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testConcreteFunctionWithBoundNestedNonTensorInputs(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x, y):
|
||||||
|
return (x['a'] + x['b'], y[0] + y[1])
|
||||||
|
|
||||||
|
a = {'a': 3000, 'b': 200, 'c': 9000}
|
||||||
|
b = (constant_op.constant(30), 4)
|
||||||
|
|
||||||
|
for cf in [ # argument x is bound to non-tensor value `a`
|
||||||
|
f.get_concrete_function(a, b),
|
||||||
|
f.get_concrete_function(a, y=b),
|
||||||
|
f.get_concrete_function(x=a, y=b)
|
||||||
|
]:
|
||||||
|
for output in [cf(a, b), cf(a, y=b), cf(y=b), cf(x=a, y=b)]:
|
||||||
|
self.assertAllEqual(output[0] + output[1], 3234)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testConcreteFunctionWithAllBoundNestedNonTensorInputs(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x, y):
|
||||||
|
return (x['a'] + x['b'], y[0] + y[1])
|
||||||
|
|
||||||
|
a = {'a': 5000, 'b': 500}
|
||||||
|
b = (50, 5)
|
||||||
|
|
||||||
|
cf = f.get_concrete_function(a, b)
|
||||||
|
for output in [cf(), cf(a), cf(y=b)]:
|
||||||
|
self.assertAllEqual(output[0] + output[1], 5555)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testConcreteFunctionStructuredSignatureKeywordOrder(self):
|
||||||
|
# Check that keyword-only arguments are sorted appropriately, so that they
|
||||||
|
# feed the right tensor into each input.
|
||||||
|
@def_function.function
|
||||||
|
def g(**kwargs):
|
||||||
|
return string_ops.reduce_join(
|
||||||
|
string_ops.reduce_join(
|
||||||
|
ops.convert_to_tensor(sorted(kwargs.items())),
|
||||||
|
axis=1,
|
||||||
|
separator='='),
|
||||||
|
axis=0,
|
||||||
|
separator=', ')
|
||||||
|
|
||||||
|
s = constant_op.constant('s')
|
||||||
|
g.get_concrete_function(q=s, a=s, p=s, r=s, v=s, m=s, l=s)
|
||||||
|
self.assertAllEqual(
|
||||||
|
g(m='a', r='b', v='c', q='d', l='e', a='f', p='g'),
|
||||||
|
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
|
||||||
|
self.assertAllEqual(
|
||||||
|
g(q='d', a='f', p='g', r='b', v='c', m='a', l='e'),
|
||||||
|
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
|
||||||
|
self.assertAllEqual(
|
||||||
|
g(a='f', l='e', m='a', p='g', q='d', r='b', v='c'),
|
||||||
|
b'a=f, l=e, m=a, p=g, q=d, r=b, v=c')
|
||||||
|
|
||||||
|
# pylint: disable=g-long-lambda
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
dict(
|
||||||
|
testcase_name='MissingArg',
|
||||||
|
conc_args=lambda: (1, constant_op.constant(2)),
|
||||||
|
call_args=lambda: (1,),
|
||||||
|
error=r'func\(x, y\) missing required arguments: y'),
|
||||||
|
dict(
|
||||||
|
testcase_name='MissingVararg',
|
||||||
|
conc_args=lambda: (1, 2, constant_op.constant(1.0)),
|
||||||
|
call_args=lambda: (1, 2),
|
||||||
|
error=r'func\(x, y, <arg3>\) missing required arguments: <arg3>'),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExtraPositionalArg',
|
||||||
|
conc_args=lambda: (1, 2),
|
||||||
|
call_args=lambda: (1, 2, 3),
|
||||||
|
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
|
||||||
|
dict(
|
||||||
|
testcase_name='MissingKeywordOnlyArg',
|
||||||
|
conc_args=lambda: (1, 2),
|
||||||
|
conc_kwargs=lambda: {'c': constant_op.constant(1.0)},
|
||||||
|
call_args=lambda: (1, 2),
|
||||||
|
error=r'func\(x, y, \*, c\) missing required arguments: c'),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExtraKeywordArg',
|
||||||
|
conc_args=lambda: (1, 2),
|
||||||
|
call_args=lambda: (1, 2),
|
||||||
|
call_kwargs=lambda: {'c': constant_op.constant(1.0)},
|
||||||
|
error=r'func\(x, y\) got unexpected keyword arguments: c'),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExpectedRaggedGotNest',
|
||||||
|
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
|
||||||
|
call_args=lambda: ({
|
||||||
|
'a': constant_op.constant([1, 2, 3])
|
||||||
|
},),
|
||||||
|
error=r'func\(x, y\): argument x had incorrect type\n'
|
||||||
|
r' expected: RaggedTensor\n'
|
||||||
|
r" got: {'a': (Eager)?Tensor}"),
|
||||||
|
dict(
|
||||||
|
testcase_name='WrongRaggedRank',
|
||||||
|
conc_args=lambda: (ragged_factory_ops.constant([[1, 2], [3]]),),
|
||||||
|
call_args=lambda: (ragged_factory_ops.constant([[[1]]]),),
|
||||||
|
error=r'func\(x, y\): argument x had incorrect type\n'),
|
||||||
|
dict(
|
||||||
|
testcase_name='WrongRaggedDType',
|
||||||
|
conc_args=lambda: (ragged_factory_ops.constant([[1]]),),
|
||||||
|
call_args=lambda: (ragged_factory_ops.constant([[1.0]]),),
|
||||||
|
error=r'func\(x, y\): argument x had incorrect type\n'),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExpectedDictGotTensor',
|
||||||
|
conc_args=lambda: ({
|
||||||
|
'a': constant_op.constant(1),
|
||||||
|
'b': constant_op.constant(1)
|
||||||
|
},),
|
||||||
|
call_args=lambda: (constant_op.constant(1),),
|
||||||
|
error=r'func\(x, y\): argument x had incorrect type\n'),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExpectedTupleGotTensor',
|
||||||
|
conc_args=lambda:
|
||||||
|
((constant_op.constant(1), constant_op.constant(2)),),
|
||||||
|
call_args=lambda: (constant_op.constant(1),),
|
||||||
|
error=r'func\(x, y\): argument x had incorrect type\n'),
|
||||||
|
dict(
|
||||||
|
testcase_name='WrongDType',
|
||||||
|
conc_args=lambda: (constant_op.constant(1),),
|
||||||
|
call_args=lambda: (constant_op.constant(1.0),),
|
||||||
|
exception=(ValueError, errors.InvalidArgumentError,
|
||||||
|
# on xla_gpu, we get InternalError instead.
|
||||||
|
errors.InternalError)),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExpectedTensorGotInt',
|
||||||
|
conc_args=lambda: (constant_op.constant(1),),
|
||||||
|
call_args=lambda: (5,),
|
||||||
|
error=r'func\(x, y\) expected a Tensor in x, but got int value 5'),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExpectedIntGotDifferentInt',
|
||||||
|
conc_args=lambda: (5,),
|
||||||
|
call_args=lambda: (8,),
|
||||||
|
error=r'ConcreteFunction func\(x, y\) was constructed with int '
|
||||||
|
r'value 5 in x, but was called with int value 8'),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExpectedIntGotTensor',
|
||||||
|
conc_args=lambda: (5,),
|
||||||
|
call_args=lambda: (constant_op.constant(6),),
|
||||||
|
error=r'ConcreteFunction func\(x, y\) was constructed with int '
|
||||||
|
'value 5 in x, but was called with (Eager)?Tensor value .*'),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoValuesForArgument',
|
||||||
|
conc_args=lambda: (1, 2),
|
||||||
|
call_args=lambda: (1, 2),
|
||||||
|
call_kwargs=lambda: {'x': 3},
|
||||||
|
error=r"func\(x, y\) got two values for argument 'x'"),
|
||||||
|
])
|
||||||
|
# pylint: enable=g-long-lambda
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testConcreteFunctionStructuredSignatureError(self,
|
||||||
|
conc_args=(),
|
||||||
|
conc_kwargs=None,
|
||||||
|
call_args=(),
|
||||||
|
call_kwargs=None,
|
||||||
|
error='.*',
|
||||||
|
exception=TypeError):
|
||||||
|
"""Tests for errors in the structrued signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conc_args: Positional arguments used for get_concrete_function.
|
||||||
|
conc_kwargs: Keyword arguments used for get_concrete_function.
|
||||||
|
call_args: Positional arguments used to call the function.
|
||||||
|
call_kwargs: Keyword arguments used to call the function.
|
||||||
|
error: Expected exception message.
|
||||||
|
exception: Expected exception type.
|
||||||
|
"""
|
||||||
|
conc_args = conc_args() if callable(conc_args) else conc_args
|
||||||
|
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
|
||||||
|
call_args = call_args() if callable(call_args) else call_args
|
||||||
|
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
|
||||||
|
self.assertIsInstance(conc_args, tuple)
|
||||||
|
self.assertIsInstance(call_args, tuple)
|
||||||
|
self.assertIsInstance(conc_kwargs, dict)
|
||||||
|
self.assertIsInstance(call_kwargs, dict)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
|
||||||
|
del y, varargs, kwargs
|
||||||
|
return x
|
||||||
|
|
||||||
|
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
|
||||||
|
with self.assertRaisesRegexp(exception, error):
|
||||||
|
self.evaluate(conc(*call_args, **call_kwargs))
|
||||||
|
|
||||||
|
# pylint: disable=g-long-lambda
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
dict(
|
||||||
|
testcase_name='MissingArg',
|
||||||
|
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
call_args=lambda: (constant_op.constant(1),),
|
||||||
|
error=r'func\(x, y\) missing required arguments: y'),
|
||||||
|
dict(
|
||||||
|
testcase_name='TwoValuesForArg',
|
||||||
|
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
call_args=lambda: (constant_op.constant(1),),
|
||||||
|
call_kwargs=lambda: {
|
||||||
|
'x': constant_op.constant(1),
|
||||||
|
'y': constant_op.constant(1)
|
||||||
|
},
|
||||||
|
error=r"func\(x, y\) got two values for argument 'x'"),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExtraPositionalArg',
|
||||||
|
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
call_args=lambda: (constant_op.constant(1), constant_op.constant(2),
|
||||||
|
constant_op.constant(3)),
|
||||||
|
error=r'func\(x, y\) takes 2 positional arguments but 3 were given'),
|
||||||
|
dict(
|
||||||
|
testcase_name='UnexpectedKeywordArg',
|
||||||
|
conc_args=lambda: (constant_op.constant(1),),
|
||||||
|
call_args=lambda: (constant_op.constant(1),),
|
||||||
|
call_kwargs=lambda: {'c': constant_op.constant(1)},
|
||||||
|
error=r'func\(x\) got unexpected keyword arguments: c'),
|
||||||
|
dict(
|
||||||
|
testcase_name='MissingVararg',
|
||||||
|
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2),
|
||||||
|
constant_op.constant(3)),
|
||||||
|
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
error=r'func\(x, y, varargs_0\) missing required '
|
||||||
|
r'arguments: varargs_0'),
|
||||||
|
dict(
|
||||||
|
testcase_name='MissingKeywordArg',
|
||||||
|
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
conc_kwargs=lambda: {'c': constant_op.constant(1)},
|
||||||
|
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
error=r'func\(x, y, c\) missing required arguments: c'),
|
||||||
|
dict(
|
||||||
|
testcase_name='ExpectedTensorGotInt',
|
||||||
|
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
call_args=lambda: (5, constant_op.constant(2)),
|
||||||
|
error=r'func\(x, y\): expected argument #0\(zero-based\) to be '
|
||||||
|
r'a Tensor; got int \(5\)'),
|
||||||
|
dict(
|
||||||
|
testcase_name='WrongDType',
|
||||||
|
conc_args=lambda: (constant_op.constant(1),),
|
||||||
|
call_args=lambda: (constant_op.constant(1.0),),
|
||||||
|
exception=(ValueError, errors.InvalidArgumentError,
|
||||||
|
# on xla_gpu, we get InternalError instead.
|
||||||
|
errors.InternalError)),
|
||||||
|
dict(
|
||||||
|
testcase_name='MissingKeywordArgNestPiece',
|
||||||
|
conc_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
conc_kwargs=lambda: {'c': ragged_factory_ops.constant([[1]])},
|
||||||
|
call_args=lambda: (constant_op.constant(1), constant_op.constant(2)),
|
||||||
|
call_kwargs=lambda: {'c': constant_op.constant(1)},
|
||||||
|
error=r'func\(x, y, c, c_1\) missing required arguments: c_1'),
|
||||||
|
])
|
||||||
|
# pylint: enable=g-long-lambda
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testConcreteFunctionFlatSignatureError(self,
|
||||||
|
conc_args=(),
|
||||||
|
conc_kwargs=None,
|
||||||
|
call_args=(),
|
||||||
|
call_kwargs=None,
|
||||||
|
error='.*',
|
||||||
|
exception=TypeError):
|
||||||
|
"""Tests for errors in the flat signature.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conc_args: Positional arguments used for get_concrete_function.
|
||||||
|
conc_kwargs: Keyword arguments used for get_concrete_function.
|
||||||
|
call_args: Positional arguments used to call the function.
|
||||||
|
call_kwargs: Keyword arguments used to call the function.
|
||||||
|
error: Expected exception message.
|
||||||
|
exception: Expected exception type.
|
||||||
|
"""
|
||||||
|
conc_args = conc_args() if callable(conc_args) else conc_args
|
||||||
|
conc_kwargs = conc_kwargs() if callable(conc_kwargs) else conc_kwargs or {}
|
||||||
|
call_args = call_args() if callable(call_args) else call_args
|
||||||
|
call_kwargs = call_kwargs() if callable(call_kwargs) else call_kwargs or {}
|
||||||
|
self.assertIsInstance(conc_args, tuple)
|
||||||
|
self.assertIsInstance(call_args, tuple)
|
||||||
|
self.assertIsInstance(conc_kwargs, dict)
|
||||||
|
self.assertIsInstance(call_kwargs, dict)
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def func(x, y=5, *varargs, **kwargs): # pylint: disable=keyword-arg-before-vararg
|
||||||
|
del y, varargs, kwargs
|
||||||
|
return x
|
||||||
|
|
||||||
|
conc = func.get_concrete_function(*conc_args, **conc_kwargs)
|
||||||
|
|
||||||
|
# Remove _function_spec, to disable the structured signature.
|
||||||
|
conc._set_function_spec(None) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(exception, error):
|
||||||
|
self.evaluate(conc(*call_args, **call_kwargs))
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def testConcreteFunctionAmbiguousSignature(self):
|
||||||
|
# When both the flat & structured signatures are applicable, but they
|
||||||
|
# give different results, we use the structured signature. Note: we expect
|
||||||
|
# this to be extremely rare.
|
||||||
|
@def_function.function
|
||||||
|
def f(x, y):
|
||||||
|
return x * 10 + y
|
||||||
|
|
||||||
|
conc = f.get_concrete_function(
|
||||||
|
x=tensor_spec.TensorSpec(None, dtypes.int32, name='y'),
|
||||||
|
y=tensor_spec.TensorSpec(None, dtypes.int32, name='x'))
|
||||||
|
|
||||||
|
result = conc(x=constant_op.constant(5), y=constant_op.constant(6))
|
||||||
|
self.assertAllEqual(result, 56)
|
||||||
|
|
||||||
|
def testPrettyPrintedSignature(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def func(x, kangaroo=None, octopus=7):
|
||||||
|
del octopus, kangaroo
|
||||||
|
return x
|
||||||
|
|
||||||
|
scalar = constant_op.constant(5)
|
||||||
|
vector = constant_op.constant([10, 10, 20])
|
||||||
|
ragged = ragged_factory_ops.constant([[10, 20], [40]])
|
||||||
|
|
||||||
|
c1 = func.get_concrete_function(scalar, vector)
|
||||||
|
c1_summary = r'func\(x, kangaroo, octopus=7\)'
|
||||||
|
c1_details = (r' Args:\n'
|
||||||
|
r' x: int32 Tensor, shape=\(\)\n'
|
||||||
|
r' kangaroo: int32 Tensor, shape=\(3,\)\n'
|
||||||
|
r' Returns:\n'
|
||||||
|
r' int32 Tensor, shape=\(\)')
|
||||||
|
self.assertRegexpMatches(
|
||||||
|
c1.pretty_printed_signature(verbose=False), c1_summary)
|
||||||
|
self.assertRegexpMatches(
|
||||||
|
c1.pretty_printed_signature(verbose=True),
|
||||||
|
c1_summary + '\n' + c1_details)
|
||||||
|
self.assertRegexpMatches(
|
||||||
|
repr(c1), r'<ConcreteFunction func\(x, kangaroo, octopus=7\) at .*>')
|
||||||
|
self.assertRegexpMatches(
|
||||||
|
str(c1), 'ConcreteFunction {}\n{}'.format(c1_summary, c1_details))
|
||||||
|
|
||||||
|
c2 = func.get_concrete_function(scalar, ragged, 3)
|
||||||
|
c2_summary = r'func\(x, kangaroo, octopus=3\)'
|
||||||
|
c2_details = (r' Args:\n'
|
||||||
|
r' x: int32 Tensor, shape=\(\)\n'
|
||||||
|
r' kangaroo: RaggedTensorSpec\(.*\)\n'
|
||||||
|
r' Returns:\n'
|
||||||
|
r' int32 Tensor, shape=\(\)')
|
||||||
|
self.assertRegexpMatches(c2.pretty_printed_signature(),
|
||||||
|
c2_summary + '\n' + c2_details)
|
||||||
|
|
||||||
|
c3 = func.get_concrete_function({'a': scalar, 'b': [ragged, ragged]})
|
||||||
|
c3_summary = r'func\(x, kangaroo=None, octopus=7\)'
|
||||||
|
c3_details = (r' Args:\n'
|
||||||
|
r" x: {'a': <1>, 'b': \[<2>, <3>\]}\n"
|
||||||
|
r' <1>: int32 Tensor, shape=\(\)\n'
|
||||||
|
r' <2>: RaggedTensorSpec\(.*\)\n'
|
||||||
|
r' <3>: RaggedTensorSpec\(.*\)\n'
|
||||||
|
r' Returns:\n'
|
||||||
|
r" {'a': <1>, 'b': \[<2>, <3>\]}\n"
|
||||||
|
r' <1>: int32 Tensor, shape=\(\)\n'
|
||||||
|
r' <2>: RaggedTensorSpec\(.*\)\n'
|
||||||
|
r' <3>: RaggedTensorSpec\(.*\)')
|
||||||
|
self.assertRegexpMatches(c3.pretty_printed_signature(),
|
||||||
|
c3_summary + '\n' + c3_details)
|
||||||
|
|
||||||
|
# pylint: disable=keyword-arg-before-vararg
|
||||||
|
@def_function.function
|
||||||
|
def func2(x, y=3, *args, **kwargs):
|
||||||
|
return (x, y, args, kwargs)
|
||||||
|
|
||||||
|
c4 = func2.get_concrete_function(scalar, 4, 5, a=scalar)
|
||||||
|
c4_summary = 'func2(x, y=4, <arg3>=5, *, a)'
|
||||||
|
self.assertEqual(c4.pretty_printed_signature(verbose=False), c4_summary)
|
||||||
|
|
||||||
|
c5 = func2.get_concrete_function(8, vector)
|
||||||
|
c5_summary = 'func2(x=8, y)'
|
||||||
|
self.assertEqual(c5.pretty_printed_signature(verbose=False), c5_summary)
|
||||||
|
|
||||||
|
|
||||||
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
class MultiDeviceTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
|||||||
@ -43,6 +43,8 @@ def _is_tensor(t):
|
|||||||
return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
|
return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(edloper): Update this to just use ConcreteFunction.__call__ with the
|
||||||
|
# structured signature.
|
||||||
def _call_concrete_function(function, inputs):
|
def _call_concrete_function(function, inputs):
|
||||||
"""Calls a restored Function with structured inputs.
|
"""Calls a restored Function with structured inputs.
|
||||||
|
|
||||||
@ -137,8 +139,6 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
|
|||||||
input_signature = coder.decode_proto(function_spec_proto.input_signature)
|
input_signature = coder.decode_proto(function_spec_proto.input_signature)
|
||||||
return function_lib.FunctionSpec(fullargspec=fullargspec,
|
return function_lib.FunctionSpec(fullargspec=fullargspec,
|
||||||
is_method=False,
|
is_method=False,
|
||||||
args_to_prepend=[],
|
|
||||||
kwargs_to_include={},
|
|
||||||
input_signature=input_signature)
|
input_signature=input_signature)
|
||||||
|
|
||||||
|
|
||||||
@ -191,6 +191,8 @@ def recreate_function(saved_function, concrete_functions):
|
|||||||
Args:
|
Args:
|
||||||
saved_function: `SavedFunction` proto.
|
saved_function: `SavedFunction` proto.
|
||||||
concrete_functions: map from function name to `ConcreteFunction`.
|
concrete_functions: map from function name to `ConcreteFunction`.
|
||||||
|
As a side effect of this function, the `FunctionSpec` from
|
||||||
|
`saved_function` is added to each `ConcreteFunction` in this map.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A `Function`.
|
A `Function`.
|
||||||
@ -254,6 +256,9 @@ def recreate_function(saved_function, concrete_functions):
|
|||||||
for concrete_function_name in saved_function.concrete_functions:
|
for concrete_function_name in saved_function.concrete_functions:
|
||||||
concrete_function_objects.append(concrete_functions[concrete_function_name])
|
concrete_function_objects.append(concrete_functions[concrete_function_name])
|
||||||
|
|
||||||
|
for cf in concrete_function_objects:
|
||||||
|
cf._set_function_spec(function_spec) # pylint: disable=protected-access
|
||||||
|
|
||||||
restored_function = RestoredFunction(
|
restored_function = RestoredFunction(
|
||||||
restored_function_body,
|
restored_function_body,
|
||||||
restored_function_body.__name__,
|
restored_function_body.__name__,
|
||||||
@ -317,6 +322,11 @@ def load_function_def_library(library, load_shared_name_suffix=None):
|
|||||||
|
|
||||||
for dep in _list_function_deps(fdef, library_function_names):
|
for dep in _list_function_deps(fdef, library_function_names):
|
||||||
functions[dep].add_to_graph(func_graph)
|
functions[dep].add_to_graph(func_graph)
|
||||||
|
|
||||||
|
# We do not initialize the new ConcreteFunction's function_spec or
|
||||||
|
# arg_keywords here (which are used to parse the structured and flat
|
||||||
|
# signatures, respectively). function_spec is set up later by
|
||||||
|
# recreate_function(); and arg_keywords by setup_bare_concrete_function().
|
||||||
func = function_lib.ConcreteFunction(func_graph)
|
func = function_lib.ConcreteFunction(func_graph)
|
||||||
func.add_to_graph(graph)
|
func.add_to_graph(graph)
|
||||||
|
|
||||||
|
|||||||
@ -77,6 +77,10 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
|
|||||||
|
|
||||||
def serialize_bare_concrete_function(concrete_function, name_map):
|
def serialize_bare_concrete_function(concrete_function, name_map):
|
||||||
"""Build a SavedBareConcreteFunction."""
|
"""Build a SavedBareConcreteFunction."""
|
||||||
|
# TODO(edloper): Currently, bare concrete functions don't have access to a
|
||||||
|
# function_spec, so they can't be called with the structured signature.
|
||||||
|
# Update the serialization to include a function_spec.
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
name = name_map.get(compat.as_text(concrete_function.name),
|
name = name_map.get(compat.as_text(concrete_function.name),
|
||||||
concrete_function.name)
|
concrete_function.name)
|
||||||
@ -151,7 +155,8 @@ def wrap_cached_variables(concrete_function):
|
|||||||
func_graph_module.func_graph_from_py_func(
|
func_graph_module.func_graph_from_py_func(
|
||||||
None, wrap_function, args=tuple(args), kwargs={},
|
None, wrap_function, args=tuple(args), kwargs={},
|
||||||
func_graph=outer_graph)
|
func_graph=outer_graph)
|
||||||
fn = defun.ConcreteFunction(outer_graph)
|
fn = defun.ConcreteFunction(
|
||||||
|
outer_graph, function_spec=concrete_function._function_spec) # pylint: disable=protected-access
|
||||||
fn._arg_keywords = concrete_function._arg_keywords # pylint: disable=protected-access
|
fn._arg_keywords = concrete_function._arg_keywords # pylint: disable=protected-access
|
||||||
fn._num_positional_args = concrete_function._num_positional_args # pylint: disable=protected-access
|
fn._num_positional_args = concrete_function._num_positional_args # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|||||||
@ -173,12 +173,12 @@ class Loader(object):
|
|||||||
# The original_outputs here had Tensors converted to TensorSpecs, so
|
# The original_outputs here had Tensors converted to TensorSpecs, so
|
||||||
# the restored function's structured_outputs field will not be
|
# the restored function's structured_outputs field will not be
|
||||||
# exactly the same. Fortunately the repacking logic cares only about
|
# exactly the same. Fortunately the repacking logic cares only about
|
||||||
# the structure.
|
# the structure; and the unpacking logic cares only about structure
|
||||||
# TODO(vbardiovsky): Should we just replicate the structures, with
|
# and types.
|
||||||
# Nones instead of real objects?
|
|
||||||
concrete_function._func_graph.structured_outputs = original_outputs # pylint: disable=protected-access
|
concrete_function._func_graph.structured_outputs = original_outputs # pylint: disable=protected-access
|
||||||
concrete_function._func_graph.structured_input_signature = ( # pylint: disable=protected-access
|
concrete_function._func_graph.structured_input_signature = ( # pylint: disable=protected-access
|
||||||
coder.decode_proto(proto.canonicalized_input_signature))
|
coder.decode_proto(proto.canonicalized_input_signature))
|
||||||
|
concrete_function._initialize_function_spec() # pylint: disable=protected-access
|
||||||
|
|
||||||
def _setup_functions_captures(self):
|
def _setup_functions_captures(self):
|
||||||
"""Setup captures and variables in restored functions."""
|
"""Setup captures and variables in restored functions."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user