From 314b503bd2a5837de23fa35dadaac0f6b557e246 Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Tue, 21 Apr 2020 15:20:49 -0700 Subject: [PATCH] Add SavedModel support to speech commands training PiperOrigin-RevId: 307691321 Change-Id: Ib52d1bf5537835ce56db11b074a0f3df9c3f9206 --- tensorflow/examples/speech_commands/freeze.py | 78 +++++++++++++++++-- .../examples/speech_commands/freeze_test.py | 25 ++++++ 2 files changed, 95 insertions(+), 8 deletions(-) diff --git a/tensorflow/examples/speech_commands/freeze.py b/tensorflow/examples/speech_commands/freeze.py index 4a48a440b6e..44940b0647f 100644 --- a/tensorflow/examples/speech_commands/freeze.py +++ b/tensorflow/examples/speech_commands/freeze.py @@ -80,6 +80,9 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms, preprocess: How the spectrogram is processed to produce features, for example 'mfcc', 'average', or 'micro'. + Returns: + Input and output tensor objects. + Raises: Exception: If the preprocessing mode isn't recognized. """ @@ -150,7 +153,59 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms, runtime_settings=runtime_settings) # Create an output to use for inference. - tf.nn.softmax(logits, name='labels_softmax') + softmax = tf.nn.softmax(logits, name='labels_softmax') + + return reshaped_input, softmax + + +def save_graph_def(file_name, frozen_graph_def): + """Writes a graph def file out to disk. + + Args: + file_name: Where to save the file. + frozen_graph_def: GraphDef proto object to save. + """ + tf.io.write_graph( + frozen_graph_def, + os.path.dirname(file_name), + os.path.basename(file_name), + as_text=False) + tf.compat.v1.logging.info('Saved frozen graph to %s', file_name) + + +def save_saved_model(file_name, sess, input_tensor, output_tensor): + """Writes a SavedModel out to disk. + + Args: + file_name: Where to save the file. + sess: TensorFlow session containing the graph. + input_tensor: Tensor object defining the input's properties. + output_tensor: Tensor object defining the output's properties. + """ + # Store the frozen graph as a SavedModel for v2 compatibility. + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(file_name) + tensor_info_inputs = { + 'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor) + } + tensor_info_outputs = { + 'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor) + } + signature = ( + tf.compat.v1.saved_model.signature_def_utils.build_signature_def( + inputs=tensor_info_inputs, + outputs=tensor_info_outputs, + method_name=tf.compat.v1.saved_model.signature_constants + .PREDICT_METHOD_NAME)) + builder.add_meta_graph_and_variables( + sess, + [tf.compat.v1.saved_model.tag_constants.SERVING], + signature_def_map={ + tf.compat.v1.saved_model.signature_constants + .DEFAULT_SERVING_SIGNATURE_DEF_KEY: + signature, + }, + ) + builder.save() def main(_): @@ -167,7 +222,7 @@ def main(_): # Create the model and load its weights. sess = tf.compat.v1.InteractiveSession() - create_inference_graph( + input_tensor, output_tensor = create_inference_graph( FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms, FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms, FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess) @@ -178,12 +233,14 @@ def main(_): # Turn all the variables into inline constants inside the graph and save it. frozen_graph_def = graph_util.convert_variables_to_constants( sess, sess.graph_def, ['labels_softmax']) - tf.io.write_graph( - frozen_graph_def, - os.path.dirname(FLAGS.output_file), - os.path.basename(FLAGS.output_file), - as_text=False) - tf.compat.v1.logging.info('Saved frozen graph to %s', FLAGS.output_file) + + if FLAGS.save_format == 'graph_def': + save_graph_def(FLAGS.output_file, frozen_graph_def) + elif FLAGS.save_format == 'saved_model': + save_saved_model(FLAGS.output_file, sess, input_tensor, output_tensor) + else: + raise Exception('Unknown save format "%s" (should be "graph_def" or' + ' "saved_model")' % (FLAGS.save_format)) if __name__ == '__main__': @@ -246,5 +303,10 @@ if __name__ == '__main__': type=str, default='mfcc', help='Spectrogram processing mode. Can be "mfcc" or "average"') + parser.add_argument( + '--save_format', + type=str, + default='graph_def', + help='How to save the result. Can be "graph_def" or "saved_model"') FLAGS, unparsed = parser.parse_known_args() tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/speech_commands/freeze_test.py b/tensorflow/examples/speech_commands/freeze_test.py index a242453d0e5..93a79b0b4f7 100644 --- a/tensorflow/examples/speech_commands/freeze_test.py +++ b/tensorflow/examples/speech_commands/freeze_test.py @@ -18,8 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os.path + from tensorflow.examples.speech_commands import freeze +from tensorflow.python.framework import graph_util from tensorflow.python.framework import test_util +from tensorflow.python.ops.variables import global_variables_initializer from tensorflow.python.platform import test @@ -103,6 +107,27 @@ class FreezeTest(test.TestCase): ops = [node.op for node in sess.graph_def.node] self.assertEqual(0, ops.count('Mfcc')) + @test_util.run_deprecated_v1 + def testCreateSavedModel(self): + tmp_dir = self.get_temp_dir() + saved_model_path = os.path.join(tmp_dir, 'saved_model') + with self.cached_session() as sess: + input_tensor, output_tensor = freeze.create_inference_graph( + wanted_words='a,b,c,d', + sample_rate=16000, + clip_duration_ms=1000.0, + clip_stride_ms=30.0, + window_size_ms=30.0, + window_stride_ms=10.0, + feature_bin_count=40, + model_architecture='conv', + preprocess='micro') + global_variables_initializer().run() + graph_util.convert_variables_to_constants( + sess, sess.graph_def, ['labels_softmax']) + freeze.save_saved_model(saved_model_path, sess, input_tensor, + output_tensor) + if __name__ == '__main__': test.main()