Serialize FunctionSpec as proto instead of a tuple.
PiperOrigin-RevId: 229174621
This commit is contained in:
parent
3ed46c325f
commit
ad683f866b
@ -866,14 +866,6 @@ def _deterministic_dict_values(dictionary):
|
|||||||
class FunctionSpec(object):
|
class FunctionSpec(object):
|
||||||
"""Specification of how to bind arguments to a function."""
|
"""Specification of how to bind arguments to a function."""
|
||||||
|
|
||||||
def as_tuple(self):
|
|
||||||
return (self._fullargspec, self._is_method, self._args_to_prepend,
|
|
||||||
self._kwargs_to_include, self.input_signature)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_tuple(spec_tuple):
|
|
||||||
return FunctionSpec(*spec_tuple)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_function_and_signature(python_function, input_signature):
|
def from_function_and_signature(python_function, input_signature):
|
||||||
"""Create a FunctionSpec instance given a python function and signature."""
|
"""Create a FunctionSpec instance given a python function and signature."""
|
||||||
@ -920,7 +912,7 @@ class FunctionSpec(object):
|
|||||||
}
|
}
|
||||||
self._default_values_start_index = offset
|
self._default_values_start_index = offset
|
||||||
if input_signature is None:
|
if input_signature is None:
|
||||||
self.input_signature = None
|
self._input_signature = None
|
||||||
else:
|
else:
|
||||||
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
|
if fullargspec.varkw is not None or fullargspec.kwonlyargs:
|
||||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||||
@ -931,8 +923,32 @@ class FunctionSpec(object):
|
|||||||
raise TypeError("input_signature must be either a tuple or a "
|
raise TypeError("input_signature must be either a tuple or a "
|
||||||
"list, received " + str(type(input_signature)))
|
"list, received " + str(type(input_signature)))
|
||||||
|
|
||||||
self.input_signature = tuple(input_signature)
|
self._input_signature = tuple(input_signature)
|
||||||
self.flat_input_signature = tuple(nest.flatten(input_signature))
|
self._flat_input_signature = tuple(nest.flatten(input_signature))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fullargspec(self):
|
||||||
|
return self._fullargspec
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_method(self):
|
||||||
|
return self._is_method
|
||||||
|
|
||||||
|
@property
|
||||||
|
def args_to_prepend(self):
|
||||||
|
return self._args_to_prepend
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kwargs_to_include(self):
|
||||||
|
return self._kwargs_to_include
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_signature(self):
|
||||||
|
return self._input_signature
|
||||||
|
|
||||||
|
@property
|
||||||
|
def flat_input_signature(self):
|
||||||
|
return self._flat_input_signature
|
||||||
|
|
||||||
def canonicalize_function_inputs(self, *args, **kwargs):
|
def canonicalize_function_inputs(self, *args, **kwargs):
|
||||||
"""Canonicalizes `args` and `kwargs`.
|
"""Canonicalizes `args` and `kwargs`.
|
||||||
@ -980,7 +996,7 @@ class FunctionSpec(object):
|
|||||||
if index is not None:
|
if index is not None:
|
||||||
arg_indices_to_values[index] = value
|
arg_indices_to_values[index] = value
|
||||||
consumed_args.append(arg)
|
consumed_args.append(arg)
|
||||||
elif self.input_signature is not None:
|
elif self._input_signature is not None:
|
||||||
raise ValueError("Cannot define a TensorFlow function from a Python "
|
raise ValueError("Cannot define a TensorFlow function from a Python "
|
||||||
"function with keyword arguments when "
|
"function with keyword arguments when "
|
||||||
"input_signature is provided.")
|
"input_signature is provided.")
|
||||||
@ -1003,12 +1019,13 @@ class FunctionSpec(object):
|
|||||||
if need_packing:
|
if need_packing:
|
||||||
inputs = nest.pack_sequence_as(
|
inputs = nest.pack_sequence_as(
|
||||||
structure=inputs, flat_sequence=flat_inputs)
|
structure=inputs, flat_sequence=flat_inputs)
|
||||||
if self.input_signature is None:
|
if self._input_signature is None:
|
||||||
return inputs, kwargs
|
return inputs, kwargs
|
||||||
else:
|
else:
|
||||||
assert not kwargs
|
assert not kwargs
|
||||||
signature_relevant_inputs = inputs[:len(self.input_signature)]
|
signature_relevant_inputs = inputs[:len(self._input_signature)]
|
||||||
if not is_same_structure(self.input_signature, signature_relevant_inputs):
|
if not is_same_structure(self._input_signature,
|
||||||
|
signature_relevant_inputs):
|
||||||
raise ValueError("Structure of Python function inputs does not match "
|
raise ValueError("Structure of Python function inputs does not match "
|
||||||
"input_signature.")
|
"input_signature.")
|
||||||
signature_inputs_flat = nest.flatten(signature_relevant_inputs)
|
signature_inputs_flat = nest.flatten(signature_relevant_inputs)
|
||||||
@ -1017,10 +1034,10 @@ class FunctionSpec(object):
|
|||||||
raise ValueError("When input_signature is provided, all inputs to "
|
raise ValueError("When input_signature is provided, all inputs to "
|
||||||
"the Python function must be Tensors.")
|
"the Python function must be Tensors.")
|
||||||
if any(not spec.is_compatible_with(other) for spec, other in zip(
|
if any(not spec.is_compatible_with(other) for spec, other in zip(
|
||||||
self.flat_input_signature, signature_inputs_flat)):
|
self._flat_input_signature, signature_inputs_flat)):
|
||||||
raise ValueError("Python inputs incompatible with input_signature: "
|
raise ValueError("Python inputs incompatible with input_signature: "
|
||||||
"inputs (%s), input_signature (%s)" %
|
"inputs (%s), input_signature (%s)" %
|
||||||
(str(inputs), str(self.input_signature)))
|
(str(inputs), str(self._input_signature)))
|
||||||
return inputs, {}
|
return inputs, {}
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,6 +60,17 @@ def _inputs_compatible(args, stored_inputs):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _deserialize_function_spec(function_spec_proto, coder):
|
||||||
|
"""Deserialize a FunctionSpec object from its proto representation."""
|
||||||
|
fullargspec = coder.decode_proto(function_spec_proto.fullargspec)
|
||||||
|
is_method = function_spec_proto.is_method
|
||||||
|
args_to_prepend = coder.decode_proto(function_spec_proto.args_to_prepend)
|
||||||
|
kwargs_to_include = coder.decode_proto(function_spec_proto.kwargs_to_include)
|
||||||
|
input_signature = coder.decode_proto(function_spec_proto.input_signature)
|
||||||
|
return function_lib.FunctionSpec(fullargspec, is_method, args_to_prepend,
|
||||||
|
kwargs_to_include, input_signature)
|
||||||
|
|
||||||
|
|
||||||
def recreate_polymorphic_function(
|
def recreate_polymorphic_function(
|
||||||
saved_polymorphic_function, functions):
|
saved_polymorphic_function, functions):
|
||||||
"""Creates a PolymorphicFunction from a SavedPolymorphicFunction.
|
"""Creates a PolymorphicFunction from a SavedPolymorphicFunction.
|
||||||
@ -77,9 +88,8 @@ def recreate_polymorphic_function(
|
|||||||
# serialization cycle.
|
# serialization cycle.
|
||||||
|
|
||||||
coder = nested_structure_coder.StructureCoder()
|
coder = nested_structure_coder.StructureCoder()
|
||||||
function_spec_tuple = coder.decode_proto(
|
function_spec = _deserialize_function_spec(
|
||||||
saved_polymorphic_function.function_spec_tuple)
|
saved_polymorphic_function.function_spec, coder)
|
||||||
function_spec = function_lib.FunctionSpec.from_tuple(function_spec_tuple)
|
|
||||||
|
|
||||||
# TODO(mdan): We may enable autograph once exceptions are supported.
|
# TODO(mdan): We may enable autograph once exceptions are supported.
|
||||||
@def_function.function(autograph=False)
|
@def_function.function(autograph=False)
|
||||||
|
@ -25,13 +25,28 @@ from tensorflow.python.saved_model import saved_object_graph_pb2
|
|||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_function_spec(function_spec, coder):
|
||||||
|
"""Serialize a FunctionSpec object into its proto representation."""
|
||||||
|
proto = saved_object_graph_pb2.FunctionSpec()
|
||||||
|
proto.fullargspec.CopyFrom(coder.encode_structure(function_spec.fullargspec))
|
||||||
|
proto.is_method = function_spec.is_method
|
||||||
|
proto.args_to_prepend.CopyFrom(
|
||||||
|
coder.encode_structure(function_spec.args_to_prepend))
|
||||||
|
proto.kwargs_to_include.CopyFrom(
|
||||||
|
coder.encode_structure(function_spec.kwargs_to_include))
|
||||||
|
proto.input_signature.CopyFrom(
|
||||||
|
coder.encode_structure(function_spec.input_signature))
|
||||||
|
return proto
|
||||||
|
|
||||||
|
|
||||||
def serialize_polymorphic_function(polymorphic_function, node_ids):
|
def serialize_polymorphic_function(polymorphic_function, node_ids):
|
||||||
"""Build a SavedPolymorphicProto."""
|
"""Build a SavedPolymorphicProto."""
|
||||||
coder = nested_structure_coder.StructureCoder()
|
coder = nested_structure_coder.StructureCoder()
|
||||||
proto = saved_object_graph_pb2.SavedPolymorphicFunction()
|
proto = saved_object_graph_pb2.SavedPolymorphicFunction()
|
||||||
|
|
||||||
proto.function_spec_tuple.CopyFrom(
|
function_spec_proto = _serialize_function_spec(
|
||||||
coder.encode_structure(polymorphic_function.function_spec.as_tuple())) # pylint: disable=protected-access
|
polymorphic_function.function_spec, coder)
|
||||||
|
proto.function_spec.CopyFrom(function_spec_proto)
|
||||||
for signature, concrete_function in list_all_concrete_functions(
|
for signature, concrete_function in list_all_concrete_functions(
|
||||||
polymorphic_function):
|
polymorphic_function):
|
||||||
bound_inputs = []
|
bound_inputs = []
|
||||||
|
@ -86,9 +86,7 @@ message SavedAsset {
|
|||||||
// A function with multiple signatures, possibly with non-Tensor arguments.
|
// A function with multiple signatures, possibly with non-Tensor arguments.
|
||||||
message SavedPolymorphicFunction {
|
message SavedPolymorphicFunction {
|
||||||
repeated SavedMonomorphicFunction monomorphic_function = 1;
|
repeated SavedMonomorphicFunction monomorphic_function = 1;
|
||||||
// Tuple representing a `FunctionSpec`.
|
FunctionSpec function_spec = 2;
|
||||||
// TODO(vbardiovsky): Make this a proto.
|
|
||||||
StructuredValue function_spec_tuple = 2;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message SavedMonomorphicFunction {
|
message SavedMonomorphicFunction {
|
||||||
@ -118,3 +116,20 @@ message SavedVariable {
|
|||||||
|
|
||||||
// TODO(andresp): Add save_slice_info_def?
|
// TODO(andresp): Add save_slice_info_def?
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Represents FunctionSpec used in PolymorphicFunction. This represents a
|
||||||
|
// function that has been wrapped as a PolymorphicFunction.
|
||||||
|
message FunctionSpec {
|
||||||
|
// Full arg spec from inspect.getfullargspec().
|
||||||
|
StructuredValue fullargspec = 1;
|
||||||
|
// Whether this represents a class method.
|
||||||
|
bool is_method = 2;
|
||||||
|
// Which arguments to always prepend, in case the original function is based
|
||||||
|
// on a functools.partial.
|
||||||
|
StructuredValue args_to_prepend = 3;
|
||||||
|
// Which kwargs to always include, in case the original function is based on a
|
||||||
|
// functools.partial.
|
||||||
|
StructuredValue kwargs_to_include = 4;
|
||||||
|
// The input signature, if specified.
|
||||||
|
StructuredValue input_signature = 5;
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user