From b6d58af1444e20e169a782af58fd3a1894f2628e Mon Sep 17 00:00:00 2001
From: Ran Chen <crccw@google.com>
Date: Wed, 7 Oct 2020 13:54:57 -0700
Subject: [PATCH] Save FunctionSpec with BareConcreteFunction

ConcreteFunction now supports being called with both structured inputs and flatten inputs, so that it's compatible with tf.function calling semantics.

To support the same semantics for saved ConcreteFunction, we need to save the FunctionSpec.

Note that calling with structured inputs is not yet supported in SavedModel C++ API.

This change also modifies ConcreteFunction.pretty_printed_signature to make it work for loaded functions. The structured_outputs of loaded functions are tensor specs instead of tensors. Ideally it should be tensors as well, but there're already users who depend on
this behavior.

PiperOrigin-RevId: 335946423
Change-Id: I4aecf2aba51459801bd0d42a3343ad1becaef187
---
 .../core/protobuf/saved_object_graph.proto    |  7 +++
 tensorflow/python/eager/function.py           | 10 +++-
 .../saved_model/function_deserialization.py   | 15 ++++--
 .../saved_model/function_serialization.py     | 12 +++--
 tensorflow/python/saved_model/load_test.py    | 53 +++++++++++++++++++
 5 files changed, 85 insertions(+), 12 deletions(-)

diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto
index f30b282b86c..a5b4cfbe823 100644
--- a/tensorflow/core/protobuf/saved_object_graph.proto
+++ b/tensorflow/core/protobuf/saved_object_graph.proto
@@ -124,6 +124,13 @@ message SavedBareConcreteFunction {
   repeated string argument_keywords = 2;
   // The prefix of `argument_keywords` which may be identified by position.
   int64 allowed_positional_arguments = 3;
+  // The spec of the function that this ConcreteFunction is traced from. This
+  // allows the ConcreteFunction to be called with nest structure inputs. This
+  // field may not be populated. If this field is absent, the concrete function
+  // can only be called with flat inputs.
+  // TODO(b/169361281): support calling saved ConcreteFunction with structured
+  // inputs in C++ SavedModel API.
+  FunctionSpec function_spec = 4;
 }
 
 message SavedConstant {
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index bc0ee33788c..46d73613d62 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -2281,10 +2281,16 @@ class ConcreteFunction(object):
       lines.append("  Args:")
       lines.extend(arg_details)
     lines.append("  Returns:")
+
+    def spec_from_value(value):
+      # For loaded function, structured_outputs are already specs.
+      if isinstance(value, type_spec.TypeSpec):
+        return value
+      return type_spec.type_spec_from_value(value)
+
     lines.append("    {}".format(
         pretty_print_spec(
-            nest.map_structure(type_spec.type_spec_from_value,
-                               self.structured_outputs))))
+            nest.map_structure(spec_from_value, self.structured_outputs))))
 
     return "\n".join(lines)
 
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index b3e6159cb72..092e4177f44 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -163,8 +163,6 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder):
 def setup_bare_concrete_function(saved_bare_concrete_function,
                                  concrete_functions):
   """Makes a restored bare concrete function callable."""
-  # Bare concrete functions accept only flat lists of Tensors with unique
-  # names.
   concrete_function = concrete_functions[
       saved_bare_concrete_function.concrete_function_name]
   # pylint: disable=protected-access
@@ -172,6 +170,12 @@ def setup_bare_concrete_function(saved_bare_concrete_function,
       saved_bare_concrete_function.argument_keywords)
   concrete_function._num_positional_args = (
       saved_bare_concrete_function.allowed_positional_arguments)
+  if saved_bare_concrete_function.HasField("function_spec"):
+    coder = nested_structure_coder.StructureCoder()
+    function_spec = _deserialize_function_spec_as_nonmethod(
+        saved_bare_concrete_function.function_spec,
+        coder)
+    concrete_function._set_function_spec(function_spec)
   # pylint: enable=protected-access
   concrete_function.add_to_graph()
   return concrete_function
@@ -338,10 +342,11 @@ def load_function_def_library(library, load_shared_name_suffix=None):
     for dep in _list_function_deps(fdef, library_function_names):
       functions[dep].add_to_graph(func_graph)
 
-    # We do not initialize the new ConcreteFunction's function_spec or
+    # We do not initialize the new ConcreteFunction's function_spec and/or
     # arg_keywords here (which are used to parse the structured and flat
