* If a composite tensor (such as RaggedTensor or SparseTensor) was passed to `get_concrete_function`, then the returned function will accept a composite tensor of the same type for that argument. * If a nested structure (such as a list or dict) was passed to `get_concrete_function`, then the returned function will accept a value with the same nesting structure. Each tensor or composite tensor value must have the same type as was used in the original argument, and each non-tensor value (such as bools or ints) must be equal to the value that was used in the original argument. * If a non-tensor value (such as a bool or int) was passed to `get_concrete_function`, then the returned function no longer deletes that argument; instead, it updates the argument's default value to the value that was passed to `get_concrete_function`. Passing in any other value will raise an exception. * Arguments are not renamed based on `TensorSpec.name`. For backwards compatibility, the functions returned by `get_concrete_function` will continue to accept arguments with the existing calling conventions (where nested structures and composite tensors are flattened; non-tensor arguments are deleted; suffixes are automatically added to disambiguate arguments with the same name; and TensorSpec.name is used to rename arguments). However, the preferred calling convention is the one that is consistent with the original arguments or type specs passed to `get_concrete_function`. PiperOrigin-RevId: 307398918 Change-Id: Ie4685b32d9f151c82f6c79a6c41379faa96b5ee8
167 lines
7.0 KiB
Python
167 lines
7.0 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tools for serializing `Function`s."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.core.protobuf import saved_object_graph_pb2
|
|
from tensorflow.python.eager import function as defun
|
|
from tensorflow.python.framework import func_graph as func_graph_module
|
|
from tensorflow.python.saved_model import nested_structure_coder
|
|
from tensorflow.python.util import compat
|
|
from tensorflow.python.util import nest
|
|
|
|
|
|
def _serialize_function_spec(function_spec, coder):
|
|
"""Serialize a FunctionSpec object into its proto representation."""
|
|
if function_spec.is_method and not function_spec.fullargspec.args:
|
|
raise NotImplementedError(
|
|
"Missing support to serialize a method function without a named "
|
|
"'self' argument.")
|
|
proto = saved_object_graph_pb2.FunctionSpec()
|
|
|
|
# Intentionally skip encoding annotations of a function because function
|
|
# annotations are mainly for optional type checking during development
|
|
# and does not affect runtime behavior.
|
|
# https://www.python.org/dev/peps/pep-3107/
|
|
# https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
|
|
proto.fullargspec.CopyFrom(
|
|
coder.encode_structure(
|
|
function_spec.fullargspec._replace(annotations={})))
|
|
|
|
proto.is_method = function_spec.is_method
|
|
proto.input_signature.CopyFrom(
|
|
coder.encode_structure(function_spec.input_signature))
|
|
return proto
|
|
|
|
|
|
def serialize_concrete_function(concrete_function, node_ids, coder):
  """Builds a `SavedConcreteFunction` proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to serialize.
    node_ids: A mapping from each of the function's captured inputs to the
      node id assigned to it in the object graph being saved.
    coder: A structure coder whose `encode_structure` is used to encode the
      input and output signatures.

  Returns:
    A populated `saved_object_graph_pb2.SavedConcreteFunction` proto whose
    `bound_inputs` lists the node ids of the captured inputs, in order.

  Raises:
    KeyError: If a captured input of `concrete_function` has no entry in
      `node_ids`.
  """
  captured_node_ids = []
  try:
    for captured in concrete_function.captured_inputs:
      captured_node_ids.append(node_ids[captured])
  except KeyError:
    # `captured` still refers to the capture whose lookup failed.
    message = (
        "Failed to add concrete function %s to object based saved model as it "
        "captures tensor %s which is unsupported or not reachable from root. "
        "One reason could be that a stateful object or a variable that the "
        "function depends on is not assigned to an attribute of the serialized "
        "trackable object "
        "(see SaveTest.test_captures_unreachable_variable)."
        % (concrete_function.name, captured))
    raise KeyError(message)

  saved_function = saved_object_graph_pb2.SavedConcreteFunction()
  output_structure = func_graph_module.convert_structure_to_signature(
      concrete_function.structured_outputs)
  encoded_inputs = coder.encode_structure(
      concrete_function.structured_input_signature)
  encoded_outputs = coder.encode_structure(output_structure)

  saved_function.canonicalized_input_signature.CopyFrom(encoded_inputs)
  saved_function.output_signature.CopyFrom(encoded_outputs)
  saved_function.bound_inputs.extend(captured_node_ids)
  return saved_function
