1199 lines
54 KiB
Python
1199 lines
54 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Exports a SavedModel from a Trackable Python object."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import functools
|
|
import gc
|
|
import os
|
|
|
|
from absl import logging
|
|
from tensorflow.core.framework import versions_pb2
|
|
from tensorflow.core.protobuf import meta_graph_pb2
|
|
from tensorflow.core.protobuf import saved_model_pb2
|
|
from tensorflow.core.protobuf import saved_object_graph_pb2
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.eager import def_function
|
|
from tensorflow.python.eager import function as defun
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import error_interpolation
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import meta_graph
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_util
|
|
from tensorflow.python.framework import versions
|
|
from tensorflow.python.lib.io import file_io
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import control_flow_ops
|
|
from tensorflow.python.ops import resource_variable_ops
|
|
from tensorflow.python.saved_model import builder_impl
|
|
from tensorflow.python.saved_model import constants
|
|
from tensorflow.python.saved_model import function_serialization
|
|
from tensorflow.python.saved_model import nested_structure_coder
|
|
from tensorflow.python.saved_model import revived_types
|
|
from tensorflow.python.saved_model import save_context
|
|
from tensorflow.python.saved_model import save_options
|
|
from tensorflow.python.saved_model import signature_constants
|
|
from tensorflow.python.saved_model import signature_def_utils
|
|
from tensorflow.python.saved_model import signature_serialization
|
|
from tensorflow.python.saved_model import tag_constants
|
|
from tensorflow.python.saved_model import utils_impl
|
|
from tensorflow.python.training.saving import checkpoint_options
|
|
from tensorflow.python.training.saving import functional_saver
|
|
from tensorflow.python.training.saving import saveable_object_util
|
|
from tensorflow.python.training.tracking import base
|
|
from tensorflow.python.training.tracking import graph_view
|
|
from tensorflow.python.training.tracking import tracking
|
|
from tensorflow.python.training.tracking import util
|
|
from tensorflow.python.util import compat
|
|
from tensorflow.python.util import object_identity
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
# Dtypes of captured tensors that `_SaveableView.map_resources` never tries to
# copy into the exported Graph as constants.
_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
|
|
|
|
# Pairs an EagerTensor constant (`eager_tensor`) with the copy of it that was
# created in the exported Graph (`graph_tensor`).
_CapturedConstant = collections.namedtuple(
    "_CapturedConstant", ("eager_tensor", "graph_tensor"))
|
|
|
|
# Maximum number of untraced function names listed in the warning that
# `_SaveableView.__init__` logs when it finds untraced functions.
_NUM_DISPLAY_UNTRACED_FUNCTIONS = 5
|
|
|
|
|
|
class _AugmentedGraphView(graph_view.ObjectGraphView):
  """An extendable graph which also tracks functions attached to objects.

  Extensions through `add_object` appear in the object graph and any checkpoints
  generated from it, even if they are not dependencies of the node they were
  attached to in the saving program. For example a `.signatures` attribute is
  added to exported SavedModel root objects without modifying the root object
  itself.

  Also tracks functions attached to objects in the graph, through the caching
  `list_functions` method. Enumerating functions only through this method
  ensures that we get a consistent view of functions, even if object attributes
  create new functions every time they are accessed.
  """

  def __init__(self, root):
    """Initializes the view.

    Args:
      root: The object at the root of the object graph; forwarded to the parent
        class constructor.
    """
    # A saveables cache is only handed to the parent class when building a
    # graph outside of any function; otherwise no cache is used.
    if (not context.executing_eagerly() and not ops.inside_function()):
      saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
    else:
      saveables_cache = None
    super(_AugmentedGraphView, self).__init__(root, saveables_cache)
    # Object -> (name -> dep)
    self._extra_dependencies = object_identity.ObjectIdentityDictionary()
    # Object -> (name -> function), filled lazily by `list_functions`.
    self._functions = object_identity.ObjectIdentityDictionary()
    # Cache shared between objects in the same object graph. This is passed to
    # each trackable object's `_list_extra_dependencies_for_serialization` and
    # `_list_functions_for_serialization` function.
    self._serialization_cache = object_identity.ObjectIdentityDictionary()

  def add_object(self, parent_node, name_in_parent, subgraph_root):
    """Attach an object to `parent_node`, overriding any existing dependency."""
    self._extra_dependencies.setdefault(parent_node,
                                        {})[name_in_parent] = subgraph_root

  def list_dependencies(self, obj):
    """Overrides a parent method to include `add_object` objects.

    Args:
      obj: The object whose dependencies are listed.

    Yields:
      `base.TrackableReference`s: first the parent class's dependencies (with
      a same-named `.signatures` extra dependency taking precedence), then the
      extra dependencies whose names were not already used.

    Raises:
      ValueError: If a dependency reported by the parent class shares its name
        with an extra dependency other than `.signatures`.
    """
    extra_dependencies = self.list_extra_dependencies(obj)
    extra_dependencies.update(self._extra_dependencies.get(obj, {}))

    used_names = set()
    for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
      used_names.add(name)
      if name in extra_dependencies:
        # Extra dependencies (except for `.signatures`, which is always added
        # when saving) should not have naming conflicts with dependencies
        # defined by the user.
        if name != signature_serialization.SIGNATURE_ATTRIBUTE_NAME:
          raise ValueError(
              "Error when exporting object {} with identifier={}. The object"
              " has an attribute named {}, which is reserved. List of all "
              "reserved attributes: {}".format(
                  obj,
                  obj._object_identifier,  # pylint: disable=protected-access
                  name,
                  # Render as a plain list rather than a `dict_keys(...)` repr.
                  list(extra_dependencies)))
        yield base.TrackableReference(name, extra_dependencies[name])
      else:
        yield base.TrackableReference(name, dep)
    for name, dep in extra_dependencies.items():
      if name in used_names:
        continue
      yield base.TrackableReference(name, dep)

  def list_extra_dependencies(self, obj):
    """Returns `obj`'s own extra serialization dependencies (name -> dep)."""
    return obj._list_extra_dependencies_for_serialization(  # pylint: disable=protected-access
        self._serialization_cache)

  def list_functions(self, obj, extra_functions=None):
    """Returns the functions attached to `obj`, caching the per-object lookup.

    Args:
      obj: The object whose functions are listed.
      extra_functions: Optional dictionary of additional name -> function
        entries merged into the result. These are not stored in the cache.

    Returns:
      A dictionary mapping names to functions. When `extra_functions` is
      non-empty this is a copy of the cached dictionary with the extras merged
      in; otherwise it is the cached dictionary itself.
    """
    obj_functions = self._functions.get(obj, None)
    if obj_functions is None:
      obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
          self._serialization_cache)
      self._functions[obj] = obj_functions
    if extra_functions:
      # Copy so the cached dictionary is not polluted with the extras.
      obj_functions = obj_functions.copy()
      obj_functions.update(extra_functions)
    return obj_functions
