Add SavedModel support to speech commands training

PiperOrigin-RevId: 307691321
Change-Id: Ib52d1bf5537835ce56db11b074a0f3df9c3f9206
This commit is contained in:
Pete Warden 2020-04-21 15:20:49 -07:00 committed by TensorFlower Gardener
parent 567f7e24ef
commit 314b503bd2
2 changed files with 95 additions and 8 deletions

View File

@ -80,6 +80,9 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
preprocess: How the spectrogram is processed to produce features, for
example 'mfcc', 'average', or 'micro'.
Returns:
Input and output tensor objects.
Raises:
Exception: If the preprocessing mode isn't recognized.
"""
@ -150,7 +153,59 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
runtime_settings=runtime_settings)
# Create an output to use for inference.
tf.nn.softmax(logits, name='labels_softmax')
softmax = tf.nn.softmax(logits, name='labels_softmax')
return reshaped_input, softmax
def save_graph_def(file_name, frozen_graph_def):
"""Writes a graph def file out to disk.
Args:
file_name: Where to save the file.
frozen_graph_def: GraphDef proto object to save.
"""
tf.io.write_graph(
frozen_graph_def,
os.path.dirname(file_name),
os.path.basename(file_name),
as_text=False)
tf.compat.v1.logging.info('Saved frozen graph to %s', file_name)
def save_saved_model(file_name, sess, input_tensor, output_tensor):
"""Writes a SavedModel out to disk.
Args:
file_name: Where to save the file.
sess: TensorFlow session containing the graph.
input_tensor: Tensor object defining the input's properties.
output_tensor: Tensor object defining the output's properties.
"""
# Store the frozen graph as a SavedModel for v2 compatibility.
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(file_name)
tensor_info_inputs = {
'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)
}
tensor_info_outputs = {
'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)
}
signature = (
tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs=tensor_info_inputs,
outputs=tensor_info_outputs,
method_name=tf.compat.v1.saved_model.signature_constants
.PREDICT_METHOD_NAME))
builder.add_meta_graph_and_variables(
sess,
[tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map={
tf.compat.v1.saved_model.signature_constants
.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature,
},
)
builder.save()
def main(_):
@ -167,7 +222,7 @@ def main(_):
# Create the model and load its weights.
sess = tf.compat.v1.InteractiveSession()
create_inference_graph(
input_tensor, output_tensor = create_inference_graph(
FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
@ -178,12 +233,14 @@ def main(_):
# Turn all the variables into inline constants inside the graph and save it.
frozen_graph_def = graph_util.convert_variables_to_constants(
sess, sess.graph_def, ['labels_softmax'])
tf.io.write_graph(
frozen_graph_def,
os.path.dirname(FLAGS.output_file),
os.path.basename(FLAGS.output_file),
as_text=False)
tf.compat.v1.logging.info('Saved frozen graph to %s', FLAGS.output_file)
if FLAGS.save_format == 'graph_def':
save_graph_def(FLAGS.output_file, frozen_graph_def)
elif FLAGS.save_format == 'saved_model':
save_saved_model(FLAGS.output_file, sess, input_tensor, output_tensor)
else:
raise Exception('Unknown save format "%s" (should be "graph_def" or'
' "saved_model")' % (FLAGS.save_format))
if __name__ == '__main__':
@ -246,5 +303,10 @@ if __name__ == '__main__':
type=str,
default='mfcc',
help='Spectrogram processing mode. Can be "mfcc" or "average"')
parser.add_argument(
'--save_format',
type=str,
default='graph_def',
help='How to save the result. Can be "graph_def" or "saved_model"')
FLAGS, unparsed = parser.parse_known_args()
tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -18,8 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
from tensorflow.examples.speech_commands import freeze
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import test_util
from tensorflow.python.ops.variables import global_variables_initializer
from tensorflow.python.platform import test
@ -103,6 +107,27 @@ class FreezeTest(test.TestCase):
ops = [node.op for node in sess.graph_def.node]
self.assertEqual(0, ops.count('Mfcc'))
@test_util.run_deprecated_v1
def testCreateSavedModel(self):
tmp_dir = self.get_temp_dir()
saved_model_path = os.path.join(tmp_dir, 'saved_model')
with self.cached_session() as sess:
input_tensor, output_tensor = freeze.create_inference_graph(
wanted_words='a,b,c,d',
sample_rate=16000,
clip_duration_ms=1000.0,
clip_stride_ms=30.0,
window_size_ms=30.0,
window_stride_ms=10.0,
feature_bin_count=40,
model_architecture='conv',
preprocess='micro')
global_variables_initializer().run()
graph_util.convert_variables_to_constants(
sess, sess.graph_def, ['labels_softmax'])
freeze.save_saved_model(saved_model_path, sess, input_tensor,
output_tensor)
if __name__ == '__main__':
test.main()