Write AssetFileDef to Metagraph's asset_file_def field.
PiperOrigin-RevId: 222144547
This commit is contained in:
parent
bbe996e251
commit
dcbc3e1deb
@ -193,6 +193,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
|
||||
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
|
||||
std::vector<AssetFileDef>* asset_file_defs) {
|
||||
// With SavedModel v2, we write asset file def into metagraph instead of
|
||||
// collection, so read from metagraph first.
|
||||
if (meta_graph_def.asset_file_def_size() > 0) {
|
||||
for (const auto& asset : meta_graph_def.asset_file_def()) {
|
||||
asset_file_defs->push_back(asset);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
// Fall back to read from collection to be backward compatible with v1.
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
|
||||
if (assets_it == collection_def_map.end()) {
|
||||
|
@ -24,5 +24,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.saved_model.builder_impl import _SavedModelBuilder
|
||||
from tensorflow.python.saved_model.builder_impl import SavedModelBuilder
|
||||
# pylint: enable=unused-import
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
|
||||
from google.protobuf.any_pb2 import Any
|
||||
@ -39,8 +40,7 @@ from tensorflow.python.util.deprecation import deprecated_args
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])
|
||||
class SavedModelBuilder(object):
|
||||
class _SavedModelBuilder(object):
|
||||
"""Builds the `SavedModel` protocol buffer and saves variables and assets.
|
||||
|
||||
The `SavedModelBuilder` class provides functionality to build a `SavedModel`
|
||||
@ -68,7 +68,7 @@ class SavedModelBuilder(object):
|
||||
builder.add_meta_graph_and_variables(sess,
|
||||
["foo-tag"],
|
||||
signature_def_map=foo_signatures,
|
||||
assets_collection=foo_assets)
|
||||
assets_list=foo_assets)
|
||||
...
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
@ -105,19 +105,8 @@ class SavedModelBuilder(object):
|
||||
# weights.
|
||||
self._has_saved_variables = False
|
||||
|
||||
def _save_and_write_assets(self, assets_collection_to_add=None):
|
||||
"""Saves asset to the meta graph and writes asset files to disk.
|
||||
|
||||
Args:
|
||||
assets_collection_to_add: The collection where the asset paths are setup.
|
||||
"""
|
||||
asset_filename_map = _maybe_save_assets(assets_collection_to_add)
|
||||
|
||||
# Return if there are no assets to write.
|
||||
if not asset_filename_map:
|
||||
tf_logging.info("No assets to write.")
|
||||
return
|
||||
|
||||
def _copy_assets_to_destination_dir(self, asset_filename_map):
|
||||
"""Copy all assets from source path to destination path."""
|
||||
assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
|
||||
self._export_dir)
|
||||
|
||||
@ -136,6 +125,25 @@ class SavedModelBuilder(object):
|
||||
tf_logging.info("Assets written to: %s",
|
||||
compat.as_text(assets_destination_dir))
|
||||
|
||||
def _save_and_write_assets(self, meta_graph_def, assets_list=None):
|
||||
"""Saves asset to the meta graph and writes asset files to disk.
|
||||
|
||||
Args:
|
||||
meta_graph_def: The meta graph def to which the assets will be added.
|
||||
assets_list: The list where the asset paths are setup.
|
||||
"""
|
||||
# Creates a function that adds assets into the meta graph def.
|
||||
write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
|
||||
asset_filename_map = _maybe_save_assets(write_fn, assets_list)
|
||||
|
||||
# Return if there are no assets to write.
|
||||
if not asset_filename_map:
|
||||
tf_logging.info("No assets to write.")
|
||||
return
|
||||
|
||||
# Copy assets from source path to destination path.
|
||||
self._copy_assets_to_destination_dir(asset_filename_map)
|
||||
|
||||
def _maybe_add_main_op(self, main_op):
|
||||
"""Adds main op to the SavedModel.
|
||||
|
||||
@ -252,12 +260,8 @@ class SavedModelBuilder(object):
|
||||
for outputs_key in outputs:
|
||||
self._validate_tensor_info(outputs[outputs_key])
|
||||
|
||||
def _add_collections(
|
||||
self, assets_collection, main_op, train_op):
|
||||
def _add_collections(self, main_op, train_op):
|
||||
"""Add asset and op collections to be saved."""
|
||||
# Save asset files and write them to disk, if any.
|
||||
self._save_and_write_assets(assets_collection)
|
||||
|
||||
self._maybe_add_main_op(main_op)
|
||||
|
||||
self._add_train_op(train_op)
|
||||
@ -280,7 +284,7 @@ class SavedModelBuilder(object):
|
||||
def add_meta_graph(self,
|
||||
tags,
|
||||
signature_def_map=None,
|
||||
assets_collection=None,
|
||||
assets_list=None,
|
||||
legacy_init_op=None,
|
||||
clear_devices=False,
|
||||
main_op=None,
|
||||
@ -297,8 +301,8 @@ class SavedModelBuilder(object):
|
||||
tags: The set of tags to annotate the meta graph def with.
|
||||
signature_def_map: The map of signature defs to be added to the meta graph
|
||||
def.
|
||||
assets_collection: Assets collection to be saved with SavedModel. Note
|
||||
that this collection should be a subset of the assets saved as part of
|
||||
assets_list: Assets to be saved with SavedModel. Note
|
||||
that this list should be a subset of the assets saved as part of
|
||||
the first meta graph in the SavedModel.
|
||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
||||
restore op upon a load. Deprecated; please use main_op instead.
|
||||
@ -332,6 +336,212 @@ class SavedModelBuilder(object):
|
||||
# Re-mapping to main_op, as treatment is identical regardless.
|
||||
main_op = main_op or legacy_init_op
|
||||
|
||||
# Add ops to collection.
|
||||
self._add_collections(main_op=main_op, train_op=None)
|
||||
|
||||
saver = self._maybe_create_saver(saver)
|
||||
|
||||
# The graph almost certainly previously contained at least one Saver, and
|
||||
# possibly several (e.g. one for loading a pretrained embedding, and another
|
||||
# for the model weights). Removing the preexisting ones was the
|
||||
# motivation for the clear_extraneous_savers option, but it turns out that
|
||||
# there are edge cases where that option breaks the graph. Until that is
|
||||
# resolved, we just leave the option set to False for now.
|
||||
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
|
||||
meta_graph_def = saver.export_meta_graph(
|
||||
clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
|
||||
|
||||
# Save asset files and write them to disk, if any.
|
||||
self._save_and_write_assets(meta_graph_def, assets_list)
|
||||
|
||||
# Tag the meta graph def and add it to the SavedModel.
|
||||
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
|
||||
|
||||
@deprecated_args(None,
|
||||
"Pass your op to the equivalent parameter main_op instead.",
|
||||
"legacy_init_op")
|
||||
def add_meta_graph_and_variables(self,
|
||||
sess,
|
||||
tags,
|
||||
signature_def_map=None,
|
||||
assets_list=None,
|
||||
legacy_init_op=None,
|
||||
clear_devices=False,
|
||||
main_op=None,
|
||||
strip_default_attrs=False,
|
||||
saver=None):
|
||||
# pylint: disable=line-too-long
|
||||
"""Adds the current meta graph to the SavedModel and saves variables.
|
||||
|
||||
Creates a Saver to save the variables from the provided session. Exports the
|
||||
corresponding meta graph def. This function assumes that the variables to be
|
||||
saved have been initialized. For a given `SavedModelBuilder`, this API must
|
||||
be called exactly once and for the first meta graph to save. For subsequent
|
||||
meta graph defs to be added, the `add_meta_graph()` API must be used.
|
||||
|
||||
Args:
|
||||
sess: The TensorFlow session from which to save the meta graph and
|
||||
variables.
|
||||
tags: The set of tags with which to save the meta graph.
|
||||
signature_def_map: The map of signature def map to add to the meta graph
|
||||
def.
|
||||
assets_list: Assets to be saved with SavedModel.
|
||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
||||
restore op upon a load. Deprecated; please use main_op instead.
|
||||
clear_devices: Set to true if the device info on the default graph should
|
||||
be cleared.
|
||||
main_op: Op or group of ops to execute when the graph is loaded. Note
|
||||
that when the main_op is specified it is run after the restore op at
|
||||
load-time.
|
||||
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
||||
removed from the NodeDefs. For a detailed guide, see
|
||||
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
||||
saver: An instance of tf.train.Saver that will be used to export the
|
||||
metagraph and save variables. If None, a sharded Saver that restores
|
||||
all variables will be used.
|
||||
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
if self._has_saved_variables:
|
||||
raise AssertionError("Graph state including variables and assets has "
|
||||
"already been saved. Please invoke "
|
||||
"`add_meta_graph()` instead.")
|
||||
|
||||
# Validate the signature def map to ensure all included TensorInfos are
|
||||
# properly populated.
|
||||
self._validate_signature_def_map(signature_def_map)
|
||||
|
||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||
# Re-mapping to main_op, as treatment is identical regardless.
|
||||
main_op = main_op or legacy_init_op
|
||||
|
||||
# Add ops to collection.
|
||||
self._add_collections(main_op=main_op, train_op=None)
|
||||
|
||||
saved_model_utils.get_or_create_variables_dir(self._export_dir)
|
||||
variables_path = saved_model_utils.get_variables_path(self._export_dir)
|
||||
|
||||
saver = self._maybe_create_saver(saver)
|
||||
|
||||
# Save the variables. Also, disable writing the checkpoint state proto. The
|
||||
# file is not used during SavedModel loading. In addition, since a
|
||||
# SavedModel can be copied or moved, this avoids the checkpoint state to
|
||||
# become outdated.
|
||||
saver.save(sess, variables_path, write_meta_graph=False, write_state=False)
|
||||
|
||||
# Export the meta graph def.
|
||||
|
||||
# The graph almost certainly previously contained at least one Saver, and
|
||||
# possibly several (e.g. one for loading a pretrained embedding, and another
|
||||
# for the model weights). Removing the preexisting ones was the
|
||||
# motivation for the clear_extraneous_savers option, but it turns out that
|
||||
# there are edge cases where that option breaks the graph. Until that is
|
||||
# resolved, we just leave the option set to False for now.
|
||||
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
|
||||
meta_graph_def = saver.export_meta_graph(
|
||||
clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
|
||||
|
||||
# Save asset files and write them to disk, if any.
|
||||
self._save_and_write_assets(meta_graph_def, assets_list)
|
||||
|
||||
# Tag the meta graph def and add it to the SavedModel.
|
||||
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
|
||||
|
||||
# Mark this instance of SavedModel as having saved variables, such that
|
||||
# subsequent attempts to save variables will fail.
|
||||
self._has_saved_variables = True
|
||||
|
||||
def save(self, as_text=False):
|
||||
"""Writes a `SavedModel` protocol buffer to disk.
|
||||
|
||||
The function writes the SavedModel protocol buffer to the export directory
|
||||
in serialized format.
|
||||
|
||||
Args:
|
||||
as_text: Writes the SavedModel protocol buffer in text format to disk.
|
||||
|
||||
Returns:
|
||||
The path to which the SavedModel protocol buffer was written.
|
||||
"""
|
||||
if not file_io.file_exists(self._export_dir):
|
||||
file_io.recursive_create_dir(self._export_dir)
|
||||
|
||||
if as_text:
|
||||
path = os.path.join(
|
||||
compat.as_bytes(self._export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
|
||||
file_io.write_string_to_file(path, str(self._saved_model))
|
||||
else:
|
||||
path = os.path.join(
|
||||
compat.as_bytes(self._export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||
file_io.write_string_to_file(path, self._saved_model.SerializeToString())
|
||||
tf_logging.info("SavedModel written to: %s", compat.as_text(path))
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"]) # pylint: disable=missing-docstring
|
||||
class SavedModelBuilder(_SavedModelBuilder):
|
||||
__doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
|
||||
"assets_collection")
|
||||
|
||||
def __init__(self, export_dir):
|
||||
super(SavedModelBuilder, self).__init__(export_dir=export_dir)
|
||||
|
||||
def _add_collections(self, assets_collection, main_op, train_op):
|
||||
"""Add asset and op collections to be saved."""
|
||||
# Save asset files and write them to disk, if any.
|
||||
self._save_and_write_assets(assets_collection)
|
||||
|
||||
self._maybe_add_main_op(main_op)
|
||||
|
||||
self._add_train_op(train_op)
|
||||
|
||||
def _save_and_write_assets(self, assets_collection_to_add=None):
|
||||
"""Saves asset to the meta graph and writes asset files to disk.
|
||||
|
||||
Args:
|
||||
assets_collection_to_add: The collection where the asset paths are setup.
|
||||
"""
|
||||
# Add assets to the collection with key `constants.ASSETS_KEY`, in the
|
||||
# graph.
|
||||
asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
|
||||
assets_collection_to_add)
|
||||
|
||||
# Return if there are no assets to write.
|
||||
if not asset_filename_map:
|
||||
tf_logging.info("No assets to write.")
|
||||
return
|
||||
|
||||
# Copy assets from source path to destination path.
|
||||
self._copy_assets_to_destination_dir(asset_filename_map)
|
||||
|
||||
@deprecated_args(None,
|
||||
"Pass your op to the equivalent parameter main_op instead.",
|
||||
"legacy_init_op")
|
||||
def add_meta_graph(self,
|
||||
tags,
|
||||
signature_def_map=None,
|
||||
assets_collection=None,
|
||||
legacy_init_op=None,
|
||||
clear_devices=False,
|
||||
main_op=None,
|
||||
strip_default_attrs=False,
|
||||
saver=None):
|
||||
if not self._has_saved_variables:
|
||||
raise AssertionError(
|
||||
"Graph state including variables and assets has not been saved yet. "
|
||||
"Please invoke `add_meta_graph_and_variables()` first.")
|
||||
|
||||
# Validate the signature def map to ensure all included TensorInfos are
|
||||
# properly populated.
|
||||
self._validate_signature_def_map(signature_def_map)
|
||||
|
||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||
# Re-mapping to main_op, as treatment is identical regardless.
|
||||
main_op = main_op or legacy_init_op
|
||||
|
||||
# Add assets and ops
|
||||
self._add_collections(assets_collection, main_op, None)
|
||||
|
||||
@ -363,38 +573,6 @@ class SavedModelBuilder(object):
|
||||
main_op=None,
|
||||
strip_default_attrs=False,
|
||||
saver=None):
|
||||
# pylint: disable=line-too-long
|
||||
"""Adds the current meta graph to the SavedModel and saves variables.
|
||||
|
||||
Creates a Saver to save the variables from the provided session. Exports the
|
||||
corresponding meta graph def. This function assumes that the variables to be
|
||||
saved have been initialized. For a given `SavedModelBuilder`, this API must
|
||||
be called exactly once and for the first meta graph to save. For subsequent
|
||||
meta graph defs to be added, the `add_meta_graph()` API must be used.
|
||||
|
||||
Args:
|
||||
sess: The TensorFlow session from which to save the meta graph and
|
||||
variables.
|
||||
tags: The set of tags with which to save the meta graph.
|
||||
signature_def_map: The map of signature def map to add to the meta graph
|
||||
def.
|
||||
assets_collection: Assets collection to be saved with SavedModel.
|
||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
||||
restore op upon a load. Deprecated; please use main_op instead.
|
||||
clear_devices: Set to true if the device info on the default graph should
|
||||
be cleared.
|
||||
main_op: Op or group of ops to execute when the graph is loaded. Note
|
||||
that when the main_op is specified it is run after the restore op at
|
||||
load-time.
|
||||
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
||||
removed from the NodeDefs. For a detailed guide, see
|
||||
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
||||
saver: An instance of tf.train.Saver that will be used to export the
|
||||
metagraph and save variables. If None, a sharded Saver that restores
|
||||
all variables will be used.
|
||||
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
if self._has_saved_variables:
|
||||
raise AssertionError("Graph state including variables and assets has "
|
||||
"already been saved. Please invoke "
|
||||
@ -441,41 +619,19 @@ class SavedModelBuilder(object):
|
||||
# subsequent attempts to save variables will fail.
|
||||
self._has_saved_variables = True
|
||||
|
||||
def save(self, as_text=False):
|
||||
"""Writes a `SavedModel` protocol buffer to disk.
|
||||
|
||||
The function writes the SavedModel protocol buffer to the export directory
|
||||
in serialized format.
|
||||
|
||||
Args:
|
||||
as_text: Writes the SavedModel protocol buffer in text format to disk.
|
||||
|
||||
Returns:
|
||||
The path to which the SavedModel protocol buffer was written.
|
||||
"""
|
||||
if not file_io.file_exists(self._export_dir):
|
||||
file_io.recursive_create_dir(self._export_dir)
|
||||
|
||||
if as_text:
|
||||
path = os.path.join(
|
||||
compat.as_bytes(self._export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
|
||||
file_io.write_string_to_file(path, str(self._saved_model))
|
||||
else:
|
||||
path = os.path.join(
|
||||
compat.as_bytes(self._export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||
file_io.write_string_to_file(path, self._saved_model.SerializeToString())
|
||||
tf_logging.info("SavedModel written to: %s", compat.as_text(path))
|
||||
|
||||
return path
|
||||
add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
|
||||
"assets_list", "assets_collection")
|
||||
add_meta_graph_and_variables.__doc__ = \
|
||||
_SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
|
||||
"assets_list", "assets_collection")
|
||||
|
||||
|
||||
def _maybe_save_assets(assets_collection_to_add=None):
|
||||
def _maybe_save_assets(write_fn, assets_to_add=None):
|
||||
"""Saves assets to the meta graph.
|
||||
|
||||
Args:
|
||||
assets_collection_to_add: The collection where the asset paths are setup.
|
||||
write_fn: A function callback that writes asset into meta graph.
|
||||
assets_to_add: The list where the asset paths are setup.
|
||||
|
||||
Returns:
|
||||
A dict of asset basenames for saving to the original full path to the asset.
|
||||
@ -486,14 +642,13 @@ def _maybe_save_assets(assets_collection_to_add=None):
|
||||
# Map of target file names to original filenames
|
||||
asset_filename_map = {}
|
||||
|
||||
if assets_collection_to_add is None:
|
||||
if assets_to_add is None:
|
||||
tf_logging.info("No assets to save.")
|
||||
return asset_filename_map
|
||||
|
||||
# Iterate over the supplied asset collection, build the `AssetFile` proto
|
||||
# and add them to the collection with key `constants.ASSETS_KEY`, in the
|
||||
# graph.
|
||||
for asset_tensor in assets_collection_to_add:
|
||||
# Iterate over the supplied assets, build the `AssetFile` proto and add them
|
||||
# to the meta graph.
|
||||
for asset_tensor in assets_to_add:
|
||||
asset_source_filepath = _asset_path_from_tensor(asset_tensor)
|
||||
if not asset_source_filepath:
|
||||
raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)
|
||||
@ -501,10 +656,11 @@ def _maybe_save_assets(assets_collection_to_add=None):
|
||||
asset_filename = _get_asset_filename_to_add(
|
||||
asset_source_filepath, asset_filename_map)
|
||||
|
||||
# Build `AssetFile` proto and add it to the asset collection in the graph.
|
||||
# Call the passed-in function that builds AssetFileDef proto and adds it
|
||||
# to either the collection or asset_file_def field of the meta graph.
|
||||
# Note that this should be done even when the file is a duplicate of an
|
||||
# already-added file, as the tensor reference should still exist.
|
||||
_add_asset_to_collection(asset_filename, asset_tensor)
|
||||
write_fn(asset_filename, asset_tensor)
|
||||
|
||||
# In the cases where we are adding a duplicate, this will result in the
|
||||
# last of the filepaths being the one used for copying the file to the
|
||||
@ -542,7 +698,7 @@ def _get_asset_filename_to_add(asset_filepath, asset_filename_map):
|
||||
|
||||
other_asset_filepath = asset_filename_map[asset_filename]
|
||||
if other_asset_filepath == asset_filepath:
|
||||
# This is the same file, stored twice in the collection list. No need
|
||||
# This is the same file, stored twice in the list. No need
|
||||
# to make unique.
|
||||
return asset_filename
|
||||
|
||||
@ -589,6 +745,20 @@ def _asset_path_from_tensor(path_tensor):
|
||||
return str_values[0]
|
||||
|
||||
|
||||
def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
|
||||
"""Builds an asset proto and adds it to the meta graph def.
|
||||
|
||||
Args:
|
||||
meta_graph_def: The meta graph def to which the asset will be added.
|
||||
asset_filename: The filename of the asset to be added.
|
||||
asset_tensor: The asset tensor used to populate the tensor info of the asset
|
||||
proto.
|
||||
"""
|
||||
asset_proto = meta_graph_def.asset_file_def.add()
|
||||
asset_proto.filename = asset_filename
|
||||
asset_proto.tensor_info.name = asset_tensor.name
|
||||
|
||||
|
||||
def _add_asset_to_collection(asset_filename, asset_tensor):
|
||||
"""Builds an asset proto and adds it to the asset collection of the graph.
|
||||
|
||||
|
@ -99,22 +99,29 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
|
||||
collection_def = meta_graph_def_to_load.collection_def
|
||||
|
||||
asset_tensor_dict = {}
|
||||
if constants.ASSETS_KEY in collection_def:
|
||||
# Location of the assets for SavedModel.
|
||||
assets_directory = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.ASSETS_DIRECTORY))
|
||||
asset_protos = []
|
||||
|
||||
if meta_graph_def_to_load.asset_file_def:
|
||||
asset_protos = meta_graph_def_to_load.asset_file_def
|
||||
elif constants.ASSETS_KEY in collection_def:
|
||||
assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
|
||||
# Process each asset and add it to the asset tensor dictionary.
|
||||
for asset_any_proto in assets_any_proto:
|
||||
asset_proto = meta_graph_pb2.AssetFileDef()
|
||||
asset_any_proto.Unpack(asset_proto)
|
||||
tensor_name = asset_proto.tensor_info.name
|
||||
if import_scope:
|
||||
tensor_name = "%s/%s" % (import_scope, tensor_name)
|
||||
asset_tensor_dict[tensor_name] = os.path.join(
|
||||
compat.as_bytes(assets_directory),
|
||||
compat.as_bytes(asset_proto.filename))
|
||||
asset_protos.append(asset_proto)
|
||||
|
||||
# Location of the assets for SavedModel.
|
||||
assets_directory = os.path.join(
|
||||
compat.as_bytes(export_dir), compat.as_bytes(constants.ASSETS_DIRECTORY))
|
||||
# Process each asset and add it to the asset tensor dictionary.
|
||||
for asset_proto in asset_protos:
|
||||
tensor_name = asset_proto.tensor_info.name
|
||||
if import_scope:
|
||||
tensor_name = "%s/%s" % (import_scope, tensor_name)
|
||||
asset_tensor_dict[tensor_name] = os.path.join(
|
||||
compat.as_bytes(assets_directory),
|
||||
compat.as_bytes(asset_proto.filename))
|
||||
|
||||
return asset_tensor_dict
|
||||
|
||||
|
||||
|
@ -54,7 +54,7 @@ def tearDownModule():
|
||||
file_io.delete_recursively(test.get_temp_dir())
|
||||
|
||||
|
||||
class SavedModelTest(test.TestCase):
|
||||
class SavedModelTestBase(test.TestCase):
|
||||
|
||||
def _get_export_dir(self, label):
|
||||
return os.path.join(test.get_temp_dir(), label)
|
||||
@ -78,14 +78,16 @@ class SavedModelTest(test.TestCase):
|
||||
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
||||
return asset_collection
|
||||
|
||||
def _validate_asset_collection(self, export_dir, graph_collection_def,
|
||||
expected_asset_file_name,
|
||||
expected_asset_file_contents,
|
||||
expected_asset_tensor_name,
|
||||
asset_id=0):
|
||||
assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
|
||||
asset = meta_graph_pb2.AssetFileDef()
|
||||
assets_any[asset_id].Unpack(asset)
|
||||
|
||||
class SavedModelTest(SavedModelTestBase):
|
||||
|
||||
def _validate_assets(self,
|
||||
export_dir,
|
||||
asset_file_def,
|
||||
expected_asset_file_name,
|
||||
expected_asset_file_contents,
|
||||
expected_asset_tensor_name,
|
||||
asset_id=0):
|
||||
assets_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||
@ -93,8 +95,10 @@ class SavedModelTest(test.TestCase):
|
||||
actual_asset_contents = file_io.read_file_to_string(assets_path)
|
||||
self.assertEqual(expected_asset_file_contents,
|
||||
compat.as_text(actual_asset_contents))
|
||||
self.assertEqual(expected_asset_file_name, asset.filename)
|
||||
self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
|
||||
self.assertEqual(expected_asset_file_name,
|
||||
asset_file_def[asset_id].filename)
|
||||
self.assertEqual(expected_asset_tensor_name,
|
||||
asset_file_def[asset_id].tensor_info.name)
|
||||
|
||||
def _validate_inputs_tensor_info_fail(self, builder, tensor_info):
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
@ -185,7 +189,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testVerifySessionGraphUsage(self):
|
||||
export_dir = self._get_export_dir("test_verify_session_graph_usage")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
@ -205,7 +209,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testSequence(self):
|
||||
export_dir = self._get_export_dir("test_sequence")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Expect an assertion error since add_meta_graph_and_variables() should be
|
||||
# invoked before any add_meta_graph() calls.
|
||||
@ -222,7 +226,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testTags(self):
|
||||
export_dir = self._get_export_dir("test_tags")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Graph with a single variable. SavedModel invoked to:
|
||||
# - add with weights.
|
||||
@ -311,7 +315,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testVariables(self):
|
||||
export_dir = self._get_export_dir("test_variables")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Graph with two variables. SavedModel invoked to:
|
||||
# - add with weights.
|
||||
@ -363,7 +367,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testGraphWithoutVariables(self):
|
||||
export_dir = self._get_export_dir("test_graph_has_variables")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Graph with no variables.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
@ -398,7 +402,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testNoOverwrite(self):
|
||||
export_dir = self._get_export_dir("test_no_overwrite")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Graph with a single variable. SavedModel invoked to:
|
||||
# - add with weights.
|
||||
@ -417,12 +421,12 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
# An attempt to create another builder with the same export directory should
|
||||
# result in an assertion error.
|
||||
self.assertRaises(AssertionError, saved_model_builder.SavedModelBuilder,
|
||||
self.assertRaises(AssertionError, saved_model_builder._SavedModelBuilder,
|
||||
export_dir)
|
||||
|
||||
def testSaveAsText(self):
|
||||
export_dir = self._get_export_dir("test_astext")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Graph with a single variable. SavedModel invoked to:
|
||||
# - add with weights.
|
||||
@ -453,7 +457,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testCollections(self):
|
||||
export_dir = self._get_export_dir("test_collections")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Graph with a single variable added to a collection. SavedModel invoked to:
|
||||
# - add with weights.
|
||||
@ -503,7 +507,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testSignatureDefs(self):
|
||||
export_dir = self._get_export_dir("test_signature_defs")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Graph with a single variable and a single entry in the signature def map.
|
||||
# SavedModel is invoked to add with weights.
|
||||
@ -563,7 +567,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testSignatureDefValidationFails(self):
|
||||
export_dir = self._get_export_dir("test_signature_def_validation_fail")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
tensor_without_encoding = meta_graph_pb2.TensorInfo()
|
||||
tensor_without_encoding.dtype = types_pb2.DT_FLOAT
|
||||
@ -585,11 +589,11 @@ class SavedModelTest(test.TestCase):
|
||||
tensor_with_name.dtype = types_pb2.DT_FLOAT
|
||||
|
||||
export_dir = self._get_export_dir("test_signature_def_validation_name_1")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
self._validate_inputs_tensor_info_accept(builder, tensor_with_name)
|
||||
|
||||
export_dir = self._get_export_dir("test_signature_def_validation_name_2")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
self._validate_outputs_tensor_info_accept(builder, tensor_with_name)
|
||||
|
||||
def testSignatureDefValidationSucceedsWithCoo(self):
|
||||
@ -599,16 +603,16 @@ class SavedModelTest(test.TestCase):
|
||||
tensor_with_coo.dtype = types_pb2.DT_FLOAT
|
||||
|
||||
export_dir = self._get_export_dir("test_signature_def_validation_coo_1")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
self._validate_inputs_tensor_info_accept(builder, tensor_with_coo)
|
||||
|
||||
export_dir = self._get_export_dir("test_signature_def_validation_coo_2")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
self._validate_outputs_tensor_info_accept(builder, tensor_with_coo)
|
||||
|
||||
def testAssets(self):
|
||||
export_dir = self._get_export_dir("test_assets")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
@ -618,21 +622,19 @@ class SavedModelTest(test.TestCase):
|
||||
compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
|
||||
file_io.write_string_to_file(ignored_filepath, "will be ignored")
|
||||
|
||||
asset_collection = self._build_asset_collection("hello42.txt",
|
||||
"foo bar baz",
|
||||
"asset_file_tensor")
|
||||
asset_list = self._build_asset_collection("hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor")
|
||||
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], assets_collection=asset_collection)
|
||||
sess, ["foo"], assets_list=asset_list)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||
"foo bar baz", "asset_file_tensor:0")
|
||||
ignored_asset_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||
@ -641,64 +643,66 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testAssetsNameCollisionDiffFile(self):
|
||||
export_dir = self._get_export_dir("test_assets_name_collision_diff_file")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
asset_collection = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar bak", "asset_file_tensor",
|
||||
asset_subdir="1")
|
||||
asset_list = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar bak", "asset_file_tensor", asset_subdir="1")
|
||||
|
||||
asset_collection = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor_1",
|
||||
asset_subdir="2")
|
||||
asset_list = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor_1", asset_subdir="2")
|
||||
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], assets_collection=asset_collection)
|
||||
sess, ["foo"], assets_list=asset_list)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt", "foo bar bak",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt_1", "foo bar baz",
|
||||
"asset_file_tensor_1:0",
|
||||
asset_id=1)
|
||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||
"foo bar bak", "asset_file_tensor:0")
|
||||
self._validate_assets(
|
||||
export_dir,
|
||||
foo_graph.asset_file_def,
|
||||
"hello42.txt_1",
|
||||
"foo bar baz",
|
||||
"asset_file_tensor_1:0",
|
||||
asset_id=1)
|
||||
|
||||
def testAssetsNameCollisionSameFilepath(self):
|
||||
export_dir = self._get_export_dir("test_assets_name_collision_same_path")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
asset_collection = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor")
|
||||
asset_list = self._build_asset_collection("hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor")
|
||||
|
||||
asset_collection = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor_1")
|
||||
asset_list = self._build_asset_collection("hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor_1")
|
||||
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], assets_collection=asset_collection)
|
||||
sess, ["foo"], assets_list=asset_list)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||
"foo bar baz", "asset_file_tensor:0")
|
||||
# The second tensor should be recorded, but the same.
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor_1:0",
|
||||
asset_id=1)
|
||||
self._validate_assets(
|
||||
export_dir,
|
||||
foo_graph.asset_file_def,
|
||||
"hello42.txt",
|
||||
"foo bar baz",
|
||||
"asset_file_tensor_1:0",
|
||||
asset_id=1)
|
||||
ignored_asset_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||
@ -707,35 +711,35 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testAssetsNameCollisionSameFile(self):
|
||||
export_dir = self._get_export_dir("test_assets_name_collision_same_file")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
asset_collection = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor",
|
||||
asset_subdir="1")
|
||||
asset_list = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor", asset_subdir="1")
|
||||
|
||||
asset_collection = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor_1",
|
||||
asset_subdir="2")
|
||||
asset_list = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor_1", asset_subdir="2")
|
||||
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], assets_collection=asset_collection)
|
||||
sess, ["foo"], assets_list=asset_list)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||
"foo bar baz", "asset_file_tensor:0")
|
||||
# The second tensor should be recorded, but the same.
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor_1:0",
|
||||
asset_id=1)
|
||||
self._validate_assets(
|
||||
export_dir,
|
||||
foo_graph.asset_file_def,
|
||||
"hello42.txt",
|
||||
"foo bar baz",
|
||||
"asset_file_tensor_1:0",
|
||||
asset_id=1)
|
||||
ignored_asset_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||
@ -744,19 +748,21 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testAssetsNameCollisionManyFiles(self):
|
||||
export_dir = self._get_export_dir("test_assets_name_collision_many_files")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
for i in range(5):
|
||||
idx = str(i)
|
||||
asset_collection = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz " + idx, "asset_file_tensor_" + idx,
|
||||
asset_list = self._build_asset_collection(
|
||||
"hello42.txt",
|
||||
"foo bar baz " + idx,
|
||||
"asset_file_tensor_" + idx,
|
||||
asset_subdir=idx)
|
||||
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], assets_collection=asset_collection)
|
||||
sess, ["foo"], assets_list=asset_list)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
@ -765,18 +771,20 @@ class SavedModelTest(test.TestCase):
|
||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||
for i in range(1, 5):
|
||||
idx = str(i)
|
||||
self._validate_asset_collection(
|
||||
export_dir, foo_graph.collection_def, "hello42.txt_" + idx,
|
||||
"foo bar baz " + idx, "asset_file_tensor_{}:0".format(idx),
|
||||
self._validate_assets(
|
||||
export_dir,
|
||||
foo_graph.asset_file_def,
|
||||
"hello42.txt_" + idx,
|
||||
"foo bar baz " + idx,
|
||||
"asset_file_tensor_{}:0".format(idx),
|
||||
asset_id=i)
|
||||
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt", "foo bar baz 0",
|
||||
"asset_file_tensor_0:0")
|
||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||
"foo bar baz 0", "asset_file_tensor_0:0")
|
||||
|
||||
def testCustomMainOp(self):
|
||||
export_dir = self._get_export_dir("test_main_op")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
# Add `v1` and `v2` variables to the graph.
|
||||
@ -811,7 +819,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testLegacyInitOp(self):
|
||||
export_dir = self._get_export_dir("test_legacy_init_op")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
# Add `v1` and `v2` variables to the graph.
|
||||
@ -855,7 +863,7 @@ class SavedModelTest(test.TestCase):
|
||||
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
|
||||
|
||||
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
g = ops.Graph()
|
||||
with self.session(graph=g) as sess:
|
||||
@ -885,7 +893,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testTrainOp(self):
|
||||
export_dir = self._get_export_dir("test_train_op")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
# Add `v1` and `v2` variables to the graph.
|
||||
@ -914,7 +922,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testTrainOpGroup(self):
|
||||
export_dir = self._get_export_dir("test_train_op_group")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
# Add `v1` and `v2` variables to the graph.
|
||||
@ -943,7 +951,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testTrainOpAfterVariables(self):
|
||||
export_dir = self._get_export_dir("test_train_op_after_variables")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
# Add `v1` and `v2` variables to the graph.
|
||||
@ -975,28 +983,28 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testMultipleAssets(self):
|
||||
export_dir = self._get_export_dir("test_multiple_assets")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
# Build an asset collection specific to `foo` graph.
|
||||
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
|
||||
"asset_file_tensor")
|
||||
asset_list = self._build_asset_collection("foo.txt", "content_foo",
|
||||
"asset_file_tensor")
|
||||
|
||||
# Add the asset collection as part of the graph with tag "foo".
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], assets_collection=asset_collection)
|
||||
sess, ["foo"], assets_list=asset_list)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
# Build an asset collection specific to `bar` graph.
|
||||
asset_collection = self._build_asset_collection("bar.txt", "content_bar",
|
||||
"asset_file_tensor")
|
||||
asset_list = self._build_asset_collection("bar.txt", "content_bar",
|
||||
"asset_file_tensor")
|
||||
|
||||
# Add the asset collection as part of the graph with tag "bar".
|
||||
builder.add_meta_graph(["bar"], assets_collection=asset_collection)
|
||||
builder.add_meta_graph(["bar"], assets_list=asset_list)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
@ -1004,43 +1012,41 @@ class SavedModelTest(test.TestCase):
|
||||
# Check assets restored for graph with tag "foo".
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"foo.txt", "content_foo",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "foo.txt",
|
||||
"content_foo", "asset_file_tensor:0")
|
||||
|
||||
# Check assets restored for graph with tag "bar".
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
bar_graph = loader.load(sess, ["bar"], export_dir)
|
||||
self._validate_asset_collection(export_dir, bar_graph.collection_def,
|
||||
"bar.txt", "content_bar",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_assets(export_dir, bar_graph.asset_file_def, "bar.txt",
|
||||
"content_bar", "asset_file_tensor:0")
|
||||
|
||||
def testDuplicateAssets(self):
|
||||
export_dir = self._get_export_dir("test_duplicate_assets")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
# Build an asset collection with `foo.txt` that has `foo` specific
|
||||
# content.
|
||||
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
|
||||
"asset_file_tensor")
|
||||
asset_list = self._build_asset_collection("foo.txt", "content_foo",
|
||||
"asset_file_tensor")
|
||||
|
||||
# Add the asset collection as part of the graph with tag "foo".
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], assets_collection=asset_collection)
|
||||
sess, ["foo"], assets_list=asset_list)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
# Build an asset collection with `foo.txt` that has `bar` specific
|
||||
# content.
|
||||
asset_collection = self._build_asset_collection("foo.txt", "content_bar",
|
||||
"asset_file_tensor")
|
||||
asset_list = self._build_asset_collection("foo.txt", "content_bar",
|
||||
"asset_file_tensor")
|
||||
|
||||
# Add the asset collection as part of the graph with tag "bar".
|
||||
builder.add_meta_graph(["bar"], assets_collection=asset_collection)
|
||||
builder.add_meta_graph(["bar"], assets_list=asset_list)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
@ -1048,9 +1054,8 @@ class SavedModelTest(test.TestCase):
|
||||
# Check assets restored for graph with tag "foo".
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"foo.txt", "content_foo",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "foo.txt",
|
||||
"content_foo", "asset_file_tensor:0")
|
||||
|
||||
# Check assets restored for graph with tag "bar".
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
@ -1059,13 +1064,12 @@ class SavedModelTest(test.TestCase):
|
||||
# Validate the assets for `bar` graph. `foo.txt` should contain the
|
||||
# original contents corresponding to `foo` graph since an asset with the
|
||||
# same name across multiple graphs is only stored the first time
|
||||
self._validate_asset_collection(export_dir, bar_graph.collection_def,
|
||||
"foo.txt", "content_foo",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_assets(export_dir, bar_graph.asset_file_def, "foo.txt",
|
||||
"content_foo", "asset_file_tensor:0")
|
||||
|
||||
def testOp(self):
|
||||
export_dir = self._get_export_dir("test_op")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with session.Session(
|
||||
graph=ops.Graph(),
|
||||
@ -1108,7 +1112,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testCustomSaveable(self):
|
||||
export_dir = self._get_export_dir("custom_saveable")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with session.Session(
|
||||
graph=ops.Graph(),
|
||||
@ -1137,7 +1141,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testCustomSaver(self):
|
||||
export_dir = self._get_export_dir("test_custom_saver")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
variables.VariableV1(1, name="v1")
|
||||
@ -1159,7 +1163,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testNoCustomSaver(self):
|
||||
export_dir = self._get_export_dir("test_no_custom_saver")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
variables.VariableV1(1, name="v1")
|
||||
@ -1181,7 +1185,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testMultipleCustomSavers(self):
|
||||
export_dir = self._get_export_dir("test_multiple_custom_savers")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
variables.VariableV1(1, name="v1")
|
||||
@ -1211,19 +1215,19 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testImportScope(self):
|
||||
export_dir = self._get_export_dir("test_scoped_assets")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Build a SavedModel with a variable, an asset, and a constant tensor.
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
|
||||
"asset_file_tensor")
|
||||
asset_list = self._build_asset_collection("foo.txt", "content_foo",
|
||||
"asset_file_tensor")
|
||||
constant_op.constant("constant value", name="constant_tensor_name")
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["tag_name"], assets_collection=asset_collection)
|
||||
sess, ["tag_name"], assets_list=asset_list)
|
||||
|
||||
# Save the asset file path for later comparison.
|
||||
asset_file_path = asset_collection[0].eval()
|
||||
asset_file_path = asset_list[0].eval()
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
@ -1244,16 +1248,14 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
# The loaded asset tensor should be scoped, but the asset file path and
|
||||
# contents should be unchanged.
|
||||
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
||||
self.assertEqual(1, len(asset_collection))
|
||||
self.assertEqual(asset_file_path, asset_collection[0].eval())
|
||||
self.assertEqual("scope_name/asset_file_tensor:0",
|
||||
asset_collection[0].name)
|
||||
asset_list = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
||||
self.assertEqual(1, len(asset_list))
|
||||
self.assertEqual(asset_file_path, asset_list[0].eval())
|
||||
self.assertEqual("scope_name/asset_file_tensor:0", asset_list[0].name)
|
||||
# The static asset data inside graph_proto.collection_def should not be
|
||||
# scoped.
|
||||
self._validate_asset_collection(export_dir, graph_proto.collection_def,
|
||||
"foo.txt", "content_foo",
|
||||
"asset_file_tensor:0")
|
||||
self._validate_assets(export_dir, graph_proto.asset_file_def, "foo.txt",
|
||||
"content_foo", "asset_file_tensor:0")
|
||||
|
||||
# The constant tensor should be scoped, but its contents should be
|
||||
# unchanged.
|
||||
@ -1264,7 +1266,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testClearDevices(self):
|
||||
export_dir = self._get_export_dir("test_clear_devices")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Specify a device and save a variable.
|
||||
ops.reset_default_graph()
|
||||
@ -1288,7 +1290,7 @@ class SavedModelTest(test.TestCase):
|
||||
|
||||
def testStripDefaultAttrs(self):
|
||||
export_dir = self._get_export_dir("test_strip_default_attrs")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Add a graph with two float32 variables and a Complex Op composing them
|
||||
# with strip_default_attrs enabled.
|
||||
@ -1361,7 +1363,7 @@ class SavedModelTest(test.TestCase):
|
||||
def testInconsistentConsumerDefaultAttrs(self):
|
||||
export_dir = self._get_export_dir(
|
||||
"test_strip_default_attrs_no_consumer_defaults")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Add a graph with a single variable and a test op with a defaultless
|
||||
# float32 attr, "test_attr".
|
||||
@ -1428,5 +1430,60 @@ class SavedModelTest(test.TestCase):
|
||||
loader.load(sess, ["foo"], export_dir)
|
||||
|
||||
|
||||
class SavedModelV1Test(SavedModelTestBase):
|
||||
|
||||
def _validate_asset_collection(self,
|
||||
export_dir,
|
||||
graph_collection_def,
|
||||
expected_asset_file_name,
|
||||
expected_asset_file_contents,
|
||||
expected_asset_tensor_name,
|
||||
asset_id=0):
|
||||
assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
|
||||
asset = meta_graph_pb2.AssetFileDef()
|
||||
assets_any[asset_id].Unpack(asset)
|
||||
assets_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||
compat.as_bytes(expected_asset_file_name))
|
||||
actual_asset_contents = file_io.read_file_to_string(assets_path)
|
||||
self.assertEqual(expected_asset_file_contents,
|
||||
compat.as_text(actual_asset_contents))
|
||||
self.assertEqual(expected_asset_file_name, asset.filename)
|
||||
self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
|
||||
|
||||
def testWritingAssetsToCollection(self):
|
||||
export_dir = self._get_export_dir("test_writing_assets_to_collection")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
# Build an asset list.
|
||||
ignored_filepath = os.path.join(
|
||||
compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
|
||||
file_io.write_string_to_file(ignored_filepath, "will be ignored")
|
||||
|
||||
asset_collection = self._build_asset_collection(
|
||||
"hello42.txt", "foo bar baz", "asset_file_tensor")
|
||||
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], assets_collection=asset_collection)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||
"hello42.txt", "foo bar baz",
|
||||
"asset_file_tensor:0")
|
||||
ignored_asset_path = os.path.join(
|
||||
compat.as_bytes(export_dir),
|
||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||
compat.as_bytes("ignored.txt"))
|
||||
self.assertFalse(file_io.file_exists(ignored_asset_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -1,6 +1,7 @@
|
||||
path: "tensorflow.saved_model.Builder"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.saved_model.builder_impl.SavedModelBuilder\'>"
|
||||
is_instance: "<class \'tensorflow.python.saved_model.builder_impl._SavedModelBuilder\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
|
@ -1,6 +1,7 @@
|
||||
path: "tensorflow.saved_model.builder.SavedModelBuilder"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.saved_model.builder_impl.SavedModelBuilder\'>"
|
||||
is_instance: "<class \'tensorflow.python.saved_model.builder_impl._SavedModelBuilder\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
|
Loading…
Reference in New Issue
Block a user