diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index bbf8630a2e8..6429022df3a 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -866,14 +866,6 @@ def _deterministic_dict_values(dictionary): class FunctionSpec(object): """Specification of how to bind arguments to a function.""" - def as_tuple(self): - return (self._fullargspec, self._is_method, self._args_to_prepend, - self._kwargs_to_include, self.input_signature) - - @staticmethod - def from_tuple(spec_tuple): - return FunctionSpec(*spec_tuple) - @staticmethod def from_function_and_signature(python_function, input_signature): """Create a FunctionSpec instance given a python function and signature.""" @@ -920,7 +912,7 @@ class FunctionSpec(object): } self._default_values_start_index = offset if input_signature is None: - self.input_signature = None + self._input_signature = None else: if fullargspec.varkw is not None or fullargspec.kwonlyargs: raise ValueError("Cannot define a TensorFlow function from a Python " @@ -931,8 +923,32 @@ class FunctionSpec(object): raise TypeError("input_signature must be either a tuple or a " "list, received " + str(type(input_signature))) - self.input_signature = tuple(input_signature) - self.flat_input_signature = tuple(nest.flatten(input_signature)) + self._input_signature = tuple(input_signature) + self._flat_input_signature = tuple(nest.flatten(input_signature)) + + @property + def fullargspec(self): + return self._fullargspec + + @property + def is_method(self): + return self._is_method + + @property + def args_to_prepend(self): + return self._args_to_prepend + + @property + def kwargs_to_include(self): + return self._kwargs_to_include + + @property + def input_signature(self): + return self._input_signature + + @property + def flat_input_signature(self): + return self._flat_input_signature def canonicalize_function_inputs(self, *args, **kwargs): """Canonicalizes `args` and `kwargs`. @@ -980,7 +996,7 @@ class FunctionSpec(object): if index is not None: arg_indices_to_values[index] = value consumed_args.append(arg) - elif self.input_signature is not None: + elif self._input_signature is not None: raise ValueError("Cannot define a TensorFlow function from a Python " "function with keyword arguments when " "input_signature is provided.") @@ -1003,12 +1019,13 @@ class FunctionSpec(object): if need_packing: inputs = nest.pack_sequence_as( structure=inputs, flat_sequence=flat_inputs) - if self.input_signature is None: + if self._input_signature is None: return inputs, kwargs else: assert not kwargs - signature_relevant_inputs = inputs[:len(self.input_signature)] - if not is_same_structure(self.input_signature, signature_relevant_inputs): + signature_relevant_inputs = inputs[:len(self._input_signature)] + if not is_same_structure(self._input_signature, + signature_relevant_inputs): raise ValueError("Structure of Python function inputs does not match " "input_signature.") signature_inputs_flat = nest.flatten(signature_relevant_inputs) @@ -1017,10 +1034,10 @@ class FunctionSpec(object): raise ValueError("When input_signature is provided, all inputs to " "the Python function must be Tensors.") if any(not spec.is_compatible_with(other) for spec, other in zip( - self.flat_input_signature, signature_inputs_flat)): + self._flat_input_signature, signature_inputs_flat)): raise ValueError("Python inputs incompatible with input_signature: " "inputs (%s), input_signature (%s)" % - (str(inputs), str(self.input_signature))) + (str(inputs), str(self._input_signature))) return inputs, {} diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index ec6cefb606b..e19a46ac192 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -60,6 +60,17 @@ def _inputs_compatible(args, stored_inputs): return True +def _deserialize_function_spec(function_spec_proto, coder): + """Deserialize a FunctionSpec object from its proto representation.""" + fullargspec = coder.decode_proto(function_spec_proto.fullargspec) + is_method = function_spec_proto.is_method + args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend) + kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include) + input_signature = coder.decode_proto(function_spec_proto.input_signature) + return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend, + kwargs_to_include, input_signature) + + def recreate_polymorphic_function( saved_polymorphic_function, functions): """Creates a PolymorphicFunction from a SavedPolymorphicFunction. @@ -77,9 +88,8 @@ def recreate_polymorphic_function( # serialization cycle. coder = nested_structure_coder.StructureCoder() - function_spec_tuple = coder.decode_proto( - saved_polymorphic_function.function_spec_tuple) - function_spec = function_lib.FunctionSpec.from_tuple(function_spec_tuple) + function_spec = _deserialize_function_spec( + saved_polymorphic_function.function_spec, coder) # TODO(mdan): We may enable autograph once exceptions are supported. @def_function.function(autograph=False) diff --git a/tensorflow/python/saved_model/function_serialization.py b/tensorflow/python/saved_model/function_serialization.py index 267852ef4a2..d2c5b33a5fd 100644 --- a/tensorflow/python/saved_model/function_serialization.py +++ b/tensorflow/python/saved_model/function_serialization.py @@ -25,13 +25,28 @@ from tensorflow.python.saved_model import saved_object_graph_pb2 from tensorflow.python.util import nest +def _serialize_function_spec(function_spec, coder): + """Serialize a FunctionSpec object into its proto representation.""" + proto = saved_object_graph_pb2.FunctionSpec() + proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec)) + proto.is_method = function_spec.is_method + proto.args_to_prepend.CopyFrom( + coder.encode_structure(function_spec.args_to_prepend)) + proto.kwargs_to_include.CopyFrom( + coder.encode_structure(function_spec.kwargs_to_include)) + proto.input_signature.CopyFrom( + coder.encode_structure(function_spec.input_signature)) + return proto + + def serialize_polymorphic_function(polymorphic_function, node_ids): """Build a SavedPolymorphicProto.""" coder = nested_structure_coder.StructureCoder() proto = saved_object_graph_pb2.SavedPolymorphicFunction() - proto.function_spec_tuple.CopyFrom( - coder.encode_structure(polymorphic_function.function_spec.as_tuple())) # pylint: disable=protected-access + function_spec_proto = _serialize_function_spec( + polymorphic_function.function_spec, coder) + proto.function_spec.CopyFrom(function_spec_proto) for signature, concrete_function in list_all_concrete_functions( polymorphic_function): bound_inputs = [] diff --git a/tensorflow/python/saved_model/saved_object_graph.proto b/tensorflow/python/saved_model/saved_object_graph.proto index a322726e1ed..f48d2d2b7e6 100644 --- a/tensorflow/python/saved_model/saved_object_graph.proto +++ b/tensorflow/python/saved_model/saved_object_graph.proto @@ -86,9 +86,7 @@ message SavedAsset { // A function with multiple signatures, possibly with non-Tensor arguments. message SavedPolymorphicFunction { repeated SavedMonomorphicFunction monomorphic_function = 1; - // Tuple representing a `FunctionSpec`. - // TODO(vbardiovsky): Make this a proto. - StructuredValue function_spec_tuple = 2; + FunctionSpec function_spec = 2; } message SavedMonomorphicFunction { @@ -118,3 +116,20 @@ message SavedVariable { // TODO(andresp): Add save_slice_info_def? } + +// Represents FunctionSpec used in PolymorphicFunction. This represents a +// function that has been wrapped as a PolymorphicFunction. +message FunctionSpec { + // Full arg spec from inspect.getfullargspec(). + StructuredValue fullargspec = 1; + // Whether this represents a class method. + bool is_method = 2; + // Which arguments to always prepend, in case the original function is based + // on a functools.partial. + StructuredValue args_to_prepend = 3; + // Which kwargs to always include, in case the original function is based on a + // functools.partial. + StructuredValue kwargs_to_include = 4; + // The input signature, if specified. + StructuredValue input_signature = 5; +}