Save FunctionSpec with BareConcreteFunction

ConcreteFunction now supports being called with both structured inputs and flatten inputs, so that it's compatible with tf.function calling semantics.

To support the same semantics for saved ConcreteFunction, we need to save the FunctionSpec.

Note that calling with structured inputs is not yet supported in SavedModel C++ API.

This change also modifies ConcreteFunction.pretty_printed_signature to make it work for loaded functions. The structured_outputs of loaded functions are tensor specs instead of tensors. Ideally it should be tensors as well, but there're already users who depend on
this behavior.

PiperOrigin-RevId: 335946423
Change-Id: I4aecf2aba51459801bd0d42a3343ad1becaef187
This commit is contained in:
Ran Chen 2020-10-07 13:54:57 -07:00 committed by TensorFlower Gardener
parent 36bc549146
commit b6d58af144
5 changed files with 85 additions and 12 deletions

View File

@ -124,6 +124,13 @@ message SavedBareConcreteFunction {
repeated string argument_keywords = 2;
// The prefix of `argument_keywords` which may be identified by position.
int64 allowed_positional_arguments = 3;
// The spec of the function that this ConcreteFunction is traced from. This
// allows the ConcreteFunction to be called with nest structure inputs. This
// field may not be populated. If this field is absent, the concrete function
// can only be called with flat inputs.
// TODO(b/169361281): support calling saved ConcreteFunction with structured
// inputs in C++ SavedModel API.
FunctionSpec function_spec = 4;
}
message SavedConstant {

View File

@ -2281,10 +2281,16 @@ class ConcreteFunction(object):
lines.append(" Args:")
lines.extend(arg_details)
lines.append(" Returns:")
def spec_from_value(value):
# For loaded function, structured_outputs are already specs.
if isinstance(value, type_spec.TypeSpec):
return value
return type_spec.type_spec_from_value(value)
lines.append(" {}".format(
pretty_print_spec(
nest.map_structure(type_spec.type_spec_from_value,
self.structured_outputs))))
nest.map_structure(spec_from_value, self.structured_outputs))))
return "\n".join(lines)

View File

@ -163,8 +163,6 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
def setup_bare_concrete_function(saved_bare_concrete_function,
concrete_functions):
"""Makes a restored bare concrete function callable."""
# Bare concrete functions accept only flat lists of Tensors with unique
# names.
concrete_function = concrete_functions[
saved_bare_concrete_function.concrete_function_name]
# pylint: disable=protected-access
@ -172,6 +170,12 @@ def setup_bare_concrete_function(saved_bare_concrete_function,
saved_bare_concrete_function.argument_keywords)
concrete_function._num_positional_args = (
saved_bare_concrete_function.allowed_positional_arguments)
if saved_bare_concrete_function.HasField("function_spec"):
coder = nested_structure_coder.StructureCoder()
function_spec = _deserialize_function_spec_as_nonmethod(
saved_bare_concrete_function.function_spec,
coder)
concrete_function._set_function_spec(function_spec)
# pylint: enable=protected-access
concrete_function.add_to_graph()
return concrete_function
@ -338,10 +342,11 @@ def load_function_def_library(library, load_shared_name_suffix=None):
for dep in _list_function_deps(fdef, library_function_names):
functions[dep].add_to_graph(func_graph)
# We do not initialize the new ConcreteFunction's function_spec or
# We do not initialize the new ConcreteFunction's function_spec and/or
# arg_keywords here (which are used to parse the structured and flat
# signatures, respectively). function_spec is set up later by
# recreate_function(); and arg_keywords by setup_bare_concrete_function().
# signatures, respectively). ConcreteFunction that are part of a saved
# function is set up later by recreate_function(); and bare ConcreteFunction
# is set up by by setup_bare_concrete_function().
func = function_lib.ConcreteFunction(func_graph)
func.add_to_graph(graph)

View File

@ -85,17 +85,19 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
def serialize_bare_concrete_function(concrete_function, name_map):
"""Build a SavedBareConcreteFunction."""
# TODO(edloper): Currently, bare concrete functions don't have access to a
# function_spec, so they can't be called with the structured signature.
# Update the serialization to include a function_spec.
# pylint: disable=protected-access
name = name_map.get(compat.as_text(concrete_function.name),
concrete_function.name)
return saved_object_graph_pb2.SavedBareConcreteFunction(
proto = saved_object_graph_pb2.SavedBareConcreteFunction(
concrete_function_name=name,
allowed_positional_arguments=concrete_function._num_positional_args,
argument_keywords=concrete_function._arg_keywords)
if concrete_function._pre_initialized_function_spec is not None:
coder = nested_structure_coder.StructureCoder()
proto.function_spec.CopyFrom(
_serialize_function_spec(
concrete_function._pre_initialized_function_spec, coder))
return proto
# pylint: enable=protected-access

View File

@ -480,6 +480,34 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(31, imported.f(input1).numpy())
self.assertEqual(32, imported.f(input3).numpy())
def test_structured_inputs_bare_concrete_function(self, cycles):
def func(x, training=True):
# x is a nested structure, we care about one particular tensor.
_, (a, b) = x
if training:
return 2 * a["a"] + b
else:
return 7
x = constant_op.constant(10)
y = constant_op.constant(11)
input1 = [6, ({"a": x}, y)]
input2 = [7, ({"a": x}, y)] # Not compatible with input1 signature.
input3 = [6, ({"a": y}, x)] # Compatible with input1 signature.
root = tracking.AutoTrackable()
root.f = def_function.function(func).get_concrete_function(input1)
imported = cycle(root, cycles)
with self.assertRaises(TypeError):
imported.f(input2)
self.assertEqual(31, imported.f(input1).numpy())
self.assertEqual(32, imported.f(input3).numpy())
def test_structured_output(self, cycles):
# Use fields with non-alphabetical order
@ -509,6 +537,31 @@ class LoadTest(test.TestCase, parameterized.TestCase):
self.assertEqual(5, result[1].numpy())
self.assertEqual(0.5, result[2]["x"].numpy())
def test_pretty_print_signature(self, cycles):
named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])
def func(input1, input2):
named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
return [named_tuple, input2, {"x": 0.5}]
root = tracking.AutoTrackable()
root.f = def_function.function(func).get_concrete_function(
constant_op.constant(2), constant_op.constant(3))
imported = cycle(root, cycles)
self.assertEqual(
imported.f.pretty_printed_signature(), """func(input1, input2)
Args:
input1: int32 Tensor, shape=()
input2: int32 Tensor, shape=()
Returns:
[NamedTupleHello(b=<1>, a=<2>), <3>, {'x': <4>}]
<1>: int32 Tensor, shape=()
<2>: int32 Tensor, shape=()
<3>: int32 Tensor, shape=()
<4>: float32 Tensor, shape=()""")
def test_positional_arguments(self, cycles):
def func(x, training=False, abc=7.1, defg=7.7):
del abc