Serialize FunctionSpec as proto instead of a tuple.
PiperOrigin-RevId: 229174621
This commit is contained in:
parent
3ed46c325f
commit
ad683f866b
@ -866,14 +866,6 @@ def _deterministic_dict_values(dictionary):
|
||||
class FunctionSpec(object):
|
||||
"""Specification of how to bind arguments to a function."""
|
||||
|
||||
def as_tuple(self):
|
||||
return (self._fullargspec, self._is_method, self._args_to_prepend,
|
||||
self._kwargs_to_include, self.input_signature)
|
||||
|
||||
@staticmethod
|
||||
def from_tuple(spec_tuple):
|
||||
return FunctionSpec(*spec_tuple)
|
||||
|
||||
@staticmethod
|
||||
def from_function_and_signature(python_function, input_signature):
|
||||
"""Create a FunctionSpec instance given a python function and signature."""
|
||||
@ -920,7 +912,7 @@ class FunctionSpec(object):
|
||||
}
|
||||
self._default_values_start_index = offset
|
||||
if input_signature is None:
|
||||
self.input_signature = None
|
||||
self._input_signature = None
|
||||
else:
|
||||
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
|
||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||
@ -931,8 +923,32 @@ class FunctionSpec(object):
|
||||
raise TypeError("input_signature must be either a tuple or a "
|
||||
"list, received " + str(type(input_signature)))
|
||||
|
||||
self.input_signature = tuple(input_signature)
|
||||
self.flat_input_signature = tuple(nest.flatten(input_signature))
|
||||
self._input_signature = tuple(input_signature)
|
||||
self._flat_input_signature = tuple(nest.flatten(input_signature))
|
||||
|
||||
@property
|
||||
def fullargspec(self):
|
||||
return self._fullargspec
|
||||
|
||||
@property
|
||||
def is_method(self):
|
||||
return self._is_method
|
||||
|
||||
@property
|
||||
def args_to_prepend(self):
|
||||
return self._args_to_prepend
|
||||
|
||||
@property
|
||||
def kwargs_to_include(self):
|
||||
return self._kwargs_to_include
|
||||
|
||||
@property
|
||||
def input_signature(self):
|
||||
return self._input_signature
|
||||
|
||||
@property
|
||||
def flat_input_signature(self):
|
||||
return self._flat_input_signature
|
||||
|
||||
def canonicalize_function_inputs(self, *args, **kwargs):
|
||||
"""Canonicalizes `args` and `kwargs`.
|
||||
@ -980,7 +996,7 @@ class FunctionSpec(object):
|
||||
if index is not None:
|
||||
arg_indices_to_values[index] = value
|
||||
consumed_args.append(arg)
|
||||
elif self.input_signature is not None:
|
||||
elif self._input_signature is not None:
|
||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||
"function with keyword arguments when "
|
||||
"input_signature is provided.")
|
||||
@ -1003,12 +1019,13 @@ class FunctionSpec(object):
|
||||
if need_packing:
|
||||
inputs = nest.pack_sequence_as(
|
||||
structure=inputs, flat_sequence=flat_inputs)
|
||||
if self.input_signature is None:
|
||||
if self._input_signature is None:
|
||||
return inputs, kwargs
|
||||
else:
|
||||
assert not kwargs
|
||||
signature_relevant_inputs = inputs[:len(self.input_signature)]
|
||||
if not is_same_structure(self.input_signature, signature_relevant_inputs):
|
||||
signature_relevant_inputs = inputs[:len(self._input_signature)]
|
||||
if not is_same_structure(self._input_signature,
|
||||
signature_relevant_inputs):
|
||||
raise ValueError("Structure of Python function inputs does not match "
|
||||
"input_signature.")
|
||||
signature_inputs_flat = nest.flatten(signature_relevant_inputs)
|
||||
@ -1017,10 +1034,10 @@ class FunctionSpec(object):
|
||||
raise ValueError("When input_signature is provided, all inputs to "
|
||||
"the Python function must be Tensors.")
|
||||
if any(not spec.is_compatible_with(other) for spec, other in zip(
|
||||
self.flat_input_signature, signature_inputs_flat)):
|
||||
self._flat_input_signature, signature_inputs_flat)):
|
||||
raise ValueError("Python inputs incompatible with input_signature: "
|
||||
"inputs (%s), input_signature (%s)" %
|
||||
(str(inputs), str(self.input_signature)))
|
||||
(str(inputs), str(self._input_signature)))
|
||||
return inputs, {}
|
||||
|
||||
|
||||
|
@ -60,6 +60,17 @@ def _inputs_compatible(args, stored_inputs):
|
||||
return True
|
||||
|
||||
|
||||
def _deserialize_function_spec(function_spec_proto, coder):
|
||||
"""Deserialize a FunctionSpec object from its proto representation."""
|
||||
fullargspec = coder.decode_proto(function_spec_proto.fullargspec)
|
||||
is_method = function_spec_proto.is_method
|
||||
args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
|
||||
kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
|
||||
input_signature = coder.decode_proto(function_spec_proto.input_signature)
|
||||
return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend,
|
||||
kwargs_to_include, input_signature)
|
||||
|
||||
|
||||
def recreate_polymorphic_function(
|
||||
saved_polymorphic_function, functions):
|
||||
"""Creates a PolymorphicFunction from a SavedPolymorphicFunction.
|
||||
@ -77,9 +88,8 @@ def recreate_polymorphic_function(
|
||||
# serialization cycle.
|
||||
|
||||
coder = nested_structure_coder.StructureCoder()
|
||||
function_spec_tuple = coder.decode_proto(
|
||||
saved_polymorphic_function.function_spec_tuple)
|
||||
function_spec = function_lib.FunctionSpec.from_tuple(function_spec_tuple)
|
||||
function_spec = _deserialize_function_spec(
|
||||
saved_polymorphic_function.function_spec, coder)
|
||||
|
||||
# TODO(mdan): We may enable autograph once exceptions are supported.
|
||||
@def_function.function(autograph=False)
|
||||
|
@ -25,13 +25,28 @@ from tensorflow.python.saved_model import saved_object_graph_pb2
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
def _serialize_function_spec(function_spec, coder):
|
||||
"""Serialize a FunctionSpec object into its proto representation."""
|
||||
proto = saved_object_graph_pb2.FunctionSpec()
|
||||
proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
|
||||
proto.is_method = function_spec.is_method
|
||||
proto.args_to_prepend.CopyFrom(
|
||||
coder.encode_structure(function_spec.args_to_prepend))
|
||||
proto.kwargs_to_include.CopyFrom(
|
||||
coder.encode_structure(function_spec.kwargs_to_include))
|
||||
proto.input_signature.CopyFrom(
|
||||
coder.encode_structure(function_spec.input_signature))
|
||||
return proto
|
||||
|
||||
|
||||
def serialize_polymorphic_function(polymorphic_function, node_ids):
|
||||
"""Build a SavedPolymorphicProto."""
|
||||
coder = nested_structure_coder.StructureCoder()
|
||||
proto = saved_object_graph_pb2.SavedPolymorphicFunction()
|
||||
|
||||
proto.function_spec_tuple.CopyFrom(
|
||||
coder.encode_structure(polymorphic_function.function_spec.as_tuple())) # pylint: disable=protected-access
|
||||
function_spec_proto = _serialize_function_spec(
|
||||
polymorphic_function.function_spec, coder)
|
||||
proto.function_spec.CopyFrom(function_spec_proto)
|
||||
for signature, concrete_function in list_all_concrete_functions(
|
||||
polymorphic_function):
|
||||
bound_inputs = []
|
||||
|
@ -86,9 +86,7 @@ message SavedAsset {
|
||||
// A function with multiple signatures, possibly with non-Tensor arguments.
|
||||
message SavedPolymorphicFunction {
|
||||
repeated SavedMonomorphicFunction monomorphic_function = 1;
|
||||
// Tuple representing a `FunctionSpec`.
|
||||
// TODO(vbardiovsky): Make this a proto.
|
||||
StructuredValue function_spec_tuple = 2;
|
||||
FunctionSpec function_spec = 2;
|
||||
}
|
||||
|
||||
message SavedMonomorphicFunction {
|
||||
@ -118,3 +116,20 @@ message SavedVariable {
|
||||
|
||||
// TODO(andresp): Add save_slice_info_def?
|
||||
}
|
||||
|
||||
// Represents FunctionSpec used in PolymorphicFunction. This represents a
|
||||
// function that has been wrapped as a PolymorphicFunction.
|
||||
message FunctionSpec {
|
||||
// Full arg spec from inspect.getfullargspec().
|
||||
StructuredValue fullargspec = 1;
|
||||
// Whether this represents a class method.
|
||||
bool is_method = 2;
|
||||
// Which arguments to always prepend, in case the original function is based
|
||||
// on a functools.partial.
|
||||
StructuredValue args_to_prepend = 3;
|
||||
// Which kwargs to always include, in case the original function is based on a
|
||||
// functools.partial.
|
||||
StructuredValue kwargs_to_include = 4;
|
||||
// The input signature, if specified.
|
||||
StructuredValue input_signature = 5;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user