|
|
|
|
|
|
class _SaveableView(object):
  """Provides a frozen view over a trackable root.

  This class helps to create a single stable view over an object to save. The
  saving code should access properties and functions via this class and not via
  the original object, as there are cases where an object constructs its
  trackable attributes and functions dynamically per call and will yield
  different objects if invoked more than once.

  Changes to the graph, for example adding objects, must happen in
  `checkpoint_view` (an `_AugmentedGraphView`) before the `_SaveableView` is
  constructed. Changes after the `_SaveableView` has been constructed will be
  ignored.
  """

  def __init__(self, checkpoint_view, options, wrapped_functions=None):
    """Initializes a SaveableView.

    Args:
      checkpoint_view: A GraphView object.
      options: A SaveOptions instance.
      wrapped_functions: Dictionary that maps concrete functions to functions
        that do not capture cached variable values.
    """
    self.options = options
    self.checkpoint_view = checkpoint_view
    trackable_objects, node_ids, slot_variables = (
        self.checkpoint_view.objects_ids_and_slot_variables())
    self.nodes = trackable_objects
    self.node_ids = node_ids
    self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
    self.slot_variables = slot_variables
    self.concrete_functions = []
    self.untraced_functions = []

    self.saveable_objects_for_node, all_saveable_functions = (
        self._add_saveable_objects())
    # The traced save/restore functions are passed to `list_functions` below as
    # extra functions, so they are registered as nodes like any other function.
    saveable_object_functions = {
        "__SAVEABLE_FUNCTION_{}".format(n): fn
        for n, fn in enumerate(all_saveable_functions)}

    # Maps functions -> wrapped functions that capture variables
    self.wrapped_functions = wrapped_functions or {}
    # Maps names of concrete functions in the object to names of wrapped
    # functions. When writing the SavedFunction protos, the names of the
    # wrapped functions should be used in place of the original functions.
    self.function_name_map = {
        compat.as_text(original.name): compat.as_text(wrapped.name)
        for original, wrapped in self.wrapped_functions.items()}

    # Also add `Function`s as nodes.
    # Iterate over a snapshot, since `self.nodes` grows inside the loop.
    nodes_without_functions = list(self.nodes)
    seen_function_names = set()
    for node in nodes_without_functions:
      for function in checkpoint_view.list_functions(
          node, saveable_object_functions).values():
        if function not in self.node_ids:
          self.node_ids[function] = len(self.nodes)
          self.nodes.append(function)
        if isinstance(function, def_function.Function):
          # Force listing the concrete functions for the side effects:
          #  - populate the cache for functions that have an input_signature
          #  and have not been called.
          #  - force side effects of creation of concrete functions, e.g. create
          #  variables on first run.
          concrete_functions = (
              function._list_all_concrete_functions_for_serialization())  # pylint: disable=protected-access
        else:
          concrete_functions = [function]
        if not concrete_functions:
          self.untraced_functions.append(function._name)  # pylint: disable=protected-access

        for concrete_function in concrete_functions:
          if concrete_function.name not in seen_function_names:
            seen_function_names.add(concrete_function.name)
            self.concrete_functions.append(concrete_function)
    if self.untraced_functions:
      logging.warning(
          "Found untraced functions such as %s while saving (showing %d of %d)."
          " These functions will not be directly callable after loading.",
          ", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
          min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)),
          len(self.untraced_functions))

  def _add_saveable_objects(self):
    """Retrieves SaveablesObjects and traces their save/restore functions.

    Returns:
      A tuple of (saveable_objects_map, all_saveable_functions):
        saveable_objects_map: Maps node -> local name -> (save function,
          restore function), for nodes with a non-empty traced saveable map.
        all_saveable_functions: A flat list of every save and restore function
          in `saveable_objects_map`.
    """
    # Maps node -> local name -> (save function, restore function)
    saveable_objects_map = object_identity.ObjectIdentityDictionary()
    all_saveable_functions = []
    for node in self.nodes:
      if resource_variable_ops.is_resource_variable(node):
        # Resource (and TPU/Mirrored) variables are automatically revived with
        # their saveables defined, so there is no need to trace the save
        # and restore functions.
        continue
      saveable_map = saveable_object_util.trace_save_restore_functions(node)
      if saveable_map:
        saveable_objects_map[node] = saveable_map
        for save_fn, restore_fn in saveable_map.values():
          all_saveable_functions.append(save_fn)
          all_saveable_functions.append(restore_fn)
    return saveable_objects_map, all_saveable_functions

  @property
  def root(self):
    """The first node of the view."""
    return self.nodes[0]

  def fill_object_graph_proto(self, proto):
    """Populate the nodes, children and slot_variables of a SavedObjectGraph."""
    for node_id, node in enumerate(self.nodes):
      assert self.node_ids[node] == node_id
      object_proto = proto.nodes.add()
      object_proto.slot_variables.extend(self.slot_variables.get(node, ()))
      # Functions and captured constants get a node entry but no children.
      if isinstance(
          node,
          (def_function.Function, defun.ConcreteFunction, _CapturedConstant)):
        continue
      for child in self.checkpoint_view.list_dependencies(node):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[child.ref]
        child_proto.local_name = child.name
      for local_name, ref_function in (
          self.checkpoint_view.list_functions(node).items()):
        child_proto = object_proto.children.add()
        child_proto.node_id = self.node_ids[ref_function]
        child_proto.local_name = local_name

      if node not in self.saveable_objects_for_node:
        continue

      for local_name, (save_fn, restore_fn) in (
          self.saveable_objects_for_node[node].items()):
        saveable_object_proto = object_proto.saveable_objects[local_name]
        saveable_object_proto.save_function = self.node_ids[save_fn]
        saveable_object_proto.restore_function = self.node_ids[restore_fn]

  @staticmethod
  def _unsaveable_function_message(name, saving_errors):
    """Builds the error text for a function whose graph is not saveable.

    Only the header is passed through `str.format`; the error strings are
    appended verbatim, so braces inside them cannot break the formatting.

    Args:
      name: The name of the function that cannot be saved.
      saving_errors: An iterable of error strings explaining why.

    Returns:
      The header line followed by the errors, one per line.
    """
    header = (
        "Unable to save function {name} for the following reason(s):\n".format(
            name=name))
    return header + "\n".join(saving_errors)

  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with resources.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.

    Raises:
      ValueError: If the graph of one of the concrete functions is marked as
        not saveable.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})

    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.Asset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id
      elif isinstance(obj, base.Trackable):
        node_object_map, node_resource_map = obj._map_resources(self.options)  # pylint: disable=protected-access
        for capturable in node_resource_map.keys():
          self.captured_tensor_node_ids[capturable] = node_id
        object_map.update(node_object_map)
        resource_map.update(node_resource_map)

    # Note: some concrete functions can have been realized when tracing other
    # functions, and might closure-capture tensors from their parent functions.
    # This is normal, but it means those concrete functions can't be serialized
    # as their own independent endpoints, so we filter them out here.
    bad_functions = []
    for concrete_function in self.concrete_functions:
      if not concrete_function.graph.saveable:
        raise ValueError(
            self._unsaveable_function_message(
                concrete_function.name,
                concrete_function.graph.saving_errors))
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture) and
            capture.dtype not in _UNCOPIABLE_DTYPES and
            capture not in self.captured_tensor_node_ids):
          if hasattr(capture, "_cached_variable"):
            # Swap in a wrapped version of the function (once per function)
            # and record the name mapping used when writing the protos.
            if concrete_function not in self.wrapped_functions:
              wrapped = self.wrapped_functions[concrete_function] = (
                  function_serialization.wrap_cached_variables(
                      concrete_function))
              self.function_name_map[compat.as_text(concrete_function.name)] = (
                  compat.as_text(wrapped.name))
            continue
          capture_constant_value = tensor_util.constant_value(capture)
          if capture_constant_value is None:
            # No constant value could be extracted for this capture, so the
            # function is dropped from the exported endpoints below.
            bad_functions.append(concrete_function)
            continue
          # Copy the captured constant into the current graph and register it
          # as a new node.
          copied_tensor = constant_op.constant(capture_constant_value)
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    self.concrete_functions = [
        self.wrapped_functions.get(x, x) for x in self.concrete_functions
        if x not in bad_functions
    ]
    return object_map, resource_map, asset_info
