Log info message if input name in function signature changes in SavedModel, which get converted here: 7e3a0d6be0/tensorflow/core/framework/graph_to_functiondef.cc (L82-L93)
Also clean up signature_serialization.canonicalize_signatures PiperOrigin-RevId: 340680532 Change-Id: Ia5794e5ead1171531ebd08a8234538add68d45d4
This commit is contained in:
parent
2ed1873d25
commit
9f14e9eb15
tensorflow/python/saved_model
@ -290,6 +290,7 @@ py_strict_library(
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:function",
|
||||
"//tensorflow/python/training/tracking:base",
|
||||
"@absl_py//absl/logging",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -26,7 +26,6 @@ import tempfile
|
||||
import weakref
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import backprop
|
||||
@ -1314,6 +1313,34 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
||||
# The signatures mapping is immutable
|
||||
imported.signatures["random_key"] = 3
|
||||
|
||||
def test_names_normalized(self, cycles):
|
||||
class ObjWithFunction(module.Module):
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A-b"),
|
||||
tensor_spec.TensorSpec([], dtype=dtypes.int32, name="A/D"),
|
||||
tensor_spec.TensorSpec([], dtype=dtypes.int32, name="bar"),
|
||||
tensor_spec.TensorSpec([], dtype=dtypes.int32, name="e"),
|
||||
])
|
||||
def foo(self, a, b, c, d=10, **options):
|
||||
del options
|
||||
return a + b + c + d
|
||||
|
||||
exported = ObjWithFunction()
|
||||
|
||||
with self.assertLogs(level="WARNING") as logs:
|
||||
imported = cycle(exported, cycles)
|
||||
|
||||
expected_message = (
|
||||
"WARNING:absl:Function `foo` contains input name(s) A-b, A/D with "
|
||||
"unsupported characters which will be renamed to a_b, a_d in the "
|
||||
"SavedModel.")
|
||||
self.assertIn(expected_message, logs.output)
|
||||
|
||||
loaded_signature = imported.signatures["serving_default"].inputs
|
||||
self.assertEqual("a_b:0", loaded_signature[0].name)
|
||||
self.assertEqual("a_d:0", loaded_signature[1].name)
|
||||
|
||||
def test_multiple_argument_signatures_no_positional(self, cycles):
|
||||
|
||||
class Exported(tracking.AutoTrackable):
|
||||
|
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import logging
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.framework import ops
|
||||
@ -34,6 +36,8 @@ from tensorflow.python.util.compat import collections_abc
|
||||
|
||||
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
|
||||
SIGNATURE_ATTRIBUTE_NAME = "signatures"
|
||||
# Max number of warnings to show if signature contains normalized input names.
|
||||
_NUM_DISPLAY_NORMALIZED_SIGNATURES = 5
|
||||
|
||||
|
||||
def _get_signature(function):
|
||||
@ -61,6 +65,7 @@ def _valid_signature(concrete_function):
|
||||
|
||||
|
||||
def _validate_inputs(concrete_function):
|
||||
"""Raises error if input type is tf.Variable."""
|
||||
if any(isinstance(inp, resource_variable_ops.VariableSpec)
|
||||
for inp in nest.flatten(
|
||||
concrete_function.structured_input_signature)):
|
||||
@ -68,6 +73,24 @@ def _validate_inputs(concrete_function):
|
||||
"exported as signatures."))
|
||||
|
||||
|
||||
def _get_signature_name_changes(concrete_function):
|
||||
"""Checks for user-specified signature input names that are normalized."""
|
||||
# Map of {user-given name: normalized name} if the names are un-identical.
|
||||
name_changes = {}
|
||||
for signature_input_name, graph_input in zip(
|
||||
concrete_function.function_def.signature.input_arg,
|
||||
concrete_function.graph.inputs):
|
||||
try:
|
||||
user_specified_name = compat.as_str(
|
||||
graph_input.op.get_attr("_user_specified_name"))
|
||||
if signature_input_name.name != user_specified_name:
|
||||
name_changes[user_specified_name] = signature_input_name.name
|
||||
except ValueError:
|
||||
# Signature input does not have a user-specified name.
|
||||
pass
|
||||
return name_changes
|
||||
|
||||
|
||||
def find_function_to_export(saveable_view):
|
||||
"""Function to export, None if no suitable function was found."""
|
||||
# If the user did not specify signatures, check the root object for a function
|
||||
@ -100,11 +123,11 @@ def canonicalize_signatures(signatures):
|
||||
if not isinstance(signatures, collections_abc.Mapping):
|
||||
signatures = {
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
|
||||
num_normalized_signatures_counter = 0
|
||||
concrete_signatures = {}
|
||||
wrapped_functions = {}
|
||||
for signature_key, function in signatures.items():
|
||||
original_function = signature_function = _get_signature(function)
|
||||
|
||||
if signature_function is None:
|
||||
raise ValueError(
|
||||
("Expected a TensorFlow function to generate a signature for, but "
|
||||
@ -115,7 +138,16 @@ def canonicalize_signatures(signatures):
|
||||
wrapped_functions.get(original_function) or
|
||||
function_serialization.wrap_cached_variables(original_function))
|
||||
_validate_inputs(signature_function)
|
||||
|
||||
if num_normalized_signatures_counter < _NUM_DISPLAY_NORMALIZED_SIGNATURES:
|
||||
signature_name_changes = _get_signature_name_changes(signature_function)
|
||||
if signature_name_changes:
|
||||
num_normalized_signatures_counter += 1
|
||||
logging.warning(
|
||||
"Function `%s` contains input name(s) %s with unsupported "
|
||||
"characters which will be renamed to %s in the SavedModel.",
|
||||
compat.as_str(signature_function.graph.name),
|
||||
", ".join(signature_name_changes.keys()),
|
||||
", ".join(signature_name_changes.values()))
|
||||
# Re-wrap the function so that it returns a dictionary of Tensors. This
|
||||
# matches the format of 1.x-style signatures.
|
||||
# pylint: disable=cell-var-from-loop
|
||||
|
Loading…
Reference in New Issue
Block a user