Extract function spec information that is needed for canonicalization.

This will allow serialization of all the information we need for canonicalization.

PiperOrigin-RevId: 225960841
This commit is contained in:
A. Unique TensorFlower 2018-12-18 01:55:09 -08:00 committed by TensorFlower Gardener
parent e4bdb31636
commit 15fa7c49e2
2 changed files with 161 additions and 135 deletions

View File

@ -269,8 +269,11 @@ class PolymorphicFunction(object):
raise ValueError(
"_canonicalize_function_inputs must be called only after _initialize "
"has run.")
# pylint: disable=protected-access
if self._input_signature is None or args or kwds:
return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access
return self._stateful_fn._function_spec.canonicalize_function_inputs(
*args, **kwds)
# pylint: enable=protected-access
# If an input signature is defined, we may need to fetch a concrete function
# without any inputs specified. In this case args and kwds should be ignored
# but running _canonicalize_function_inputs would raise an exception.

View File

@ -769,6 +769,146 @@ def _deterministic_dict_values(dictionary):
return tuple(dictionary[key] for key in sorted(dictionary))
class FunctionSpec(object):
  """Specification of how to bind arguments to a function.

  Captures the pieces of a Python function's argspec (positional argument
  names, defaults, varargs name) together with an optional input signature,
  so that call-time `args`/`kwargs` can be canonicalized without holding on
  to the rest of the wrapping object.
  """

  def __init__(self, python_function, input_signature):
    """Builds the spec from a Python callable and an optional signature.

    Args:
      python_function: The function to inspect. If it is a
        `functools.partial`, the underlying `func` is inspected and the
        partial's bound positional/keyword arguments are remembered so they
        can be merged into every call.
      input_signature: Either `None` or a tuple/list describing the expected
        inputs; it is stored as a tuple along with a flattened copy.

    Raises:
      ValueError: If `input_signature` is given and the inspected function
        accepts `**kwargs` or keyword-only arguments.
      TypeError: If `input_signature` is neither `None`, a tuple nor a list.
    """
    if isinstance(python_function, functools.partial):
      python_function_to_inspect = python_function.func
      self._args_to_prepend = python_function.args or tuple()
      self._kwargs_to_include = python_function.keywords or {}
    else:
      python_function_to_inspect = python_function
      self._args_to_prepend = tuple()
      self._kwargs_to_include = {}

    fullargspec = tf_inspect.getfullargspec(python_function_to_inspect)
    self._default_values = fullargspec.defaults

    if tf_inspect.ismethod(python_function_to_inspect):
      # Remove `self`: default arguments shouldn't be matched to it.
      args = fullargspec.args[1:]
    else:
      args = fullargspec.args

    # A cache mapping from argument name to index, for canonicalizing
    # arguments that are called in a keyword-like fashion.
    self._args_to_indices = {arg: i for i, arg in enumerate(args)}
    self.arg_names = args
    self.vararg_name = fullargspec.varargs

    # A cache mapping from arg index to default value, for canonicalization.
    offset = len(args) - len(fullargspec.defaults or [])
    self._arg_indices_to_default_values = {
        offset + index: default
        for index, default in enumerate(fullargspec.defaults or [])
    }
    # Index of the first positional argument that has a default value.
    self._default_values_start_index = offset

    if input_signature is None:
      self.input_signature = None
    else:
      if fullargspec.varkw is not None or fullargspec.kwonlyargs:
        raise ValueError("Cannot define a TensorFlow function from a Python "
                         "function with keyword arguments when "
                         "input_signature is provided.")

      if not isinstance(input_signature, (tuple, list)):
        raise TypeError("input_signature must be either a tuple or a "
                        "list, received " + str(type(input_signature)))

      self.input_signature = tuple(input_signature)
      self.flat_input_signature = tuple(nest.flatten(input_signature))

  def canonicalize_function_inputs(self, *args, **kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalize the inputs to the Python function using a `FunctionSpec`
    instance. In particular, we parse the varargs and kwargs that the
    original function was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs.

    Args:
      *args: The varargs this object was called with.
      **kwargs: The keyword args this function was called with.

    Returns:
      A canonicalized ordering of the inputs, as an `(inputs, kwargs)` pair.
      When an input signature is set, the returned kwargs dict is empty.

    Raises:
      ValueError: If a keyword in `kwargs` cannot be matched with a positional
        argument when an input signature is specified, or when the inputs
        do not conform to the input signature.
      TypeError: If the call is purely positional, the function has default
        values, and fewer positional arguments were supplied than there are
        parameters without defaults.
    """
    args = self._args_to_prepend + args
    kwargs = dict(kwargs, **self._kwargs_to_include)
    if not kwargs:
      inputs = args
      if self._default_values:
        num_missing = self._default_values_start_index - len(args)
        if num_missing > 0:
          # Without this check the slice below would start at a negative
          # index and take defaults from the *end* of the tuple, silently
          # binding them to the wrong parameters.
          missing = self.arg_names[len(args):self._default_values_start_index]
          raise TypeError("Missing required positional argument(s): %s" %
                          ", ".join(str(name) for name in missing))
        inputs = args + self._default_values[
            len(args) - self._default_values_start_index:]
    else:
      # Maps from index of arg to its corresponding value, according to `args`
      # and `kwargs`; seeded with the default values for the named args that
      # aren't in `args`.
      arg_indices_to_values = {
          index: default for index, default in six.iteritems(
              self._arg_indices_to_default_values) if index >= len(args)
      }
      consumed_args = []
      for arg, value in six.iteritems(kwargs):
        index = self._args_to_indices.get(arg, None)
        if index is not None:
          arg_indices_to_values[index] = value
          consumed_args.append(arg)
        elif self.input_signature is not None:
          raise ValueError("Cannot define a TensorFlow function from a Python "
                           "function with keyword arguments when "
                           "input_signature is provided.")
      for arg in consumed_args:
        # After this loop, `kwargs` will only contain true keyword arguments, as
        # opposed to named arguments called in a keyword-like fashion.
        kwargs.pop(arg)
      inputs = args + _deterministic_dict_values(arg_indices_to_values)

    flat_inputs = nest.flatten(inputs)

    # Check for NumPy arrays in arguments and convert them to Tensors.
    # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
    # finding a way to store them directly in the cache key (currently not
    # possible since ndarrays are not hashable).
    need_packing = False
    for index, value in enumerate(flat_inputs):
      # Exact type match is kept on purpose: ndarray subclasses are left as-is.
      if type(value) == np.ndarray:
        flat_inputs[index] = constant_op.constant(value)
        need_packing = True
    if need_packing:
      inputs = nest.pack_sequence_as(
          structure=inputs, flat_sequence=flat_inputs)

    if self.input_signature is None:
      return inputs, kwargs
    else:
      # Any keyword that did not map to a positional argument has already
      # raised above, so nothing may be left over here.
      assert not kwargs
      signature_relevant_inputs = inputs[:len(self.input_signature)]
      try:
        nest.assert_same_structure(self.input_signature,
                                   signature_relevant_inputs)
      except (ValueError, TypeError):
        raise ValueError("Structure of Python function inputs does not match "
                         "input_signature.")
      signature_inputs_flat = nest.flatten(signature_relevant_inputs)
      if any(
          not pywrap_tensorflow.IsTensor(arg) for arg in signature_inputs_flat):
        raise ValueError("When input_signature is provided, all inputs to "
                         "the Python function must be Tensors.")
      if any(not spec.is_compatible_with(other) for spec, other in zip(
          self.flat_input_signature, signature_inputs_flat)):
        raise ValueError("Python inputs incompatible with input_signature: "
                         "inputs (%s), input_signature (%s)" %
                         (str(inputs), str(self.input_signature)))
      return inputs, {}
class PolymorphicFunction(object):
"""Wrapper class for the graph functions defined for a Python function.
@ -805,15 +945,11 @@ class PolymorphicFunction(object):
ValueError: if `input_signature` is not None and the `python_function`'s
argspec has keyword arguments.
"""
if isinstance(python_function, functools.partial):
self._python_function = python_function.func
self._args_to_prepend = python_function.args or tuple()
self._kwargs_to_include = python_function.keywords or {}
else:
self._python_function = python_function
self._args_to_prepend = tuple()
self._kwargs_to_include = {}
self._function_spec = FunctionSpec(python_function, input_signature)
self._name = name
self._autograph = autograph
self._function_cache = collections.OrderedDict()
@ -827,41 +963,6 @@ class PolymorphicFunction(object):
# different functions for each instance.
self._descriptor_cache = weakref.WeakKeyDictionary()
fullargspec = tf_inspect.getfullargspec(self._python_function)
if tf_inspect.ismethod(self._python_function):
# Remove `self`: default arguments shouldn't be matched to it.
args = fullargspec.args[1:]
else:
args = fullargspec.args
# A cache mapping from argument name to index, for canonicalizing
# arguments that are called in a keyword-like fashion.
self._args_to_indices = {arg: i for i, arg in enumerate(args)}
self._arg_names = args
self._vararg_name = fullargspec.varargs
# A cache mapping from arg index to default value, for canonicalization.
offset = len(args) - len(fullargspec.defaults or [])
self._arg_indices_to_default_values = {
offset + index: default
for index, default in enumerate(fullargspec.defaults or [])
}
self._default_values = fullargspec.defaults
self._default_values_start_index = offset
if input_signature is None:
self._input_signature = None
else:
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
raise ValueError("Cannot define a TensorFlow function from a Python "
"function with keyword arguments when "
"input_signature is provided.")
if not isinstance(input_signature, (tuple, list)):
raise TypeError("input_signature must be either a tuple or a "
"list, received " + str(type(input_signature)))
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
def __call__(self, *args, **kwargs):
"""Calls a graph function specialized to the inputs."""
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
@ -870,7 +971,17 @@ class PolymorphicFunction(object):
@property
def python_function(self):
  """Returns the wrapped Python function."""
  # A single return: the duplicated statement that followed was unreachable.
  return self._python_function
@property
def _input_signature(self):
  """Returns the `input_signature` stored on the underlying function spec."""
  return self._function_spec.input_signature
@property
def _flat_input_signature(self):
  """Returns the `flat_input_signature` of the underlying function spec."""
  return self._function_spec.flat_input_signature
def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
"""Returns a concrete function which cleans up its graph function."""
@ -1050,96 +1161,6 @@ class PolymorphicFunction(object):
return CacheKey(input_signature, parent_graph, device_functions,
colocation_stack, uses_xla)
def _canonicalize_function_inputs(self, *args, **kwargs):
  """Canonicalizes `args` and `kwargs`.

  Canonicalize the inputs to the Python function using its fullargspec. In
  particular, we parse the varargs and kwargs that this
  `PolymorphicFunction` was called with into a tuple corresponding to the
  Python function's positional (named) arguments and a dictionary
  corresponding to its kwargs.

  Args:
    *args: The varargs this object was called with.
    **kwargs: The keyword args this function was called with.

  Returns:
    A canonicalized ordering of the inputs, as an `(inputs, kwargs)` pair.
    When an input signature is set, the returned kwargs dict is empty.

  Raises:
    ValueError: If a keyword in `kwargs` cannot be matched with a positional
      argument when an input signature is specified, or when the inputs
      do not conform to the input signature.
    TypeError: If the call is purely positional, the function has default
      values, and fewer positional arguments were supplied than there are
      parameters without defaults.
  """
  args = self._args_to_prepend + args
  kwargs = dict(kwargs, **self._kwargs_to_include)
  if not kwargs:
    inputs = args
    if self._default_values:
      num_missing = self._default_values_start_index - len(args)
      if num_missing > 0:
        # Without this check the slice below would start at a negative index
        # and take defaults from the *end* of the tuple, silently binding
        # them to the wrong parameters.
        missing = self._arg_names[len(args):self._default_values_start_index]
        raise TypeError("Missing required positional argument(s): %s" %
                        ", ".join(str(name) for name in missing))
      inputs = args + self._default_values[len(args) -
                                           self._default_values_start_index:]
  else:
    # Maps from index of arg to its corresponding value, according to `args`
    # and `kwargs`; seeded with the default values for the named args that
    # aren't in `args`.
    arg_indices_to_values = {
        index: default for index, default in six.iteritems(
            self._arg_indices_to_default_values) if index >= len(args)
    }
    consumed_args = []
    for arg, value in six.iteritems(kwargs):
      index = self._args_to_indices.get(arg, None)
      if index is not None:
        arg_indices_to_values[index] = value
        consumed_args.append(arg)
      elif self._input_signature is not None:
        raise ValueError("Cannot define a TensorFlow function from a Python "
                         "function with keyword arguments when "
                         "input_signature is provided.")
    for arg in consumed_args:
      # After this loop, `kwargs` will only contain true keyword arguments, as
      # opposed to named arguments called in a keyword-like fashion.
      kwargs.pop(arg)
    inputs = args + _deterministic_dict_values(arg_indices_to_values)

  flat_inputs = nest.flatten(inputs)

  # Check for NumPy arrays in arguments and convert them to Tensors.
  # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
  # finding a way to store them directly in the cache key (currently not
  # possible since ndarrays are not hashable).
  need_packing = False
  for index, value in enumerate(flat_inputs):
    # Exact type match is kept on purpose: ndarray subclasses are left as-is.
    if type(value) == np.ndarray:
      flat_inputs[index] = constant_op.constant(value)
      need_packing = True
  if need_packing:
    inputs = nest.pack_sequence_as(structure=inputs,
                                   flat_sequence=flat_inputs)

  if self._input_signature is None:
    return inputs, kwargs
  else:
    # Any keyword that did not map to a positional argument has already
    # raised above, so nothing may be left over here.
    assert not kwargs
    signature_relevant_inputs = inputs[:len(self._input_signature)]
    try:
      nest.assert_same_structure(self._input_signature,
                                 signature_relevant_inputs)
    except (ValueError, TypeError):
      raise ValueError("Structure of Python function inputs does not match "
                       "input_signature.")
    signature_inputs_flat = nest.flatten(signature_relevant_inputs)
    if any(not pywrap_tensorflow.IsTensor(arg)
           for arg in signature_inputs_flat):
      raise ValueError("When input_signature is provided, all inputs to "
                       "the Python function must be Tensors.")
    if any(not spec.is_compatible_with(other)
           for spec, other in zip(self._flat_input_signature,
                                  signature_inputs_flat)):
      raise ValueError("Python inputs incompatible with input_signature: "
                       "inputs (%s), input_signature (%s)" %
                       (str(inputs), str(self._input_signature)))
    return inputs, {}
def _maybe_define_function(self, args, kwargs):
"""Gets a function for these inputs, defining it if necessary.
@ -1159,7 +1180,8 @@ class PolymorphicFunction(object):
TypeError: If the function inputs include non-hashable objects
"""
if self._input_signature is None or args is not None or kwargs is not None:
args, kwargs = self._canonicalize_function_inputs(*args, **kwargs)
args, kwargs = self._function_spec.canonicalize_function_inputs(
*args, **kwargs)
cache_key = self._cache_key(args, kwargs)
with self._lock:
try:
@ -1177,8 +1199,9 @@ class PolymorphicFunction(object):
else:
arglen = len(self._input_signature)
arg_names = (
self._arg_names[:arglen]
+ [self._vararg_name] * (arglen - len(self._arg_names)))
self._function_spec.arg_names[:arglen]
+ [self._function_spec.vararg_name] *
(arglen - len(self._function_spec.arg_names)))
graph_function = Function(
func_graph_module.func_graph_from_py_func(
self._name,