Add SavedModel support to speech commands training
PiperOrigin-RevId: 307691321 Change-Id: Ib52d1bf5537835ce56db11b074a0f3df9c3f9206
This commit is contained in:
parent
567f7e24ef
commit
314b503bd2
@ -80,6 +80,9 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
|||||||
preprocess: How the spectrogram is processed to produce features, for
|
preprocess: How the spectrogram is processed to produce features, for
|
||||||
example 'mfcc', 'average', or 'micro'.
|
example 'mfcc', 'average', or 'micro'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Input and output tensor objects.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Exception: If the preprocessing mode isn't recognized.
|
Exception: If the preprocessing mode isn't recognized.
|
||||||
"""
|
"""
|
||||||
@ -150,7 +153,59 @@ def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
|
|||||||
runtime_settings=runtime_settings)
|
runtime_settings=runtime_settings)
|
||||||
|
|
||||||
# Create an output to use for inference.
|
# Create an output to use for inference.
|
||||||
tf.nn.softmax(logits, name='labels_softmax')
|
softmax = tf.nn.softmax(logits, name='labels_softmax')
|
||||||
|
|
||||||
|
return reshaped_input, softmax
|
||||||
|
|
||||||
|
|
||||||
|
def save_graph_def(file_name, frozen_graph_def):
|
||||||
|
"""Writes a graph def file out to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name: Where to save the file.
|
||||||
|
frozen_graph_def: GraphDef proto object to save.
|
||||||
|
"""
|
||||||
|
tf.io.write_graph(
|
||||||
|
frozen_graph_def,
|
||||||
|
os.path.dirname(file_name),
|
||||||
|
os.path.basename(file_name),
|
||||||
|
as_text=False)
|
||||||
|
tf.compat.v1.logging.info('Saved frozen graph to %s', file_name)
|
||||||
|
|
||||||
|
|
||||||
|
def save_saved_model(file_name, sess, input_tensor, output_tensor):
|
||||||
|
"""Writes a SavedModel out to disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name: Where to save the file.
|
||||||
|
sess: TensorFlow session containing the graph.
|
||||||
|
input_tensor: Tensor object defining the input's properties.
|
||||||
|
output_tensor: Tensor object defining the output's properties.
|
||||||
|
"""
|
||||||
|
# Store the frozen graph as a SavedModel for v2 compatibility.
|
||||||
|
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(file_name)
|
||||||
|
tensor_info_inputs = {
|
||||||
|
'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)
|
||||||
|
}
|
||||||
|
tensor_info_outputs = {
|
||||||
|
'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)
|
||||||
|
}
|
||||||
|
signature = (
|
||||||
|
tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
|
||||||
|
inputs=tensor_info_inputs,
|
||||||
|
outputs=tensor_info_outputs,
|
||||||
|
method_name=tf.compat.v1.saved_model.signature_constants
|
||||||
|
.PREDICT_METHOD_NAME))
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess,
|
||||||
|
[tf.compat.v1.saved_model.tag_constants.SERVING],
|
||||||
|
signature_def_map={
|
||||||
|
tf.compat.v1.saved_model.signature_constants
|
||||||
|
.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||||
|
signature,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
builder.save()
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
@ -167,7 +222,7 @@ def main(_):
|
|||||||
|
|
||||||
# Create the model and load its weights.
|
# Create the model and load its weights.
|
||||||
sess = tf.compat.v1.InteractiveSession()
|
sess = tf.compat.v1.InteractiveSession()
|
||||||
create_inference_graph(
|
input_tensor, output_tensor = create_inference_graph(
|
||||||
FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
|
FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
|
||||||
FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
|
FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
|
||||||
FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
|
FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
|
||||||
@ -178,12 +233,14 @@ def main(_):
|
|||||||
# Turn all the variables into inline constants inside the graph and save it.
|
# Turn all the variables into inline constants inside the graph and save it.
|
||||||
frozen_graph_def = graph_util.convert_variables_to_constants(
|
frozen_graph_def = graph_util.convert_variables_to_constants(
|
||||||
sess, sess.graph_def, ['labels_softmax'])
|
sess, sess.graph_def, ['labels_softmax'])
|
||||||
tf.io.write_graph(
|
|
||||||
frozen_graph_def,
|
if FLAGS.save_format == 'graph_def':
|
||||||
os.path.dirname(FLAGS.output_file),
|
save_graph_def(FLAGS.output_file, frozen_graph_def)
|
||||||
os.path.basename(FLAGS.output_file),
|
elif FLAGS.save_format == 'saved_model':
|
||||||
as_text=False)
|
save_saved_model(FLAGS.output_file, sess, input_tensor, output_tensor)
|
||||||
tf.compat.v1.logging.info('Saved frozen graph to %s', FLAGS.output_file)
|
else:
|
||||||
|
raise Exception('Unknown save format "%s" (should be "graph_def" or'
|
||||||
|
' "saved_model")' % (FLAGS.save_format))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -246,5 +303,10 @@ if __name__ == '__main__':
|
|||||||
type=str,
|
type=str,
|
||||||
default='mfcc',
|
default='mfcc',
|
||||||
help='Spectrogram processing mode. Can be "mfcc" or "average"')
|
help='Spectrogram processing mode. Can be "mfcc" or "average"')
|
||||||
|
parser.add_argument(
|
||||||
|
'--save_format',
|
||||||
|
type=str,
|
||||||
|
default='graph_def',
|
||||||
|
help='How to save the result. Can be "graph_def" or "saved_model"')
|
||||||
FLAGS, unparsed = parser.parse_known_args()
|
FLAGS, unparsed = parser.parse_known_args()
|
||||||
tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||||
|
@ -18,8 +18,12 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os.path
|
||||||
|
|
||||||
from tensorflow.examples.speech_commands import freeze
|
from tensorflow.examples.speech_commands import freeze
|
||||||
|
from tensorflow.python.framework import graph_util
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.ops.variables import global_variables_initializer
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -103,6 +107,27 @@ class FreezeTest(test.TestCase):
|
|||||||
ops = [node.op for node in sess.graph_def.node]
|
ops = [node.op for node in sess.graph_def.node]
|
||||||
self.assertEqual(0, ops.count('Mfcc'))
|
self.assertEqual(0, ops.count('Mfcc'))
|
||||||
|
|
||||||
|
@test_util.run_deprecated_v1
|
||||||
|
def testCreateSavedModel(self):
|
||||||
|
tmp_dir = self.get_temp_dir()
|
||||||
|
saved_model_path = os.path.join(tmp_dir, 'saved_model')
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
input_tensor, output_tensor = freeze.create_inference_graph(
|
||||||
|
wanted_words='a,b,c,d',
|
||||||
|
sample_rate=16000,
|
||||||
|
clip_duration_ms=1000.0,
|
||||||
|
clip_stride_ms=30.0,
|
||||||
|
window_size_ms=30.0,
|
||||||
|
window_stride_ms=10.0,
|
||||||
|
feature_bin_count=40,
|
||||||
|
model_architecture='conv',
|
||||||
|
preprocess='micro')
|
||||||
|
global_variables_initializer().run()
|
||||||
|
graph_util.convert_variables_to_constants(
|
||||||
|
sess, sess.graph_def, ['labels_softmax'])
|
||||||
|
freeze.save_saved_model(saved_model_path, sess, input_tensor,
|
||||||
|
output_tensor)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user