STT-tensorflow/tensorflow/python/saved_model/function_serialization.py
Edward Loper f39aab3092 Functions returned by the get_concrete_function method of tf.Function objects can now be called with arguments consistent with the original arguments or type specs passed to get_concrete_function. In particular:
* If a composite tensor (such as RaggedTensor or SparseTensor) was passed to `get_concrete_function`, then the returned function will accept a composite tensor of the same type for that argument.

* If a nested structure (such as a list or dict) was passed to `get_concrete_function`, then the returned function will accept a value with the same nesting structure.  Each tensor or composite tensor value must have the same type as was used in the original argument; and each non-Tensor value (such as bools or ints) must be equal to the value that was used in the original argument.

* If a non-tensor value (such as a bool or int) was passed to `get_concrete_function`, then the returned function no longer deletes that argument; instead, it updates the argument's default value to the value that was passed to `get_concrete_function`.  Passing in any other value will raise an exception.

* Under the new calling convention, arguments are not renamed based on `TensorSpec.name`.

For backwards compatibility, the functions returned by `get_concrete_function` will continue to accept arguments with the existing calling conventions (where nested structures and composite tensors are flattened; non-tensor arguments are deleted; suffixes are automatically added to disambiguate arguments with the same name; and TensorSpec.name is used to rename arguments).  However, the preferred calling convention is the one that is consistent with the original arguments or type specs passed to `get_concrete_function`.

PiperOrigin-RevId: 307398918
Change-Id: Ie4685b32d9f151c82f6c79a6c41379faa96b5ee8
2020-04-20 08:04:18 -07:00

167 lines
7.0 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for serializing `Function`s."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import saved_object_graph_pb2
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import nest
def _serialize_function_spec(function_spec, coder):
  """Encodes `function_spec` as a `saved_object_graph_pb2.FunctionSpec` proto.

  Args:
    function_spec: The FunctionSpec to serialize. Its `fullargspec`,
      `is_method` and `input_signature` attributes are read.
    coder: A `nested_structure_coder.StructureCoder` used to encode the
      argspec and the input signature.

  Returns:
    A `saved_object_graph_pb2.FunctionSpec` proto.

  Raises:
    NotImplementedError: If `function_spec` is a method whose argspec lists
      no named arguments (so there is no named `self` to serialize).
  """
  fullargspec = function_spec.fullargspec
  if function_spec.is_method and not fullargspec.args:
    raise NotImplementedError(
        "Missing support to serialize a method function without a named "
        "'self' argument.")

  # Annotations are dropped on purpose: they exist mainly for optional type
  # checking during development and do not affect runtime behavior.
  # https://www.python.org/dev/peps/pep-3107/
  # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
  argspec_without_annotations = fullargspec._replace(annotations={})
  encoded_argspec = coder.encode_structure(argspec_without_annotations)
  encoded_signature = coder.encode_structure(function_spec.input_signature)

  spec_proto = saved_object_graph_pb2.FunctionSpec()
  spec_proto.fullargspec.CopyFrom(encoded_argspec)
  spec_proto.is_method = function_spec.is_method
  spec_proto.input_signature.CopyFrom(encoded_signature)
  return spec_proto