|
|
|
|
|
|
def _tensor_dict_to_tensorinfo(tensor_dict):
  """Maps each value of `tensor_dict` through `build_tensor_info_internal`.

  Args:
    tensor_dict: A dictionary whose values are passed to
      `utils_impl.build_tensor_info_internal`.

  Returns:
    A new dictionary with the same keys and the converted values.
  """
  tensor_infos = {}
  for name, tensor in tensor_dict.items():
    tensor_infos[name] = utils_impl.build_tensor_info_internal(tensor)
  return tensor_infos
|
|
|
|
|
|
def _map_captures_to_created_tensors(original_captures, resource_map):
  """Maps eager tensors captured by a function to Graph resources for export.

  Args:
    original_captures: An iterable of `(exterior, interior)` pairs: each tensor
      captured by the function together with the interior placeholder for that
      tensor (inside the function body).
    resource_map: A dictionary mapping from resource tensors owned by the eager
      context to resource tensors in the exported graph.

  Returns:
    A list of stand-in tensors which belong to the exported graph, corresponding
    to the function's captures.

  Raises:
    AssertionError: If the function references a resource which is not part of
      `resource_map`.
  """
  export_captures = []
  for exterior, interior in original_captures:
    mapped_resource = resource_map.get(exterior, None)
    if mapped_resource is None:
      trackable_referrers = []
      # Try to figure out where the resource came from by iterating over objects
      # which reference it. This is slow and doesn't help us figure out how to
      # match it to other objects when loading the SavedModel as a checkpoint,
      # so we can't continue saving. But we can at least tell the user what
      # needs attaching.
      for primary_referrer in gc.get_referrers(exterior):
        if isinstance(primary_referrer, base.Trackable):
          trackable_referrers.append(primary_referrer)
        for secondary_referrer in gc.get_referrers(primary_referrer):
          if isinstance(secondary_referrer, base.Trackable):
            trackable_referrers.append(secondary_referrer)
      raise AssertionError(
          ("Tried to export a function which references untracked resource {}. "
           "TensorFlow objects (e.g. tf.Variable) captured by functions must "
           "be tracked by assigning them to an attribute of a tracked object "
           "or assigned to an attribute of the main object directly.\n\n"
           "Trackable Python objects referring to this tensor "
           "(from gc.get_referrers, limited to two hops):\n{}"
          ).format(interior,
                   "\n".join([repr(obj) for obj in trackable_referrers])))
    export_captures.append(mapped_resource)
  return export_captures
|
|
|
|
|
|
def _map_function_arguments_to_created_inputs(function_arguments, signature_key,
                                              function_name):
  """Creates exterior placeholders in the exported graph for function arguments.

  Functions have two types of inputs: tensors captured from the outside (eager)
  context, and arguments to the function which we expect to receive from the
  user at each call. `_map_captures_to_created_tensors` replaces captured
  tensors with stand-ins (typically these are resource dtype tensors associated
  with variables). This function runs over every argument instead, creating a
  new placeholder for each which will belong to the exported graph rather than
  the function body.

  The exterior placeholders are outside the function body, directly contained
  in a MetaGraph of the SavedModel. The function body itself contains nearly
  identical placeholders used when running the function, but the exterior ones
  allow Session-based APIs to call the function using feeds and fetches which
  name Tensors in the MetaGraph.

  Args:
    function_arguments: A list of argument placeholders in the function body.
    signature_key: The name of the signature being exported, for error messages.
    function_name: The name of the function, for error messages.

  Returns:
    A tuple of (mapped_inputs, exterior_argument_placeholders):
      mapped_inputs: A list with one newly created exported-graph placeholder
        per entry of `function_arguments`, in the same order.
      exterior_argument_placeholders: A dictionary mapping from argument names
        to those same placeholders, i.e. the explicit arguments to the function
        which a user is expected to provide.

  Raises:
    ValueError: If argument names are not unique.
  """
  mapped_inputs = []
  exterior_argument_placeholders = {}
  for interior_placeholder in function_arguments:
    interior_op = interior_placeholder.op
    user_input_name = compat.as_str_any(
        interior_op.get_attr("_user_specified_name"))
    # If the internal placeholders for a function have names which were
    # uniquified by TensorFlow, then a single user-specified argument name
    # must refer to multiple Tensors. The resulting signatures would be
    # confusing to call, so the user is told to specify explicit names.
    if user_input_name != interior_op.name:
      # This should be unreachable, since concrete functions may not be
      # generated with non-unique argument names.
      message_template = (
          "Got non-flat/non-unique argument names for SavedModel "
          "signature '{}': more than one argument to '{}' was named '{}'. "
          "Signatures have one Tensor per named input, so to have "
          "predictable names Python functions used to generate these "
          "signatures should avoid *args and Tensors in nested "
          "structures unless unique names are specified for each. Use "
          "tf.TensorSpec(..., name=...) to provide a name for a Tensor "
          "input.")
      raise ValueError(
          message_template.format(signature_key,
                                  compat.as_str_any(function_name),
                                  user_input_name))
    exterior_placeholder = array_ops.placeholder(
        shape=interior_placeholder.shape,
        dtype=interior_placeholder.dtype,
        name="{}_{}".format(signature_key, user_input_name))
    exterior_argument_placeholders[user_input_name] = exterior_placeholder
    mapped_inputs.append(exterior_placeholder)
  return mapped_inputs, exterior_argument_placeholders
