STT-tensorflow/tensorflow/python/saved_model/signature_serialization.py
Katherine Wu 12d00c3e34 Serialize concrete function signature using structured_input_signature instead of function inputs.
This way, if the user calls set_shape on the inputs within the function body, the serialized signature is not affected.

PiperOrigin-RevId: 338345710
Change-Id: Ie31b9e2206de57aca4e592bbd43fafbff0d2bda6
2020-10-21 15:25:57 -07:00

280 lines
12 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers for working with signatures in tf.saved_model.save."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import function_serialization
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training.tracking import base
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
# Key that `find_function_to_export` looks up among the functions listed for
# the root object; a function stored under this key is returned as the
# signature to export without further checks.
DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
# Dependency name that `validate_saveable_view` treats as reserved: a root
# dependency with this name must be a `_SignatureMap`.
SIGNATURE_ATTRIBUTE_NAME = "signatures"
def _get_signature(function):
  """Resolves `function` to a concrete function, or returns None.

  Args:
    function: Any object. A `defun.Function` / `def_function.Function` whose
      `input_signature` is not None is first replaced by the result of its
      `_get_concrete_function_garbage_collected()`.

  Returns:
    The (possibly replaced) object if it is a `defun.ConcreteFunction`,
    otherwise None.
  """
  is_tf_function = isinstance(
      function, (defun.Function, def_function.Function))
  if is_tf_function and function.input_signature is not None:
    # A fixed input signature means there is exactly one trace to ask for.
    function = function._get_concrete_function_garbage_collected()  # pylint: disable=protected-access
  if isinstance(function, defun.ConcreteFunction):
    return function
  return None
def _valid_signature(concrete_function):
  """Returns whether concrete function can be converted to a signature.

  Args:
    concrete_function: The concrete function to check.

  Returns:
    False if the function has no outputs, or if `_validate_inputs` or
    `_normalize_outputs` raises a ValueError for it; True otherwise.
  """
  # Functions without outputs don't make sense as signatures. We just don't
  # have any way to run an Operation with no outputs as a SignatureDef in the
  # 1.x style.
  if not concrete_function.outputs:
    return False
  try:
    _validate_inputs(concrete_function)
    # The name/key arguments only feed error messages, which are discarded
    # here, so placeholders are enough.
    _normalize_outputs(concrete_function.structured_outputs, "unused", "unused")
  except ValueError:
    return False
  else:
    return True
def _validate_inputs(concrete_function):
  """Rejects concrete functions that take `tf.Variable` inputs.

  Args:
    concrete_function: Object whose `structured_input_signature` is flattened
      and inspected.

  Raises:
    ValueError: If any flattened element of the structured input signature is
      a `resource_variable_ops.VariableSpec`.
  """
  flat_signature = nest.flatten(concrete_function.structured_input_signature)
  for spec in flat_signature:
    if isinstance(spec, resource_variable_ops.VariableSpec):
      raise ValueError(("Functions that expect tf.Variable inputs cannot be "
                        "exported as signatures."))