-    # signatures, respectively).  function_spec is set up later by
-    # recreate_function(); and arg_keywords by setup_bare_concrete_function().
+    # signatures, respectively). ConcreteFunction that are part of a saved
+    # function is set up later by recreate_function(); and bare ConcreteFunction
+    # is set up by by setup_bare_concrete_function().
     func = function_lib.ConcreteFunction(func_graph)
     func.add_to_graph(graph)
 
diff --git a/tensorflow/python/saved_model/function_serialization.py b/tensorflow/python/saved_model/function_serialization.py
index 07a80539a33..ad18e8f5d2a 100644
--- a/tensorflow/python/saved_model/function_serialization.py
+++ b/tensorflow/python/saved_model/function_serialization.py
@@ -85,17 +85,19 @@ def serialize_concrete_function(concrete_function, node_ids, coder):
 
 def serialize_bare_concrete_function(concrete_function, name_map):
   """Build a SavedBareConcreteFunction."""
-  # TODO(edloper): Currently, bare concrete functions don't have access to a
-  # function_spec, so they can't be called with the structured signature.
-  # Update the serialization to include a function_spec.
-
   # pylint: disable=protected-access
   name = name_map.get(compat.as_text(concrete_function.name),
                       concrete_function.name)
-  return saved_object_graph_pb2.SavedBareConcreteFunction(
+  proto = saved_object_graph_pb2.SavedBareConcreteFunction(
       concrete_function_name=name,
       allowed_positional_arguments=concrete_function._num_positional_args,
       argument_keywords=concrete_function._arg_keywords)
+  if concrete_function._pre_initialized_function_spec is not None:
+    coder = nested_structure_coder.StructureCoder()
+    proto.function_spec.CopyFrom(
+        _serialize_function_spec(
+            concrete_function._pre_initialized_function_spec, coder))
+  return proto
   # pylint: enable=protected-access
 
 
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index db5b7e449d3..a8244850308 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -480,6 +480,34 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(31, imported.f(input1).numpy())
     self.assertEqual(32, imported.f(input3).numpy())
 
+  def test_structured_inputs_bare_concrete_function(self, cycles):
+
+    def func(x, training=True):
+      # x is a nested structure, we care about one particular tensor.
+      _, (a, b) = x
+      if training:
+        return 2 * a["a"] + b
+      else:
+        return 7
+
+    x = constant_op.constant(10)
+    y = constant_op.constant(11)
+
+    input1 = [6, ({"a": x}, y)]
+    input2 = [7, ({"a": x}, y)]  # Not compatible with input1 signature.
+    input3 = [6, ({"a": y}, x)]  # Compatible with input1 signature.
+
+    root = tracking.AutoTrackable()
+    root.f = def_function.function(func).get_concrete_function(input1)
+
+    imported = cycle(root, cycles)
+
+    with self.assertRaises(TypeError):
+      imported.f(input2)
+
+    self.assertEqual(31, imported.f(input1).numpy())
+    self.assertEqual(32, imported.f(input3).numpy())
+
   def test_structured_output(self, cycles):
 
     # Use fields with non-alphabetical order
@@ -509,6 +537,31 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     self.assertEqual(5, result[1].numpy())
     self.assertEqual(0.5, result[2]["x"].numpy())
 
+  def test_pretty_print_signature(self, cycles):
+
+    named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"])
+
+    def func(input1, input2):
+      named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2)
+      return [named_tuple, input2, {"x": 0.5}]
+
+    root = tracking.AutoTrackable()
+    root.f = def_function.function(func).get_concrete_function(
+        constant_op.constant(2), constant_op.constant(3))
+
+    imported = cycle(root, cycles)
+    self.assertEqual(
+        imported.f.pretty_printed_signature(), """func(input1, input2)
+  Args:
+    input1: int32 Tensor, shape=()
+    input2: int32 Tensor, shape=()
+  Returns:
+    [NamedTupleHello(b=<1>, a=<2>), <3>, {'x': <4>}]
+      <1>: int32 Tensor, shape=()
+      <2>: int32 Tensor, shape=()
+      <3>: int32 Tensor, shape=()
+      <4>: float32 Tensor, shape=()""")
+
   def test_positional_arguments(self, cycles):
     def func(x, training=False, abc=7.1, defg=7.7):
       del abc