|
|
|
|
|
|
def _call_function_with_mapped_captures(function, args, resource_map):
  """Calls `function` in the exported graph, using mapped resource captures.

  Args:
    function: The concrete function to call.
    args: The flat list of argument tensors to call it with.
    resource_map: A dictionary used to translate each of the function's
      captures into its exported-graph counterpart.

  Returns:
    Whatever `function._call_flat` returns for `args` and the mapped captures.
  """
  mapped_captures = _map_captures_to_created_tensors(
      function.graph.captures, resource_map)
  # Call the function quite directly: the new captured resource tensors have to
  # be fed in explicitly because they were not part of the original function
  # definition.
  return function._call_flat(args, mapped_captures)  # pylint: disable=protected-access
|
|
|
|
|
|
def _generate_signatures(signature_functions, resource_map):
  """Validates and calls `signature_functions` in the default graph.

  Args:
    signature_functions: A dictionary mapping string keys to concrete TensorFlow
      functions (e.g. from `signature_serialization.canonicalize_signatures`)
      which will be used to generate SignatureDefs.
    resource_map: A dictionary mapping from resource tensors in the eager
      context to resource tensors in the Graph being exported. This dictionary
      is used to re-bind resources captured by functions to tensors which will
      exist in the SavedModel.

  Returns:
    Each function in the `signature_functions` dictionary is called with
    placeholder Tensors, generating a function call operation and output
    Tensors. The placeholder Tensors, the function call operation, and the
    output Tensors from the function call are part of the default Graph.

    This function then returns a dictionary with the same keys as
    `signature_functions`, with the concrete functions replaced by SignatureDefs
    (built with the predict method name) that reference the generated
    placeholders and Tensor outputs.

    The caller is expected to include the default Graph set while calling this
    function as a MetaGraph in a SavedModel, including the returned
    SignatureDefs as part of that MetaGraph.
  """
  signature_defs = {}
  for signature_key, function in sorted(signature_functions.items()):
    func_graph = function.graph
    capture_count = len(func_graph.captures)
    # Only the leading graph inputs are passed on as arguments; the trailing
    # `capture_count` inputs are sliced off.
    if capture_count:
      argument_inputs = func_graph.inputs[:-capture_count]
    else:
      argument_inputs = func_graph.inputs
    mapped_inputs, exterior_argument_placeholders = (
        _map_function_arguments_to_created_inputs(
            argument_inputs, signature_key, function.name))
    outputs = _call_function_with_mapped_captures(
        function, mapped_inputs, resource_map)
    input_infos = _tensor_dict_to_tensorinfo(exterior_argument_placeholders)
    output_infos = _tensor_dict_to_tensorinfo(outputs)
    signature_defs[signature_key] = signature_def_utils.build_signature_def(
        input_infos,
        output_infos,
        method_name=signature_constants.PREDICT_METHOD_NAME)
  return signature_defs
|
|
|
|
|
|
def _trace_resource_initializers(accessible_objects):
  """Create concrete functions from `CapturableResource` objects.

  Args:
    accessible_objects: An iterable of objects; only instances of
      `tracking.CapturableResource` are considered.

  Returns:
    A list with one concrete function per `CapturableResource`, in iteration
    order. Each function calls the resource's `_initialize` method and returns
    a dummy constant.
  """

  def _wrap_initializer(obj):
    obj._initialize()  # pylint: disable=protected-access
    return constant_op.constant(1.)  # Dummy control output

  def _wrap_obj_initializer(obj):
    # Bind `obj` through a call so each lambda keeps its own resource.
    return lambda: _wrap_initializer(obj)

  return [
      def_function.function(
          _wrap_obj_initializer(resource),
          # All inputs are captures.
          input_signature=[]).get_concrete_function()
      for resource in accessible_objects
      if isinstance(resource, tracking.CapturableResource)
  ]
|
|
|
|
|
|
# Bookkeeping for the assets encountered while mapping resources.
_AssetInfo = collections.namedtuple(
    "_AssetInfo",
    (
        # List of AssetFileDef protocol buffers
        "asset_defs",
        # Map from asset variable resource Tensors to their init ops
        "asset_initializers_by_resource",
        # Map from base asset filenames to full paths
        "asset_filename_map",
        # Map from Asset to index of corresponding AssetFileDef
        "asset_index",
    ))
