diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto index f30b282b86c..a5b4cfbe823 100644 --- a/tensorflow/core/protobuf/saved_object_graph.proto +++ b/tensorflow/core/protobuf/saved_object_graph.proto @@ -124,6 +124,13 @@ message SavedBareConcreteFunction { repeated string argument_keywords = 2; // The prefix of `argument_keywords` which may be identified by position. int64 allowed_positional_arguments = 3; + // The spec of the function that this ConcreteFunction is traced from. This + // allows the ConcreteFunction to be called with nest structure inputs. This + // field may not be populated. If this field is absent, the concrete function + // can only be called with flat inputs. + // TODO(b/169361281): support calling saved ConcreteFunction with structured + // inputs in C++ SavedModel API. + FunctionSpec function_spec = 4; } message SavedConstant { diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index bc0ee33788c..46d73613d62 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -2281,10 +2281,16 @@ class ConcreteFunction(object): lines.append(" Args:") lines.extend(arg_details) lines.append(" Returns:") + + def spec_from_value(value): + # For loaded function, structured_outputs are already specs. + if isinstance(value, type_spec.TypeSpec): + return value + return type_spec.type_spec_from_value(value) + lines.append(" {}".format( pretty_print_spec( - nest.map_structure(type_spec.type_spec_from_value, - self.structured_outputs)))) + nest.map_structure(spec_from_value, self.structured_outputs)))) return "\n".join(lines) diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py index b3e6159cb72..092e4177f44 100644 --- a/tensorflow/python/saved_model/function_deserialization.py +++ b/tensorflow/python/saved_model/function_deserialization.py @@ -163,8 +163,6 @@ def _deserialize_function_spec_as_nonmethod(function_spec_proto, coder): def setup_bare_concrete_function(saved_bare_concrete_function, concrete_functions): """Makes a restored bare concrete function callable.""" - # Bare concrete functions accept only flat lists of Tensors with unique - # names. concrete_function = concrete_functions[ saved_bare_concrete_function.concrete_function_name] # pylint: disable=protected-access @@ -172,6 +170,12 @@ def setup_bare_concrete_function(saved_bare_concrete_function, saved_bare_concrete_function.argument_keywords) concrete_function._num_positional_args = ( saved_bare_concrete_function.allowed_positional_arguments) + if saved_bare_concrete_function.HasField("function_spec"): + coder = nested_structure_coder.StructureCoder() + function_spec = _deserialize_function_spec_as_nonmethod( + saved_bare_concrete_function.function_spec, + coder) + concrete_function._set_function_spec(function_spec) # pylint: enable=protected-access concrete_function.add_to_graph() return concrete_function @@ -338,10 +342,11 @@ def load_function_def_library(library, load_shared_name_suffix=None): for dep in _list_function_deps(fdef, library_function_names): functions[dep].add_to_graph(func_graph) - # We do not initialize the new ConcreteFunction's function_spec or + # We do not initialize the new ConcreteFunction's function_spec and/or # arg_keywords here (which are used to parse the structured and flat - # signatures, respectively). function_spec is set up later by - # recreate_function(); and arg_keywords by setup_bare_concrete_function(). + # signatures, respectively). ConcreteFunction that are part of a saved + # function is set up later by recreate_function(); and bare ConcreteFunction + # is set up by by setup_bare_concrete_function(). func = function_lib.ConcreteFunction(func_graph) func.add_to_graph(graph) diff --git a/tensorflow/python/saved_model/function_serialization.py b/tensorflow/python/saved_model/function_serialization.py index 07a80539a33..ad18e8f5d2a 100644 --- a/tensorflow/python/saved_model/function_serialization.py +++ b/tensorflow/python/saved_model/function_serialization.py @@ -85,17 +85,19 @@ def serialize_concrete_function(concrete_function, node_ids, coder): def serialize_bare_concrete_function(concrete_function, name_map): """Build a SavedBareConcreteFunction.""" - # TODO(edloper): Currently, bare concrete functions don't have access to a - # function_spec, so they can't be called with the structured signature. - # Update the serialization to include a function_spec. - # pylint: disable=protected-access name = name_map.get(compat.as_text(concrete_function.name), concrete_function.name) - return saved_object_graph_pb2.SavedBareConcreteFunction( + proto = saved_object_graph_pb2.SavedBareConcreteFunction( concrete_function_name=name, allowed_positional_arguments=concrete_function._num_positional_args, argument_keywords=concrete_function._arg_keywords) + if concrete_function._pre_initialized_function_spec is not None: + coder = nested_structure_coder.StructureCoder() + proto.function_spec.CopyFrom( + _serialize_function_spec( + concrete_function._pre_initialized_function_spec, coder)) + return proto # pylint: enable=protected-access diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index db5b7e449d3..a8244850308 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -480,6 +480,34 @@ class LoadTest(test.TestCase, parameterized.TestCase): self.assertEqual(31, imported.f(input1).numpy()) self.assertEqual(32, imported.f(input3).numpy()) + def test_structured_inputs_bare_concrete_function(self, cycles): + + def func(x, training=True): + # x is a nested structure, we care about one particular tensor. + _, (a, b) = x + if training: + return 2 * a["a"] + b + else: + return 7 + + x = constant_op.constant(10) + y = constant_op.constant(11) + + input1 = [6, ({"a": x}, y)] + input2 = [7, ({"a": x}, y)] # Not compatible with input1 signature. + input3 = [6, ({"a": y}, x)] # Compatible with input1 signature. + + root = tracking.AutoTrackable() + root.f = def_function.function(func).get_concrete_function(input1) + + imported = cycle(root, cycles) + + with self.assertRaises(TypeError): + imported.f(input2) + + self.assertEqual(31, imported.f(input1).numpy()) + self.assertEqual(32, imported.f(input3).numpy()) + def test_structured_output(self, cycles): # Use fields with non-alphabetical order @@ -509,6 +537,31 @@ class LoadTest(test.TestCase, parameterized.TestCase): self.assertEqual(5, result[1].numpy()) self.assertEqual(0.5, result[2]["x"].numpy()) + def test_pretty_print_signature(self, cycles): + + named_tuple_type = collections.namedtuple("NamedTupleHello", ["b", "a"]) + + def func(input1, input2): + named_tuple = named_tuple_type(a=input1 + input2, b=input1 * input2) + return [named_tuple, input2, {"x": 0.5}] + + root = tracking.AutoTrackable() + root.f = def_function.function(func).get_concrete_function( + constant_op.constant(2), constant_op.constant(3)) + + imported = cycle(root, cycles) + self.assertEqual( + imported.f.pretty_printed_signature(), """func(input1, input2) + Args: + input1: int32 Tensor, shape=() + input2: int32 Tensor, shape=() + Returns: + [NamedTupleHello(b=<1>, a=<2>), <3>, {'x': <4>}] + <1>: int32 Tensor, shape=() + <2>: int32 Tensor, shape=() + <3>: int32 Tensor, shape=() + <4>: float32 Tensor, shape=()""") + def test_positional_arguments(self, cycles): def func(x, training=False, abc=7.1, defg=7.7): del abc