Add SavedModel support to speech commands training
PiperOrigin-RevId: 307691321 Change-Id: Ib52d1bf5537835ce56db11b074a0f3df9c3f9206
This commit is contained in:
parent
567f7e24ef
commit
314b503bd2
@ -80,6 +80,9 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
||||
preprocess: How the spectrogram is processed to produce features, for
|
||||
example 'mfcc', 'average', or 'micro'.
|
||||
|
||||
Returns:
|
||||
Input and output tensor objects.
|
||||
|
||||
Raises:
|
||||
Exception: If the preprocessing mode isn't recognized.
|
||||
"""
|
||||
@ -150,7 +153,59 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
||||
runtime_settings=runtime_settings)
|
||||
|
||||
# Create an output to use for inference.
|
||||
tf.nn.softmax(logits, name='labels_softmax')
|
||||
softmax = tf.nn.softmax(logits, name='labels_softmax')
|
||||
|
||||
return reshaped_input, softmax
|
||||
|
||||
|
||||
def save_graph_def(file_name, frozen_graph_def):
|
||||
"""Writes a graph def file out to disk.
|
||||
|
||||
Args:
|
||||
file_name: Where to save the file.
|
||||
frozen_graph_def: GraphDef proto object to save.
|
||||
"""
|
||||
tf.io.write_graph(
|
||||
frozen_graph_def,
|
||||
os.path.dirname(file_name),
|
||||
os.path.basename(file_name),
|
||||
as_text=False)
|
||||
tf.compat.v1.logging.info('Saved frozen graph to %s', file_name)
|
||||
|
||||
|
||||
def save_saved_model(file_name, sess, input_tensor, output_tensor):
|
||||
"""Writes a SavedModel out to disk.
|
||||
|
||||
Args:
|
||||
file_name: Where to save the file.
|
||||
sess: TensorFlow session containing the graph.
|
||||
input_tensor: Tensor object defining the input's properties.
|
||||
output_tensor: Tensor object defining the output's properties.
|
||||
"""
|
||||
# Store the frozen graph as a SavedModel for v2 compatibility.
|
||||
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(file_name)
|
||||
tensor_info_inputs = {
|
||||
'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)
|
||||
}
|
||||
tensor_info_outputs = {
|
||||
'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)
|
||||
}
|
||||
signature = (
|
||||
tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
|
||||
inputs=tensor_info_inputs,
|
||||
outputs=tensor_info_outputs,
|
||||
method_name=tf.compat.v1.saved_model.signature_constants
|
||||
.PREDICT_METHOD_NAME))
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess,
|
||||
[tf.compat.v1.saved_model.tag_constants.SERVING],
|
||||
signature_def_map={
|
||||
tf.compat.v1.saved_model.signature_constants
|
||||
.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||
signature,
|
||||
},
|
||||
)
|
||||
builder.save()
|
||||
|
||||
|
||||
def main(_):
|
||||
@ -167,7 +222,7 @@ def main(_):
|
||||
|
||||
# Create the model and load its weights.
|
||||
sess = tf.compat.v1.InteractiveSession()
|
||||
create_inference_graph(
|
||||
input_tensor, output_tensor = create_inference_graph(
|
||||
FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
|
||||
FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
|
||||
FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
|
||||
@ -178,12 +233,14 @@ def main(_):
|
||||
# Turn all the variables into inline constants inside the graph and save it.
|
||||
frozen_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess, sess.graph_def, ['labels_softmax'])
|
||||
tf.io.write_graph(
|
||||
frozen_graph_def,
|
||||
os.path.dirname(FLAGS.output_file),
|
||||
os.path.basename(FLAGS.output_file),
|
||||
as_text=False)
|
||||
tf.compat.v1.logging.info('Saved frozen graph to %s', FLAGS.output_file)
|
||||
|
||||
if FLAGS.save_format == 'graph_def':
|
||||
save_graph_def(FLAGS.output_file, frozen_graph_def)
|
||||
elif FLAGS.save_format == 'saved_model':
|
||||
save_saved_model(FLAGS.output_file, sess, input_tensor, output_tensor)
|
||||
else:
|
||||
raise Exception('Unknown save format "%s" (should be "graph_def" or'
|
||||
' "saved_model")' % (FLAGS.save_format))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -246,5 +303,10 @@ if __name__ == '__main__':
|
||||
type=str,
|
||||
default='mfcc',
|
||||
help='Spectrogram processing mode. Can be "mfcc" or "average"')
|
||||
parser.add_argument(
|
||||
'--save_format',
|
||||
type=str,
|
||||
default='graph_def',
|
||||
help='How to save the result. Can be "graph_def" or "saved_model"')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -18,8 +18,12 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
|
||||
from tensorflow.examples.speech_commands import freeze
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.variables import global_variables_initializer
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -103,6 +107,27 @@ class FreezeTest(test.TestCase):
|
||||
ops = [node.op for node in sess.graph_def.node]
|
||||
self.assertEqual(0, ops.count('Mfcc'))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testCreateSavedModel(self):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
saved_model_path = os.path.join(tmp_dir, 'saved_model')
|
||||
with self.cached_session() as sess:
|
||||
input_tensor, output_tensor = freeze.create_inference_graph(
|
||||
wanted_words='a,b,c,d',
|
||||
sample_rate=16000,
|
||||
clip_duration_ms=1000.0,
|
||||
clip_stride_ms=30.0,
|
||||
window_size_ms=30.0,
|
||||
window_stride_ms=10.0,
|
||||
feature_bin_count=40,
|
||||
model_architecture='conv',
|
||||
preprocess='micro')
|
||||
global_variables_initializer().run()
|
||||
graph_util.convert_variables_to_constants(
|
||||
sess, sess.graph_def, ['labels_softmax'])
|
||||
freeze.save_saved_model(saved_model_path, sess, input_tensor,
|
||||
output_tensor)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user