def serialize_concrete_function(concrete_function, node_ids, coder):
  """Builds a `SavedConcreteFunction` proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to serialize.
    node_ids: Mapping used to look up a node id for every tensor in
      `concrete_function.captured_inputs`; the ids become `bound_inputs`.
    coder: A `nested_structure_coder.StructureCoder` used to encode the
      input and output signatures.

  Returns:
    A `saved_object_graph_pb2.SavedConcreteFunction` proto.

  Raises:
    KeyError: If a captured input has no entry in `node_ids`.
  """
  bound_inputs = []
  for captured_tensor in concrete_function.captured_inputs:
    try:
      node_id = node_ids[captured_tensor]
    except KeyError:
      raise KeyError(
          "Failed to add concrete function %s to object based saved model as it "
          "captures tensor %s which is unsupported or not reachable from root. "
          "One reason could be that a stateful object or a variable that the "
          "function depends on is not assigned to an attribute of the serialized "
          "trackable object "
          "(see SaveTest.test_captures_unreachable_variable)."
          % (concrete_function.name, captured_tensor))
    bound_inputs.append(node_id)

  saved_proto = saved_object_graph_pb2.SavedConcreteFunction()
  output_signature = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  encoded_inputs = coder.encode_structure(
      concrete_function.structured_input_signature)
  encoded_outputs = coder.encode_structure(output_signature)

  saved_proto.canonicalized_input_signature.CopyFrom(encoded_inputs)
  saved_proto.output_signature.CopyFrom(encoded_outputs)
  saved_proto.bound_inputs.extend(bound_inputs)
  return saved_proto
def serialize_bare_concrete_function(concrete_function, name_map):
  """Builds a `SavedBareConcreteFunction` proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to serialize.
    name_map: Dict keyed by a concrete function's name (as text). If the
      function's name is present, the mapped value is saved as its name;
      otherwise `concrete_function.name` is saved unchanged.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction` proto recording the
    name, the number of allowed positional arguments and the argument
    keywords.
  """
  # TODO(edloper): Currently, bare concrete functions don't have access to a
  # function_spec, so they can't be called with the structured signature.
  # Update the serialization to include a function_spec.
  original_name = concrete_function.name
  saved_name = name_map.get(compat.as_text(original_name), original_name)
  # pylint: disable=protected-access
  positional_count = concrete_function._num_positional_args
  keywords = concrete_function._arg_keywords
  # pylint: enable=protected-access
  return saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=saved_name,
      allowed_positional_arguments=positional_count,
      argument_keywords=keywords)
def serialize_function(function, name_map):
  """Builds a `SavedFunction` proto for `function`.

  Args:
    function: The function object to serialize. Its `function_spec` is
      encoded, and every concrete function it reports for serialization is
      recorded by name.
    name_map: Dict keyed by a concrete function's name (as text). Mapped
      values replace the name; unmapped functions keep their own `name`.

  Returns:
    A `saved_object_graph_pb2.SavedFunction` proto.
  """
  coder = nested_structure_coder.StructureCoder()
  saved_function = saved_object_graph_pb2.SavedFunction()
  saved_function.function_spec.CopyFrom(
      _serialize_function_spec(function.function_spec, coder))

  # pylint: disable=protected-access
  concrete_functions = (
      function._list_all_concrete_functions_for_serialization())
  # pylint: enable=protected-access
  saved_names = [
      name_map.get(compat.as_text(concrete.name), concrete.name)
      for concrete in concrete_functions
  ]
  saved_function.concrete_functions.extend(saved_names)
  return saved_function
def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  While building the wrapper, the capture table of `concrete_function.graph`
  is temporarily rewritten so that cached captures point at fresh reads made
  in the new outer graph. The table is always restored before returning,
  including when tracing fails, so `concrete_function` is never left
  referring to tensors from a discarded graph.

  Args:
    concrete_function: A concrete function that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # Maps id(captured tensor) -> the original `captures` entry it replaced, so
  # the inner graph's capture table can be put back afterwards.
  original_captures = {}
  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      for capture, placeholder in concrete_function.graph.captures:
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # NOTE(review): `_cached_variable` is called to obtain the variable;
        # presumably a weak reference, in which case a dead reference would
        # yield None and `read_value` below would fail — confirm.
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        original_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    if not original_captures:
      # Nothing was cached, so nothing was rewritten; no wrapper is needed.
      return concrete_function

    # Snapshot the inner graph while its captures point at the outer reads.
    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)
    fn = defun.ConcreteFunction(
        outer_graph, function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access
  finally:
    # Return the captures to their original values, even if an error was
    # raised part-way through rewriting or tracing.
    for key, capture in original_captures.items():
      captures[key] = capture
  return fn