From a670c87ea6e919ecc2bf8354a9fb361ff8fc98be Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Jan 2020 14:21:38 -0800 Subject: [PATCH] Utility function for supporting classify and regress APIs for Keras models. PiperOrigin-RevId: 292012653 Change-Id: I7a031bbbd75d4d9ae98cf5fe23eb482083db34a5 --- tensorflow/python/saved_model/BUILD | 25 ++ .../python/saved_model/method_name_updater.py | 148 +++++++ .../saved_model/method_name_updater_test.py | 377 ++++++++++++++++++ tensorflow/python/saved_model/saved_model.py | 1 + ...ature_def_utils.-method-name-updater.pbtxt | 17 + ...flow.saved_model.signature_def_utils.pbtxt | 4 + tensorflow/tools/compatibility/renames_v2.py | 2 + 7 files changed, 574 insertions(+) create mode 100644 tensorflow/python/saved_model/method_name_updater.py create mode 100644 tensorflow/python/saved_model/method_name_updater_test.py create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.-method-name-updater.pbtxt diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index e8132813d4f..f99340e6bad 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -24,6 +24,7 @@ py_library( ":load", ":loader", ":main_op", + ":method_name_updater", ":save", ":signature_constants", ":signature_def_utils", @@ -516,3 +517,27 @@ py_library( "@six_archive//:six", ], ) + +py_library( + name = "method_name_updater", + srcs = ["method_name_updater.py"], + srcs_version = "PY2AND3", + deps = [ + ":constants", + ":loader", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + +tf_py_test( + name = "method_name_updater_test", + srcs = ["method_name_updater_test.py"], + deps = [ + ":method_name_updater", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework", + "//tensorflow/python/eager:test", + ], +) diff --git a/tensorflow/python/saved_model/method_name_updater.py b/tensorflow/python/saved_model/method_name_updater.py new file mode 100644 index 00000000000..12f0bdd3552 --- /dev/null +++ b/tensorflow/python/saved_model/method_name_updater.py @@ -0,0 +1,148 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SignatureDef method name utility functions. + +Utility functions for manipulating signature_def.method_names. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.python.lib.io import file_io +from tensorflow.python.platform import tf_logging +from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import loader_impl as loader +from tensorflow.python.util import compat +from tensorflow.python.util.tf_export import tf_export + + +# TODO(jdchung): Consider integrated this into the saved_model_cli so that users +# could do this from the command line directly. +@tf_export(v1=["saved_model.signature_def_utils.MethodNameUpdater"]) +class MethodNameUpdater(object): + """Updates the method name(s) of the SavedModel stored in the given path. + + The `MethodNameUpdater` class provides the functionality to update the method + name field in the signature_defs of the given SavedModel. For example, it + can be used to replace the `predict` `method_name` to `regress`. + + Typical usages of the `MethodNameUpdater` + ```python + ... + updater = tf.compat.v1.saved_model.MethodNameUpdater(export_dir) + # Update all signature_defs with key "foo" in all meta graph defs. + updater.replace_method_name(signature_key="foo", method_name="regress") + # Update a single signature_def with key "bar" in the meta graph def with + # tags ["serve"] + updater.replace_method_name(signature_key="bar", method_name="classify", + tags="serve") + updater.save(new_export_dir) + ``` + + Note: This function will only be available through the v1 compatibility + library as tf.compat.v1.saved_model.builder.MethodNameUpdater. + """ + + def __init__(self, export_dir): + """Creates an MethodNameUpdater object. + + Args: + export_dir: Directory containing the SavedModel files. + + Raises: + IOError: If the saved model file does not exist, or cannot be successfully + parsed. + """ + self._export_dir = export_dir + self._saved_model = loader.parse_saved_model(export_dir) + + def replace_method_name(self, signature_key, method_name, tags=None): + """Replaces the method_name in the specified signature_def. + + This will match and replace multiple sig defs iff tags is None (i.e when + multiple `MetaGraph`s have a signature_def with the same key). + If tags is not None, this will only replace a single signature_def in the + `MetaGraph` with matching tags. + + Args: + signature_key: Key of the signature_def to be updated. + method_name: new method_name to replace the existing one. + tags: A tag or sequence of tags identifying the `MetaGraph` to update. If + None, all meta graphs will be updated. + Raises: + ValueError: if signature_key or method_name are not defined or + if no metagraphs were found with the associated tags or + if no meta graph has a signature_def that matches signature_key. + """ + if not signature_key: + raise ValueError("signature_key must be defined.") + if not method_name: + raise ValueError("method_name must be defined.") + + if (tags is not None and not isinstance(tags, list)): + tags = [tags] + found_match = False + for meta_graph_def in self._saved_model.meta_graphs: + if tags is None or set(tags) == set(meta_graph_def.meta_info_def.tags): + if signature_key not in meta_graph_def.signature_def: + raise ValueError( + "MetaGraphDef associated with tags " + str(tags) + + " does not have a signature_def with key: " + signature_key + + ". This means either you specified the wrong signature key or " + "forgot to put the signature_def with the corresponding key in " + "your SavedModel.") + meta_graph_def.signature_def[signature_key].method_name = method_name + found_match = True + + if not found_match: + raise ValueError( + "MetaGraphDef associated with tags " + str(tags) + + " could not be found in SavedModel. This means either you specified " + "the invalid tags your SavedModel does not have a MetaGraph with " + "the specified tags") + + def save(self, new_export_dir=None): + """Saves the updated `SavedModel`. + + Args: + new_export_dir: Path where the updated `SavedModel` will be saved. If + None, the input `SavedModel` will be overriden with the updates. + + Raises: + errors.OpError: If there are errors during the file save operation. + """ + + is_input_text_proto = file_io.file_exists(os.path.join( + compat.as_bytes(self._export_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))) + if not new_export_dir: + new_export_dir = self._export_dir + + if is_input_text_proto: + # TODO(jdchung): Add a util for the path creation below. + path = os.path.join( + compat.as_bytes(new_export_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)) + file_io.write_string_to_file(path, str(self._saved_model)) + else: + path = os.path.join( + compat.as_bytes(new_export_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) + file_io.write_string_to_file( + path, self._saved_model.SerializeToString(deterministic=True)) + tf_logging.info("SavedModel written to: %s", compat.as_text(path)) diff --git a/tensorflow/python/saved_model/method_name_updater_test.py b/tensorflow/python/saved_model/method_name_updater_test.py new file mode 100644 index 00000000000..9009784990b --- /dev/null +++ b/tensorflow/python/saved_model/method_name_updater_test.py @@ -0,0 +1,377 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for method name utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +from google.protobuf import text_format +from tensorflow.core.protobuf import saved_model_pb2 +from tensorflow.python.lib.io import file_io +from tensorflow.python.platform import test +from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import loader_impl as loader +from tensorflow.python.saved_model import method_name_updater +from tensorflow.python.util import compat + +_SAVED_MODEL_PROTO = text_format.Parse(""" +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + tags: "serve" + } + signature_def: { + key: "serving_default" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "predict" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { + dim { size: -1 } + dim { size: 100 } + } + } + } + } + } + signature_def: { + key: "foo" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "predict" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { dim { size: 1 } } + } + } + } + } +} +meta_graphs { + meta_info_def { + tags: "serve" + tags: "gpu" + } + signature_def: { + key: "serving_default" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "predict" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { + dim { size: -1 } + } + } + } + } + } + signature_def: { + key: "bar" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "predict" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { dim { size: 1 } } + } + } + } + } +} +""", saved_model_pb2.SavedModel()) + + +class MethodNameUpdaterTest(test.TestCase): + + def setUp(self): + super(MethodNameUpdaterTest, self).setUp() + self._saved_model_path = tempfile.mkdtemp(prefix=test.get_temp_dir()) + + def testBasic(self): + path = os.path.join( + compat.as_bytes(self._saved_model_path), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) + file_io.write_string_to_file( + path, _SAVED_MODEL_PROTO.SerializeToString(deterministic=True)) + + updater = method_name_updater.MethodNameUpdater(self._saved_model_path) + updater.replace_method_name( + signature_key="serving_default", method_name="classify") + updater.save() + + actual = loader.parse_saved_model(self._saved_model_path) + self.assertProtoEquals( + actual, + text_format.Parse( + """ + saved_model_schema_version: 1 + meta_graphs { + meta_info_def { + tags: "serve" + } + signature_def: { + key: "serving_default" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "classify" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { + dim { size: -1 } + dim { size: 100 } + } + } + } + } + } + signature_def: { + key: "foo" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "predict" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { dim { size: 1 } } + } + } + } + } + } + meta_graphs { + meta_info_def { + tags: "serve" + tags: "gpu" + } + signature_def: { + key: "serving_default" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "classify" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { + dim { size: -1 } + } + } + } + } + } + signature_def: { + key: "bar" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "predict" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { dim { size: 1 } } + } + } + } + } + } + """, saved_model_pb2.SavedModel())) + + def testTextFormatAndNewExportDir(self): + path = os.path.join( + compat.as_bytes(self._saved_model_path), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)) + file_io.write_string_to_file(path, str(_SAVED_MODEL_PROTO)) + + updater = method_name_updater.MethodNameUpdater(self._saved_model_path) + updater.replace_method_name( + signature_key="foo", method_name="regress", tags="serve") + updater.replace_method_name( + signature_key="bar", method_name="classify", tags=["gpu", "serve"]) + + new_export_dir = tempfile.mkdtemp(prefix=test.get_temp_dir()) + updater.save(new_export_dir) + + self.assertTrue( + file_io.file_exists( + os.path.join( + compat.as_bytes(new_export_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)))) + actual = loader.parse_saved_model(new_export_dir) + self.assertProtoEquals( + actual, + text_format.Parse( + """ + saved_model_schema_version: 1 + meta_graphs { + meta_info_def { + tags: "serve" + } + signature_def: { + key: "serving_default" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "predict" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { + dim { size: -1 } + dim { size: 100 } + } + } + } + } + } + signature_def: { + key: "foo" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "regress" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { dim { size: 1 } } + } + } + } + } + } + meta_graphs { + meta_info_def { + tags: "serve" + tags: "gpu" + } + signature_def: { + key: "serving_default" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "predict" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { + dim { size: -1 } + } + } + } + } + } + signature_def: { + key: "bar" + value: { + inputs: { + key: "inputs" + value { name: "input_node:0" } + } + method_name: "classify" + outputs: { + key: "outputs" + value { + dtype: DT_FLOAT + tensor_shape { dim { size: 1 } } + } + } + } + } + } + """, saved_model_pb2.SavedModel())) + + def testExceptions(self): + with self.assertRaises(IOError): + updater = method_name_updater.MethodNameUpdater( + tempfile.mkdtemp(prefix=test.get_temp_dir())) + + path = os.path.join( + compat.as_bytes(self._saved_model_path), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) + file_io.write_string_to_file( + path, _SAVED_MODEL_PROTO.SerializeToString(deterministic=True)) + updater = method_name_updater.MethodNameUpdater(self._saved_model_path) + + with self.assertRaisesRegex(ValueError, "signature_key must be defined"): + updater.replace_method_name( + signature_key=None, method_name="classify") + + with self.assertRaisesRegex(ValueError, "method_name must be defined"): + updater.replace_method_name( + signature_key="foobar", method_name="") + + with self.assertRaisesRegex( + ValueError, + r"MetaGraphDef associated with tags \['gpu'\] could not be found"): + updater.replace_method_name( + signature_key="bar", method_name="classify", tags=["gpu"]) + + with self.assertRaisesRegex( + ValueError, r"MetaGraphDef associated with tags \['serve'\] does not " + r"have a signature_def with key: baz"): + updater.replace_method_name( + signature_key="baz", method_name="classify", tags=["serve"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/saved_model/saved_model.py b/tensorflow/python/saved_model/saved_model.py index 9c926d789f4..68862a81229 100644 --- a/tensorflow/python/saved_model/saved_model.py +++ b/tensorflow/python/saved_model/saved_model.py @@ -25,6 +25,7 @@ from tensorflow.python.saved_model import builder from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import main_op +from tensorflow.python.saved_model import method_name_updater from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.-method-name-updater.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.-method-name-updater.pbtxt new file mode 100644 index 00000000000..bbe74622ba1 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.-method-name-updater.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.saved_model.signature_def_utils.MethodNameUpdater" +tf_class { + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'export_dir\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "replace_method_name" + argspec: "args=[\'self\', \'signature_key\', \'method_name\', \'tags\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "save" + argspec: "args=[\'self\', \'new_export_dir\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.pbtxt index a5602464eeb..dde72390316 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.signature_def_utils.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.saved_model.signature_def_utils" tf_module { + member { + name: "MethodNameUpdater" + mtype: "" + } member_method { name: "build_signature_def" argspec: "args=[\'inputs\', \'outputs\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index 6fa7bc8aaa7..c3edae0f616 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -1082,6 +1082,8 @@ renames = { 'tf.saved_model.REGRESS_METHOD_NAME', 'tf.saved_model.signature_constants.REGRESS_OUTPUTS': 'tf.saved_model.REGRESS_OUTPUTS', + 'tf.saved_model.signature_def_utils.MethodNameUpdater': + 'tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater', 'tf.saved_model.signature_def_utils.build_signature_def': 'tf.compat.v1.saved_model.signature_def_utils.build_signature_def', 'tf.saved_model.signature_def_utils.classification_signature_def':