|
|
|
|
|
|
def _process_asset(trackable_asset, asset_info, resource_map):
  """Add `trackable_asset` to `asset_info` and `resource_map`.

  Creates a string placeholder and a resource variable initialized from it in
  the current default graph, then records the asset in every field of
  `asset_info` and maps the asset's path tensor to the new variable.

  Args:
    trackable_asset: The asset to register; its `asset_path` tensor is read.
    asset_info: The `_AssetInfo` whose containers are updated in place.
    resource_map: Dictionary updated in place, mapping the asset's path tensor
      to the newly created variable.
  """
  path_tensor = trackable_asset.asset_path
  source_path = tensor_util.constant_value(path_tensor)
  try:
    source_path = str(source_path.astype(str))
  except AttributeError:
    # Already a string rather than a numpy array
    pass
  exported_filename = builder_impl.get_asset_filename_to_add(
      asset_filepath=source_path,
      asset_filename_map=asset_info.asset_filename_map)
  # TODO(andresp): Instead of mapping 1-1 between trackable asset
  # and asset in the graph def consider deduping the assets that
  # point to the same file.
  path_placeholder = array_ops.placeholder(
      shape=path_tensor.shape,
      dtype=dtypes.string,
      name="asset_path_initializer")
  path_variable = resource_variable_ops.ResourceVariable(path_placeholder)
  asset_info.asset_filename_map[exported_filename] = source_path

  file_def = meta_graph_pb2.AssetFileDef()
  file_def.filename = exported_filename
  file_def.tensor_info.name = path_placeholder.name
  # Position the new def will occupy once it has been appended.
  file_def_index = len(asset_info.asset_defs)
  asset_info.asset_defs.append(file_def)

  asset_info.asset_initializers_by_resource[path_tensor] = (
      path_variable.initializer)
  asset_info.asset_index[trackable_asset] = file_def_index
  resource_map[path_tensor] = path_variable
|
|
|
|
|
|
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
                         namespace_whitelist):
  """Generates a MetaGraph which calls `signature_functions`.

  Args:
    meta_graph_def: The MetaGraphDef proto to fill.
    saveable_view: The _SaveableView being exported.
    signature_functions: A dictionary mapping signature keys to concrete
      functions containing signatures to add to the MetaGraph.
    namespace_whitelist: List of strings containing whitelisted op namespaces.

  Returns:
    A tuple of (_AssetInfo, Graph) containing the captured assets and
    exported Graph generated from tracing the saveable_view.
  """
  # List objects from the eager context to make sure Optimizers give us the
  # right Graph-dependent variables.
  initializer_functions = _trace_resource_initializers(saveable_view.nodes)
  exported_graph = ops.Graph()
  initializer_ops = []
  with exported_graph.as_default():
    object_map, resource_map, asset_info = saveable_view.map_resources()
    asset_initializers = asset_info.asset_initializers_by_resource
    for initializer_function in initializer_functions:
      # Each resource initializer runs after the initializers of the assets it
      # captures.
      asset_dependencies = [
          asset_initializers[capture]
          for capture in initializer_function.graph.external_captures
          if asset_initializers.get(capture, None) is not None
      ]
      with ops.control_dependencies(asset_dependencies):
        initializer_ops.append(
            _call_function_with_mapped_captures(initializer_function, [],
                                                resource_map))
    initializer_ops.extend(asset_initializers.values())
    with ops.control_dependencies(initializer_ops):
      init_op = control_flow_ops.no_op()
    # Add the same op to the main_op collection and to the init_op
    # signature. The collection is for compatibility with older loader APIs;
    # only one will be executed.
    main_op_collection = meta_graph_def.collection_def[constants.MAIN_OP_KEY]
    main_op_collection.node_list.value.append(init_op.name)
    init_signature = signature_def_utils.op_signature_def(
        init_op, constants.INIT_OP_SIGNATURE_KEY)
    meta_graph_def.signature_def[constants.INIT_OP_SIGNATURE_KEY].CopyFrom(
        init_signature)

  # Saving an object-based checkpoint again gathers variables. We need to do the
  # gathering from the eager context so Optimizers save the right set of
  # variables, but want any operations associated with the save/restore to be in
  # the exported graph (thus the `to_graph` argument).
  call_with_mapped_captures = functools.partial(
      _call_function_with_mapped_captures, resource_map=resource_map)
  saver = functional_saver.MultiDeviceSaver(
      saveable_view.checkpoint_view.frozen_saveable_objects(
          object_map=object_map, to_graph=exported_graph,
          call_with_mapped_captures=call_with_mapped_captures))

  with exported_graph.as_default():
    signatures = _generate_signatures(signature_functions, resource_map)
    for concrete_function in saveable_view.concrete_functions:
      concrete_function.add_to_graph()
    meta_graph_def.saver_def.CopyFrom(saver.to_proto())
  graph_def = exported_graph.as_graph_def(add_shapes=True)
  _verify_ops(graph_def, namespace_whitelist)

  meta_graph_def.graph_def.CopyFrom(graph_def)
  meta_info = meta_graph_def.meta_info_def
  meta_info.tags.append(tag_constants.SERVING)
  meta_info.tensorflow_version = versions.__version__
  meta_info.tensorflow_git_version = versions.__git_version__
  # We currently always strip default attributes.
  meta_info.stripped_default_attrs = True
  meta_info.stripped_op_list.MergeFrom(
      meta_graph.stripped_op_list_for_graph(meta_graph_def.graph_def))
  meta_graph_def.asset_file_def.extend(asset_info.asset_defs)
  for signature_key, signature in signatures.items():
    meta_graph_def.signature_def[signature_key].CopyFrom(signature)
  meta_graph.strip_graph_default_valued_attrs(meta_graph_def)
  return asset_info, exported_graph
|
|
|
|
|
|
def _verify_ops(graph_def, namespace_whitelist):
  """Checks that every namespaced op used by `graph_def` is whitelisted.

  An op name counts as namespaced when it contains ">"; the text before the
  first ">" is taken as its namespace.

  Args:
    graph_def: Graph definition whose used ops are listed with
      `meta_graph.ops_used_by_graph_def`.
    namespace_whitelist: Container of namespaces that may appear in the graph.

  Raises:
    ValueError: If at least one used op lives in a namespace that is not in
      `namespace_whitelist`.
  """
  rejected_ops = []
  rejected_namespaces = set()

  for op_name in meta_graph.ops_used_by_graph_def(graph_def):
    prefix, separator, _ = op_name.partition(">")
    # Ops without ">" are not namespaced and are always accepted.
    if not separator or prefix in namespace_whitelist:
      continue
    rejected_ops.append(op_name)
    rejected_namespaces.add(prefix)

  if not rejected_ops:
    return

  raise ValueError(
      "Attempted to save ops from non-whitelisted namespaces to SavedModel: "
      "{}.\nPlease verify that these ops should be saved, since they must be "
      "available when loading the SavedModel. If loading from Python, you "
      "must import the library defining these ops. From C++, link the custom "
      "ops to the serving binary. Once you've confirmed this, please add the "
      "following namespaces to the `namespace_whitelist` argument in "
      "tf.saved_model.SaveOptions: {}.".format(rejected_ops,
                                               rejected_namespaces))
