Add SaveOptions object with option to whitelist op namespaces. Added options argument to all functions that save out a SavedModel.

PiperOrigin-RevId: 266021878
This commit is contained in:
Katherine Wu 2019-08-28 16:38:17 -07:00 committed by TensorFlower Gardener
parent 25f08ffe5d
commit cc739565b9
27 changed files with 201 additions and 27 deletions

View File

@ -1119,7 +1119,8 @@ class Network(base_layer.Layer):
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None):
signatures=None,
options=None):
"""Saves the model to Tensorflow SavedModel or a single HDF5 file.
The savefile includes:
@ -1148,6 +1149,8 @@ class Network(base_layer.Layer):
signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
format only. Please see the `signatures` argument in
`tf.saved_model.save` for details.
options: Optional `tf.saved_model.SaveOptions` object that specifies
options for saving to SavedModel.
Example:
@ -1163,7 +1166,7 @@ class Network(base_layer.Layer):
```
"""
saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
signatures)
signatures, options)
def save_weights(self, filepath, overwrite=True, save_format=None):
"""Saves all layer weights.

View File

@ -49,7 +49,8 @@ def save_model(model,
overwrite=True,
include_optimizer=True,
save_format=None,
signatures=None):
signatures=None,
options=None):
"""Saves a model as a TensorFlow SavedModel or HDF5 file.
The saved model contains:
@ -83,6 +84,8 @@ def save_model(model,
signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
format only. Please see the `signatures` argument in
`tf.saved_model.save` for details.
options: Optional `tf.saved_model.SaveOptions` object that specifies
options for saving to SavedModel.
Raises:
ImportError: If save format is hdf5, and h5py is not available.
@ -109,7 +112,7 @@ def save_model(model,
model, filepath, overwrite, include_optimizer)
else:
saved_model_save.save(model, filepath, overwrite, include_optimizer,
signatures)
signatures, options)
@keras_export('keras.models.load_model')

View File

