Serialize FunctionSpec as proto instead of a tuple.

PiperOrigin-RevId: 229174621
This commit is contained in:
Vojtech Bardiovsky 2019-01-14 06:51:47 -08:00 committed by TensorFlower Gardener
parent 3ed46c325f
commit ad683f866b
4 changed files with 82 additions and 25 deletions

View File

@ -866,14 +866,6 @@ def _deterministic_dict_values(dictionary):
class FunctionSpec(object):
"""Specification of how to bind arguments to a function."""
def as_tuple(self):
return (self._fullargspec, self._is_method, self._args_to_prepend,
self._kwargs_to_include, self.input_signature)
@staticmethod
def from_tuple(spec_tuple):
return FunctionSpec(*spec_tuple)
@staticmethod
def from_function_and_signature(python_function, input_signature):
"""Create a FunctionSpec instance given a python function and signature."""
@ -920,7 +912,7 @@ class FunctionSpec(object):
}
self._default_values_start_index = offset
if input_signature is None:
self.input_signature = None
self._input_signature = None
else:
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
raise ValueError("Cannot define a TensorFlow function from a Python "
@ -931,8 +923,32 @@ class FunctionSpec(object):
raise TypeError("input_signature must be either a tuple or a "
"list, received " + str(type(input_signature)))
self.input_signature = tuple(input_signature)
self.flat_input_signature = tuple(nest.flatten(input_signature))
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
@property
def fullargspec(self):
return self._fullargspec
@property
def is_method(self):
return self._is_method
@property
def args_to_prepend(self):
return self._args_to_prepend
@property
def kwargs_to_include(self):
return self._kwargs_to_include
@property
def input_signature(self):
return self._input_signature
@property
def flat_input_signature(self):
return self._flat_input_signature
def canonicalize_function_inputs(self, *args, **kwargs):
"""Canonicalizes `args` and `kwargs`.
@ -980,7 +996,7 @@ class FunctionSpec(object):
if index is not None:
arg_indices_to_values[index] = value
consumed_args.append(arg)
elif self.input_signature is not None:
elif self._input_signature is not None:
raise ValueError("Cannot define a TensorFlow function from a Python "
"function with keyword arguments when "
"input_signature is provided.")
@ -1003,12 +1019,13 @@ class FunctionSpec(object):
if need_packing:
inputs = nest.pack_sequence_as(
structure=inputs, flat_sequence=flat_inputs)
if self.input_signature is None:
if self._input_signature is None:
return inputs, kwargs
else:
assert not kwargs
signature_relevant_inputs = inputs[:len(self.input_signature)]
if not is_same_structure(self.input_signature, signature_relevant_inputs):
signature_relevant_inputs = inputs[:len(self._input_signature)]
if not is_same_structure(self._input_signature,
signature_relevant_inputs):
raise ValueError("Structure of Python function inputs does not match "
"input_signature.")
signature_inputs_flat = nest.flatten(signature_relevant_inputs)
@ -1017,10 +1034,10 @@ class FunctionSpec(object):
raise ValueError("When input_signature is provided, all inputs to "
"the Python function must be Tensors.")
if any(not spec.is_compatible_with(other) for spec, other in zip(
self.flat_input_signature, signature_inputs_flat)):
self._flat_input_signature, signature_inputs_flat)):
raise ValueError("Python inputs incompatible with input_signature: "
"inputs (%s), input_signature (%s)" %
(str(inputs), str(self.input_signature)))
(str(inputs), str(self._input_signature)))
return inputs, {}

View File

@ -60,6 +60,17 @@ def _inputs_compatible(args, stored_inputs):
return True
def _deserialize_function_spec(function_spec_proto, coder):
"""Deserialize a FunctionSpec object from its proto representation."""
fullargspec = coder.decode_proto(function_spec_proto.fullargspec)
is_method = function_spec_proto.is_method
args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
input_signature = coder.decode_proto(function_spec_proto.input_signature)
return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend,
kwargs_to_include, input_signature)
def recreate_polymorphic_function(
saved_polymorphic_function, functions):
"""Creates a PolymorphicFunction from a SavedPolymorphicFunction.
@ -77,9 +88,8 @@ def recreate_polymorphic_function(
# serialization cycle.
coder = nested_structure_coder.StructureCoder()
function_spec_tuple = coder.decode_proto(
saved_polymorphic_function.function_spec_tuple)
function_spec = function_lib.FunctionSpec.from_tuple(function_spec_tuple)
function_spec = _deserialize_function_spec(
saved_polymorphic_function.function_spec, coder)
# TODO(mdan): We may enable autograph once exceptions are supported.
@def_function.function(autograph=False)

View File

@ -25,13 +25,28 @@ from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.util import nest
def _serialize_function_spec(function_spec, coder):
"""Serialize a FunctionSpec object into its proto representation."""
proto = saved_object_graph_pb2.FunctionSpec()
proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
proto.is_method = function_spec.is_method
proto.args_to_prepend.CopyFrom(
coder.encode_structure(function_spec.args_to_prepend))
proto.kwargs_to_include.CopyFrom(
coder.encode_structure(function_spec.kwargs_to_include))
proto.input_signature.CopyFrom(
coder.encode_structure(function_spec.input_signature))
return proto
def serialize_polymorphic_function(polymorphic_function, node_ids):
"""Build a SavedPolymorphicProto."""
coder = nested_structure_coder.StructureCoder()
proto = saved_object_graph_pb2.SavedPolymorphicFunction()
proto.function_spec_tuple.CopyFrom(
coder.encode_structure(polymorphic_function.function_spec.as_tuple())) # pylint: disable=protected-access
function_spec_proto = _serialize_function_spec(
polymorphic_function.function_spec, coder)
proto.function_spec.CopyFrom(function_spec_proto)
for signature, concrete_function in list_all_concrete_functions(
polymorphic_function):
bound_inputs = []

View File

@ -86,9 +86,7 @@ message SavedAsset {
// A function with multiple signatures, possibly with non-Tensor arguments.
message SavedPolymorphicFunction {
repeated SavedMonomorphicFunction monomorphic_function = 1;
// Tuple representing a `FunctionSpec`.
// TODO(vbardiovsky): Make this a proto.
StructuredValue function_spec_tuple = 2;
FunctionSpec function_spec = 2;
}
message SavedMonomorphicFunction {
@ -118,3 +116,20 @@ message SavedVariable {
// TODO(andresp): Add save_slice_info_def?
}
// Represents FunctionSpec used in PolymorphicFunction. This represents a
// function that has been wrapped as a PolymorphicFunction.
message FunctionSpec {
// Full arg spec from inspect.getfullargspec().
StructuredValue fullargspec = 1;
// Whether this represents a class method.
bool is_method = 2;
// Which arguments to always prepend, in case the original function is based
// on a functools.partial.
StructuredValue args_to_prepend = 3;
// Which kwargs to always include, in case the original function is based on a
// functools.partial.
StructuredValue kwargs_to_include = 4;
// The input signature, if specified.
StructuredValue input_signature = 5;
}