# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts a trained checkpoint into a frozen model for mobile inference.
|
|
|
|
Once you've trained a model using the `train.py` script, you can use this tool
|
|
to convert it into a binary GraphDef file that can be loaded into the Android,
|
|
iOS, or Raspberry Pi example code. Here's an example of how to run it:
|
|
|
|
bazel run tensorflow/examples/speech_commands/freeze -- \
|
|
--sample_rate=16000 --dct_coefficient_count=40 --window_size_ms=20 \
|
|
--window_stride_ms=10 --clip_duration_ms=1000 \
|
|
--model_architecture=conv \
|
|
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \
|
|
--output_file=/tmp/my_frozen_graph.pb
|
|
|
|
One thing to watch out for is that you need to pass in the same arguments for
|
|
`sample_rate` and other command line variables here as you did for the training
|
|
script.
|
|
|
|
The resulting graph has an input for WAV-encoded data named 'wav_data', one for
|
|
raw PCM data (as floats in the range -1.0 to 1.0) called 'decoded_sample_data',
|
|
and the output is called 'labels_softmax'.
|
|
|
|
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

import input_data
import models
from tensorflow.python.framework import graph_util
from tensorflow.python.ops import gen_audio_ops as audio_ops

# If it's available, load the specialized feature generator. If this doesn't
# work, try building with bazel instead of running the Python script directly.
# bazel run tensorflow/examples/speech_commands:freeze_graph
try:
  from tensorflow.lite.experimental.microfrontend.python.ops import audio_microfrontend_op as frontend_op  # pylint:disable=g-import-not-at-top
except ImportError:
  frontend_op = None

FLAGS = None


def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           feature_bin_count, model_architecture, preprocess):
  """Creates an audio model with the nodes needed for inference.

  Uses the supplied arguments to create a model, and inserts the input and
  output nodes that are needed to use the graph for inference.

  Args:
    wanted_words: Comma-separated list of the words we're trying to recognize.
    sample_rate: How many samples per second are in the input audio files.
    clip_duration_ms: Length of each audio clip to be analyzed, in
      milliseconds.
    clip_stride_ms: How often to run recognition. Useful for models with cache.
    window_size_ms: Time slice duration to estimate frequencies from.
    window_stride_ms: How far apart time slices should be.
    feature_bin_count: Number of frequency bands to analyze.
    model_architecture: Name of the kind of model to generate.
    preprocess: How the spectrogram is processed to produce features, for
      example 'mfcc', 'average', or 'micro'.

  Returns:
    Input and output tensor objects.

  Raises:
    Exception: If the preprocessing mode isn't recognized.
  """
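  # For example (hypothetical values, matching the flag defaults below):
  #   create_inference_graph('yes,no,up,down,left,right,on,off,stop,go',
  #                          16000, 1000, 30, 30.0, 10.0, 40, 'conv', 'mfcc')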
  words_list = input_data.prepare_words_list(wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), sample_rate, clip_duration_ms, window_size_ms,
      window_stride_ms, feature_bin_count, preprocess)
  runtime_settings = {'clip_stride_ms': clip_stride_ms}

  wav_data_placeholder = tf.compat.v1.placeholder(tf.string, [],
                                                  name='wav_data')
  decoded_sample_data = tf.audio.decode_wav(
      wav_data_placeholder,
      desired_channels=1,
      desired_samples=model_settings['desired_samples'],
      name='decoded_sample_data')
  spectrogram = audio_ops.audio_spectrogram(
      decoded_sample_data.audio,
      window_size=model_settings['window_size_samples'],
      stride=model_settings['window_stride_samples'],
      magnitude_squared=True)

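  # Feature extraction mirrors what input_data.py applies during training:
  # 'average' pools the raw spectrogram, 'mfcc' converts it to cepstral
  # coefficients, and 'micro' runs the TF Lite Micro frontend op.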
  if preprocess == 'average':
    fingerprint_input = tf.nn.pool(
        input=tf.expand_dims(spectrogram, -1),
        window_shape=[1, model_settings['average_window_width']],
        strides=[1, model_settings['average_window_width']],
        pooling_type='AVG',
        padding='SAME')
  elif preprocess == 'mfcc':
    fingerprint_input = audio_ops.mfcc(
        spectrogram,
        sample_rate,
        dct_coefficient_count=model_settings['fingerprint_width'])
  elif preprocess == 'micro':
    if not frontend_op:
      raise Exception(
          'Micro frontend op is currently not available when running TensorFlow'
          ' directly from Python, you need to build and run through Bazel, for'
          ' example'
          ' `bazel run tensorflow/examples/speech_commands:freeze_graph`')
    sample_rate = model_settings['sample_rate']
    window_size_ms = (model_settings['window_size_samples'] *
                      1000) / sample_rate
    window_step_ms = (model_settings['window_stride_samples'] *
                      1000) / sample_rate
    int16_input = tf.cast(
        tf.multiply(decoded_sample_data.audio, 32767), tf.int16)
    micro_frontend = frontend_op.audio_microfrontend(
        int16_input,
        sample_rate=sample_rate,
        window_size=window_size_ms,
        window_step=window_step_ms,
        num_channels=model_settings['fingerprint_width'],
        out_scale=1,
        out_type=tf.float32)
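    # Scale the frontend's output into the range the model was trained on
    # (the same 10/256 rescaling that input_data.py applies on this path).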
    fingerprint_input = tf.multiply(micro_frontend, (10.0 / 256.0))
  else:
    raise Exception('Unknown preprocess mode "%s" (should be "mfcc",'
                    ' "average", or "micro")' % (preprocess))

  fingerprint_size = model_settings['fingerprint_size']
  reshaped_input = tf.reshape(fingerprint_input, [-1, fingerprint_size])

  logits = models.create_model(
      reshaped_input, model_settings, model_architecture, is_training=False,
      runtime_settings=runtime_settings)

  # Create an output to use for inference.
  softmax = tf.nn.softmax(logits, name='labels_softmax')

  return reshaped_input, softmax


def save_graph_def(file_name, frozen_graph_def):
  """Writes a graph def file out to disk.

  Args:
    file_name: Where to save the file.
    frozen_graph_def: GraphDef proto object to save.
  """
  tf.io.write_graph(
      frozen_graph_def,
      os.path.dirname(file_name),
      os.path.basename(file_name),
      as_text=False)
  tf.compat.v1.logging.info('Saved frozen graph to %s', file_name)


def save_saved_model(file_name, sess, input_tensor, output_tensor):
  """Writes a SavedModel out to disk.

  Args:
    file_name: Directory to write the SavedModel into.
    sess: TensorFlow session containing the graph.
    input_tensor: Tensor object defining the input's properties.
    output_tensor: Tensor object defining the output's properties.
  """
  # Store the frozen graph as a SavedModel for v2 compatibility.
  builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(file_name)
  tensor_info_inputs = {
      'input': tf.compat.v1.saved_model.utils.build_tensor_info(input_tensor)
  }
  tensor_info_outputs = {
      'output': tf.compat.v1.saved_model.utils.build_tensor_info(output_tensor)
  }
  signature = (
      tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
          inputs=tensor_info_inputs,
          outputs=tensor_info_outputs,
          method_name=tf.compat.v1.saved_model.signature_constants
          .PREDICT_METHOD_NAME))
  builder.add_meta_graph_and_variables(
      sess,
      [tf.compat.v1.saved_model.tag_constants.SERVING],
      signature_def_map={
          tf.compat.v1.saved_model.signature_constants
          .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              signature,
      },
  )
  builder.save()


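# A minimal sketch (an assumption, not part of this script) of how a model
# written with --save_format=saved_model might be loaded back under TF2. Note
# that the signature's 'input' is the flattened fingerprint tensor returned by
# create_inference_graph(), not the raw WAV string, and the directory path is
# just a placeholder:
#
#   loaded = tf.saved_model.load('/tmp/my_saved_model')
#   infer = loaded.signatures['serving_default']
#   scores = infer(input=fingerprint_batch)['output']

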
def main(_):
  if FLAGS.quantize:
    try:
      _ = tf.contrib
    except AttributeError as e:
      msg = e.args[0]
      msg += ('\n\n The --quantize option still requires contrib, which is not '
              'part of TensorFlow 2.0. Please install a previous version:'
              '\n `pip install tensorflow<=1.15`')
      e.args = (msg,)
      raise e

  # Create the model and load its weights.
  sess = tf.compat.v1.InteractiveSession()
  input_tensor, output_tensor = create_inference_graph(
      FLAGS.wanted_words, FLAGS.sample_rate, FLAGS.clip_duration_ms,
      FLAGS.clip_stride_ms, FLAGS.window_size_ms, FLAGS.window_stride_ms,
      FLAGS.feature_bin_count, FLAGS.model_architecture, FLAGS.preprocess)
  if FLAGS.quantize:
    tf.contrib.quantize.create_eval_graph()
  models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)

  # Turn all the variables into inline constants inside the graph and save it.
  frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, sess.graph_def, ['labels_softmax'])

  if FLAGS.save_format == 'graph_def':
    save_graph_def(FLAGS.output_file, frozen_graph_def)
  elif FLAGS.save_format == 'saved_model':
    save_saved_model(FLAGS.output_file, sess, input_tensor, output_tensor)
  else:
    raise Exception('Unknown save format "%s" (should be "graph_def" or'
                    ' "saved_model")' % (FLAGS.save_format))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--clip_stride_ms',
      type=int,
      default=30,
      help='How often to run recognition. Useful for models with cache.',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How long the stride is between spectrogram timeslices',)
  parser.add_argument(
      '--feature_bin_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',
  )
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before freezing.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='yes,no,up,down,left,right,on,off,stop,go',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--output_file', type=str, help='Where to save the frozen graph.')
  parser.add_argument(
      '--quantize',
      action='store_true',
      default=False,
      help='Whether the model should be set up for eight-bit deployment')
  parser.add_argument(
      '--preprocess',
      type=str,
      default='mfcc',
      help='Spectrogram processing mode. Can be "mfcc", "average", or "micro"')
  parser.add_argument(
      '--save_format',
      type=str,
      default='graph_def',
      help='How to save the result. Can be "graph_def" or "saved_model"')
  FLAGS, unparsed = parser.parse_known_args()
  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)