@ -58,7 +58,8 @@ training_lib = LazyLoader(
# pylint:enable=g-inconsistent-quotes
def save(model, filepath, overwrite, include_optimizer, signatures=None):
def save(model, filepath, overwrite, include_optimizer, signatures=None,
options=None):
"""Saves a model as a SavedModel to the filepath.
Args:
@ -69,6 +70,8 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None):
signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
format only. Please see the `signatures` argument in `tf.saved_model.save`
for details.
options: Optional`tf.saved_model.SaveOptions` object that specifies
options for saving to SavedModel.
Raises:
ValueError: if the model's inputs have not been defined.
@ -89,7 +92,7 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None):
# Trace all functions and signatures with `training=0` instead of using the
# default learning phase placeholder.
with K.learning_phase_scope(0):
save_lib.save(model, filepath, signatures)
save_lib.save(model, filepath, signatures, options)
if not include_optimizer:
model.optimizer = orig_optimizer

View File

@ -290,6 +290,7 @@ py_library(
":function_serialization",
":nested_structure_coder",
":revived_types",
":save_options",
":signature_constants",
":signature_def_utils",
":signature_serialization",
@ -502,3 +503,11 @@ tf_py_test(
"//tensorflow/python/eager:test",
],
)
py_library(
name = "save_options",
srcs = ["save_options.py"],
deps = [
"@six_archive//:six",
],
)

View File

@ -44,6 +44,7 @@ from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_serialization
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import save_options
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import signature_serialization
@ -533,7 +534,8 @@ def _process_asset(trackable_asset, asset_info, resource_map):
resource_map[original_path_tensor] = asset_variable
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
namespace_whitelist):
"""Generates a MetaGraph which calls `signature_functions`.
Args:
@ -541,6 +543,7 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
saveable_view: The _SaveableView being exported.
signature_functions: A dictionary mapping signature keys to concrete
functions containing signatures to add to the MetaGraph.
namespace_whitelist: List of strings containing whitelisted op namespaces.
Returns:
An _AssetInfo, which contains information to help creating the SavedModel.
@ -593,6 +596,7 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
saver_def = saver.to_proto()
meta_graph_def.saver_def.CopyFrom(saver_def)
graph_def = exported_graph.as_graph_def(add_shapes=True)
_verify_ops(graph_def, namespace_whitelist)
meta_graph_def.graph_def.CopyFrom(graph_def)
meta_graph_def.meta_info_def.tags.append(tag_constants.SERVING)
@ -610,6 +614,32 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
return asset_info, exported_graph
def _verify_ops(graph_def, namespace_whitelist):
"""Verifies that all namespaced ops in the graph are whitelisted."""
invalid_ops = []
invalid_namespaces = set()
all_operations = []
all_operations.extend(meta_graph.ops_used_by_graph_def(graph_def))
for op in all_operations:
if ">" in op:
namespace = op.split(">")[0]
if namespace not in namespace_whitelist:
invalid_ops.append(op)
invalid_namespaces.add(namespace)
if invalid_ops:
raise ValueError(
"Attempted to save ops from non-whitelisted namespaces to SavedModel: "
"{}.\nPlease verify that these ops should be saved, since they must be "
"available when loading the SavedModel. If loading from Python, you "
"must import the library defining these ops. From C++, link the custom "
"ops to the serving binary. Once you've confirmed this, please add the "
"following namespaces to the `namespace_whitelist` argument in "
"tf.saved_model.SaveOptions: {}.".format(
invalid_ops, invalid_namespaces))
def _serialize_object_graph(saveable_view, asset_file_def_index):
"""Save a SavedObjectGraph proto for `root`."""
# SavedObjectGraph is similar to the TrackableObjectGraph proto in the
@ -672,7 +702,7 @@ def _write_object_proto(obj, proto, asset_file_def_index):
@tf_export("saved_model.save",
v1=["saved_model.save", "saved_model.experimental.save"])
def save(obj, export_dir, signatures=None):
def save(obj, export_dir, signatures=None, options=None):
# pylint: disable=line-too-long
"""Exports the Trackable object `obj` to [SavedModel format](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md).
@ -808,6 +838,8 @@ def save(obj, export_dir, signatures=None):
signatures or concrete functions. The keys of such a dictionary may be
arbitrary strings, but will typically be from the
`tf.saved_model.signature_constants` module.
options: Optional, `tf.saved_model.SaveOptions` object that specifies
options for saving.
Raises:
ValueError: If `obj` is not trackable.
@ -830,6 +862,7 @@ def save(obj, export_dir, signatures=None):
if not isinstance(obj, base.Trackable):
raise ValueError(
"Expected a Trackable object for export, got {}.".format(obj))
options = options or save_options.SaveOptions()
checkpoint_graph_view = _AugmentedGraphView(obj)
if signatures is None:
@ -857,7 +890,7 @@ def save(obj, export_dir, signatures=None):
meta_graph_def = saved_model.meta_graphs.add()
object_saver = util.TrackableSaver(checkpoint_graph_view)
asset_info, exported_graph = _fill_meta_graph_def(
meta_graph_def, saveable_view, signatures)
meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
saved_model.saved_model_schema_version = (
constants.SAVED_MODEL_SCHEMA_VERSION)
# So far we've just been generating protocol buffers with no I/O. Now we write

View File

@ -0,0 +1,64 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Options for saving SavedModels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@tf_export("saved_model.SaveOptions")
class SaveOptions(object):
"""Options for saving to SavedModel.
This function may be used in the `options` argument in functions that
save a SavedModel (`tf.saved_model.save`, `tf.keras.models.save_model`).
"""
# Define object attributes in __slots__ for improved memory and performance.
__slots__ = ("namespace_whitelist",)
def __init__(self, namespace_whitelist=None):
"""Creates an object that stores options for SavedModel saving.
Args:
namespace_whitelist: List of strings containing op namespaces to whitelist
when saving a model. Saving an object that uses namespaced ops must
explicitly add all namespaces to the whitelist. The namespaced ops must
be registered into the framework when loading the SavedModel.
"""
self.namespace_whitelist = _validate_namespace_whitelist(
namespace_whitelist)
def _validate_namespace_whitelist(namespace_whitelist):
"""Validates namespace whitelist argument."""
if namespace_whitelist is None:
return []
if not isinstance(namespace_whitelist, list):
raise TypeError("Namespace whitelist must be a list of strings.")
processed = []
for namespace in namespace_whitelist:
if not isinstance(namespace, six.string_types):
raise ValueError("Whitelisted namespace must be a string. Got: {} of type"
" {}.".format(namespace, type(namespace)))
processed.append(compat.as_str(namespace))
return processed

View File

@ -21,6 +21,9 @@ from __future__ import print_function
import os
import sys
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import backprop
@ -391,6 +394,28 @@ class SaveTest(test.TestCase):
_import_and_infer(save_dir, {"x": 3}))
class SavingOptionsTest(test.TestCase):
def testOpNameSpace(self):
# TODO(kathywu): Add test that saves out SavedModel with a custom op when
# the ">" character is allowed in op names.
graph_def = graph_pb2.GraphDef()
text_format.Merge("node { name: 'A' op: 'Test>CustomOp' }",
graph_def)
with self.assertRaisesRegexp(
ValueError, "Attempted to save ops from non-whitelisted namespaces"):
save._verify_ops(graph_def, [])
save._verify_ops(graph_def, ["Test"])
# Test with multiple carrots in op name.
text_format.Merge("node { name: 'A' op: 'Test>>A>CustomOp' }",
graph_def)
with self.assertRaisesRegexp(
ValueError, "Attempted to save ops from non-whitelisted namespaces"):
save._verify_ops(graph_def, [])
save._verify_ops(graph_def, ["Test"])
class AssetTests(test.TestCase):
def setUp(self):

View File

@ -277,7 +277,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -294,7 +294,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -278,7 +278,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -278,7 +278,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -277,7 +277,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -294,7 +294,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -30,6 +30,6 @@ tf_module {
}
member_method {
name: "save_model"
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
}

View File

@ -0,0 +1,13 @@
path: "tensorflow.saved_model.SaveOptions"
tf_class {
is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "namespace_whitelist"
mtype: "<type \'member_descriptor\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'namespace_whitelist\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -2,6 +2,6 @@ path: "tensorflow.saved_model.experimental"
tf_module {
member_method {
name: "save"
argspec: "args=[\'obj\', \'export_dir\', \'signatures\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'obj\', \'export_dir\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
}

View File

@ -84,6 +84,10 @@ tf_module {
name: "SERVING"
mtype: "<type \'str\'>"
}
member {
name: "SaveOptions"
mtype: "<type \'type\'>"
}
member {
name: "TPU"
mtype: "<type \'str\'>"
@ -186,7 +190,7 @@ tf_module {
}
member_method {
name: "save"
argspec: "args=[\'obj\', \'export_dir\', \'signatures\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'obj\', \'export_dir\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "simple_save"

View File

@ -277,7 +277,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -294,7 +294,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -278,7 +278,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -278,7 +278,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -277,7 +277,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -294,7 +294,7 @@ tf_class {
}
member_method {
name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "save_weights"

View File

@ -30,6 +30,6 @@ tf_module {
}
member_method {
name: "save_model"
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\'], "
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
}
}

View File

@ -0,0 +1,13 @@
path: "tensorflow.saved_model.SaveOptions"
tf_class {
is_instance: "<class \'tensorflow.python.saved_model.save_options.SaveOptions\'>"
is_instance: "<type \'object\'>"
member {
name: "namespace_whitelist"
mtype: "<type \'member_descriptor\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'namespace_whitelist\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
}

View File

@ -72,6 +72,10 @@ tf_module {
name: "SERVING"
mtype: "<type \'str\'>"
}
member {
name: "SaveOptions"
mtype: "<type \'type\'>"
}
member {
name: "TPU"
mtype: "<type \'str\'>"
@ -98,6 +102,6 @@ tf_module {
}
member_method {
name: "save"
argspec: "args=[\'obj\', \'export_dir\', \'signatures\'], varargs=None, keywords=None, defaults=[\'None\'], "
argspec: "args=[\'obj\', \'export_dir\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
}

View File

@ -51,7 +51,7 @@ if sys.version_info.major == 3:
_NORMALIZE_TYPE = {}
for t in ('property', 'object', 'getset_descriptor', 'int', 'str', 'type',
'tuple', 'module', 'collections.defaultdict', 'set', 'dict',
'NoneType', 'frozenset'):
'NoneType', 'frozenset', 'member_descriptor'):
_NORMALIZE_TYPE["<class '%s'>" % t] = "<type '%s'>" % t
for e in 'Exception', 'RuntimeError':
_NORMALIZE_TYPE["<class '%s'>" % e] = "<type 'exceptions.%s'>" % e