|
|
|
|
|
|
def _serialize_object_graph(saveable_view, asset_file_def_index):
  """Builds the SavedObjectGraph proto for the objects in `saveable_view`.

  Args:
    saveable_view: View that fills in the object graph and supplies the nodes,
      concrete functions, function name map and captured tensor node ids.
    asset_file_def_index: Lookup forwarded to `_write_object_proto` for each
      node.

  Returns:
    The populated `SavedObjectGraph` proto.
  """
  # SavedObjectGraph is similar to the TrackableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  object_graph = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(object_graph)

  structure_coder = nested_structure_coder.StructureCoder()
  for function in saveable_view.concrete_functions:
    # Functions may have been renamed for export; fall back to the own name.
    original_name = compat.as_text(function.name)
    exported_name = saveable_view.function_name_map.get(
        original_name, original_name)
    function_proto = function_serialization.serialize_concrete_function(
        function, saveable_view.captured_tensor_node_ids, structure_coder)
    if function_proto is None:
      continue
    object_graph.concrete_functions[exported_name].CopyFrom(function_proto)

  # Nodes of the view and nodes of the proto are walked in lockstep.
  for node, node_proto in zip(saveable_view.nodes, object_graph.nodes):
    _write_object_proto(node, node_proto, asset_file_def_index,
                        saveable_view.function_name_map)
  return object_graph
|
|
|
|
|
|
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
  """Saves an object into SavedObject proto.

  Exactly one kind-specific field of `proto` is filled, chosen by the type of
  `obj` (asset, variable, function, bare concrete function, captured constant,
  capturable resource, or generic user object). Afterwards the object may
  further customize the proto through an optional `_write_object_proto` hook.

  Args:
    obj: The object to describe.
    proto: The SavedObject proto to fill in place.
    asset_file_def_index: Lookup, indexed by asset object, giving the asset's
      index in the asset file defs.
    function_name_map: Name map forwarded to the function serialization
      helpers.

  Raises:
    ValueError: If `obj` is a resource variable whose name does not end in
      ":0".
  """
  # Save options are only fetched when a branch needs them. Keeping the name
  # bound up front avoids an unbound local in the hook call at the bottom for
  # objects that are not resource variables.
  options = None

  if isinstance(obj, tracking.Asset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    if not obj.name.endswith(":0"):
      # Interpolate the offending name; the placeholder used to be emitted
      # verbatim as "%s".
      raise ValueError(("Cowardly refusing to save variable %s because of"
                        " unexpected suffix which won't be restored.")
                       % (obj.name,))
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
    options = save_context.get_save_options()
    if options.experimental_variable_policy._save_variable_devices(  # pylint: disable=protected-access
    ):
      if hasattr(obj, "device"):
        proto.variable.device = obj.device
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(function_serialization.serialize_function(
        obj, function_name_map))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(
            obj, function_name_map))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      # pylint:disable=protected-access
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier=obj._object_identifier,
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]),
          metadata=obj._tracking_metadata)
      # pylint:enable=protected-access
    proto.user_object.CopyFrom(registered_type_proto)

  # Give the object a chance to modify the SavedObject proto.
  # This is currently used by MirroredVariables to optionally write their
  # component variables to the proto.
  #
  # This is not yet an official Trackable method, the only current use case
  # being MirroredVariables. See the method implementation there for more
  # documentation.
  if hasattr(obj, "_write_object_proto"):
    if options is None:
      options = save_context.get_save_options()
    obj._write_object_proto(proto, options)  # pylint: disable=protected-access
|
|
|
|
|
|
def _export_debug_info(exported_graph, export_dir):
  """Writes a GraphDebugInfo file for the functions of `exported_graph`.

  Collects the operations of every function in the graph's function library
  that is an `_EagerDefinedFunction`, builds a GraphDebugInfo from them and
  writes it atomically into the debug directory of the SavedModel.

  Args:
    exported_graph: A Graph that has been created by tracing a saveable view.
    export_dir: SavedModel directory in which to write the debug info.
  """
  traced_ops = []
  for function_name in exported_graph._functions:  # pylint: disable=protected-access
    function = exported_graph._get_function(function_name)  # pylint: disable=protected-access
    # Only eagerly defined functions are inspected; everything else is skipped.
    if isinstance(function, defun._EagerDefinedFunction):  # pylint: disable=protected-access
      traced_ops.extend(
          (function_name, op) for op in function.graph.get_operations())

  debug_info = error_interpolation.create_graph_debug_info_def(traced_ops)
  debug_info_path = os.path.join(
      utils_impl.get_or_create_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)
  file_io.atomic_write_string_to_file(
      debug_info_path, debug_info.SerializeToString(deterministic=True))
|
|
|
|
|
|
@tf_export(
    "saved_model.save",
    v1=["saved_model.save", "saved_model.experimental.save"])