def find_function_to_export(saveable_view):
  """Function to export, None if no suitable function was found.

  Args:
    saveable_view: Object providing `list_functions(obj)` and a `root`
      attribute.

  Returns:
    The function stored under `DEFAULT_SIGNATURE_ATTR` on the root object if
    there is one. Otherwise, the single root function that resolves to a valid
    signature, or None when there are zero or several such functions.
  """
  # If the user did not specify signatures, check the root object for a function
  # that can be made into a signature.
  root_functions = saveable_view.list_functions(saveable_view.root)
  explicit_default = root_functions.get(DEFAULT_SIGNATURE_ATTR, None)
  if explicit_default is not None:
    return explicit_default

  # TODO(andresp): Discuss removing this behaviour. It can lead to WTFs when a
  # user decides to annotate more functions with tf.function and suddenly
  # serving that model way later in the process stops working.
  resolved = (_get_signature(fn) for fn in root_functions.values())
  candidates = [
      concrete for concrete in resolved
      if concrete is not None and _valid_signature(concrete)]

  # Only an unambiguous choice is exported implicitly.
  if len(candidates) != 1:
    return None
  signature = _get_signature(candidates[0])
  if signature and _valid_signature(signature):
    return signature
  return None
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: None, a mapping from signature keys to functions, or a single
      function (which is then stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`).

  Returns:
    A `(concrete_signatures, wrapped_functions)` tuple. `concrete_signatures`
    maps each signature key to a concrete function that accepts keyword
    arguments and returns a dictionary of Tensors. `wrapped_functions` maps
    each resolved concrete function to the result of
    `function_serialization.wrap_cached_variables` for it. Both dictionaries
    are empty when `signatures` is None.

  Raises:
    ValueError: If an entry cannot be resolved to a concrete function by
      `_get_signature`, or if `_validate_inputs` rejects it.
  """
  if signatures is None:
    return {}, {}

  if not isinstance(signatures, collections_abc.Mapping):
    # A bare function becomes the default serving signature.
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}

  concrete_signatures = {}
  wrapped_functions = {}
  for signature_key, function in signatures.items():
    original_function = _get_signature(function)
    if original_function is None:
      raise ValueError(
          ("Expected a TensorFlow function to generate a signature for, but "
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))

    # Reuse the wrapper if this concrete function was already seen under
    # another signature key; otherwise create one.
    signature_function = wrapped_functions.get(original_function)
    if not signature_function:
      signature_function = function_serialization.wrap_cached_variables(
          original_function)
    wrapped_functions[original_function] = signature_function
    _validate_inputs(signature_function)

    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      structured_outputs = signature_function(**kwargs)
      return _normalize_outputs(
          structured_outputs, signature_function.name, signature_key)

    structured_signature = signature_function.structured_input_signature
    if structured_signature is not None:
      # The structured input signature may contain other non-tensor
      # arguments, so keep only the TensorSpec leaves.
      inputs = [
          leaf for leaf in nest.flatten(
              structured_signature, expand_composites=True)
          if isinstance(leaf, tensor_spec.TensorSpec)]
    else:
      # Structured input signature isn't always defined for some functions.
      inputs = signature_function.inputs

    # Build one TensorSpec per keyword, named after that keyword.
    tensor_spec_signature = {}
    arg_keywords = signature_function._arg_keywords  # pylint: disable=protected-access
    for keyword, inp in zip(arg_keywords, inputs):
      keyword = compat.as_str(keyword)
      if isinstance(inp, tensor_spec.TensorSpec):
        spec = tensor_spec.TensorSpec(inp.shape, inp.dtype, name=keyword)
      else:
        spec = tensor_spec.TensorSpec.from_tensor(inp, name=keyword)
      tensor_spec_signature[keyword] = spec

    # Tracing happens here, inside the loop iteration, so the wrapper sees
    # this iteration's `signature_function` and `signature_key`.
    final_concrete = signature_wrapper._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
        **tensor_spec_signature)

    # pylint: disable=protected-access
    # If there is only one input to the signature, a very common case, then
    # ordering is unambiguous and we can let people pass a positional
    # argument. Since SignatureDefs are unordered (protobuf "map") multiple
    # arguments means we need to be keyword-only.
    final_concrete._num_positional_args = (
        1 if len(final_concrete._arg_keywords) == 1 else 0)
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
    # pylint: enable=cell-var-from-loop
  return concrete_signatures, wrapped_functions
def _normalize_outputs(outputs, function_name, signature_key):
  """Construct an output dictionary from unnormalized function outputs.

  Args:
    outputs: A mapping, a sequence, or a single value. Non-mapping values are
      converted to a dictionary keyed "output_0", "output_1", ...
    function_name: Name of the function, used only in error messages.
    signature_key: Signature key, used only in error messages.

  Returns:
    `outputs` itself if it was a mapping, otherwise the newly built dictionary.

  Raises:
    ValueError: If any key is not a string type or any value is not an
      `ops.Tensor`.
  """
  # Convert `outputs` to a dictionary (if it's not one already).
  if not isinstance(outputs, collections_abc.Mapping):
    if isinstance(outputs, collections_abc.Sequence):
      sequence = outputs
    else:
      sequence = [outputs]
    outputs = {}
    for position, output in enumerate(sequence):
      outputs["output_{}".format(position)] = output

  # Check that the keys of `outputs` are strings and the values are Tensors.
  for key, value in outputs.items():
    if not isinstance(key, compat.bytes_or_text_types):
      raise ValueError(
          ("Got a dictionary with a non-string key {!r} in the output of the "
           "function {} used to generate the SavedModel signature {!r}.")
          .format(key, compat.as_str_any(function_name), signature_key))
    if not isinstance(value, ops.Tensor):
      raise ValueError(
          ("Got a non-Tensor value {!r} for key {!r} in the output of the "
           "function {} used to generate the SavedModel signature {!r}. "
           "Outputs for functions used as signatures must be a single Tensor, "
           "a sequence of Tensors, or a dictionary from string to Tensor.")
          .format(value, key, compat.as_str_any(function_name), signature_key))
  return outputs
# _SignatureMap is immutable to ensure that users do not expect changes to be
# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the
# only way to create a _SignatureMap and there is no way to modify it. So we can
# safely ignore/overwrite ".signatures" attributes attached to objects being
# saved if they contain a _SignatureMap. A ".signatures" attribute containing
# any other type (e.g. a regular dict) will raise an exception asking the user
# to first "del obj.signatures" if they want it overwritten.
class _SignatureMap(collections_abc.Mapping, base.Trackable):
  """A collection of SavedModel signatures.

  Read access goes through the `Mapping` interface; the only way to add an
  entry is the private `_add_signature` method.
  """

  def __init__(self):
    # Backing store: signature name -> function.
    self._signatures = {}

  def _add_signature(self, name, concrete_function):
    """Adds a signature to the _SignatureMap."""
    # Ideally this object would be immutable, but restore is streaming so we do
    # need a private API for adding new signatures to an existing object.
    self._signatures[name] = concrete_function

  def __getitem__(self, key):
    """Returns the signature stored under `key`."""
    return self._signatures[key]

  def __iter__(self):
    """Iterates over the signature names."""
    return iter(self._signatures)

  def __len__(self):
    """Returns the number of stored signatures."""
    return len(self._signatures)

  def __repr__(self):
    return "_SignatureMap({})".format(self._signatures)

  def _list_functions_for_serialization(self, unused_serialization_cache):
    """Returns the entries whose values are functions, keyed by name."""
    function_types = (def_function.Function, defun.ConcreteFunction)
    functions = {}
    for key, value in self.items():
      if isinstance(value, function_types):
        functions[key] = value
    return functions
# Register `_SignatureMap` with `revived_types` under the identifier
# "signature_map". The predicate matches `_SignatureMap` instances, the object
# factory builds an empty map (ignoring the proto it is given), and
# `_SignatureMap._add_signature` is supplied as the setter.
# NOTE(review): presumably the setter is how entries are attached when the
# object is revived on load -- confirm against `revived_types`.
revived_types.register_revived_type(
    "signature_map",
    lambda obj: isinstance(obj, _SignatureMap),
    versions=[revived_types.VersionedTypeRegistration(
        # Standard dependencies are enough to reconstruct the trackable
        # items in dictionaries, so we don't need to save any extra information.
        object_factory=lambda proto: _SignatureMap(),
        version=1,
        min_producer_version=1,
        min_consumer_version=1,
        setter=_SignatureMap._add_signature  # pylint: disable=protected-access
    )])
def create_signature_map(signatures):
  """Creates an object containing `signatures`.

  Args:
    signatures: Mapping from signature name to concrete function.

  Returns:
    A `_SignatureMap` holding every entry of `signatures`.
  """
  signature_map = _SignatureMap()
  for name, func in signatures.items():
    # These hold for any signature that came from canonicalize_signatures.
    # They are checked here, at save time, because crashing on load (e.g. in
    # _add_signature) would be more problematic if a future export change
    # violated them.
    assert isinstance(func, defun.ConcreteFunction)
    assert isinstance(func.structured_outputs, collections_abc.Mapping)
    # pylint: disable=protected-access
    # Exactly one keyword argument allows one positional argument; any other
    # count means keyword-only.
    expected_positional = 1 if len(func._arg_keywords) == 1 else 0
    assert expected_positional == func._num_positional_args
    signature_map._add_signature(name, func)
    # pylint: enable=protected-access
  return signature_map
def validate_saveable_view(saveable_view):
  """Performs signature-related sanity checks on `saveable_view`.

  Args:
    saveable_view: Object providing `list_dependencies(obj)` and a `root`
      attribute.

  Raises:
    ValueError: If the root object has a dependency named
      `SIGNATURE_ATTRIBUTE_NAME` that is not a `_SignatureMap`.
  """
  root_dependencies = saveable_view.list_dependencies(saveable_view.root)
  for name, dep in root_dependencies:
    if name != SIGNATURE_ATTRIBUTE_NAME:
      continue
    if not isinstance(dep, _SignatureMap):
      raise ValueError(
          ("Exporting an object {} which has an attribute named "
           "'{signatures}'. This is a reserved attribute used to store "
           "SavedModel signatures in objects which come from "
           "`tf.saved_model.load`. Delete this attribute "
           "(e.g. 'del obj.{signatures}') before saving if this shadowing is "
           "acceptable.").format(
               saveable_view.root,
               signatures=SIGNATURE_ATTRIBUTE_NAME))
    # The reserved name was found and is acceptable; stop scanning.
    break