Utility function for supporting classify and regress APIs for Keras models.
PiperOrigin-RevId: 292012653 Change-Id: I7a031bbbd75d4d9ae98cf5fe23eb482083db34a5
This commit is contained in:
parent
4152ed672e
commit
a670c87ea6
@ -24,6 +24,7 @@ py_library(
|
||||
":load",
|
||||
":loader",
|
||||
":main_op",
|
||||
":method_name_updater",
|
||||
":save",
|
||||
":signature_constants",
|
||||
":signature_def_utils",
|
||||
@ -516,3 +517,27 @@ py_library(
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "method_name_updater",
|
||||
srcs = ["method_name_updater.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":constants",
|
||||
":loader",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "method_name_updater_test",
|
||||
srcs = ["method_name_updater_test.py"],
|
||||
deps = [
|
||||
":method_name_updater",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python/eager:test",
|
||||
],
|
||||
)
|
||||
|
148
tensorflow/python/saved_model/method_name_updater.py
Normal file
148
tensorflow/python/saved_model/method_name_updater.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""SignatureDef method name utility functions.
|
||||
|
||||
Utility functions for manipulating signature_def.method_names.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import loader_impl as loader
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# TODO(jdchung): Consider integrated this into the saved_model_cli so that users
|
||||
# could do this from the command line directly.
|
||||
@tf_export(v1=["saved_model.signature_def_utils.MethodNameUpdater"])
|
||||
class MethodNameUpdater(object):
|
||||
"""Updates the method name(s) of the SavedModel stored in the given path.
|
||||
|
||||
The `MethodNameUpdater` class provides the functionality to update the method
|
||||
name field in the signature_defs of the given SavedModel. For example, it
|
||||
can be used to replace the `predict` `method_name` to `regress`.
|
||||
|
||||
Typical usages of the `MethodNameUpdater`
|
||||
```python
|
||||
...
|
||||
updater = tf.compat.v1.saved_model.MethodNameUpdater(export_dir)
|
||||
# Update all signature_defs with key "foo" in all meta graph defs.
|
||||
updater.replace_method_name(signature_key="foo", method_name="regress")
|
||||
# Update a single signature_def with key "bar" in the meta graph def with
|
||||
# tags ["serve"]
|
||||
updater.replace_method_name(signature_key="bar", method_name="classify",
|
||||
tags="serve")
|
||||
updater.save(new_export_dir)
|
||||
```
|
||||
|
||||
Note: This function will only be available through the v1 compatibility
|
||||
library as tf.compat.v1.saved_model.builder.MethodNameUpdater.
|
||||
"""
|
||||
|
||||
def __init__(self, export_dir):
|
||||
"""Creates an MethodNameUpdater object.
|
||||
|
||||
Args:
|
||||
export_dir: Directory containing the SavedModel files.
|
||||
|
||||
Raises:
|
||||
IOError: If the saved model file does not exist, or cannot be successfully
|
||||
parsed.
|
||||
"""
|
||||
self._export_dir = export_dir
|
||||
self._saved_model = loader.parse_saved_model(export_dir)
|
||||
|
||||
def replace_method_name(self, signature_key, method_name, tags=None):
|
||||
"""Replaces the method_name in the specified signature_def.
|
||||
|
||||
This will match and replace multiple sig defs iff tags is None (i.e when
|
||||
multiple `MetaGraph`s have a signature_def with the same key).
|
||||
If tags is not None, this will only replace a single signature_def in the
|
||||
`MetaGraph` with matching tags.
|
||||
|
||||
Args:
|
||||
signature_key: Key of the signature_def to be updated.
|
||||
method_name: new method_name to replace the existing one.
|
||||
tags: A tag or sequence of tags identifying the `MetaGraph` to update. If
|
||||
None, all meta graphs will be updated.
|
||||
Raises:
|
||||
ValueError: if signature_key or method_name are not defined or
|
||||
if no metagraphs were found with the associated tags or
|
||||
if no meta graph has a signature_def that matches signature_key.
|
||||
"""
|
||||
if not signature_key:
|
||||
raise ValueError("signature_key must be defined.")
|
||||
if not method_name:
|
||||
raise ValueError("method_name must be defined.")
|
||||
|
||||
if (tags is not None and not isinstance(tags, list)):
|
||||
tags = [tags]
|
||||
found_match = False
|
||||
for meta_graph_def in self._saved_model.meta_graphs:
|
||||
if tags is None or set(tags) == set(meta_graph_def.meta_info_def.tags):
|
||||
if signature_key not in meta_graph_def.signature_def:
|
||||
raise ValueError(
|
||||
"MetaGraphDef associated with tags " + str(tags) +
|
||||
" does not have a signature_def with key: " + signature_key +
|
||||
". This means either you specified the wrong signature key or "
|
||||
"forgot to put the signature_def with the corresponding key in "
|
||||
"your SavedModel.")
|
||||
meta_graph_def.signature_def[signature_key].method_name = method_name
|
||||
found_match = True
|
||||
|
||||
if not found_match:
|
||||
raise ValueError(
|
||||
"MetaGraphDef associated with tags " + str(tags) +
|
||||
" could not be found in SavedModel. This means either you specified "
|
||||
"the invalid tags your SavedModel does not have a MetaGraph with "
|
||||
"the specified tags")
|
||||
|
||||
def save(self, new_export_dir=None):
|
||||
"""Saves the updated `SavedModel`.
|
||||
|
||||
Args:
|
||||
new_export_dir: Path where the updated `SavedModel` will be saved. If
|
||||
None, the input `SavedModel` will be overriden with the updates.
|
||||
|
||||
Raises:
|
||||
errors.OpError: If there are errors during the file save operation.
|
||||
"""
|
||||
|
||||
is_input_text_proto = file_io.file_exists(os.path.join(
|
||||
compat.as_bytes(self._export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)))
|
||||
if not new_export_dir:
|
||||
new_export_dir = self._export_dir
|
||||
|
||||
if is_input_text_proto:
|
||||
# TODO(jdchung): Add a util for the path creation below.
|
||||
path = os.path.join(
|
||||
compat.as_bytes(new_export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
|
||||
file_io.write_string_to_file(path, str(self._saved_model))
|
||||
else:
|
||||
path = os.path.join(
|
||||
compat.as_bytes(new_export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||
file_io.write_string_to_file(
|
||||
path, self._saved_model.SerializeToString(deterministic=True))
|
||||
tf_logging.info("SavedModel written to: %s", compat.as_text(path))
|
377
tensorflow/python/saved_model/method_name_updater_test.py
Normal file
377
tensorflow/python/saved_model/method_name_updater_test.py
Normal file
@ -0,0 +1,377 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for method name utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.core.protobuf import saved_model_pb2
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import loader_impl as loader
|
||||
from tensorflow.python.saved_model import method_name_updater
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_SAVED_MODEL_PROTO = text_format.Parse("""
|
||||
saved_model_schema_version: 1
|
||||
meta_graphs {
|
||||
meta_info_def {
|
||||
tags: "serve"
|
||||
}
|
||||
signature_def: {
|
||||
key: "serving_default"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "predict"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim { size: -1 }
|
||||
dim { size: 100 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
signature_def: {
|
||||
key: "foo"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "predict"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape { dim { size: 1 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
meta_graphs {
|
||||
meta_info_def {
|
||||
tags: "serve"
|
||||
tags: "gpu"
|
||||
}
|
||||
signature_def: {
|
||||
key: "serving_default"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "predict"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim { size: -1 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
signature_def: {
|
||||
key: "bar"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "predict"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape { dim { size: 1 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""", saved_model_pb2.SavedModel())
|
||||
|
||||
|
||||
class MethodNameUpdaterTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(MethodNameUpdaterTest, self).setUp()
|
||||
self._saved_model_path = tempfile.mkdtemp(prefix=test.get_temp_dir())
|
||||
|
||||
def testBasic(self):
|
||||
path = os.path.join(
|
||||
compat.as_bytes(self._saved_model_path),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||
file_io.write_string_to_file(
|
||||
path, _SAVED_MODEL_PROTO.SerializeToString(deterministic=True))
|
||||
|
||||
updater = method_name_updater.MethodNameUpdater(self._saved_model_path)
|
||||
updater.replace_method_name(
|
||||
signature_key="serving_default", method_name="classify")
|
||||
updater.save()
|
||||
|
||||
actual = loader.parse_saved_model(self._saved_model_path)
|
||||
self.assertProtoEquals(
|
||||
actual,
|
||||
text_format.Parse(
|
||||
"""
|
||||
saved_model_schema_version: 1
|
||||
meta_graphs {
|
||||
meta_info_def {
|
||||
tags: "serve"
|
||||
}
|
||||
signature_def: {
|
||||
key: "serving_default"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "classify"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim { size: -1 }
|
||||
dim { size: 100 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
signature_def: {
|
||||
key: "foo"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "predict"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape { dim { size: 1 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
meta_graphs {
|
||||
meta_info_def {
|
||||
tags: "serve"
|
||||
tags: "gpu"
|
||||
}
|
||||
signature_def: {
|
||||
key: "serving_default"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "classify"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim { size: -1 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
signature_def: {
|
||||
key: "bar"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "predict"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape { dim { size: 1 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""", saved_model_pb2.SavedModel()))
|
||||
|
||||
def testTextFormatAndNewExportDir(self):
|
||||
path = os.path.join(
|
||||
compat.as_bytes(self._saved_model_path),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
|
||||
file_io.write_string_to_file(path, str(_SAVED_MODEL_PROTO))
|
||||
|
||||
updater = method_name_updater.MethodNameUpdater(self._saved_model_path)
|
||||
updater.replace_method_name(
|
||||
signature_key="foo", method_name="regress", tags="serve")
|
||||
updater.replace_method_name(
|
||||
signature_key="bar", method_name="classify", tags=["gpu", "serve"])
|
||||
|
||||
new_export_dir = tempfile.mkdtemp(prefix=test.get_temp_dir())
|
||||
updater.save(new_export_dir)
|
||||
|
||||
self.assertTrue(
|
||||
file_io.file_exists(
|
||||
os.path.join(
|
||||
compat.as_bytes(new_export_dir),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))))
|
||||
actual = loader.parse_saved_model(new_export_dir)
|
||||
self.assertProtoEquals(
|
||||
actual,
|
||||
text_format.Parse(
|
||||
"""
|
||||
saved_model_schema_version: 1
|
||||
meta_graphs {
|
||||
meta_info_def {
|
||||
tags: "serve"
|
||||
}
|
||||
signature_def: {
|
||||
key: "serving_default"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "predict"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim { size: -1 }
|
||||
dim { size: 100 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
signature_def: {
|
||||
key: "foo"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "regress"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape { dim { size: 1 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
meta_graphs {
|
||||
meta_info_def {
|
||||
tags: "serve"
|
||||
tags: "gpu"
|
||||
}
|
||||
signature_def: {
|
||||
key: "serving_default"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "predict"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim { size: -1 }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
signature_def: {
|
||||
key: "bar"
|
||||
value: {
|
||||
inputs: {
|
||||
key: "inputs"
|
||||
value { name: "input_node:0" }
|
||||
}
|
||||
method_name: "classify"
|
||||
outputs: {
|
||||
key: "outputs"
|
||||
value {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape { dim { size: 1 } }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
""", saved_model_pb2.SavedModel()))
|
||||
|
||||
def testExceptions(self):
|
||||
with self.assertRaises(IOError):
|
||||
updater = method_name_updater.MethodNameUpdater(
|
||||
tempfile.mkdtemp(prefix=test.get_temp_dir()))
|
||||
|
||||
path = os.path.join(
|
||||
compat.as_bytes(self._saved_model_path),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
|
||||
file_io.write_string_to_file(
|
||||
path, _SAVED_MODEL_PROTO.SerializeToString(deterministic=True))
|
||||
updater = method_name_updater.MethodNameUpdater(self._saved_model_path)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "signature_key must be defined"):
|
||||
updater.replace_method_name(
|
||||
signature_key=None, method_name="classify")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "method_name must be defined"):
|
||||
updater.replace_method_name(
|
||||
signature_key="foobar", method_name="")
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r"MetaGraphDef associated with tags \['gpu'\] could not be found"):
|
||||
updater.replace_method_name(
|
||||
signature_key="bar", method_name="classify", tags=["gpu"])
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r"MetaGraphDef associated with tags \['serve'\] does not "
|
||||
r"have a signature_def with key: baz"):
|
||||
updater.replace_method_name(
|
||||
signature_key="baz", method_name="classify", tags=["serve"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -25,6 +25,7 @@ from tensorflow.python.saved_model import builder
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import loader
|
||||
from tensorflow.python.saved_model import main_op
|
||||
from tensorflow.python.saved_model import method_name_updater
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
|
@ -0,0 +1,17 @@
|
||||
path: "tensorflow.saved_model.signature_def_utils.MethodNameUpdater"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.saved_model.method_name_updater.MethodNameUpdater\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'export_dir\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "replace_method_name"
|
||||
argspec: "args=[\'self\', \'signature_key\', \'method_name\', \'tags\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'new_export_dir\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.saved_model.signature_def_utils"
|
||||
tf_module {
|
||||
member {
|
||||
name: "MethodNameUpdater"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "build_signature_def"
|
||||
argspec: "args=[\'inputs\', \'outputs\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
|
@ -1082,6 +1082,8 @@ renames = {
|
||||
'tf.saved_model.REGRESS_METHOD_NAME',
|
||||
'tf.saved_model.signature_constants.REGRESS_OUTPUTS':
|
||||
'tf.saved_model.REGRESS_OUTPUTS',
|
||||
'tf.saved_model.signature_def_utils.MethodNameUpdater':
|
||||
'tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater',
|
||||
'tf.saved_model.signature_def_utils.build_signature_def':
|
||||
'tf.compat.v1.saved_model.signature_def_utils.build_signature_def',
|
||||
'tf.saved_model.signature_def_utils.classification_signature_def':
|
||||
|
Loading…
Reference in New Issue
Block a user