def save(obj, export_dir, signatures=None, options=None):
  # pylint: disable=line-too-long
  """Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).

  Example usage:

  ```python
  class Adder(tf.Module):

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
      return x + x + 1.

  to_export = Adder()
  tf.saved_model.save(to_export, '/tmp/adder')
  ```

  The resulting SavedModel is then servable with an input named "x", its value
  having any shape and dtype float32.

  The optional `signatures` argument controls which methods in `obj` will be
  available to programs which consume `SavedModel`s, for example, serving
  APIs. Python functions may be decorated with
  `@tf.function(input_signature=...)` and passed as signatures directly, or
  lazily with a call to `get_concrete_function` on the method decorated with
  `@tf.function`.

  If the `signatures` argument is omitted, `obj` will be searched for
  `@tf.function`-decorated methods. If exactly one `@tf.function` is found, that
  method will be used as the default signature for the SavedModel. This behavior
  is expected to change in the future, when a corresponding
  `tf.saved_model.load` symbol is added. At that point signatures will be
  completely optional, and any `@tf.function` attached to `obj` or its
  dependencies will be exported for use with `load`.

  When invoking a signature in an exported SavedModel, `Tensor` arguments are
  identified by name. These names will come from the Python function's argument
  names by default. They may be overridden by specifying a `name=...` argument
  in the corresponding `tf.TensorSpec` object. Explicit naming is required if
  multiple `Tensor`s are passed through a single argument to the Python
  function.

  The outputs of functions used as `signatures` must either be flat lists, in
  which case outputs will be numbered, or a dictionary mapping string keys to
  `Tensor`, in which case the keys will be used to name outputs.

  Signatures are available in objects returned by `tf.saved_model.load` as a
  `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
  on an object with a custom `.signatures` attribute will raise an exception.

  Since `tf.keras.Model` objects are also Trackable, this function can be
  used to export Keras models. For example, exporting with a signature
  specified:

  ```python
  class Model(tf.keras.Model):

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def serve(self, serialized):
      ...

  m = Model()
  tf.saved_model.save(m, '/tmp/saved_model/')
  ```

  Exporting from a function without a fixed signature:

  ```python
  class Model(tf.keras.Model):

    @tf.function
    def call(self, x):
      ...

  m = Model()
  tf.saved_model.save(
      m, '/tmp/saved_model/',
      signatures=m.call.get_concrete_function(
          tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="inp")))
  ```

  `tf.keras.Model` instances constructed from inputs and outputs already have a
  signature and so do not require a `@tf.function` decorator or a `signatures`
  argument. If neither are specified, the model's forward pass is exported.

  ```python
  x = input_layer.Input((4,), name="x")
  y = core.Dense(5, name="out")(x)
  model = training.Model(x, y)
  tf.saved_model.save(model, '/tmp/saved_model/')
  # The exported SavedModel takes "x" with shape [None, 4] and returns "out"
  # with shape [None, 5]
  ```

  Variables must be tracked by assigning them to an attribute of a tracked
  object or to an attribute of `obj` directly. TensorFlow objects (e.g. layers
  from `tf.keras.layers`, optimizers from `tf.train`) track their variables
  automatically. This is the same tracking scheme that `tf.train.Checkpoint`
  uses, and an exported `Checkpoint` object may be restored as a training
  checkpoint by pointing `tf.train.Checkpoint.restore` to the SavedModel's
  "variables/" subdirectory. Currently, variables are the only stateful objects
  supported by `tf.saved_model.save`, but others (e.g. tables) will be supported
  in the future.

  `tf.function` does not hard-code device annotations from outside the function
  body, instead using the calling context's device. This means for example
  that exporting a model that runs on a GPU and serving it on a CPU will
  generally work, with some exceptions. `tf.device` annotations inside the body
  of the function will be hard-coded in the exported model; this type of
  annotation is discouraged. Device-specific operations, e.g. with "cuDNN" in
  the name or with device-specific layouts, may cause issues. Currently a
  `DistributionStrategy` is another exception: active distribution strategies
  will cause device placements to be hard-coded in a function. Exporting a
  single-device computation and importing under a `DistributionStrategy` is
  not currently supported, but may be in the future.

  SavedModels exported with `tf.saved_model.save` [strip default-valued
  attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes)
  automatically, which removes one source of incompatibilities when the consumer
  of a SavedModel is running an older TensorFlow version than the
  producer. There are however other sources of incompatibilities which are not
  handled automatically, such as when the exported model contains operations
  which the consumer does not have definitions for.

  A single tf.function can generate many ConcreteFunctions. If a downstream tool
  wants to refer to all concrete functions generated by a single tf.function you
  can use the `function_aliases` argument to store a map from the alias name to
  all concrete function names.
  E.g.
  ```python
  class MyModel(tf.Module):
    @tf.function
    def func(self):
      ...

    @tf.function
    def serve(self):
      ...
      self.func()

  model = MyModel()
  signatures = {
      'serving_default': model.serve.get_concrete_function(),
  }
  options = tf.saved_model.SaveOptions(function_aliases={
      'my_func': model.func,
  })
  tf.saved_model.save(model, export_dir, signatures, options)
  ```

  Args:
    obj: A trackable object to export.
    export_dir: A directory in which to write the SavedModel.
    signatures: Optional, one of three types:
      * a `tf.function` with an input signature specified, which will use the
        default serving signature key,
      * the result of `f.get_concrete_function` on a `@tf.function`-decorated
        function `f`, in which case `f` will be used to generate a signature for
        the SavedModel under the default serving signature key,
      * a dictionary, which maps signature keys to either `tf.function`
        instances with input signatures or concrete functions. Keys of such a
        dictionary may be arbitrary strings, but will typically be from the
        `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.

  Raises:
    ValueError: If `obj` is not trackable.
    FileNotFoundError: If, when executing eagerly, waiting for pending save
      operations raises a `NotFoundError`.

  @compatibility(eager)
  Not well supported when graph building. From TensorFlow 1.x,
  `tf.compat.v1.enable_eager_execution()` should run first. Calling
  tf.saved_model.save in a loop when graph building from TensorFlow 1.x will
  add new save operations to the default graph each iteration.

  May not be called from within a function body.
  @end_compatibility
  """
  # pylint: enable=line-too-long
  if not options:
    options = save_options.SaveOptions()
  # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
  # compatible (no sessions) and share it with this export API rather than
  # making a SavedModel proto and writing it directly.
  saved_model_proto = saved_model_pb2.SavedModel()
  first_meta_graph = saved_model_proto.meta_graphs.add()

  # The meta graph added above is filled in place; the returned copy of it is
  # therefore not needed here.
  _, exported_graph, object_saver, asset_info = _build_meta_graph(
      obj, signatures, options, first_meta_graph)
  saved_model_proto.saved_model_schema_version = (
      constants.SAVED_MODEL_SCHEMA_VERSION)

  # Write the checkpoint, copy assets into the assets directory, and write out
  # the SavedModel proto itself.
  utils_impl.get_or_create_variables_dir(export_dir)
  io_options = checkpoint_options.CheckpointOptions(
      experimental_io_device=options.experimental_io_device)
  variables_path = utils_impl.get_variables_path(export_dir)
  object_saver.save(variables_path, options=io_options)
  builder_impl.copy_assets_to_destination_dir(asset_info.asset_filename_map,
                                              export_dir)
  # Note that this needs to be the last file operation when saving the
  # SavedModel. Users rely on checking saved_model_dir/saved_model.pb as an
  # indication that the SavedModel is completely written.
  if context.executing_eagerly():
    try:
      context.async_wait()  # Ensure save operations have completed.
    except errors.NotFoundError as err:
      io_device_hint = (
          "\n If trying to save on a different device from the "
          "computational device, consider using setting the "
          "`experimental_io_device` option on tf.saved_model.SaveOptions "
          "to the io_device such as '/job:localhost'.")
      raise FileNotFoundError(str(err) + io_device_hint)

  saved_model_path = os.path.join(
      compat.as_str(export_dir),
      compat.as_str(constants.SAVED_MODEL_FILENAME_PB))
  serialized_saved_model = saved_model_proto.SerializeToString(
      deterministic=True)
  file_io.atomic_write_string_to_file(saved_model_path, serialized_saved_model)

  # Save debug info, if requested.
  if options.save_debug_info:
    _export_debug_info(exported_graph, export_dir)

  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point, we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)