|
|
|
|
|
|
def serialize_bare_concrete_function(concrete_function, name_map):
  """Builds a `SavedBareConcreteFunction` proto for `concrete_function`.

  Args:
    concrete_function: The concrete function to serialize.
    name_map: A mapping from a function's name (converted to text) to the
      name it should be saved under. Names missing from the map are saved
      unchanged.

  Returns:
    A `saved_object_graph_pb2.SavedBareConcreteFunction` proto holding the
    (possibly remapped) function name plus the function's positional-argument
    count and argument keywords.
  """
  # TODO(edloper): Currently, bare concrete functions don't have access to a
  # function_spec, so they can't be called with the structured signature.
  # Update the serialization to include a function_spec.
  original_name = concrete_function.name
  saved_name = name_map.get(compat.as_text(original_name), original_name)

  # pylint: disable=protected-access
  positional_count = concrete_function._num_positional_args
  keywords = concrete_function._arg_keywords
  # pylint: enable=protected-access

  return saved_object_graph_pb2.SavedBareConcreteFunction(
      concrete_function_name=saved_name,
      allowed_positional_arguments=positional_count,
      argument_keywords=keywords)
|
|
|
|
|
|
def serialize_function(function, name_map):
  """Builds a `SavedFunction` proto for `function`.

  Args:
    function: A `Function` object. Its `function_spec` is serialized, and the
      names of the concrete functions it reports for serialization are
      recorded.
    name_map: A mapping from a concrete function's name (converted to text)
      to the name it should be recorded under. Unmapped names are kept as-is.

  Returns:
    A `saved_object_graph_pb2.SavedFunction` proto.
  """
  structure_coder = nested_structure_coder.StructureCoder()
  saved_function = saved_object_graph_pb2.SavedFunction()

  spec_proto = _serialize_function_spec(function.function_spec, structure_coder)
  saved_function.function_spec.CopyFrom(spec_proto)

  # pylint: disable=protected-access
  concrete_functions = function._list_all_concrete_functions_for_serialization()
  # pylint: enable=protected-access
  saved_names = [
      name_map.get(compat.as_text(concrete.name), concrete.name)
      for concrete in concrete_functions
  ]
  saved_function.concrete_functions.extend(saved_names)
  return saved_function
|
|
|
|
|
|
def wrap_cached_variables(concrete_function):
  """Wraps the concrete function if it uses cached read tensors.

  This function creates a new concrete function that captures variables
  instead of the cached read tensors.

  While building the wrapper, the capture table of `concrete_function.graph`
  is temporarily rewritten. It is always restored before this function
  returns or raises, so the original concrete function is left unmodified.

  Args:
    concrete_function: A concrete function that may capture cached read
      tensors.

  Returns:
    A concrete function that wraps the original concrete function, which
    captures variables instead. If the original function did not capture any
    cached values, then the function is not wrapped and the original object is
    returned.
  """
  outer_graph = func_graph_module.FuncGraph(
      "{}_no_cache".format(concrete_function.graph.name))
  # Capture table of the inner graph, keyed by `id(external_capture)`; the
  # values are `(external_tensor, placeholder)` pairs, as written below.
  captures = concrete_function.graph._captures  # pylint: disable=protected-access
  # Original table entries for every key rewritten below. A non-empty dict
  # means at least one cached read tensor was found.
  remapped_captures = {}

  try:
    # Update the external captures to use read tensors generated in the outer
    # graph.
    with outer_graph.as_default():
      for capture, placeholder in concrete_function.graph.captures:
        cached_variable = getattr(capture, "_cached_variable", None)
        if cached_variable is None:
          continue
        # `_cached_variable` is called to obtain the variable (presumably a
        # weak reference -- NOTE(review): a dead reference would yield None
        # and fail on `read_value`; confirm this cannot happen).
        cached_variable = cached_variable()
        new_cached_value = cached_variable.read_value()
        remapped_captures[id(capture)] = captures[id(capture)]
        captures[id(capture)] = (new_cached_value, placeholder)

    if not remapped_captures:
      return concrete_function

    # Built while the capture table is rewritten, so its captured inputs are
    # the fresh read tensors created in `outer_graph` above.
    inner_concrete = defun.ConcreteFunction(concrete_function.graph)

    def wrap_function(*args):
      return inner_concrete._call_flat(args, inner_concrete.captured_inputs)  # pylint:disable=protected-access

    args = nest.flatten(concrete_function.structured_input_signature,
                        expand_composites=True)
    func_graph_module.func_graph_from_py_func(
        None, wrap_function, args=tuple(args), kwargs={},
        func_graph=outer_graph)

    fn = defun.ConcreteFunction(
        outer_graph, function_spec=concrete_function._function_spec)  # pylint: disable=protected-access
    # Preserve the original (flat) calling convention on the wrapper.
    fn._arg_keywords = concrete_function._arg_keywords  # pylint: disable=protected-access
    fn._num_positional_args = concrete_function._num_positional_args  # pylint: disable=protected-access
    return fn
  finally:
    # Return the captures to their original values. This runs on every exit
    # path, including exceptions from `read_value` or graph construction, so
    # the original function's graph is never left with rewritten captures.
    for key, capture in remapped_captures.items():
      captures[key] = capture
|