Extract function spec information that is needed for canonicalization.
This will allow serialization of all the information we need for canonicalization. PiperOrigin-RevId: 225960841
This commit is contained in:
parent
e4bdb31636
commit
15fa7c49e2
@ -269,8 +269,11 @@ class PolymorphicFunction(object):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"_canonicalize_function_inputs must be called only after _initialize "
|
"_canonicalize_function_inputs must be called only after _initialize "
|
||||||
"has run.")
|
"has run.")
|
||||||
|
# pylint: disable=protected-access
|
||||||
if self._input_signature is None or args or kwds:
|
if self._input_signature is None or args or kwds:
|
||||||
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
|
return self._stateful_fn._function_spec.canonicalize_function_inputs(
|
||||||
|
*args, **kwds)
|
||||||
|
# pylint: enable=protected-access
|
||||||
# If an input signature is defined, we may need to fetch a concrete function
|
# If an input signature is defined, we may need to fetch a concrete function
|
||||||
# without any inputs specified. In this case args and kwds should be ignored
|
# without any inputs specified. In this case args and kwds should be ignored
|
||||||
# but running _canonicalize_function_inputs would raise an exception.
|
# but running _canonicalize_function_inputs would raise an exception.
|
||||||
|
|||||||
@ -769,6 +769,146 @@ def _deterministic_dict_values(dictionary):
|
|||||||
return tuple(dictionary[key] for key in sorted(dictionary))
|
return tuple(dictionary[key] for key in sorted(dictionary))
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionSpec(object):
|
||||||
|
"""Specification of how to bind arguments to a function."""
|
||||||
|
|
||||||
|
def __init__(self, python_function, input_signature):
|
||||||
|
if isinstance(python_function, functools.partial):
|
||||||
|
python_function_to_inspect = python_function.func
|
||||||
|
self._args_to_prepend = python_function.args or tuple()
|
||||||
|
self._kwargs_to_include = python_function.keywords or {}
|
||||||
|
else:
|
||||||
|
python_function_to_inspect = python_function
|
||||||
|
self._args_to_prepend = tuple()
|
||||||
|
self._kwargs_to_include = {}
|
||||||
|
|
||||||
|
fullargspec = tf_inspect.getfullargspec(python_function_to_inspect)
|
||||||
|
self._default_values = fullargspec.defaults
|
||||||
|
|
||||||
|
if tf_inspect.ismethod(python_function_to_inspect):
|
||||||
|
# Remove `self`: default arguments shouldn't be matched to it.
|
||||||
|
args = fullargspec.args[1:]
|
||||||
|
else:
|
||||||
|
args = fullargspec.args
|
||||||
|
|
||||||
|
# A cache mapping from argument name to index, for canonicalizing
|
||||||
|
# arguments that are called in a keyword-like fashion.
|
||||||
|
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
|
||||||
|
self.arg_names = args
|
||||||
|
self.vararg_name = fullargspec.varargs
|
||||||
|
|
||||||
|
# A cache mapping from arg index to default value, for canonicalization.
|
||||||
|
offset = len(args) - len(fullargspec.defaults or [])
|
||||||
|
self._arg_indices_to_default_values = {
|
||||||
|
offset + index: default
|
||||||
|
for index, default in enumerate(fullargspec.defaults or [])
|
||||||
|
}
|
||||||
|
self._default_values_start_index = offset
|
||||||
|
if input_signature is None:
|
||||||
|
self.input_signature = None
|
||||||
|
else:
|
||||||
|
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
|
||||||
|
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||||
|
"function with keyword arguments when "
|
||||||
|
"input_signature is provided.")
|
||||||
|
|
||||||
|
if not isinstance(input_signature, (tuple, list)):
|
||||||
|
raise TypeError("input_signature must be either a tuple or a "
|
||||||
|
"list, received " + str(type(input_signature)))
|
||||||
|
|
||||||
|
self.input_signature = tuple(input_signature)
|
||||||
|
self.flat_input_signature = tuple(nest.flatten(input_signature))
|
||||||
|
|
||||||
|
def canonicalize_function_inputs(self, *args, **kwargs):
|
||||||
|
"""Canonicalizes `args` and `kwargs`.
|
||||||
|
|
||||||
|
Canonicalize the inputs to the Python function using a `FunctionSpec`
|
||||||
|
instance. In particular, we parse the varags and kwargs that the
|
||||||
|
original function was called with into a tuple corresponding to the
|
||||||
|
Python function's positional (named) arguments and a dictionary
|
||||||
|
corresponding to its kwargs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: The varargs this object was called with.
|
||||||
|
**kwargs: The keyword args this function was called with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A canonicalized ordering of the inputs.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a keyword in `kwargs` cannot be matched with a positional
|
||||||
|
argument when an input signature is specified, or when the inputs
|
||||||
|
do not conform to the input signature.
|
||||||
|
"""
|
||||||
|
args = self._args_to_prepend + args
|
||||||
|
kwargs = dict(kwargs, **self._kwargs_to_include)
|
||||||
|
if not kwargs:
|
||||||
|
if self._default_values:
|
||||||
|
inputs = args + self._default_values[
|
||||||
|
len(args) - self._default_values_start_index:]
|
||||||
|
else:
|
||||||
|
inputs = args
|
||||||
|
else:
|
||||||
|
# Maps from index of arg to its corresponding value, according to `args`
|
||||||
|
# and `kwargs`; seeded with the default values for the named args that
|
||||||
|
# aren't in `args`.
|
||||||
|
arg_indices_to_values = {
|
||||||
|
index: default for index, default in six.iteritems(
|
||||||
|
self._arg_indices_to_default_values) if index >= len(args)
|
||||||
|
}
|
||||||
|
consumed_args = []
|
||||||
|
for arg, value in six.iteritems(kwargs):
|
||||||
|
index = self._args_to_indices.get(arg, None)
|
||||||
|
if index is not None:
|
||||||
|
arg_indices_to_values[index] = value
|
||||||
|
consumed_args.append(arg)
|
||||||
|
elif self.input_signature is not None:
|
||||||
|
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||||
|
"function with keyword arguments when "
|
||||||
|
"input_signature is provided.")
|
||||||
|
for arg in consumed_args:
|
||||||
|
# After this loop, `kwargs` will only contain true keyword arguments, as
|
||||||
|
# opposed to named arguments called in a keyword-like fashion.
|
||||||
|
kwargs.pop(arg)
|
||||||
|
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||||
|
flat_inputs = nest.flatten(inputs)
|
||||||
|
|
||||||
|
# Check for NumPy arrays in arguments and convert them to Tensors.
|
||||||
|
# TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
|
||||||
|
# finding a way to store them directly in the cache key (currently not
|
||||||
|
# possible since ndarrays are not hashable).
|
||||||
|
need_packing = False
|
||||||
|
for index, value in enumerate(flat_inputs):
|
||||||
|
if type(value) == np.ndarray:
|
||||||
|
flat_inputs[index] = constant_op.constant(value)
|
||||||
|
need_packing = True
|
||||||
|
if need_packing:
|
||||||
|
inputs = nest.pack_sequence_as(
|
||||||
|
structure=inputs, flat_sequence=flat_inputs)
|
||||||
|
if self.input_signature is None:
|
||||||
|
return inputs, kwargs
|
||||||
|
else:
|
||||||
|
assert not kwargs
|
||||||
|
signature_relevant_inputs = inputs[:len(self.input_signature)]
|
||||||
|
try:
|
||||||
|
nest.assert_same_structure(self.input_signature,
|
||||||
|
signature_relevant_inputs)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
raise ValueError("Structure of Python function inputs does not match "
|
||||||
|
"input_signature.")
|
||||||
|
signature_inputs_flat = nest.flatten(signature_relevant_inputs)
|
||||||
|
if any(
|
||||||
|
not pywrap_tensorflow.IsTensor(arg) for arg in signature_inputs_flat):
|
||||||
|
raise ValueError("When input_signature is provided, all inputs to "
|
||||||
|
"the Python function must be Tensors.")
|
||||||
|
if any(not spec.is_compatible_with(other) for spec, other in zip(
|
||||||
|
self.flat_input_signature, signature_inputs_flat)):
|
||||||
|
raise ValueError("Python inputs incompatible with input_signature: "
|
||||||
|
"inputs (%s), input_signature (%s)" %
|
||||||
|
(str(inputs), str(self.input_signature)))
|
||||||
|
return inputs, {}
|
||||||
|
|
||||||
|
|
||||||
class PolymorphicFunction(object):
|
class PolymorphicFunction(object):
|
||||||
"""Wrapper class for the graph functions defined for a Python function.
|
"""Wrapper class for the graph functions defined for a Python function.
|
||||||
|
|
||||||
@ -805,15 +945,11 @@ class PolymorphicFunction(object):
|
|||||||
ValueError: if `input_signature` is not None and the `python_function`'s
|
ValueError: if `input_signature` is not None and the `python_function`'s
|
||||||
argspec has keyword arguments.
|
argspec has keyword arguments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(python_function, functools.partial):
|
if isinstance(python_function, functools.partial):
|
||||||
self._python_function = python_function.func
|
self._python_function = python_function.func
|
||||||
self._args_to_prepend = python_function.args or tuple()
|
|
||||||
self._kwargs_to_include = python_function.keywords or {}
|
|
||||||
else:
|
else:
|
||||||
self._python_function = python_function
|
self._python_function = python_function
|
||||||
self._args_to_prepend = tuple()
|
self._function_spec = FunctionSpec(python_function, input_signature)
|
||||||
self._kwargs_to_include = {}
|
|
||||||
self._name = name
|
self._name = name
|
||||||
self._autograph = autograph
|
self._autograph = autograph
|
||||||
self._function_cache = collections.OrderedDict()
|
self._function_cache = collections.OrderedDict()
|
||||||
@ -827,41 +963,6 @@ class PolymorphicFunction(object):
|
|||||||
# different functions for each instance.
|
# different functions for each instance.
|
||||||
self._descriptor_cache = weakref.WeakKeyDictionary()
|
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
fullargspec = tf_inspect.getfullargspec(self._python_function)
|
|
||||||
if tf_inspect.ismethod(self._python_function):
|
|
||||||
# Remove `self`: default arguments shouldn't be matched to it.
|
|
||||||
args = fullargspec.args[1:]
|
|
||||||
else:
|
|
||||||
args = fullargspec.args
|
|
||||||
|
|
||||||
# A cache mapping from argument name to index, for canonicalizing
|
|
||||||
# arguments that are called in a keyword-like fashion.
|
|
||||||
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
|
|
||||||
self._arg_names = args
|
|
||||||
self._vararg_name = fullargspec.varargs
|
|
||||||
# A cache mapping from arg index to default value, for canonicalization.
|
|
||||||
offset = len(args) - len(fullargspec.defaults or [])
|
|
||||||
self._arg_indices_to_default_values = {
|
|
||||||
offset + index: default
|
|
||||||
for index, default in enumerate(fullargspec.defaults or [])
|
|
||||||
}
|
|
||||||
self._default_values = fullargspec.defaults
|
|
||||||
self._default_values_start_index = offset
|
|
||||||
if input_signature is None:
|
|
||||||
self._input_signature = None
|
|
||||||
else:
|
|
||||||
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
|
|
||||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
|
||||||
"function with keyword arguments when "
|
|
||||||
"input_signature is provided.")
|
|
||||||
|
|
||||||
if not isinstance(input_signature, (tuple, list)):
|
|
||||||
raise TypeError("input_signature must be either a tuple or a "
|
|
||||||
"list, received " + str(type(input_signature)))
|
|
||||||
|
|
||||||
self._input_signature = tuple(input_signature)
|
|
||||||
self._flat_input_signature = tuple(nest.flatten(input_signature))
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""Calls a graph function specialized to the inputs."""
|
"""Calls a graph function specialized to the inputs."""
|
||||||
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
|
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
|
||||||
@ -870,7 +971,17 @@ class PolymorphicFunction(object):
|
|||||||
@property
|
@property
|
||||||
def python_function(self):
|
def python_function(self):
|
||||||
"""Returns the wrapped Python function."""
|
"""Returns the wrapped Python function."""
|
||||||
return self._python_function
|
return self._python_function # pylint: disable=protected-access
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _input_signature(self):
|
||||||
|
"""Returns the wrapped Python function."""
|
||||||
|
return self._function_spec.input_signature # pylint: disable=protected-access
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _flat_input_signature(self):
|
||||||
|
"""Returns the wrapped Python function."""
|
||||||
|
return self._function_spec.flat_input_signature # pylint: disable=protected-access
|
||||||
|
|
||||||
def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
|
def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
|
||||||
"""Returns a concrete function which cleans up its graph function."""
|
"""Returns a concrete function which cleans up its graph function."""
|
||||||
@ -1050,96 +1161,6 @@ class PolymorphicFunction(object):
|
|||||||
return CacheKey(input_signature, parent_graph, device_functions,
|
return CacheKey(input_signature, parent_graph, device_functions,
|
||||||
colocation_stack, uses_xla)
|
colocation_stack, uses_xla)
|
||||||
|
|
||||||
def _canonicalize_function_inputs(self, *args, **kwargs):
|
|
||||||
"""Canonicalizes `args` and `kwargs`.
|
|
||||||
|
|
||||||
Canonicalize the inputs to the Python function using its fullargspec. In
|
|
||||||
particular, we parse the varags and kwargs that this
|
|
||||||
`PolymorphicFunction` was called with into a tuple corresponding to the
|
|
||||||
Python function's positional (named) arguments and a dictionary
|
|
||||||
corresponding to its kwargs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*args: The varargs this object was called with.
|
|
||||||
**kwargs: The keyword args this function was called with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A canonicalized ordering of the inputs.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If a keyword in `kwargs` cannot be matched with a positional
|
|
||||||
argument when an input signature is specified, or when the inputs
|
|
||||||
do not conform to the input signature.
|
|
||||||
"""
|
|
||||||
args = self._args_to_prepend + args
|
|
||||||
kwargs = dict(kwargs, **self._kwargs_to_include)
|
|
||||||
if not kwargs:
|
|
||||||
if self._default_values:
|
|
||||||
inputs = args + self._default_values[len(args) -
|
|
||||||
self._default_values_start_index:]
|
|
||||||
else:
|
|
||||||
inputs = args
|
|
||||||
else:
|
|
||||||
# Maps from index of arg to its corresponding value, according to `args`
|
|
||||||
# and `kwargs`; seeded with the default values for the named args that
|
|
||||||
# aren't in `args`.
|
|
||||||
arg_indices_to_values = {
|
|
||||||
index: default for index, default in six.iteritems(
|
|
||||||
self._arg_indices_to_default_values) if index >= len(args)
|
|
||||||
}
|
|
||||||
consumed_args = []
|
|
||||||
for arg, value in six.iteritems(kwargs):
|
|
||||||
index = self._args_to_indices.get(arg, None)
|
|
||||||
if index is not None:
|
|
||||||
arg_indices_to_values[index] = value
|
|
||||||
consumed_args.append(arg)
|
|
||||||
elif self._input_signature is not None:
|
|
||||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
|
||||||
"function with keyword arguments when "
|
|
||||||
"input_signature is provided.")
|
|
||||||
for arg in consumed_args:
|
|
||||||
# After this loop, `kwargs` will only contain true keyword arguments, as
|
|
||||||
# opposed to named arguments called in a keyword-like fashion.
|
|
||||||
kwargs.pop(arg)
|
|
||||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
|
||||||
flat_inputs = nest.flatten(inputs)
|
|
||||||
|
|
||||||
# Check for NumPy arrays in arguments and convert them to Tensors.
|
|
||||||
# TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
|
|
||||||
# finding a way to store them directly in the cache key (currently not
|
|
||||||
# possible since ndarrays are not hashable).
|
|
||||||
need_packing = False
|
|
||||||
for index, value in enumerate(flat_inputs):
|
|
||||||
if type(value) == np.ndarray:
|
|
||||||
flat_inputs[index] = constant_op.constant(value)
|
|
||||||
need_packing = True
|
|
||||||
if need_packing:
|
|
||||||
inputs = nest.pack_sequence_as(structure=inputs,
|
|
||||||
flat_sequence=flat_inputs)
|
|
||||||
if self._input_signature is None:
|
|
||||||
return inputs, kwargs
|
|
||||||
else:
|
|
||||||
assert not kwargs
|
|
||||||
signature_relevant_inputs = inputs[:len(self._input_signature)]
|
|
||||||
try:
|
|
||||||
nest.assert_same_structure(self._input_signature,
|
|
||||||
signature_relevant_inputs)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
raise ValueError("Structure of Python function inputs does not match "
|
|
||||||
"input_signature.")
|
|
||||||
signature_inputs_flat = nest.flatten(signature_relevant_inputs)
|
|
||||||
if any(not pywrap_tensorflow.IsTensor(arg)
|
|
||||||
for arg in signature_inputs_flat):
|
|
||||||
raise ValueError("When input_signature is provided, all inputs to "
|
|
||||||
"the Python function must be Tensors.")
|
|
||||||
if any(not spec.is_compatible_with(other)
|
|
||||||
for spec, other in zip(self._flat_input_signature,
|
|
||||||
signature_inputs_flat)):
|
|
||||||
raise ValueError("Python inputs incompatible with input_signature: "
|
|
||||||
"inputs (%s), input_signature (%s)" %
|
|
||||||
(str(inputs), str(self._input_signature)))
|
|
||||||
return inputs, {}
|
|
||||||
|
|
||||||
def _maybe_define_function(self, args, kwargs):
|
def _maybe_define_function(self, args, kwargs):
|
||||||
"""Gets a function for these inputs, defining it if necessary.
|
"""Gets a function for these inputs, defining it if necessary.
|
||||||
|
|
||||||
@ -1159,7 +1180,8 @@ class PolymorphicFunction(object):
|
|||||||
TypeError: If the function inputs include non-hashable objects
|
TypeError: If the function inputs include non-hashable objects
|
||||||
"""
|
"""
|
||||||
if self._input_signature is None or args is not None or kwargs is not None:
|
if self._input_signature is None or args is not None or kwargs is not None:
|
||||||
args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
|
args, kwargs = self._function_spec.canonicalize_function_inputs(
|
||||||
|
*args, **kwargs)
|
||||||
cache_key = self._cache_key(args, kwargs)
|
cache_key = self._cache_key(args, kwargs)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
try:
|
try:
|
||||||
@ -1177,8 +1199,9 @@ class PolymorphicFunction(object):
|
|||||||
else:
|
else:
|
||||||
arglen = len(self._input_signature)
|
arglen = len(self._input_signature)
|
||||||
arg_names = (
|
arg_names = (
|
||||||
self._arg_names[:arglen]
|
self._function_spec.arg_names[:arglen]
|
||||||
+ [self._vararg_name] * (arglen - len(self._arg_names)))
|
+ [self._function_spec.vararg_name] *
|
||||||
|
(arglen - len(self._function_spec.arg_names)))
|
||||||
graph_function = Function(
|
graph_function = Function(
|
||||||
func_graph_module.func_graph_from_py_func(
|
func_graph_module.func_graph_from_py_func(
|
||||||
self._name,
|
self._name,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user