|
|
|
|
|
|
def export_meta_graph(obj, filename, signatures=None, options=None):
  """Writes the MetaGraph proto built from `obj` to `filename`.

  The MetaGraph is produced by the same `_build_meta_graph` step that
  `saved_model.save` uses, but no checkpoint is written, which makes this
  useful when only the graph defining the model is wanted. If
  `options.save_debug_info` is set, debug info is exported relative to the
  directory containing `filename`.

  Args:
    obj: A trackable object to build the MetaGraph from.
    filename: The file into which to write the MetaGraph.
    signatures: Optional, either a `tf.function` with an input signature
      specified or the result of `f.get_concrete_function` on a
      `@tf.function`-decorated function `f`, in which case `f` will be used to
      generate a signature for the SavedModel under the default serving
      signature key. `signatures` may also be a dictionary, in which case it
      maps from signature keys to either `tf.function` instances with input
      signatures or concrete functions. The keys of such a dictionary may be
      arbitrary strings, but will typically be from the
      `tf.saved_model.signature_constants` module.
    options: Optional, `tf.saved_model.SaveOptions` object that specifies
      options for saving.
  """
  if not options:
    options = save_options.SaveOptions()
  export_dir = os.path.dirname(filename)
  (meta_graph_def, exported_graph, unused_object_saver,
   unused_asset_info) = _build_meta_graph(obj, signatures, options)

  serialized_meta_graph = meta_graph_def.SerializeToString(deterministic=True)
  file_io.atomic_write_string_to_file(filename, serialized_meta_graph)

  # Save debug info, if requested.
  if options.save_debug_info:
    _export_debug_info(exported_graph, export_dir)

  # Clean reference cycles so repeated export()s don't make work for the garbage
  # collector. Before this point, we need to keep references to captured
  # constants in the saved graph.
  ops.dismantle_graph(exported_graph)
|
|
|
|
|
|
def _build_meta_graph_impl(obj,
                           signatures,
                           options,
                           meta_graph_def=None):
  """Creates a MetaGraph containing the resources and functions of an object.

  Args:
    obj: A trackable object to build the MetaGraph from.
    signatures: Signature specification to export. If None, a function to
      export is searched for in `obj`.
    options: `tf.saved_model.SaveOptions` object that specifies options for
      saving.
    meta_graph_def: Optional MetaGraphDef proto to fill in place. A new one is
      created only when this is None.

  Returns:
    A tuple `(meta_graph_def, exported_graph, object_saver, asset_info)`.

  Raises:
    AssertionError: If called while inside a traced function.
    ValueError: If `obj` is not trackable.
  """
  if ops.inside_function():
    raise AssertionError(
        "tf.saved_model.save is not supported inside a traced @tf.function. "
        "Move the call to the outer eagerly-executed context.")
  if not isinstance(obj, base.Trackable):
    raise ValueError(
        "Expected a Trackable object for export, got {}.".format(obj))
  # Compare against None explicitly: a caller-supplied proto must always be the
  # one that gets filled, regardless of how it evaluates in a boolean context.
  if meta_graph_def is None:
    meta_graph_def = meta_graph_pb2.MetaGraphDef()

  checkpoint_graph_view = _AugmentedGraphView(obj)
  if signatures is None:
    signatures = signature_serialization.find_function_to_export(
        checkpoint_graph_view)

  signatures, wrapped_functions = (
      signature_serialization.canonicalize_signatures(signatures))
  signature_serialization.validate_saveable_view(checkpoint_graph_view)
  signature_map = signature_serialization.create_signature_map(signatures)
  # Attach the signature map to the root so it is exported with the object.
  checkpoint_graph_view.add_object(
      parent_node=checkpoint_graph_view.root,
      name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
      subgraph_root=signature_map)

  # Use _SaveableView to provide a frozen listing of properties and functions.
  # Note we run this twice since, while constructing the view the first time
  # there can be side effects of creating variables.
  _ = _SaveableView(checkpoint_graph_view, options)
  saveable_view = _SaveableView(checkpoint_graph_view, options,
                                wrapped_functions)
  object_saver = util.TrackableSaver(checkpoint_graph_view)
  asset_info, exported_graph = _fill_meta_graph_def(meta_graph_def,
                                                    saveable_view, signatures,
                                                    options.namespace_whitelist)
  if options.function_aliases:
    # Record, for every concrete function cached by an aliased function, the
    # mapping from the concrete function's name to the user-chosen alias.
    function_aliases = meta_graph_def.meta_info_def.function_aliases
    for alias, func in options.function_aliases.items():
      for fdef in func._stateful_fn._function_cache.all_values():  # pylint: disable=protected-access
        function_aliases[fdef.name] = alias
      for fdef in func._stateless_fn._function_cache.all_values():  # pylint: disable=protected-access
        function_aliases[fdef.name] = alias

  object_graph_proto = _serialize_object_graph(saveable_view,
                                               asset_info.asset_index)
  meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)

  return meta_graph_def, exported_graph, object_saver, asset_info
|
|
|
|
|
|
def _build_meta_graph(obj,
                      signatures,
                      options,
                      meta_graph_def=None):
  """Builds a MetaGraph for `obj` while a save context is active.

  This is a thin wrapper that enters `save_context.save_context(options)` and
  delegates all work to `_build_meta_graph_impl`.

  Args:
    obj: A trackable object to build the MetaGraph from.
    signatures: Can be a `tf.function` with an input signature specified or the
      result of `f.get_concrete_function` on a `@tf.function`-decorated function
      `f`. `signatures` may also be a dictionary, in which case it maps from
      signature keys to `tf.function` instances. If None, finds signature to
      export from the `@tf.function`-decorated methods in `obj`.
    options: `tf.saved_model.SaveOptions` object that specifies options for
      saving.
    meta_graph_def: Optional, the MetaGraphDef proto to fill.

  Raises:
    AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
    ValueError: If `obj` is not trackable.

  Returns:
    meta_graph_def: Filled MetaGraphDef proto
    exported_graph: `tf.Graph` object generated from `obj`.
    object_saver: `util.TrackableSaver` of the `obj` and its dependencies.
    asset_info: `_AssetInfo` tuple containing external assets in the `obj`.
  """
  with save_context.save_context(options):
    build_result = _build_meta_graph_impl(obj, signatures, options,
                                          meta_graph_def)
  return build_result
|