Extract function spec information that is needed for canonicalization.
This will allow serialization of all the information we need for canonicalization. PiperOrigin-RevId: 225960841
This commit is contained in:
parent
e4bdb31636
commit
15fa7c49e2
@ -269,8 +269,11 @@ class PolymorphicFunction(object):
|
||||
raise ValueError(
|
||||
"_canonicalize_function_inputs must be called only after _initialize "
|
||||
"has run.")
|
||||
# pylint: disable=protected-access
|
||||
if self._input_signature is None or args or kwds:
|
||||
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
|
||||
return self._stateful_fn._function_spec.canonicalize_function_inputs(
|
||||
*args, **kwds)
|
||||
# pylint: enable=protected-access
|
||||
# If an input signature is defined, we may need to fetch a concrete function
|
||||
# without any inputs specified. In this case args and kwds should be ignored
|
||||
# but running _canonicalize_function_inputs would raise an exception.
|
||||
|
||||
@ -769,6 +769,146 @@ def _deterministic_dict_values(dictionary):
|
||||
return tuple(dictionary[key] for key in sorted(dictionary))
|
||||
|
||||
|
||||
class FunctionSpec(object):
|
||||
"""Specification of how to bind arguments to a function."""
|
||||
|
||||
def __init__(self, python_function, input_signature):
|
||||
if isinstance(python_function, functools.partial):
|
||||
python_function_to_inspect = python_function.func
|
||||
self._args_to_prepend = python_function.args or tuple()
|
||||
self._kwargs_to_include = python_function.keywords or {}
|
||||
else:
|
||||
python_function_to_inspect = python_function
|
||||
self._args_to_prepend = tuple()
|
||||
self._kwargs_to_include = {}
|
||||
|
||||
fullargspec = tf_inspect.getfullargspec(python_function_to_inspect)
|
||||
self._default_values = fullargspec.defaults
|
||||
|
||||
if tf_inspect.ismethod(python_function_to_inspect):
|
||||
# Remove `self`: default arguments shouldn't be matched to it.
|
||||
args = fullargspec.args[1:]
|
||||
else:
|
||||
args = fullargspec.args
|
||||
|
||||
# A cache mapping from argument name to index, for canonicalizing
|
||||
# arguments that are called in a keyword-like fashion.
|
||||
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
|
||||
self.arg_names = args
|
||||
self.vararg_name = fullargspec.varargs
|
||||
|
||||
# A cache mapping from arg index to default value, for canonicalization.
|
||||
offset = len(args) - len(fullargspec.defaults or [])
|
||||
self._arg_indices_to_default_values = {
|
||||
offset + index: default
|
||||
for index, default in enumerate(fullargspec.defaults or [])
|
||||
}
|
||||
self._default_values_start_index = offset
|
||||
if input_signature is None:
|
||||
self.input_signature = None
|
||||
else:
|
||||
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
|
||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||
"function with keyword arguments when "
|
||||
"input_signature is provided.")
|
||||
|
||||
if not isinstance(input_signature, (tuple, list)):
|
||||
raise TypeError("input_signature must be either a tuple or a "
|
||||
"list, received " + str(type(input_signature)))
|
||||
|
||||
self.input_signature = tuple(input_signature)
|
||||
self.flat_input_signature = tuple(nest.flatten(input_signature))
|
||||
|
||||
def canonicalize_function_inputs(self, *args, **kwargs):
|
||||
"""Canonicalizes `args` and `kwargs`.
|
||||
|
||||
Canonicalize the inputs to the Python function using a `FunctionSpec`
|
||||
instance. In particular, we parse the varags and kwargs that the
|
||||
original function was called with into a tuple corresponding to the
|
||||
Python function's positional (named) arguments and a dictionary
|
||||
corresponding to its kwargs.
|
||||
|
||||
Args:
|
||||
*args: The varargs this object was called with.
|
||||
**kwargs: The keyword args this function was called with.
|
||||
|
||||
Returns:
|
||||
A canonicalized ordering of the inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a keyword in `kwargs` cannot be matched with a positional
|
||||
argument when an input signature is specified, or when the inputs
|
||||
do not conform to the input signature.
|
||||
"""
|
||||
args = self._args_to_prepend + args
|
||||
kwargs = dict(kwargs, **self._kwargs_to_include)
|
||||
if not kwargs:
|
||||
if self._default_values:
|
||||
inputs = args + self._default_values[
|
||||
len(args) - self._default_values_start_index:]
|
||||
else:
|
||||
inputs = args
|
||||
else:
|
||||
# Maps from index of arg to its corresponding value, according to `args`
|
||||
# and `kwargs`; seeded with the default values for the named args that
|
||||
# aren't in `args`.
|
||||
arg_indices_to_values = {
|
||||
index: default for index, default in six.iteritems(
|
||||
self._arg_indices_to_default_values) if index >= len(args)
|
||||
}
|
||||
consumed_args = []
|
||||
for arg, value in six.iteritems(kwargs):
|
||||
index = self._args_to_indices.get(arg, None)
|
||||
if index is not None:
|
||||
arg_indices_to_values[index] = value
|
||||
consumed_args.append(arg)
|
||||
elif self.input_signature is not None:
|
||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||
"function with keyword arguments when "
|
||||
"input_signature is provided.")
|
||||
for arg in consumed_args:
|
||||
# After this loop, `kwargs` will only contain true keyword arguments, as
|
||||
# opposed to named arguments called in a keyword-like fashion.
|
||||
kwargs.pop(arg)
|
||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||
flat_inputs = nest.flatten(inputs)
|
||||
|
||||
# Check for NumPy arrays in arguments and convert them to Tensors.
|
||||
# TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
|
||||
# finding a way to store them directly in the cache key (currently not
|
||||
# possible since ndarrays are not hashable).
|
||||
need_packing = False
|
||||
for index, value in enumerate(flat_inputs):
|
||||
if type(value) == np.ndarray:
|
||||
flat_inputs[index] = constant_op.constant(value)
|
||||
need_packing = True
|
||||
if need_packing:
|
||||
inputs = nest.pack_sequence_as(
|
||||
structure=inputs, flat_sequence=flat_inputs)
|
||||
if self.input_signature is None:
|
||||
return inputs, kwargs
|
||||
else:
|
||||
assert not kwargs
|
||||
signature_relevant_inputs = inputs[:len(self.input_signature)]
|
||||
try:
|
||||
nest.assert_same_structure(self.input_signature,
|
||||
signature_relevant_inputs)
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError("Structure of Python function inputs does not match "
|
||||
"input_signature.")
|
||||
signature_inputs_flat = nest.flatten(signature_relevant_inputs)
|
||||
if any(
|
||||
not pywrap_tensorflow.IsTensor(arg) for arg in signature_inputs_flat):
|
||||
raise ValueError("When input_signature is provided, all inputs to "
|
||||
"the Python function must be Tensors.")
|
||||
if any(not spec.is_compatible_with(other) for spec, other in zip(
|
||||
self.flat_input_signature, signature_inputs_flat)):
|
||||
raise ValueError("Python inputs incompatible with input_signature: "
|
||||
"inputs (%s), input_signature (%s)" %
|
||||
(str(inputs), str(self.input_signature)))
|
||||
return inputs, {}
|
||||
|
||||
|
||||
class PolymorphicFunction(object):
|
||||
"""Wrapper class for the graph functions defined for a Python function.
|
||||
|
||||
@ -805,15 +945,11 @@ class PolymorphicFunction(object):
|
||||
ValueError: if `input_signature` is not None and the `python_function`'s
|
||||
argspec has keyword arguments.
|
||||
"""
|
||||
|
||||
if isinstance(python_function, functools.partial):
|
||||
self._python_function = python_function.func
|
||||
self._args_to_prepend = python_function.args or tuple()
|
||||
self._kwargs_to_include = python_function.keywords or {}
|
||||
else:
|
||||
self._python_function = python_function
|
||||
self._args_to_prepend = tuple()
|
||||
self._kwargs_to_include = {}
|
||||
self._function_spec = FunctionSpec(python_function, input_signature)
|
||||
self._name = name
|
||||
self._autograph = autograph
|
||||
self._function_cache = collections.OrderedDict()
|
||||
@ -827,41 +963,6 @@ class PolymorphicFunction(object):
|
||||
# different functions for each instance.
|
||||
self._descriptor_cache = weakref.WeakKeyDictionary()
|
||||
|
||||
fullargspec = tf_inspect.getfullargspec(self._python_function)
|
||||
if tf_inspect.ismethod(self._python_function):
|
||||
# Remove `self`: default arguments shouldn't be matched to it.
|
||||
args = fullargspec.args[1:]
|
||||
else:
|
||||
args = fullargspec.args
|
||||
|
||||
# A cache mapping from argument name to index, for canonicalizing
|
||||
# arguments that are called in a keyword-like fashion.
|
||||
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
|
||||
self._arg_names = args
|
||||
self._vararg_name = fullargspec.varargs
|
||||
# A cache mapping from arg index to default value, for canonicalization.
|
||||
offset = len(args) - len(fullargspec.defaults or [])
|
||||
self._arg_indices_to_default_values = {
|
||||
offset + index: default
|
||||
for index, default in enumerate(fullargspec.defaults or [])
|
||||
}
|
||||
self._default_values = fullargspec.defaults
|
||||
self._default_values_start_index = offset
|
||||
if input_signature is None:
|
||||
self._input_signature = None
|
||||
else:
|
||||
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
|
||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||
"function with keyword arguments when "
|
||||
"input_signature is provided.")
|
||||
|
||||
if not isinstance(input_signature, (tuple, list)):
|
||||
raise TypeError("input_signature must be either a tuple or a "
|
||||
"list, received " + str(type(input_signature)))
|
||||
|
||||
self._input_signature = tuple(input_signature)
|
||||
self._flat_input_signature = tuple(nest.flatten(input_signature))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Calls a graph function specialized to the inputs."""
|
||||
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
|
||||
@ -870,7 +971,17 @@ class PolymorphicFunction(object):
|
||||
@property
|
||||
def python_function(self):
|
||||
"""Returns the wrapped Python function."""
|
||||
return self._python_function
|
||||
return self._python_function # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _input_signature(self):
|
||||
"""Returns the wrapped Python function."""
|
||||
return self._function_spec.input_signature # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def _flat_input_signature(self):
|
||||
"""Returns the wrapped Python function."""
|
||||
return self._function_spec.flat_input_signature # pylint: disable=protected-access
|
||||
|
||||
def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
|
||||
"""Returns a concrete function which cleans up its graph function."""
|
||||
@ -1050,96 +1161,6 @@ class PolymorphicFunction(object):
|
||||
return CacheKey(input_signature, parent_graph, device_functions,
|
||||
colocation_stack, uses_xla)
|
||||
|
||||
def _canonicalize_function_inputs(self, *args, **kwargs):
|
||||
"""Canonicalizes `args` and `kwargs`.
|
||||
|
||||
Canonicalize the inputs to the Python function using its fullargspec. In
|
||||
particular, we parse the varags and kwargs that this
|
||||
`PolymorphicFunction` was called with into a tuple corresponding to the
|
||||
Python function's positional (named) arguments and a dictionary
|
||||
corresponding to its kwargs.
|
||||
|
||||
Args:
|
||||
*args: The varargs this object was called with.
|
||||
**kwargs: The keyword args this function was called with.
|
||||
|
||||
Returns:
|
||||
A canonicalized ordering of the inputs.
|
||||
|
||||
Raises:
|
||||
ValueError: If a keyword in `kwargs` cannot be matched with a positional
|
||||
argument when an input signature is specified, or when the inputs
|
||||
do not conform to the input signature.
|
||||
"""
|
||||
args = self._args_to_prepend + args
|
||||
kwargs = dict(kwargs, **self._kwargs_to_include)
|
||||
if not kwargs:
|
||||
if self._default_values:
|
||||
inputs = args + self._default_values[len(args) -
|
||||
self._default_values_start_index:]
|
||||
else:
|
||||
inputs = args
|
||||
else:
|
||||
# Maps from index of arg to its corresponding value, according to `args`
|
||||
# and `kwargs`; seeded with the default values for the named args that
|
||||
# aren't in `args`.
|
||||
arg_indices_to_values = {
|
||||
index: default for index, default in six.iteritems(
|
||||
self._arg_indices_to_default_values) if index >= len(args)
|
||||
}
|
||||
consumed_args = []
|
||||
for arg, value in six.iteritems(kwargs):
|
||||
index = self._args_to_indices.get(arg, None)
|
||||
if index is not None:
|
||||
arg_indices_to_values[index] = value
|
||||
consumed_args.append(arg)
|
||||
elif self._input_signature is not None:
|
||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||
"function with keyword arguments when "
|
||||
"input_signature is provided.")
|
||||
for arg in consumed_args:
|
||||
# After this loop, `kwargs` will only contain true keyword arguments, as
|
||||
# opposed to named arguments called in a keyword-like fashion.
|
||||
kwargs.pop(arg)
|
||||
inputs = args + _deterministic_dict_values(arg_indices_to_values)
|
||||
flat_inputs = nest.flatten(inputs)
|
||||
|
||||
# Check for NumPy arrays in arguments and convert them to Tensors.
|
||||
# TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
|
||||
# finding a way to store them directly in the cache key (currently not
|
||||
# possible since ndarrays are not hashable).
|
||||
need_packing = False
|
||||
for index, value in enumerate(flat_inputs):
|
||||
if type(value) == np.ndarray:
|
||||
flat_inputs[index] = constant_op.constant(value)
|
||||
need_packing = True
|
||||
if need_packing:
|
||||
inputs = nest.pack_sequence_as(structure=inputs,
|
||||
flat_sequence=flat_inputs)
|
||||
if self._input_signature is None:
|
||||
return inputs, kwargs
|
||||
else:
|
||||
assert not kwargs
|
||||
signature_relevant_inputs = inputs[:len(self._input_signature)]
|
||||
try:
|
||||
nest.assert_same_structure(self._input_signature,
|
||||
signature_relevant_inputs)
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError("Structure of Python function inputs does not match "
|
||||
"input_signature.")
|
||||
signature_inputs_flat = nest.flatten(signature_relevant_inputs)
|
||||
if any(not pywrap_tensorflow.IsTensor(arg)
|
||||
for arg in signature_inputs_flat):
|
||||
raise ValueError("When input_signature is provided, all inputs to "
|
||||
"the Python function must be Tensors.")
|
||||
if any(not spec.is_compatible_with(other)
|
||||
for spec, other in zip(self._flat_input_signature,
|
||||
signature_inputs_flat)):
|
||||
raise ValueError("Python inputs incompatible with input_signature: "
|
||||
"inputs (%s), input_signature (%s)" %
|
||||
(str(inputs), str(self._input_signature)))
|
||||
return inputs, {}
|
||||
|
||||
def _maybe_define_function(self, args, kwargs):
|
||||
"""Gets a function for these inputs, defining it if necessary.
|
||||
|
||||
@ -1159,7 +1180,8 @@ class PolymorphicFunction(object):
|
||||
TypeError: If the function inputs include non-hashable objects
|
||||
"""
|
||||
if self._input_signature is None or args is not None or kwargs is not None:
|
||||
args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
|
||||
args, kwargs = self._function_spec.canonicalize_function_inputs(
|
||||
*args, **kwargs)
|
||||
cache_key = self._cache_key(args, kwargs)
|
||||
with self._lock:
|
||||
try:
|
||||
@ -1177,8 +1199,9 @@ class PolymorphicFunction(object):
|
||||
else:
|
||||
arglen = len(self._input_signature)
|
||||
arg_names = (
|
||||
self._arg_names[:arglen]
|
||||
+ [self._vararg_name] * (arglen - len(self._arg_names)))
|
||||
self._function_spec.arg_names[:arglen]
|
||||
+ [self._function_spec.vararg_name] *
|
||||
(arglen - len(self._function_spec.arg_names)))
|
||||
graph_function = Function(
|
||||
func_graph_module.func_graph_from_py_func(
|
||||
self._name,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user