Log info message if input name in function signature changes in SavedModel, which get converted here: 7e3a0d6be0/tensorflow/core/framework/graph_to_functiondef.cc (L82-L93)

Also clean up signature_serialization.canonicalize_signatures

PiperOrigin-RevId: 340680532
Change-Id: Ia5794e5ead1171531ebd08a8234538add68d45d4
This commit is contained in:
Monica Song 2020-11-04 10:29:27 -08:00 committed by TensorFlower Gardener
parent 2ed1873d25
commit 9f14e9eb15
3 changed files with 63 additions and 3 deletions
tensorflow/python/saved_model

View File

@ -290,6 +290,7 @@ py_strict_library(
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/tracking:base",
"@absl_py//absl/logging",
],
)

View File

@ -26,7 +26,6 @@ import tempfile
import weakref
from absl.testing import parameterized
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
@ -1314,6 +1313,34 @@ class LoadTest(test.TestCase, parameterized.TestCase):
# The signatures mapping is immutable
imported.signatures["random_key"] = 3
def test_names_normalized(self, cycles):
class ObjWithFunction(module.Module):
@def_function.function(input_signature=[
tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A-b"),
tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A/D"),
tensor_spec.TensorSpec([], dtype=dtypes.int32, name="bar"),
tensor_spec.TensorSpec([], dtype=dtypes.int32, name="e"),
])
def foo(self, a, b, c, d=10, **options):
del options
return a + b + c + d
exported = ObjWithFunction()
with self.assertLogs(level="WARNING") as logs:
imported = cycle(exported, cycles)
expected_message = (
"WARNING:absl:Function `foo` contains input name(s) A-b, A/D with "
"unsupported characters which will be renamed to a_b, a_d in the "
"SavedModel.")
self.assertIn(expected_message, logs.output)
loaded_signature = imported.signatures["serving_default"].inputs
self.assertEqual("a_b:0", loaded_signature[0].name)
self.assertEqual("a_d:0", loaded_signature[1].name)
def test_multiple_argument_signatures_no_positional(self, cycles):
class Exported(tracking.AutoTrackable):

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import ops
@ -34,6 +36,8 @@ from tensorflow.python.util.compat import collections_abc
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
SIGNATURE_ATTRIBUTE_NAME = "signatures"
# Max number of warnings to show if signature contains normalized input names.
_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5
def _get_signature(function):
@ -61,6 +65,7 @@ def _valid_signature(concrete_function):
def _validate_inputs(concrete_function):
"""Raises error if input type is tf.Variable."""
if any(isinstance(inp, resource_variable_ops.VariableSpec)
for inp in nest.flatten(
concrete_function.structured_input_signature)):
@ -68,6 +73,24 @@ def _validate_inputs(concrete_function):
"exported as signatures."))
def _get_signature_name_changes(concrete_function):
"""Checks for user-specified signature input names that are normalized."""
# Map of {user-given name: normalized name} if the names are un-identical.
name_changes = {}
for signature_input_name, graph_input in zip(
concrete_function.function_def.signature.input_arg,
concrete_function.graph.inputs):
try:
user_specified_name = compat.as_str(
graph_input.op.get_attr("_user_specified_name"))
if signature_input_name.name != user_specified_name:
name_changes[user_specified_name] = signature_input_name.name
except ValueError:
# Signature input does not have a user-specified name.
pass
return name_changes
def find_function_to_export(saveable_view):
"""Function to export, None if no suitable function was found."""
# If the user did not specify signatures, check the root object for a function
@ -100,11 +123,11 @@ def canonicalize_signatures(signatures):
if not isinstance(signatures, collections_abc.Mapping):
signatures = {
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
num_normalized_signatures_counter = 0
concrete_signatures = {}
wrapped_functions = {}
for signature_key, function in signatures.items():
original_function = signature_function = _get_signature(function)
if signature_function is None:
raise ValueError(
("Expected a TensorFlow function to generate a signature for, but "
@ -115,7 +138,16 @@ def canonicalize_signatures(signatures):
wrapped_functions.get(original_function) or
function_serialization.wrap_cached_variables(original_function))
_validate_inputs(signature_function)
if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
signature_name_changes = _get_signature_name_changes(signature_function)
if signature_name_changes:
num_normalized_signatures_counter += 1
logging.warning(
"Function `%s` contains input name(s) %s with unsupported "
"characters which will be renamed to %s in the SavedModel.",
compat.as_str(signature_function.graph.name),
", ".join(signature_name_changes.keys()),
", ".join(signature_name_changes.values()))
# Re-wrap the function so that it returns a dictionary of Tensors. This
# matches the format of 1.x-style signatures.
# pylint: disable=cell-var-from-loop