From 9f14e9eb15ba26fbf1e30750494cdd4f0b42ab5f Mon Sep 17 00:00:00 2001 From: Monica Song Date: Wed, 4 Nov 2020 10:29:27 -0800 Subject: [PATCH] Log info message if input name in function signature changes in SavedModel, which get converted here: https://github.com/tensorflow/tensorflow/blob/7e3a0d6be0e1c5f2c87d8552c092055d6340f596/tensorflow/core/framework/graph_to_functiondef.cc#L82-L93 Also clean up signature_serialization.canonicalize_signatures PiperOrigin-RevId: 340680532 Change-Id: Ia5794e5ead1171531ebd08a8234538add68d45d4 --- tensorflow/python/saved_model/BUILD | 1 + tensorflow/python/saved_model/load_test.py | 29 ++++++++++++++- .../saved_model/signature_serialization.py | 36 +++++++++++++++++-- 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 5768cbdc15d..3e4fb285faa 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -290,6 +290,7 @@ py_strict_library( "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:function", "//tensorflow/python/training/tracking:base", + "@absl_py//absl/logging", ], ) diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 6d8d7520c2f..23211132dc5 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -26,7 +26,6 @@ import tempfile import weakref from absl.testing import parameterized - from tensorflow.python.client import session as session_lib from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import backprop @@ -1314,6 +1313,34 @@ class LoadTest(test.TestCase, parameterized.TestCase): # The signatures mapping is immutable imported.signatures["random_key"] = 3 + def test_names_normalized(self, cycles): + class ObjWithFunction(module.Module): + + @def_function.function(input_signature=[ + tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A-b"), + tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A/D"), + tensor_spec.TensorSpec([], dtype=dtypes.int32, name="bar"), + tensor_spec.TensorSpec([], dtype=dtypes.int32, name="e"), + ]) + def foo(self, a, b, c, d=10, **options): + del options + return a + b + c + d + + exported = ObjWithFunction() + + with self.assertLogs(level="WARNING") as logs: + imported = cycle(exported, cycles) + + expected_message = ( + "WARNING:absl:Function `foo` contains input name(s) A-b, A/D with " + "unsupported characters which will be renamed to a_b, a_d in the " + "SavedModel.") + self.assertIn(expected_message, logs.output) + + loaded_signature = imported.signatures["serving_default"].inputs + self.assertEqual("a_b:0", loaded_signature[0].name) + self.assertEqual("a_d:0", loaded_signature[1].name) + def test_multiple_argument_signatures_no_positional(self, cycles): class Exported(tracking.AutoTrackable): diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py index 14f4df81380..4250efd7c01 100644 --- a/tensorflow/python/saved_model/signature_serialization.py +++ b/tensorflow/python/saved_model/signature_serialization.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl import logging + from tensorflow.python.eager import def_function from tensorflow.python.eager import function as defun from tensorflow.python.framework import ops @@ -34,6 +36,8 @@ from tensorflow.python.util.compat import collections_abc DEFAULT_SIGNATURE_ATTR = "_default_save_signature" SIGNATURE_ATTRIBUTE_NAME = "signatures" +# Max number of warnings to show if signature contains normalized input names. +_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5 def _get_signature(function): @@ -61,6 +65,7 @@ def _valid_signature(concrete_function): def _validate_inputs(concrete_function): + """Raises error if input type is tf.Variable.""" if any(isinstance(inp, resource_variable_ops.VariableSpec) for inp in nest.flatten( concrete_function.structured_input_signature)): @@ -68,6 +73,24 @@ def _validate_inputs(concrete_function): "exported as signatures.")) +def _get_signature_name_changes(concrete_function): + """Checks for user-specified signature input names that are normalized.""" + # Map of {user-given name: normalized name} if the names are un-identical. + name_changes = {} + for signature_input_name, graph_input in zip( + concrete_function.function_def.signature.input_arg, + concrete_function.graph.inputs): + try: + user_specified_name = compat.as_str( + graph_input.op.get_attr("_user_specified_name")) + if signature_input_name.name != user_specified_name: + name_changes[user_specified_name] = signature_input_name.name + except ValueError: + # Signature input does not have a user-specified name. + pass + return name_changes + + def find_function_to_export(saveable_view): """Function to export, None if no suitable function was found.""" # If the user did not specify signatures, check the root object for a function @@ -100,11 +123,11 @@ def canonicalize_signatures(signatures): if not isinstance(signatures, collections_abc.Mapping): signatures = { signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures} + num_normalized_signatures_counter = 0 concrete_signatures = {} wrapped_functions = {} for signature_key, function in signatures.items(): original_function = signature_function = _get_signature(function) - if signature_function is None: raise ValueError( ("Expected a TensorFlow function to generate a signature for, but " @@ -115,7 +138,16 @@ def canonicalize_signatures(signatures): wrapped_functions.get(original_function) or function_serialization.wrap_cached_variables(original_function)) _validate_inputs(signature_function) - + if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES: + signature_name_changes = _get_signature_name_changes(signature_function) + if signature_name_changes: + num_normalized_signatures_counter += 1 + logging.warning( + "Function `%s` contains input name(s) %s with unsupported " + "characters which will be renamed to %s in the SavedModel.", + compat.as_str(signature_function.graph.name), + ", ".join(signature_name_changes.keys()), + ", ".join(signature_name_changes.values())) # Re-wrap the function so that it returns a dictionary of Tensors. This # matches the format of 1.x-style signatures. # pylint: disable=cell-var-from-loop