From 15fa7c49e27963df5304d7f827e6c4459079cc18 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 18 Dec 2018 01:55:09 -0800 Subject: [PATCH] Extract function spec information that is needed for canonicalization. This will allow serialization of all the information we need for canonicalization. PiperOrigin-RevId: 225960841 --- tensorflow/python/eager/def_function.py | 5 +- tensorflow/python/eager/function.py | 291 +++++++++++++----------- 2 files changed, 161 insertions(+), 135 deletions(-) diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 52b481915ef..fc14558cc72 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -269,8 +269,11 @@ class PolymorphicFunction(object): raise ValueError( "_canonicalize_function_inputs must be called only after _initialize " "has run.") + # pylint: disable=protected-access if self._input_signature is None or args or kwds: - return self._stateful_fn._canonicalize_function_inputs(*args, **kwds) # pylint: disable=protected-access + return self._stateful_fn._function_spec.canonicalize_function_inputs( + *args, **kwds) + # pylint: enable=protected-access # If an input signature is defined, we may need to fetch a concrete function # without any inputs specified. In this case args and kwds should be ignored # but running _canonicalize_function_inputs would raise an exception. diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 885403dd10c..7ba9f9290bc 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -769,6 +769,146 @@ def _deterministic_dict_values(dictionary): return tuple(dictionary[key] for key in sorted(dictionary)) +class FunctionSpec(object): + """Specification of how to bind arguments to a function.""" + + def __init__(self, python_function, input_signature): + if isinstance(python_function, functools.partial): + python_function_to_inspect = python_function.func + self._args_to_prepend = python_function.args or tuple() + self._kwargs_to_include = python_function.keywords or {} + else: + python_function_to_inspect = python_function + self._args_to_prepend = tuple() + self._kwargs_to_include = {} + + fullargspec = tf_inspect.getfullargspec(python_function_to_inspect) + self._default_values = fullargspec.defaults + + if tf_inspect.ismethod(python_function_to_inspect): + # Remove `self`: default arguments shouldn't be matched to it. + args = fullargspec.args[1:] + else: + args = fullargspec.args + + # A cache mapping from argument name to index, for canonicalizing + # arguments that are called in a keyword-like fashion. + self._args_to_indices = {arg: i for i, arg in enumerate(args)} + self.arg_names = args + self.vararg_name = fullargspec.varargs + + # A cache mapping from arg index to default value, for canonicalization. + offset = len(args) - len(fullargspec.defaults or []) + self._arg_indices_to_default_values = { + offset + index: default + for index, default in enumerate(fullargspec.defaults or []) + } + self._default_values_start_index = offset + if input_signature is None: + self.input_signature = None + else: + if fullargspec.varkw is not None or fullargspec.kwonlyargs: + raise ValueError("Cannot define a TensorFlow function from a Python " + "function with keyword arguments when " + "input_signature is provided.") + + if not isinstance(input_signature, (tuple, list)): + raise TypeError("input_signature must be either a tuple or a " + "list, received " + str(type(input_signature))) + + self.input_signature = tuple(input_signature) + self.flat_input_signature = tuple(nest.flatten(input_signature)) + + def canonicalize_function_inputs(self, *args, **kwargs): + """Canonicalizes `args` and `kwargs`. + + Canonicalize the inputs to the Python function using a `FunctionSpec` + instance. In particular, we parse the varags and kwargs that the + original function was called with into a tuple corresponding to the + Python function's positional (named) arguments and a dictionary + corresponding to its kwargs. + + Args: + *args: The varargs this object was called with. + **kwargs: The keyword args this function was called with. + + Returns: + A canonicalized ordering of the inputs. + + Raises: + ValueError: If a keyword in `kwargs` cannot be matched with a positional + argument when an input signature is specified, or when the inputs + do not conform to the input signature. + """ + args = self._args_to_prepend + args + kwargs = dict(kwargs, **self._kwargs_to_include) + if not kwargs: + if self._default_values: + inputs = args + self._default_values[ + len(args) - self._default_values_start_index:] + else: + inputs = args + else: + # Maps from index of arg to its corresponding value, according to `args` + # and `kwargs`; seeded with the default values for the named args that + # aren't in `args`. + arg_indices_to_values = { + index: default for index, default in six.iteritems( + self._arg_indices_to_default_values) if index >= len(args) + } + consumed_args = [] + for arg, value in six.iteritems(kwargs): + index = self._args_to_indices.get(arg, None) + if index is not None: + arg_indices_to_values[index] = value + consumed_args.append(arg) + elif self.input_signature is not None: + raise ValueError("Cannot define a TensorFlow function from a Python " + "function with keyword arguments when " + "input_signature is provided.") + for arg in consumed_args: + # After this loop, `kwargs` will only contain true keyword arguments, as + # opposed to named arguments called in a keyword-like fashion. + kwargs.pop(arg) + inputs = args + _deterministic_dict_values(arg_indices_to_values) + flat_inputs = nest.flatten(inputs) + + # Check for NumPy arrays in arguments and convert them to Tensors. + # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps + # finding a way to store them directly in the cache key (currently not + # possible since ndarrays are not hashable). + need_packing = False + for index, value in enumerate(flat_inputs): + if type(value) == np.ndarray: + flat_inputs[index] = constant_op.constant(value) + need_packing = True + if need_packing: + inputs = nest.pack_sequence_as( + structure=inputs, flat_sequence=flat_inputs) + if self.input_signature is None: + return inputs, kwargs + else: + assert not kwargs + signature_relevant_inputs = inputs[:len(self.input_signature)] + try: + nest.assert_same_structure(self.input_signature, + signature_relevant_inputs) + except (ValueError, TypeError): + raise ValueError("Structure of Python function inputs does not match " + "input_signature.") + signature_inputs_flat = nest.flatten(signature_relevant_inputs) + if any( + not pywrap_tensorflow.IsTensor(arg) for arg in signature_inputs_flat): + raise ValueError("When input_signature is provided, all inputs to " + "the Python function must be Tensors.") + if any(not spec.is_compatible_with(other) for spec, other in zip( + self.flat_input_signature, signature_inputs_flat)): + raise ValueError("Python inputs incompatible with input_signature: " + "inputs (%s), input_signature (%s)" % + (str(inputs), str(self.input_signature))) + return inputs, {} + + class PolymorphicFunction(object): """Wrapper class for the graph functions defined for a Python function. @@ -805,15 +945,11 @@ class PolymorphicFunction(object): ValueError: if `input_signature` is not None and the `python_function`'s argspec has keyword arguments. """ - if isinstance(python_function, functools.partial): self._python_function = python_function.func - self._args_to_prepend = python_function.args or tuple() - self._kwargs_to_include = python_function.keywords or {} else: self._python_function = python_function - self._args_to_prepend = tuple() - self._kwargs_to_include = {} + self._function_spec = FunctionSpec(python_function, input_signature) self._name = name self._autograph = autograph self._function_cache = collections.OrderedDict() @@ -827,41 +963,6 @@ class PolymorphicFunction(object): # different functions for each instance. self._descriptor_cache = weakref.WeakKeyDictionary() - fullargspec = tf_inspect.getfullargspec(self._python_function) - if tf_inspect.ismethod(self._python_function): - # Remove `self`: default arguments shouldn't be matched to it. - args = fullargspec.args[1:] - else: - args = fullargspec.args - - # A cache mapping from argument name to index, for canonicalizing - # arguments that are called in a keyword-like fashion. - self._args_to_indices = {arg: i for i, arg in enumerate(args)} - self._arg_names = args - self._vararg_name = fullargspec.varargs - # A cache mapping from arg index to default value, for canonicalization. - offset = len(args) - len(fullargspec.defaults or []) - self._arg_indices_to_default_values = { - offset + index: default - for index, default in enumerate(fullargspec.defaults or []) - } - self._default_values = fullargspec.defaults - self._default_values_start_index = offset - if input_signature is None: - self._input_signature = None - else: - if fullargspec.varkw is not None or fullargspec.kwonlyargs: - raise ValueError("Cannot define a TensorFlow function from a Python " - "function with keyword arguments when " - "input_signature is provided.") - - if not isinstance(input_signature, (tuple, list)): - raise TypeError("input_signature must be either a tuple or a " - "list, received " + str(type(input_signature))) - - self._input_signature = tuple(input_signature) - self._flat_input_signature = tuple(nest.flatten(input_signature)) - def __call__(self, *args, **kwargs): """Calls a graph function specialized to the inputs.""" graph_function, args, kwargs = self._maybe_define_function(args, kwargs) @@ -870,7 +971,17 @@ class PolymorphicFunction(object): @property def python_function(self): """Returns the wrapped Python function.""" - return self._python_function + return self._python_function # pylint: disable=protected-access + + @property + def _input_signature(self): + """Returns the wrapped Python function.""" + return self._function_spec.input_signature # pylint: disable=protected-access + + @property + def _flat_input_signature(self): + """Returns the wrapped Python function.""" + return self._function_spec.flat_input_signature # pylint: disable=protected-access def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs): """Returns a concrete function which cleans up its graph function.""" @@ -1050,96 +1161,6 @@ class PolymorphicFunction(object): return CacheKey(input_signature, parent_graph, device_functions, colocation_stack, uses_xla) - def _canonicalize_function_inputs(self, *args, **kwargs): - """Canonicalizes `args` and `kwargs`. - - Canonicalize the inputs to the Python function using its fullargspec. In - particular, we parse the varags and kwargs that this - `PolymorphicFunction` was called with into a tuple corresponding to the - Python function's positional (named) arguments and a dictionary - corresponding to its kwargs. - - Args: - *args: The varargs this object was called with. - **kwargs: The keyword args this function was called with. - - Returns: - A canonicalized ordering of the inputs. - - Raises: - ValueError: If a keyword in `kwargs` cannot be matched with a positional - argument when an input signature is specified, or when the inputs - do not conform to the input signature. - """ - args = self._args_to_prepend + args - kwargs = dict(kwargs, **self._kwargs_to_include) - if not kwargs: - if self._default_values: - inputs = args + self._default_values[len(args) - - self._default_values_start_index:] - else: - inputs = args - else: - # Maps from index of arg to its corresponding value, according to `args` - # and `kwargs`; seeded with the default values for the named args that - # aren't in `args`. - arg_indices_to_values = { - index: default for index, default in six.iteritems( - self._arg_indices_to_default_values) if index >= len(args) - } - consumed_args = [] - for arg, value in six.iteritems(kwargs): - index = self._args_to_indices.get(arg, None) - if index is not None: - arg_indices_to_values[index] = value - consumed_args.append(arg) - elif self._input_signature is not None: - raise ValueError("Cannot define a TensorFlow function from a Python " - "function with keyword arguments when " - "input_signature is provided.") - for arg in consumed_args: - # After this loop, `kwargs` will only contain true keyword arguments, as - # opposed to named arguments called in a keyword-like fashion. - kwargs.pop(arg) - inputs = args + _deterministic_dict_values(arg_indices_to_values) - flat_inputs = nest.flatten(inputs) - - # Check for NumPy arrays in arguments and convert them to Tensors. - # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps - # finding a way to store them directly in the cache key (currently not - # possible since ndarrays are not hashable). - need_packing = False - for index, value in enumerate(flat_inputs): - if type(value) == np.ndarray: - flat_inputs[index] = constant_op.constant(value) - need_packing = True - if need_packing: - inputs = nest.pack_sequence_as(structure=inputs, - flat_sequence=flat_inputs) - if self._input_signature is None: - return inputs, kwargs - else: - assert not kwargs - signature_relevant_inputs = inputs[:len(self._input_signature)] - try: - nest.assert_same_structure(self._input_signature, - signature_relevant_inputs) - except (ValueError, TypeError): - raise ValueError("Structure of Python function inputs does not match " - "input_signature.") - signature_inputs_flat = nest.flatten(signature_relevant_inputs) - if any(not pywrap_tensorflow.IsTensor(arg) - for arg in signature_inputs_flat): - raise ValueError("When input_signature is provided, all inputs to " - "the Python function must be Tensors.") - if any(not spec.is_compatible_with(other) - for spec, other in zip(self._flat_input_signature, - signature_inputs_flat)): - raise ValueError("Python inputs incompatible with input_signature: " - "inputs (%s), input_signature (%s)" % - (str(inputs), str(self._input_signature))) - return inputs, {} - def _maybe_define_function(self, args, kwargs): """Gets a function for these inputs, defining it if necessary. @@ -1159,7 +1180,8 @@ class PolymorphicFunction(object): TypeError: If the function inputs include non-hashable objects """ if self._input_signature is None or args is not None or kwargs is not None: - args, kwargs = self._canonicalize_function_inputs(*args, **kwargs) + args, kwargs = self._function_spec.canonicalize_function_inputs( + *args, **kwargs) cache_key = self._cache_key(args, kwargs) with self._lock: try: @@ -1177,8 +1199,9 @@ class PolymorphicFunction(object): else: arglen = len(self._input_signature) arg_names = ( - self._arg_names[:arglen] - + [self._vararg_name] * (arglen - len(self._arg_names))) + self._function_spec.arg_names[:arglen] + + [self._function_spec.vararg_name] * + (arglen - len(self._function_spec.arg_names))) graph_function = Function( func_graph_module.func_graph_from_py_func( self._name,