Write AssetFileDef to Metagraph's asset_file_def field.
PiperOrigin-RevId: 222144547
This commit is contained in:
parent
bbe996e251
commit
dcbc3e1deb
@ -193,6 +193,15 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
|||||||
|
|
||||||
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
|
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
|
||||||
std::vector<AssetFileDef>* asset_file_defs) {
|
std::vector<AssetFileDef>* asset_file_defs) {
|
||||||
|
// With SavedModel v2, we write asset file def into metagraph instead of
|
||||||
|
// collection, so read from metagraph first.
|
||||||
|
if (meta_graph_def.asset_file_def_size() > 0) {
|
||||||
|
for (const auto& asset : meta_graph_def.asset_file_def()) {
|
||||||
|
asset_file_defs->push_back(asset);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
// Fall back to read from collection to be backward compatible with v1.
|
||||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||||
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
|
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
|
||||||
if (assets_it == collection_def_map.end()) {
|
if (assets_it == collection_def_map.end()) {
|
||||||
|
@ -24,5 +24,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
|
from tensorflow.python.saved_model.builder_impl import _SavedModelBuilder
|
||||||
from tensorflow.python.saved_model.builder_impl import SavedModelBuilder
|
from tensorflow.python.saved_model.builder_impl import SavedModelBuilder
|
||||||
# pylint: enable=unused-import
|
# pylint: enable=unused-import
|
||||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from google.protobuf.any_pb2 import Any
|
from google.protobuf.any_pb2 import Any
|
||||||
@ -39,8 +40,7 @@ from tensorflow.python.util.deprecation import deprecated_args
|
|||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])
|
class _SavedModelBuilder(object):
|
||||||
class SavedModelBuilder(object):
|
|
||||||
"""Builds the `SavedModel` protocol buffer and saves variables and assets.
|
"""Builds the `SavedModel` protocol buffer and saves variables and assets.
|
||||||
|
|
||||||
The `SavedModelBuilder` class provides functionality to build a `SavedModel`
|
The `SavedModelBuilder` class provides functionality to build a `SavedModel`
|
||||||
@ -68,7 +68,7 @@ class SavedModelBuilder(object):
|
|||||||
builder.add_meta_graph_and_variables(sess,
|
builder.add_meta_graph_and_variables(sess,
|
||||||
["foo-tag"],
|
["foo-tag"],
|
||||||
signature_def_map=foo_signatures,
|
signature_def_map=foo_signatures,
|
||||||
assets_collection=foo_assets)
|
assets_list=foo_assets)
|
||||||
...
|
...
|
||||||
|
|
||||||
with tf.Session(graph=tf.Graph()) as sess:
|
with tf.Session(graph=tf.Graph()) as sess:
|
||||||
@ -105,19 +105,8 @@ class SavedModelBuilder(object):
|
|||||||
# weights.
|
# weights.
|
||||||
self._has_saved_variables = False
|
self._has_saved_variables = False
|
||||||
|
|
||||||
def _save_and_write_assets(self, assets_collection_to_add=None):
|
def _copy_assets_to_destination_dir(self, asset_filename_map):
|
||||||
"""Saves asset to the meta graph and writes asset files to disk.
|
"""Copy all assets from source path to destination path."""
|
||||||
|
|
||||||
Args:
|
|
||||||
assets_collection_to_add: The collection where the asset paths are setup.
|
|
||||||
"""
|
|
||||||
asset_filename_map = _maybe_save_assets(assets_collection_to_add)
|
|
||||||
|
|
||||||
# Return if there are no assets to write.
|
|
||||||
if not asset_filename_map:
|
|
||||||
tf_logging.info("No assets to write.")
|
|
||||||
return
|
|
||||||
|
|
||||||
assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
|
assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
|
||||||
self._export_dir)
|
self._export_dir)
|
||||||
|
|
||||||
@ -136,6 +125,25 @@ class SavedModelBuilder(object):
|
|||||||
tf_logging.info("Assets written to: %s",
|
tf_logging.info("Assets written to: %s",
|
||||||
compat.as_text(assets_destination_dir))
|
compat.as_text(assets_destination_dir))
|
||||||
|
|
||||||
|
def _save_and_write_assets(self, meta_graph_def, assets_list=None):
|
||||||
|
"""Saves asset to the meta graph and writes asset files to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
meta_graph_def: The meta graph def to which the assets will be added.
|
||||||
|
assets_list: The list where the asset paths are setup.
|
||||||
|
"""
|
||||||
|
# Creates a function that adds assets into the meta graph def.
|
||||||
|
write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
|
||||||
|
asset_filename_map = _maybe_save_assets(write_fn, assets_list)
|
||||||
|
|
||||||
|
# Return if there are no assets to write.
|
||||||
|
if not asset_filename_map:
|
||||||
|
tf_logging.info("No assets to write.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Copy assets from source path to destination path.
|
||||||
|
self._copy_assets_to_destination_dir(asset_filename_map)
|
||||||
|
|
||||||
def _maybe_add_main_op(self, main_op):
|
def _maybe_add_main_op(self, main_op):
|
||||||
"""Adds main op to the SavedModel.
|
"""Adds main op to the SavedModel.
|
||||||
|
|
||||||
@ -252,12 +260,8 @@ class SavedModelBuilder(object):
|
|||||||
for outputs_key in outputs:
|
for outputs_key in outputs:
|
||||||
self._validate_tensor_info(outputs[outputs_key])
|
self._validate_tensor_info(outputs[outputs_key])
|
||||||
|
|
||||||
def _add_collections(
|
def _add_collections(self, main_op, train_op):
|
||||||
self, assets_collection, main_op, train_op):
|
|
||||||
"""Add asset and op collections to be saved."""
|
"""Add asset and op collections to be saved."""
|
||||||
# Save asset files and write them to disk, if any.
|
|
||||||
self._save_and_write_assets(assets_collection)
|
|
||||||
|
|
||||||
self._maybe_add_main_op(main_op)
|
self._maybe_add_main_op(main_op)
|
||||||
|
|
||||||
self._add_train_op(train_op)
|
self._add_train_op(train_op)
|
||||||
@ -280,7 +284,7 @@ class SavedModelBuilder(object):
|
|||||||
def add_meta_graph(self,
|
def add_meta_graph(self,
|
||||||
tags,
|
tags,
|
||||||
signature_def_map=None,
|
signature_def_map=None,
|
||||||
assets_collection=None,
|
assets_list=None,
|
||||||
legacy_init_op=None,
|
legacy_init_op=None,
|
||||||
clear_devices=False,
|
clear_devices=False,
|
||||||
main_op=None,
|
main_op=None,
|
||||||
@ -297,8 +301,8 @@ class SavedModelBuilder(object):
|
|||||||
tags: The set of tags to annotate the meta graph def with.
|
tags: The set of tags to annotate the meta graph def with.
|
||||||
signature_def_map: The map of signature defs to be added to the meta graph
|
signature_def_map: The map of signature defs to be added to the meta graph
|
||||||
def.
|
def.
|
||||||
assets_collection: Assets collection to be saved with SavedModel. Note
|
assets_list: Assets to be saved with SavedModel. Note
|
||||||
that this collection should be a subset of the assets saved as part of
|
that this list should be a subset of the assets saved as part of
|
||||||
the first meta graph in the SavedModel.
|
the first meta graph in the SavedModel.
|
||||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
legacy_init_op: Legacy support for op or group of ops to execute after the
|
||||||
restore op upon a load. Deprecated; please use main_op instead.
|
restore op upon a load. Deprecated; please use main_op instead.
|
||||||
@ -332,6 +336,212 @@ class SavedModelBuilder(object):
|
|||||||
# Re-mapping to main_op, as treatment is identical regardless.
|
# Re-mapping to main_op, as treatment is identical regardless.
|
||||||
main_op = main_op or legacy_init_op
|
main_op = main_op or legacy_init_op
|
||||||
|
|
||||||
|
# Add ops to collection.
|
||||||
|
self._add_collections(main_op=main_op, train_op=None)
|
||||||
|
|
||||||
|
saver = self._maybe_create_saver(saver)
|
||||||
|
|
||||||
|
# The graph almost certainly previously contained at least one Saver, and
|
||||||
|
# possibly several (e.g. one for loading a pretrained embedding, and another
|
||||||
|
# for the model weights). Removing the preexisting ones was the
|
||||||
|
# motivation for the clear_extraneous_savers option, but it turns out that
|
||||||
|
# there are edge cases where that option breaks the graph. Until that is
|
||||||
|
# resolved, we just leave the option set to False for now.
|
||||||
|
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
|
||||||
|
meta_graph_def = saver.export_meta_graph(
|
||||||
|
clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
|
||||||
|
|
||||||
|
# Save asset files and write them to disk, if any.
|
||||||
|
self._save_and_write_assets(meta_graph_def, assets_list)
|
||||||
|
|
||||||
|
# Tag the meta graph def and add it to the SavedModel.
|
||||||
|
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
|
||||||
|
|
||||||
|
@deprecated_args(None,
|
||||||
|
"Pass your op to the equivalent parameter main_op instead.",
|
||||||
|
"legacy_init_op")
|
||||||
|
def add_meta_graph_and_variables(self,
|
||||||
|
sess,
|
||||||
|
tags,
|
||||||
|
signature_def_map=None,
|
||||||
|
assets_list=None,
|
||||||
|
legacy_init_op=None,
|
||||||
|
clear_devices=False,
|
||||||
|
main_op=None,
|
||||||
|
strip_default_attrs=False,
|
||||||
|
saver=None):
|
||||||
|
# pylint: disable=line-too-long
|
||||||
|
"""Adds the current meta graph to the SavedModel and saves variables.
|
||||||
|
|
||||||
|
Creates a Saver to save the variables from the provided session. Exports the
|
||||||
|
corresponding meta graph def. This function assumes that the variables to be
|
||||||
|
saved have been initialized. For a given `SavedModelBuilder`, this API must
|
||||||
|
be called exactly once and for the first meta graph to save. For subsequent
|
||||||
|
meta graph defs to be added, the `add_meta_graph()` API must be used.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sess: The TensorFlow session from which to save the meta graph and
|
||||||
|
variables.
|
||||||
|
tags: The set of tags with which to save the meta graph.
|
||||||
|
signature_def_map: The map of signature def map to add to the meta graph
|
||||||
|
def.
|
||||||
|
assets_list: Assets to be saved with SavedModel.
|
||||||
|
legacy_init_op: Legacy support for op or group of ops to execute after the
|
||||||
|
restore op upon a load. Deprecated; please use main_op instead.
|
||||||
|
clear_devices: Set to true if the device info on the default graph should
|
||||||
|
be cleared.
|
||||||
|
main_op: Op or group of ops to execute when the graph is loaded. Note
|
||||||
|
that when the main_op is specified it is run after the restore op at
|
||||||
|
load-time.
|
||||||
|
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
||||||
|
removed from the NodeDefs. For a detailed guide, see
|
||||||
|
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
||||||
|
saver: An instance of tf.train.Saver that will be used to export the
|
||||||
|
metagraph and save variables. If None, a sharded Saver that restores
|
||||||
|
all variables will be used.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# pylint: enable=line-too-long
|
||||||
|
if self._has_saved_variables:
|
||||||
|
raise AssertionError("Graph state including variables and assets has "
|
||||||
|
"already been saved. Please invoke "
|
||||||
|
"`add_meta_graph()` instead.")
|
||||||
|
|
||||||
|
# Validate the signature def map to ensure all included TensorInfos are
|
||||||
|
# properly populated.
|
||||||
|
self._validate_signature_def_map(signature_def_map)
|
||||||
|
|
||||||
|
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||||
|
# Re-mapping to main_op, as treatment is identical regardless.
|
||||||
|
main_op = main_op or legacy_init_op
|
||||||
|
|
||||||
|
# Add ops to collection.
|
||||||
|
self._add_collections(main_op=main_op, train_op=None)
|
||||||
|
|
||||||
|
saved_model_utils.get_or_create_variables_dir(self._export_dir)
|
||||||
|
variables_path = saved_model_utils.get_variables_path(self._export_dir)
|
||||||
|
|
||||||
|
saver = self._maybe_create_saver(saver)
|
||||||
|
|
||||||
|
# Save the variables. Also, disable writing the checkpoint state proto. The
|
||||||
|
# file is not used during SavedModel loading. In addition, since a
|
||||||
|
# SavedModel can be copied or moved, this avoids the checkpoint state to
|
||||||
|
# become outdated.
|
||||||
|
saver.save(sess, variables_path, write_meta_graph=False, write_state=False)
|
||||||
|
|
||||||
|
# Export the meta graph def.
|
||||||
|
|
||||||
|
# The graph almost certainly previously contained at least one Saver, and
|
||||||
|
# possibly several (e.g. one for loading a pretrained embedding, and another
|
||||||
|
# for the model weights). Removing the preexisting ones was the
|
||||||
|
# motivation for the clear_extraneous_savers option, but it turns out that
|
||||||
|
# there are edge cases where that option breaks the graph. Until that is
|
||||||
|
# resolved, we just leave the option set to False for now.
|
||||||
|
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
|
||||||
|
meta_graph_def = saver.export_meta_graph(
|
||||||
|
clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
|
||||||
|
|
||||||
|
# Save asset files and write them to disk, if any.
|
||||||
|
self._save_and_write_assets(meta_graph_def, assets_list)
|
||||||
|
|
||||||
|
# Tag the meta graph def and add it to the SavedModel.
|
||||||
|
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
|
||||||
|
|
||||||
|
# Mark this instance of SavedModel as having saved variables, such that
|
||||||
|
# subsequent attempts to save variables will fail.
|
||||||
|
self._has_saved_variables = True
|
||||||
|
|
||||||
|
def save(self, as_text=False):
|
||||||
|
"""Writes a `SavedModel` protocol buffer to disk.
|
||||||
|
|
||||||
|
The function writes the SavedModel protocol buffer to the export directory
|
||||||
|
in serialized format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
as_text: Writes the SavedModel protocol buffer in text format to disk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The path to which the SavedModel protocol buffer was written.
|
||||||
|
"""
|
||||||
|
if not file_io.file_exists(self._export_dir):
|
||||||
|
file_io.recursive_create_dir(self._export_dir)
|
||||||
|
|
||||||
|
if as_text:
|
||||||
|
path = os.path.join(
|
||||||
|
compat.as_bytes(self._export_dir),
|
||||||
|
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
|
||||||
|
file_io.write_string_to_file(path, str(self._saved_model))
|
||||||
|
else:
|
||||||
|
path = os.path.join(
|
||||||
|
compat.as_bytes(self._export_dir),
|
||||||
|
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||||
|
file_io.write_string_to_file(path, self._saved_model.SerializeToString())
|
||||||
|
tf_logging.info("SavedModel written to: %s", compat.as_text(path))
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"]) # pylint: disable=missing-docstring
|
||||||
|
class SavedModelBuilder(_SavedModelBuilder):
|
||||||
|
__doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
|
||||||
|
"assets_collection")
|
||||||
|
|
||||||
|
def __init__(self, export_dir):
|
||||||
|
super(SavedModelBuilder, self).__init__(export_dir=export_dir)
|
||||||
|
|
||||||
|
def _add_collections(self, assets_collection, main_op, train_op):
|
||||||
|
"""Add asset and op collections to be saved."""
|
||||||
|
# Save asset files and write them to disk, if any.
|
||||||
|
self._save_and_write_assets(assets_collection)
|
||||||
|
|
||||||
|
self._maybe_add_main_op(main_op)
|
||||||
|
|
||||||
|
self._add_train_op(train_op)
|
||||||
|
|
||||||
|
def _save_and_write_assets(self, assets_collection_to_add=None):
|
||||||
|
"""Saves asset to the meta graph and writes asset files to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
assets_collection_to_add: The collection where the asset paths are setup.
|
||||||
|
"""
|
||||||
|
# Add assets to the collection with key `constants.ASSETS_KEY`, in the
|
||||||
|
# graph.
|
||||||
|
asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
|
||||||
|
assets_collection_to_add)
|
||||||
|
|
||||||
|
# Return if there are no assets to write.
|
||||||
|
if not asset_filename_map:
|
||||||
|
tf_logging.info("No assets to write.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Copy assets from source path to destination path.
|
||||||
|
self._copy_assets_to_destination_dir(asset_filename_map)
|
||||||
|
|
||||||
|
@deprecated_args(None,
|
||||||
|
"Pass your op to the equivalent parameter main_op instead.",
|
||||||
|
"legacy_init_op")
|
||||||
|
def add_meta_graph(self,
|
||||||
|
tags,
|
||||||
|
signature_def_map=None,
|
||||||
|
assets_collection=None,
|
||||||
|
legacy_init_op=None,
|
||||||
|
clear_devices=False,
|
||||||
|
main_op=None,
|
||||||
|
strip_default_attrs=False,
|
||||||
|
saver=None):
|
||||||
|
if not self._has_saved_variables:
|
||||||
|
raise AssertionError(
|
||||||
|
"Graph state including variables and assets has not been saved yet. "
|
||||||
|
"Please invoke `add_meta_graph_and_variables()` first.")
|
||||||
|
|
||||||
|
# Validate the signature def map to ensure all included TensorInfos are
|
||||||
|
# properly populated.
|
||||||
|
self._validate_signature_def_map(signature_def_map)
|
||||||
|
|
||||||
|
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||||
|
# Re-mapping to main_op, as treatment is identical regardless.
|
||||||
|
main_op = main_op or legacy_init_op
|
||||||
|
|
||||||
# Add assets and ops
|
# Add assets and ops
|
||||||
self._add_collections(assets_collection, main_op, None)
|
self._add_collections(assets_collection, main_op, None)
|
||||||
|
|
||||||
@ -363,38 +573,6 @@ class SavedModelBuilder(object):
|
|||||||
main_op=None,
|
main_op=None,
|
||||||
strip_default_attrs=False,
|
strip_default_attrs=False,
|
||||||
saver=None):
|
saver=None):
|
||||||
# pylint: disable=line-too-long
|
|
||||||
"""Adds the current meta graph to the SavedModel and saves variables.
|
|
||||||
|
|
||||||
Creates a Saver to save the variables from the provided session. Exports the
|
|
||||||
corresponding meta graph def. This function assumes that the variables to be
|
|
||||||
saved have been initialized. For a given `SavedModelBuilder`, this API must
|
|
||||||
be called exactly once and for the first meta graph to save. For subsequent
|
|
||||||
meta graph defs to be added, the `add_meta_graph()` API must be used.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sess: The TensorFlow session from which to save the meta graph and
|
|
||||||
variables.
|
|
||||||
tags: The set of tags with which to save the meta graph.
|
|
||||||
signature_def_map: The map of signature def map to add to the meta graph
|
|
||||||
def.
|
|
||||||
assets_collection: Assets collection to be saved with SavedModel.
|
|
||||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
|
||||||
restore op upon a load. Deprecated; please use main_op instead.
|
|
||||||
clear_devices: Set to true if the device info on the default graph should
|
|
||||||
be cleared.
|
|
||||||
main_op: Op or group of ops to execute when the graph is loaded. Note
|
|
||||||
that when the main_op is specified it is run after the restore op at
|
|
||||||
load-time.
|
|
||||||
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
|
||||||
removed from the NodeDefs. For a detailed guide, see
|
|
||||||
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
|
||||||
saver: An instance of tf.train.Saver that will be used to export the
|
|
||||||
metagraph and save variables. If None, a sharded Saver that restores
|
|
||||||
all variables will be used.
|
|
||||||
|
|
||||||
"""
|
|
||||||
# pylint: enable=line-too-long
|
|
||||||
if self._has_saved_variables:
|
if self._has_saved_variables:
|
||||||
raise AssertionError("Graph state including variables and assets has "
|
raise AssertionError("Graph state including variables and assets has "
|
||||||
"already been saved. Please invoke "
|
"already been saved. Please invoke "
|
||||||
@ -441,41 +619,19 @@ class SavedModelBuilder(object):
|
|||||||
# subsequent attempts to save variables will fail.
|
# subsequent attempts to save variables will fail.
|
||||||
self._has_saved_variables = True
|
self._has_saved_variables = True
|
||||||
|
|
||||||
def save(self, as_text=False):
|
add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
|
||||||
"""Writes a `SavedModel` protocol buffer to disk.
|
"assets_list", "assets_collection")
|
||||||
|
add_meta_graph_and_variables.__doc__ = \
|
||||||
The function writes the SavedModel protocol buffer to the export directory
|
_SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
|
||||||
in serialized format.
|
"assets_list", "assets_collection")
|
||||||
|
|
||||||
Args:
|
|
||||||
as_text: Writes the SavedModel protocol buffer in text format to disk.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The path to which the SavedModel protocol buffer was written.
|
|
||||||
"""
|
|
||||||
if not file_io.file_exists(self._export_dir):
|
|
||||||
file_io.recursive_create_dir(self._export_dir)
|
|
||||||
|
|
||||||
if as_text:
|
|
||||||
path = os.path.join(
|
|
||||||
compat.as_bytes(self._export_dir),
|
|
||||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
|
|
||||||
file_io.write_string_to_file(path, str(self._saved_model))
|
|
||||||
else:
|
|
||||||
path = os.path.join(
|
|
||||||
compat.as_bytes(self._export_dir),
|
|
||||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
|
||||||
file_io.write_string_to_file(path, self._saved_model.SerializeToString())
|
|
||||||
tf_logging.info("SavedModel written to: %s", compat.as_text(path))
|
|
||||||
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_save_assets(assets_collection_to_add=None):
|
def _maybe_save_assets(write_fn, assets_to_add=None):
|
||||||
"""Saves assets to the meta graph.
|
"""Saves assets to the meta graph.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
assets_collection_to_add: The collection where the asset paths are setup.
|
write_fn: A function callback that writes asset into meta graph.
|
||||||
|
assets_to_add: The list where the asset paths are setup.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dict of asset basenames for saving to the original full path to the asset.
|
A dict of asset basenames for saving to the original full path to the asset.
|
||||||
@ -486,14 +642,13 @@ def _maybe_save_assets(assets_collection_to_add=None):
|
|||||||
# Map of target file names to original filenames
|
# Map of target file names to original filenames
|
||||||
asset_filename_map = {}
|
asset_filename_map = {}
|
||||||
|
|
||||||
if assets_collection_to_add is None:
|
if assets_to_add is None:
|
||||||
tf_logging.info("No assets to save.")
|
tf_logging.info("No assets to save.")
|
||||||
return asset_filename_map
|
return asset_filename_map
|
||||||
|
|
||||||
# Iterate over the supplied asset collection, build the `AssetFile` proto
|
# Iterate over the supplied assets, build the `AssetFile` proto and add them
|
||||||
# and add them to the collection with key `constants.ASSETS_KEY`, in the
|
# to the meta graph.
|
||||||
# graph.
|
for asset_tensor in assets_to_add:
|
||||||
for asset_tensor in assets_collection_to_add:
|
|
||||||
asset_source_filepath = _asset_path_from_tensor(asset_tensor)
|
asset_source_filepath = _asset_path_from_tensor(asset_tensor)
|
||||||
if not asset_source_filepath:
|
if not asset_source_filepath:
|
||||||
raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)
|
raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)
|
||||||
@ -501,10 +656,11 @@ def _maybe_save_assets(assets_collection_to_add=None):
|
|||||||
asset_filename = _get_asset_filename_to_add(
|
asset_filename = _get_asset_filename_to_add(
|
||||||
asset_source_filepath, asset_filename_map)
|
asset_source_filepath, asset_filename_map)
|
||||||
|
|
||||||
# Build `AssetFile` proto and add it to the asset collection in the graph.
|
# Call the passed-in function that builds AssetFileDef proto and adds it
|
||||||
|
# to either the collection or asset_file_def field of the meta graph.
|
||||||
# Note that this should be done even when the file is a duplicate of an
|
# Note that this should be done even when the file is a duplicate of an
|
||||||
# already-added file, as the tensor reference should still exist.
|
# already-added file, as the tensor reference should still exist.
|
||||||
_add_asset_to_collection(asset_filename, asset_tensor)
|
write_fn(asset_filename, asset_tensor)
|
||||||
|
|
||||||
# In the cases where we are adding a duplicate, this will result in the
|
# In the cases where we are adding a duplicate, this will result in the
|
||||||
# last of the filepaths being the one used for copying the file to the
|
# last of the filepaths being the one used for copying the file to the
|
||||||
@ -542,7 +698,7 @@ def _get_asset_filename_to_add(asset_filepath, asset_filename_map):
|
|||||||
|
|
||||||
other_asset_filepath = asset_filename_map[asset_filename]
|
other_asset_filepath = asset_filename_map[asset_filename]
|
||||||
if other_asset_filepath == asset_filepath:
|
if other_asset_filepath == asset_filepath:
|
||||||
# This is the same file, stored twice in the collection list. No need
|
# This is the same file, stored twice in the list. No need
|
||||||
# to make unique.
|
# to make unique.
|
||||||
return asset_filename
|
return asset_filename
|
||||||
|
|
||||||
@ -589,6 +745,20 @@ def _asset_path_from_tensor(path_tensor):
|
|||||||
return str_values[0]
|
return str_values[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
|
||||||
|
"""Builds an asset proto and adds it to the meta graph def.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
meta_graph_def: The meta graph def to which the asset will be added.
|
||||||
|
asset_filename: The filename of the asset to be added.
|
||||||
|
asset_tensor: The asset tensor used to populate the tensor info of the asset
|
||||||
|
proto.
|
||||||
|
"""
|
||||||
|
asset_proto = meta_graph_def.asset_file_def.add()
|
||||||
|
asset_proto.filename = asset_filename
|
||||||
|
asset_proto.tensor_info.name = asset_tensor.name
|
||||||
|
|
||||||
|
|
||||||
def _add_asset_to_collection(asset_filename, asset_tensor):
|
def _add_asset_to_collection(asset_filename, asset_tensor):
|
||||||
"""Builds an asset proto and adds it to the asset collection of the graph.
|
"""Builds an asset proto and adds it to the asset collection of the graph.
|
||||||
|
|
||||||
|
@ -99,22 +99,29 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
|
|||||||
collection_def = meta_graph_def_to_load.collection_def
|
collection_def = meta_graph_def_to_load.collection_def
|
||||||
|
|
||||||
asset_tensor_dict = {}
|
asset_tensor_dict = {}
|
||||||
if constants.ASSETS_KEY in collection_def:
|
asset_protos = []
|
||||||
# Location of the assets for SavedModel.
|
|
||||||
assets_directory = os.path.join(
|
if meta_graph_def_to_load.asset_file_def:
|
||||||
compat.as_bytes(export_dir),
|
asset_protos = meta_graph_def_to_load.asset_file_def
|
||||||
compat.as_bytes(constants.ASSETS_DIRECTORY))
|
elif constants.ASSETS_KEY in collection_def:
|
||||||
assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
|
assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
|
||||||
# Process each asset and add it to the asset tensor dictionary.
|
|
||||||
for asset_any_proto in assets_any_proto:
|
for asset_any_proto in assets_any_proto:
|
||||||
asset_proto = meta_graph_pb2.AssetFileDef()
|
asset_proto = meta_graph_pb2.AssetFileDef()
|
||||||
asset_any_proto.Unpack(asset_proto)
|
asset_any_proto.Unpack(asset_proto)
|
||||||
tensor_name = asset_proto.tensor_info.name
|
asset_protos.append(asset_proto)
|
||||||
if import_scope:
|
|
||||||
tensor_name = "%s/%s" % (import_scope, tensor_name)
|
# Location of the assets for SavedModel.
|
||||||
asset_tensor_dict[tensor_name] = os.path.join(
|
assets_directory = os.path.join(
|
||||||
compat.as_bytes(assets_directory),
|
compat.as_bytes(export_dir), compat.as_bytes(constants.ASSETS_DIRECTORY))
|
||||||
compat.as_bytes(asset_proto.filename))
|
# Process each asset and add it to the asset tensor dictionary.
|
||||||
|
for asset_proto in asset_protos:
|
||||||
|
tensor_name = asset_proto.tensor_info.name
|
||||||
|
if import_scope:
|
||||||
|
tensor_name = "%s/%s" % (import_scope, tensor_name)
|
||||||
|
asset_tensor_dict[tensor_name] = os.path.join(
|
||||||
|
compat.as_bytes(assets_directory),
|
||||||
|
compat.as_bytes(asset_proto.filename))
|
||||||
|
|
||||||
return asset_tensor_dict
|
return asset_tensor_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ def tearDownModule():
|
|||||||
file_io.delete_recursively(test.get_temp_dir())
|
file_io.delete_recursively(test.get_temp_dir())
|
||||||
|
|
||||||
|
|
||||||
class SavedModelTest(test.TestCase):
|
class SavedModelTestBase(test.TestCase):
|
||||||
|
|
||||||
def _get_export_dir(self, label):
|
def _get_export_dir(self, label):
|
||||||
return os.path.join(test.get_temp_dir(), label)
|
return os.path.join(test.get_temp_dir(), label)
|
||||||
@ -78,14 +78,16 @@ class SavedModelTest(test.TestCase):
|
|||||||
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
||||||
return asset_collection
|
return asset_collection
|
||||||
|
|
||||||
def _validate_asset_collection(self, export_dir, graph_collection_def,
|
|
||||||
expected_asset_file_name,
|
class SavedModelTest(SavedModelTestBase):
|
||||||
expected_asset_file_contents,
|
|
||||||
expected_asset_tensor_name,
|
def _validate_assets(self,
|
||||||
asset_id=0):
|
export_dir,
|
||||||
assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
|
asset_file_def,
|
||||||
asset = meta_graph_pb2.AssetFileDef()
|
expected_asset_file_name,
|
||||||
assets_any[asset_id].Unpack(asset)
|
expected_asset_file_contents,
|
||||||
|
expected_asset_tensor_name,
|
||||||
|
asset_id=0):
|
||||||
assets_path = os.path.join(
|
assets_path = os.path.join(
|
||||||
compat.as_bytes(export_dir),
|
compat.as_bytes(export_dir),
|
||||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||||
@ -93,8 +95,10 @@ class SavedModelTest(test.TestCase):
|
|||||||
actual_asset_contents = file_io.read_file_to_string(assets_path)
|
actual_asset_contents = file_io.read_file_to_string(assets_path)
|
||||||
self.assertEqual(expected_asset_file_contents,
|
self.assertEqual(expected_asset_file_contents,
|
||||||
compat.as_text(actual_asset_contents))
|
compat.as_text(actual_asset_contents))
|
||||||
self.assertEqual(expected_asset_file_name, asset.filename)
|
self.assertEqual(expected_asset_file_name,
|
||||||
self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
|
asset_file_def[asset_id].filename)
|
||||||
|
self.assertEqual(expected_asset_tensor_name,
|
||||||
|
asset_file_def[asset_id].tensor_info.name)
|
||||||
|
|
||||||
def _validate_inputs_tensor_info_fail(self, builder, tensor_info):
|
def _validate_inputs_tensor_info_fail(self, builder, tensor_info):
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
@ -185,7 +189,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testVerifySessionGraphUsage(self):
|
def testVerifySessionGraphUsage(self):
|
||||||
export_dir = self._get_export_dir("test_verify_session_graph_usage")
|
export_dir = self._get_export_dir("test_verify_session_graph_usage")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
@ -205,7 +209,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testSequence(self):
|
def testSequence(self):
|
||||||
export_dir = self._get_export_dir("test_sequence")
|
export_dir = self._get_export_dir("test_sequence")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Expect an assertion error since add_meta_graph_and_variables() should be
|
# Expect an assertion error since add_meta_graph_and_variables() should be
|
||||||
# invoked before any add_meta_graph() calls.
|
# invoked before any add_meta_graph() calls.
|
||||||
@ -222,7 +226,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testTags(self):
|
def testTags(self):
|
||||||
export_dir = self._get_export_dir("test_tags")
|
export_dir = self._get_export_dir("test_tags")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Graph with a single variable. SavedModel invoked to:
|
# Graph with a single variable. SavedModel invoked to:
|
||||||
# - add with weights.
|
# - add with weights.
|
||||||
@ -311,7 +315,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testVariables(self):
|
def testVariables(self):
|
||||||
export_dir = self._get_export_dir("test_variables")
|
export_dir = self._get_export_dir("test_variables")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Graph with two variables. SavedModel invoked to:
|
# Graph with two variables. SavedModel invoked to:
|
||||||
# - add with weights.
|
# - add with weights.
|
||||||
@ -363,7 +367,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testGraphWithoutVariables(self):
|
def testGraphWithoutVariables(self):
|
||||||
export_dir = self._get_export_dir("test_graph_has_variables")
|
export_dir = self._get_export_dir("test_graph_has_variables")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Graph with no variables.
|
# Graph with no variables.
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
@ -398,7 +402,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testNoOverwrite(self):
|
def testNoOverwrite(self):
|
||||||
export_dir = self._get_export_dir("test_no_overwrite")
|
export_dir = self._get_export_dir("test_no_overwrite")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Graph with a single variable. SavedModel invoked to:
|
# Graph with a single variable. SavedModel invoked to:
|
||||||
# - add with weights.
|
# - add with weights.
|
||||||
@ -417,12 +421,12 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
# An attempt to create another builder with the same export directory should
|
# An attempt to create another builder with the same export directory should
|
||||||
# result in an assertion error.
|
# result in an assertion error.
|
||||||
self.assertRaises(AssertionError, saved_model_builder.SavedModelBuilder,
|
self.assertRaises(AssertionError, saved_model_builder._SavedModelBuilder,
|
||||||
export_dir)
|
export_dir)
|
||||||
|
|
||||||
def testSaveAsText(self):
|
def testSaveAsText(self):
|
||||||
export_dir = self._get_export_dir("test_astext")
|
export_dir = self._get_export_dir("test_astext")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Graph with a single variable. SavedModel invoked to:
|
# Graph with a single variable. SavedModel invoked to:
|
||||||
# - add with weights.
|
# - add with weights.
|
||||||
@ -453,7 +457,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testCollections(self):
|
def testCollections(self):
|
||||||
export_dir = self._get_export_dir("test_collections")
|
export_dir = self._get_export_dir("test_collections")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Graph with a single variable added to a collection. SavedModel invoked to:
|
# Graph with a single variable added to a collection. SavedModel invoked to:
|
||||||
# - add with weights.
|
# - add with weights.
|
||||||
@ -503,7 +507,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testSignatureDefs(self):
|
def testSignatureDefs(self):
|
||||||
export_dir = self._get_export_dir("test_signature_defs")
|
export_dir = self._get_export_dir("test_signature_defs")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Graph with a single variable and a single entry in the signature def map.
|
# Graph with a single variable and a single entry in the signature def map.
|
||||||
# SavedModel is invoked to add with weights.
|
# SavedModel is invoked to add with weights.
|
||||||
@ -563,7 +567,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testSignatureDefValidationFails(self):
|
def testSignatureDefValidationFails(self):
|
||||||
export_dir = self._get_export_dir("test_signature_def_validation_fail")
|
export_dir = self._get_export_dir("test_signature_def_validation_fail")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
tensor_without_encoding = meta_graph_pb2.TensorInfo()
|
tensor_without_encoding = meta_graph_pb2.TensorInfo()
|
||||||
tensor_without_encoding.dtype = types_pb2.DT_FLOAT
|
tensor_without_encoding.dtype = types_pb2.DT_FLOAT
|
||||||
@ -585,11 +589,11 @@ class SavedModelTest(test.TestCase):
|
|||||||
tensor_with_name.dtype = types_pb2.DT_FLOAT
|
tensor_with_name.dtype = types_pb2.DT_FLOAT
|
||||||
|
|
||||||
export_dir = self._get_export_dir("test_signature_def_validation_name_1")
|
export_dir = self._get_export_dir("test_signature_def_validation_name_1")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
self._validate_inputs_tensor_info_accept(builder, tensor_with_name)
|
self._validate_inputs_tensor_info_accept(builder, tensor_with_name)
|
||||||
|
|
||||||
export_dir = self._get_export_dir("test_signature_def_validation_name_2")
|
export_dir = self._get_export_dir("test_signature_def_validation_name_2")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
self._validate_outputs_tensor_info_accept(builder, tensor_with_name)
|
self._validate_outputs_tensor_info_accept(builder, tensor_with_name)
|
||||||
|
|
||||||
def testSignatureDefValidationSucceedsWithCoo(self):
|
def testSignatureDefValidationSucceedsWithCoo(self):
|
||||||
@ -599,16 +603,16 @@ class SavedModelTest(test.TestCase):
|
|||||||
tensor_with_coo.dtype = types_pb2.DT_FLOAT
|
tensor_with_coo.dtype = types_pb2.DT_FLOAT
|
||||||
|
|
||||||
export_dir = self._get_export_dir("test_signature_def_validation_coo_1")
|
export_dir = self._get_export_dir("test_signature_def_validation_coo_1")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
self._validate_inputs_tensor_info_accept(builder, tensor_with_coo)
|
self._validate_inputs_tensor_info_accept(builder, tensor_with_coo)
|
||||||
|
|
||||||
export_dir = self._get_export_dir("test_signature_def_validation_coo_2")
|
export_dir = self._get_export_dir("test_signature_def_validation_coo_2")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
self._validate_outputs_tensor_info_accept(builder, tensor_with_coo)
|
self._validate_outputs_tensor_info_accept(builder, tensor_with_coo)
|
||||||
|
|
||||||
def testAssets(self):
|
def testAssets(self):
|
||||||
export_dir = self._get_export_dir("test_assets")
|
export_dir = self._get_export_dir("test_assets")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
@ -618,21 +622,19 @@ class SavedModelTest(test.TestCase):
|
|||||||
compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
|
compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
|
||||||
file_io.write_string_to_file(ignored_filepath, "will be ignored")
|
file_io.write_string_to_file(ignored_filepath, "will be ignored")
|
||||||
|
|
||||||
asset_collection = self._build_asset_collection("hello42.txt",
|
asset_list = self._build_asset_collection("hello42.txt", "foo bar baz",
|
||||||
"foo bar baz",
|
"asset_file_tensor")
|
||||||
"asset_file_tensor")
|
|
||||||
|
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo"], assets_collection=asset_collection)
|
sess, ["foo"], assets_list=asset_list)
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||||
"hello42.txt", "foo bar baz",
|
"foo bar baz", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
|
||||||
ignored_asset_path = os.path.join(
|
ignored_asset_path = os.path.join(
|
||||||
compat.as_bytes(export_dir),
|
compat.as_bytes(export_dir),
|
||||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||||
@ -641,64 +643,66 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testAssetsNameCollisionDiffFile(self):
|
def testAssetsNameCollisionDiffFile(self):
|
||||||
export_dir = self._get_export_dir("test_assets_name_collision_diff_file")
|
export_dir = self._get_export_dir("test_assets_name_collision_diff_file")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
asset_collection = self._build_asset_collection(
|
asset_list = self._build_asset_collection(
|
||||||
"hello42.txt", "foo bar bak", "asset_file_tensor",
|
"hello42.txt", "foo bar bak", "asset_file_tensor", asset_subdir="1")
|
||||||
asset_subdir="1")
|
|
||||||
|
|
||||||
asset_collection = self._build_asset_collection(
|
asset_list = self._build_asset_collection(
|
||||||
"hello42.txt", "foo bar baz", "asset_file_tensor_1",
|
"hello42.txt", "foo bar baz", "asset_file_tensor_1", asset_subdir="2")
|
||||||
asset_subdir="2")
|
|
||||||
|
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo"], assets_collection=asset_collection)
|
sess, ["foo"], assets_list=asset_list)
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||||
"hello42.txt", "foo bar bak",
|
"foo bar bak", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
self._validate_assets(
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
export_dir,
|
||||||
"hello42.txt_1", "foo bar baz",
|
foo_graph.asset_file_def,
|
||||||
"asset_file_tensor_1:0",
|
"hello42.txt_1",
|
||||||
asset_id=1)
|
"foo bar baz",
|
||||||
|
"asset_file_tensor_1:0",
|
||||||
|
asset_id=1)
|
||||||
|
|
||||||
def testAssetsNameCollisionSameFilepath(self):
|
def testAssetsNameCollisionSameFilepath(self):
|
||||||
export_dir = self._get_export_dir("test_assets_name_collision_same_path")
|
export_dir = self._get_export_dir("test_assets_name_collision_same_path")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
asset_collection = self._build_asset_collection(
|
asset_list = self._build_asset_collection("hello42.txt", "foo bar baz",
|
||||||
"hello42.txt", "foo bar baz", "asset_file_tensor")
|
"asset_file_tensor")
|
||||||
|
|
||||||
asset_collection = self._build_asset_collection(
|
asset_list = self._build_asset_collection("hello42.txt", "foo bar baz",
|
||||||
"hello42.txt", "foo bar baz", "asset_file_tensor_1")
|
"asset_file_tensor_1")
|
||||||
|
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo"], assets_collection=asset_collection)
|
sess, ["foo"], assets_list=asset_list)
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||||
"hello42.txt", "foo bar baz",
|
"foo bar baz", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
|
||||||
# The second tensor should be recorded, but the same.
|
# The second tensor should be recorded, but the same.
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(
|
||||||
"hello42.txt", "foo bar baz",
|
export_dir,
|
||||||
"asset_file_tensor_1:0",
|
foo_graph.asset_file_def,
|
||||||
asset_id=1)
|
"hello42.txt",
|
||||||
|
"foo bar baz",
|
||||||
|
"asset_file_tensor_1:0",
|
||||||
|
asset_id=1)
|
||||||
ignored_asset_path = os.path.join(
|
ignored_asset_path = os.path.join(
|
||||||
compat.as_bytes(export_dir),
|
compat.as_bytes(export_dir),
|
||||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||||
@ -707,35 +711,35 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testAssetsNameCollisionSameFile(self):
|
def testAssetsNameCollisionSameFile(self):
|
||||||
export_dir = self._get_export_dir("test_assets_name_collision_same_file")
|
export_dir = self._get_export_dir("test_assets_name_collision_same_file")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
asset_collection = self._build_asset_collection(
|
asset_list = self._build_asset_collection(
|
||||||
"hello42.txt", "foo bar baz", "asset_file_tensor",
|
"hello42.txt", "foo bar baz", "asset_file_tensor", asset_subdir="1")
|
||||||
asset_subdir="1")
|
|
||||||
|
|
||||||
asset_collection = self._build_asset_collection(
|
asset_list = self._build_asset_collection(
|
||||||
"hello42.txt", "foo bar baz", "asset_file_tensor_1",
|
"hello42.txt", "foo bar baz", "asset_file_tensor_1", asset_subdir="2")
|
||||||
asset_subdir="2")
|
|
||||||
|
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo"], assets_collection=asset_collection)
|
sess, ["foo"], assets_list=asset_list)
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||||
"hello42.txt", "foo bar baz",
|
"foo bar baz", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
|
||||||
# The second tensor should be recorded, but the same.
|
# The second tensor should be recorded, but the same.
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(
|
||||||
"hello42.txt", "foo bar baz",
|
export_dir,
|
||||||
"asset_file_tensor_1:0",
|
foo_graph.asset_file_def,
|
||||||
asset_id=1)
|
"hello42.txt",
|
||||||
|
"foo bar baz",
|
||||||
|
"asset_file_tensor_1:0",
|
||||||
|
asset_id=1)
|
||||||
ignored_asset_path = os.path.join(
|
ignored_asset_path = os.path.join(
|
||||||
compat.as_bytes(export_dir),
|
compat.as_bytes(export_dir),
|
||||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||||
@ -744,19 +748,21 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testAssetsNameCollisionManyFiles(self):
|
def testAssetsNameCollisionManyFiles(self):
|
||||||
export_dir = self._get_export_dir("test_assets_name_collision_many_files")
|
export_dir = self._get_export_dir("test_assets_name_collision_many_files")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
idx = str(i)
|
idx = str(i)
|
||||||
asset_collection = self._build_asset_collection(
|
asset_list = self._build_asset_collection(
|
||||||
"hello42.txt", "foo bar baz " + idx, "asset_file_tensor_" + idx,
|
"hello42.txt",
|
||||||
|
"foo bar baz " + idx,
|
||||||
|
"asset_file_tensor_" + idx,
|
||||||
asset_subdir=idx)
|
asset_subdir=idx)
|
||||||
|
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo"], assets_collection=asset_collection)
|
sess, ["foo"], assets_list=asset_list)
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
@ -765,18 +771,20 @@ class SavedModelTest(test.TestCase):
|
|||||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||||
for i in range(1, 5):
|
for i in range(1, 5):
|
||||||
idx = str(i)
|
idx = str(i)
|
||||||
self._validate_asset_collection(
|
self._validate_assets(
|
||||||
export_dir, foo_graph.collection_def, "hello42.txt_" + idx,
|
export_dir,
|
||||||
"foo bar baz " + idx, "asset_file_tensor_{}:0".format(idx),
|
foo_graph.asset_file_def,
|
||||||
|
"hello42.txt_" + idx,
|
||||||
|
"foo bar baz " + idx,
|
||||||
|
"asset_file_tensor_{}:0".format(idx),
|
||||||
asset_id=i)
|
asset_id=i)
|
||||||
|
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||||
"hello42.txt", "foo bar baz 0",
|
"foo bar baz 0", "asset_file_tensor_0:0")
|
||||||
"asset_file_tensor_0:0")
|
|
||||||
|
|
||||||
def testCustomMainOp(self):
|
def testCustomMainOp(self):
|
||||||
export_dir = self._get_export_dir("test_main_op")
|
export_dir = self._get_export_dir("test_main_op")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
# Add `v1` and `v2` variables to the graph.
|
# Add `v1` and `v2` variables to the graph.
|
||||||
@ -811,7 +819,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testLegacyInitOp(self):
|
def testLegacyInitOp(self):
|
||||||
export_dir = self._get_export_dir("test_legacy_init_op")
|
export_dir = self._get_export_dir("test_legacy_init_op")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
# Add `v1` and `v2` variables to the graph.
|
# Add `v1` and `v2` variables to the graph.
|
||||||
@ -855,7 +863,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
|
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
|
||||||
|
|
||||||
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
|
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
g = ops.Graph()
|
g = ops.Graph()
|
||||||
with self.session(graph=g) as sess:
|
with self.session(graph=g) as sess:
|
||||||
@ -885,7 +893,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testTrainOp(self):
|
def testTrainOp(self):
|
||||||
export_dir = self._get_export_dir("test_train_op")
|
export_dir = self._get_export_dir("test_train_op")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
# Add `v1` and `v2` variables to the graph.
|
# Add `v1` and `v2` variables to the graph.
|
||||||
@ -914,7 +922,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testTrainOpGroup(self):
|
def testTrainOpGroup(self):
|
||||||
export_dir = self._get_export_dir("test_train_op_group")
|
export_dir = self._get_export_dir("test_train_op_group")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
# Add `v1` and `v2` variables to the graph.
|
# Add `v1` and `v2` variables to the graph.
|
||||||
@ -943,7 +951,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testTrainOpAfterVariables(self):
|
def testTrainOpAfterVariables(self):
|
||||||
export_dir = self._get_export_dir("test_train_op_after_variables")
|
export_dir = self._get_export_dir("test_train_op_after_variables")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
# Add `v1` and `v2` variables to the graph.
|
# Add `v1` and `v2` variables to the graph.
|
||||||
@ -975,28 +983,28 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultipleAssets(self):
|
def testMultipleAssets(self):
|
||||||
export_dir = self._get_export_dir("test_multiple_assets")
|
export_dir = self._get_export_dir("test_multiple_assets")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
# Build an asset collection specific to `foo` graph.
|
# Build an asset collection specific to `foo` graph.
|
||||||
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
|
asset_list = self._build_asset_collection("foo.txt", "content_foo",
|
||||||
"asset_file_tensor")
|
"asset_file_tensor")
|
||||||
|
|
||||||
# Add the asset collection as part of the graph with tag "foo".
|
# Add the asset collection as part of the graph with tag "foo".
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo"], assets_collection=asset_collection)
|
sess, ["foo"], assets_list=asset_list)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
# Build an asset collection specific to `bar` graph.
|
# Build an asset collection specific to `bar` graph.
|
||||||
asset_collection = self._build_asset_collection("bar.txt", "content_bar",
|
asset_list = self._build_asset_collection("bar.txt", "content_bar",
|
||||||
"asset_file_tensor")
|
"asset_file_tensor")
|
||||||
|
|
||||||
# Add the asset collection as part of the graph with tag "bar".
|
# Add the asset collection as part of the graph with tag "bar".
|
||||||
builder.add_meta_graph(["bar"], assets_collection=asset_collection)
|
builder.add_meta_graph(["bar"], assets_list=asset_list)
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
@ -1004,43 +1012,41 @@ class SavedModelTest(test.TestCase):
|
|||||||
# Check assets restored for graph with tag "foo".
|
# Check assets restored for graph with tag "foo".
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(export_dir, foo_graph.asset_file_def, "foo.txt",
|
||||||
"foo.txt", "content_foo",
|
"content_foo", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
|
||||||
|
|
||||||
# Check assets restored for graph with tag "bar".
|
# Check assets restored for graph with tag "bar".
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
bar_graph = loader.load(sess, ["bar"], export_dir)
|
bar_graph = loader.load(sess, ["bar"], export_dir)
|
||||||
self._validate_asset_collection(export_dir, bar_graph.collection_def,
|
self._validate_assets(export_dir, bar_graph.asset_file_def, "bar.txt",
|
||||||
"bar.txt", "content_bar",
|
"content_bar", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
|
||||||
|
|
||||||
def testDuplicateAssets(self):
|
def testDuplicateAssets(self):
|
||||||
export_dir = self._get_export_dir("test_duplicate_assets")
|
export_dir = self._get_export_dir("test_duplicate_assets")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
# Build an asset collection with `foo.txt` that has `foo` specific
|
# Build an asset collection with `foo.txt` that has `foo` specific
|
||||||
# content.
|
# content.
|
||||||
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
|
asset_list = self._build_asset_collection("foo.txt", "content_foo",
|
||||||
"asset_file_tensor")
|
"asset_file_tensor")
|
||||||
|
|
||||||
# Add the asset collection as part of the graph with tag "foo".
|
# Add the asset collection as part of the graph with tag "foo".
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo"], assets_collection=asset_collection)
|
sess, ["foo"], assets_list=asset_list)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
# Build an asset collection with `foo.txt` that has `bar` specific
|
# Build an asset collection with `foo.txt` that has `bar` specific
|
||||||
# content.
|
# content.
|
||||||
asset_collection = self._build_asset_collection("foo.txt", "content_bar",
|
asset_list = self._build_asset_collection("foo.txt", "content_bar",
|
||||||
"asset_file_tensor")
|
"asset_file_tensor")
|
||||||
|
|
||||||
# Add the asset collection as part of the graph with tag "bar".
|
# Add the asset collection as part of the graph with tag "bar".
|
||||||
builder.add_meta_graph(["bar"], assets_collection=asset_collection)
|
builder.add_meta_graph(["bar"], assets_list=asset_list)
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
@ -1048,9 +1054,8 @@ class SavedModelTest(test.TestCase):
|
|||||||
# Check assets restored for graph with tag "foo".
|
# Check assets restored for graph with tag "foo".
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
foo_graph = loader.load(sess, ["foo"], export_dir)
|
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||||
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
self._validate_assets(export_dir, foo_graph.asset_file_def, "foo.txt",
|
||||||
"foo.txt", "content_foo",
|
"content_foo", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
|
||||||
|
|
||||||
# Check assets restored for graph with tag "bar".
|
# Check assets restored for graph with tag "bar".
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
@ -1059,13 +1064,12 @@ class SavedModelTest(test.TestCase):
|
|||||||
# Validate the assets for `bar` graph. `foo.txt` should contain the
|
# Validate the assets for `bar` graph. `foo.txt` should contain the
|
||||||
# original contents corresponding to `foo` graph since an asset with the
|
# original contents corresponding to `foo` graph since an asset with the
|
||||||
# same name across multiple graphs is only stored the first time
|
# same name across multiple graphs is only stored the first time
|
||||||
self._validate_asset_collection(export_dir, bar_graph.collection_def,
|
self._validate_assets(export_dir, bar_graph.asset_file_def, "foo.txt",
|
||||||
"foo.txt", "content_foo",
|
"content_foo", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
|
||||||
|
|
||||||
def testOp(self):
|
def testOp(self):
|
||||||
export_dir = self._get_export_dir("test_op")
|
export_dir = self._get_export_dir("test_op")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with session.Session(
|
with session.Session(
|
||||||
graph=ops.Graph(),
|
graph=ops.Graph(),
|
||||||
@ -1108,7 +1112,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testCustomSaveable(self):
|
def testCustomSaveable(self):
|
||||||
export_dir = self._get_export_dir("custom_saveable")
|
export_dir = self._get_export_dir("custom_saveable")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with session.Session(
|
with session.Session(
|
||||||
graph=ops.Graph(),
|
graph=ops.Graph(),
|
||||||
@ -1137,7 +1141,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testCustomSaver(self):
|
def testCustomSaver(self):
|
||||||
export_dir = self._get_export_dir("test_custom_saver")
|
export_dir = self._get_export_dir("test_custom_saver")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
variables.VariableV1(1, name="v1")
|
variables.VariableV1(1, name="v1")
|
||||||
@ -1159,7 +1163,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testNoCustomSaver(self):
|
def testNoCustomSaver(self):
|
||||||
export_dir = self._get_export_dir("test_no_custom_saver")
|
export_dir = self._get_export_dir("test_no_custom_saver")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
variables.VariableV1(1, name="v1")
|
variables.VariableV1(1, name="v1")
|
||||||
@ -1181,7 +1185,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testMultipleCustomSavers(self):
|
def testMultipleCustomSavers(self):
|
||||||
export_dir = self._get_export_dir("test_multiple_custom_savers")
|
export_dir = self._get_export_dir("test_multiple_custom_savers")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
variables.VariableV1(1, name="v1")
|
variables.VariableV1(1, name="v1")
|
||||||
@ -1211,19 +1215,19 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testImportScope(self):
|
def testImportScope(self):
|
||||||
export_dir = self._get_export_dir("test_scoped_assets")
|
export_dir = self._get_export_dir("test_scoped_assets")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Build a SavedModel with a variable, an asset, and a constant tensor.
|
# Build a SavedModel with a variable, an asset, and a constant tensor.
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
self._init_and_validate_variable(sess, "v", 42)
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
|
asset_list = self._build_asset_collection("foo.txt", "content_foo",
|
||||||
"asset_file_tensor")
|
"asset_file_tensor")
|
||||||
constant_op.constant("constant value", name="constant_tensor_name")
|
constant_op.constant("constant value", name="constant_tensor_name")
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["tag_name"], assets_collection=asset_collection)
|
sess, ["tag_name"], assets_list=asset_list)
|
||||||
|
|
||||||
# Save the asset file path for later comparison.
|
# Save the asset file path for later comparison.
|
||||||
asset_file_path = asset_collection[0].eval()
|
asset_file_path = asset_list[0].eval()
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
@ -1244,16 +1248,14 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
# The loaded asset tensor should be scoped, but the asset file path and
|
# The loaded asset tensor should be scoped, but the asset file path and
|
||||||
# contents should be unchanged.
|
# contents should be unchanged.
|
||||||
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
asset_list = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
||||||
self.assertEqual(1, len(asset_collection))
|
self.assertEqual(1, len(asset_list))
|
||||||
self.assertEqual(asset_file_path, asset_collection[0].eval())
|
self.assertEqual(asset_file_path, asset_list[0].eval())
|
||||||
self.assertEqual("scope_name/asset_file_tensor:0",
|
self.assertEqual("scope_name/asset_file_tensor:0", asset_list[0].name)
|
||||||
asset_collection[0].name)
|
|
||||||
# The static asset data inside graph_proto.collection_def should not be
|
# The static asset data inside graph_proto.collection_def should not be
|
||||||
# scoped.
|
# scoped.
|
||||||
self._validate_asset_collection(export_dir, graph_proto.collection_def,
|
self._validate_assets(export_dir, graph_proto.asset_file_def, "foo.txt",
|
||||||
"foo.txt", "content_foo",
|
"content_foo", "asset_file_tensor:0")
|
||||||
"asset_file_tensor:0")
|
|
||||||
|
|
||||||
# The constant tensor should be scoped, but its contents should be
|
# The constant tensor should be scoped, but its contents should be
|
||||||
# unchanged.
|
# unchanged.
|
||||||
@ -1264,7 +1266,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testClearDevices(self):
|
def testClearDevices(self):
|
||||||
export_dir = self._get_export_dir("test_clear_devices")
|
export_dir = self._get_export_dir("test_clear_devices")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Specify a device and save a variable.
|
# Specify a device and save a variable.
|
||||||
ops.reset_default_graph()
|
ops.reset_default_graph()
|
||||||
@ -1288,7 +1290,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
|
|
||||||
def testStripDefaultAttrs(self):
|
def testStripDefaultAttrs(self):
|
||||||
export_dir = self._get_export_dir("test_strip_default_attrs")
|
export_dir = self._get_export_dir("test_strip_default_attrs")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Add a graph with two float32 variables and a Complex Op composing them
|
# Add a graph with two float32 variables and a Complex Op composing them
|
||||||
# with strip_default_attrs enabled.
|
# with strip_default_attrs enabled.
|
||||||
@ -1361,7 +1363,7 @@ class SavedModelTest(test.TestCase):
|
|||||||
def testInconsistentConsumerDefaultAttrs(self):
|
def testInconsistentConsumerDefaultAttrs(self):
|
||||||
export_dir = self._get_export_dir(
|
export_dir = self._get_export_dir(
|
||||||
"test_strip_default_attrs_no_consumer_defaults")
|
"test_strip_default_attrs_no_consumer_defaults")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
# Add a graph with a single variable and a test op with a defaultless
|
# Add a graph with a single variable and a test op with a defaultless
|
||||||
# float32 attr, "test_attr".
|
# float32 attr, "test_attr".
|
||||||
@ -1428,5 +1430,60 @@ class SavedModelTest(test.TestCase):
|
|||||||
loader.load(sess, ["foo"], export_dir)
|
loader.load(sess, ["foo"], export_dir)
|
||||||
|
|
||||||
|
|
||||||
|
class SavedModelV1Test(SavedModelTestBase):
|
||||||
|
|
||||||
|
def _validate_asset_collection(self,
|
||||||
|
export_dir,
|
||||||
|
graph_collection_def,
|
||||||
|
expected_asset_file_name,
|
||||||
|
expected_asset_file_contents,
|
||||||
|
expected_asset_tensor_name,
|
||||||
|
asset_id=0):
|
||||||
|
assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
|
||||||
|
asset = meta_graph_pb2.AssetFileDef()
|
||||||
|
assets_any[asset_id].Unpack(asset)
|
||||||
|
assets_path = os.path.join(
|
||||||
|
compat.as_bytes(export_dir),
|
||||||
|
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||||
|
compat.as_bytes(expected_asset_file_name))
|
||||||
|
actual_asset_contents = file_io.read_file_to_string(assets_path)
|
||||||
|
self.assertEqual(expected_asset_file_contents,
|
||||||
|
compat.as_text(actual_asset_contents))
|
||||||
|
self.assertEqual(expected_asset_file_name, asset.filename)
|
||||||
|
self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
|
||||||
|
|
||||||
|
def testWritingAssetsToCollection(self):
|
||||||
|
export_dir = self._get_export_dir("test_writing_assets_to_collection")
|
||||||
|
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
|
# Build an asset list.
|
||||||
|
ignored_filepath = os.path.join(
|
||||||
|
compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
|
||||||
|
file_io.write_string_to_file(ignored_filepath, "will be ignored")
|
||||||
|
|
||||||
|
asset_collection = self._build_asset_collection(
|
||||||
|
"hello42.txt", "foo bar baz", "asset_file_tensor")
|
||||||
|
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess, ["foo"], assets_collection=asset_collection)
|
||||||
|
|
||||||
|
# Save the SavedModel to disk.
|
||||||
|
builder.save()
|
||||||
|
|
||||||
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
|
foo_graph = loader.load(sess, ["foo"], export_dir)
|
||||||
|
self._validate_asset_collection(export_dir, foo_graph.collection_def,
|
||||||
|
"hello42.txt", "foo bar baz",
|
||||||
|
"asset_file_tensor:0")
|
||||||
|
ignored_asset_path = os.path.join(
|
||||||
|
compat.as_bytes(export_dir),
|
||||||
|
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||||
|
compat.as_bytes("ignored.txt"))
|
||||||
|
self.assertFalse(file_io.file_exists(ignored_asset_path))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
path: "tensorflow.saved_model.Builder"
|
path: "tensorflow.saved_model.Builder"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.saved_model.builder_impl.SavedModelBuilder\'>"
|
is_instance: "<class \'tensorflow.python.saved_model.builder_impl.SavedModelBuilder\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.saved_model.builder_impl._SavedModelBuilder\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
path: "tensorflow.saved_model.builder.SavedModelBuilder"
|
path: "tensorflow.saved_model.builder.SavedModelBuilder"
|
||||||
tf_class {
|
tf_class {
|
||||||
is_instance: "<class \'tensorflow.python.saved_model.builder_impl.SavedModelBuilder\'>"
|
is_instance: "<class \'tensorflow.python.saved_model.builder_impl.SavedModelBuilder\'>"
|
||||||
|
is_instance: "<class \'tensorflow.python.saved_model.builder_impl._SavedModelBuilder\'>"
|
||||||
is_instance: "<type \'object\'>"
|
is_instance: "<type \'object\'>"
|
||||||
member_method {
|
member_method {
|
||||||
name: "__init__"
|
name: "__init__"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user